Spaces:
Sleeping
Sleeping
| """ | |
| AI 垃圾分类助手 - 训练模块 | |
| 支持两种数据目录结构: | |
| 1. 已划分: dataset/train/ + dataset/val/ [+ dataset/test/] | |
| 2. 未划分: dataset/trashnet/ (自动随机划分) | |
| 训练结束后保存 loss 曲线图并输出详细评估报告 | |
| """ | |
| import argparse | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, random_split | |
| from torchvision import datasets, transforms | |
| from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import matplotlib | |
| matplotlib.use("Agg") # 不依赖 GUI 后端 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from config import CLASS_NAMES, CLASS_NAMES_CN | |
| # ── 设备 ────────────────────────────────────────────── | |
def get_device():
    """Choose the compute device for training and announce the choice.

    Preference order is Apple-Silicon MPS, then CUDA, then CPU. A status line is
    printed for whichever backend is selected.

    Returns:
        torch.device: the selected device.
    """
    # (device name, availability probe, message printed when chosen), in priority order.
    accelerators = (
        ("mps", torch.backends.mps.is_available, "✓ 使用 MPS (Apple Silicon) 加速训练"),
        ("cuda", torch.cuda.is_available, "✓ 使用 CUDA 加速训练"),
    )
    for name, is_available, message in accelerators:
        if is_available():
            print(message)
            return torch.device(name)
    print("⚠ 使用 CPU 训练 (建议使用 MPS/CUDA)")
    return torch.device("cpu")
| # ── 数据预处理 ──────────────────────────────────────── | |
def get_transforms():
    """Build the training and evaluation image pipelines.

    Both pipelines first resize to 256x256 and finish with tensor conversion plus
    normalisation by the standard ImageNet channel mean/std. Training inserts
    random resized crop (224), horizontal flip, +/-15 degree rotation and colour
    jitter; evaluation uses a deterministic 224x224 centre crop instead.

    Returns:
        tuple: ``(train_tf, eval_tf)``, each a ``transforms.Compose``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def to_normalized_tensor():
        # Shared tail of both pipelines: PIL image -> tensor -> per-channel normalisation.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]

    augmentations = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]

    train_tf = transforms.Compose(
        [transforms.Resize((256, 256)), *augmentations, *to_normalized_tensor()]
    )
    eval_tf = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.CenterCrop(224), *to_normalized_tensor()]
    )
    return train_tf, eval_tf
| # ── 模型 ────────────────────────────────────────────── | |
def create_model(num_classes=6, pretrained=True):
    """Build a MobileNetV3-Small classifier with a ``num_classes``-way output layer.

    Args:
        num_classes: number of output classes of the new final layer.
        pretrained: if True (default, previous behaviour) start from torchvision's
            ImageNet-1k weights, which torchvision downloads/caches on first use.
            Pass False to build a randomly initialised network without network
            access (e.g. before loading a saved checkpoint, or in tests).

    Returns:
        torch.nn.Module: the model with its last classifier layer replaced.
    """
    weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None
    model = mobilenet_v3_small(weights=weights)
    # Index 3 is the final Linear layer of torchvision's MobileNetV3 classifier head;
    # swap it for one with the right number of outputs, keeping its input width.
    old_head = model.classifier[3]
    model.classifier[3] = nn.Linear(old_head.in_features, num_classes)
    return model
| # ── 训练 / 评估 ─────────────────────────────────────── | |
def train_epoch(model, loader, criterion, optimizer, device, desc="Training"):
    """Run one training epoch over ``loader``, updating ``model`` in place.

    Args:
        model: network to train; it is put into train mode.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer that steps ``model``'s parameters.
        device: device every batch is moved to.
        desc: label for the tqdm progress bar.

    Returns:
        tuple: ``(mean_loss, accuracy_percent)`` over all samples seen. The loss
        is a per-sample mean assuming ``criterion`` averages over the batch
        (``reduction="mean"``). Returns ``(0.0, 0.0)`` when the loader is empty
        instead of raising ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    correct = total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Re-weight the batch-mean loss by batch size so the epoch mean is per sample.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        acc = 100.0 * correct / total if total > 0 else 0
        pbar.set_postfix(loss=f"{running_loss/total:.4f}", acc=f"{acc:.1f}%")
    if total == 0:
        # Empty loader: nothing was trained; avoid dividing by zero.
        return 0.0, 0.0
    return running_loss / total, 100.0 * correct / total
def evaluate(model, loader, criterion, device, desc="Evaluating"):
    """Measure mean loss and accuracy of ``model`` on ``loader`` without training it.

    Args:
        model: network to evaluate; it is put into eval mode.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device every batch is moved to.
        desc: label for the tqdm progress bar.

    Returns:
        tuple: ``(mean_loss, accuracy_percent)``. The loss is a per-sample mean
        assuming ``criterion`` averages over the batch. Returns ``(0.0, 0.0)``
        for an empty loader instead of raising ZeroDivisionError.
    """
    model.eval()
    running_loss = 0.0
    correct = total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    # No gradients are needed here; no_grad stops autograd from recording the
    # forward pass (which otherwise keeps activations alive for a backward that
    # never happens, wasting memory and time).
    with torch.no_grad():
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            acc = 100.0 * correct / total if total > 0 else 0
            pbar.set_postfix(loss=f"{running_loss/total:.4f}", acc=f"{acc:.1f}%")
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, 100.0 * correct / total
def detailed_evaluate(model, loader, class_names, device):
    """Evaluate ``model`` on ``loader`` and collect per-class statistics.

    Args:
        model: network to evaluate; it is put into eval mode.
        loader: iterable of ``(inputs, labels)`` batches with integer labels in
            ``range(len(class_names))``.
        class_names: class names; only their count is used here.
        device: device every batch is moved to.

    Returns:
        tuple: ``(avg_loss, overall_acc, per_class_acc, conf_matrix)`` where
        ``avg_loss`` is the per-sample mean cross-entropy, ``overall_acc`` and
        ``per_class_acc`` are percentages, and ``conf_matrix[true, predicted]``
        holds integer counts. An empty loader yields zeros instead of a
        division error.
    """
    model.eval()
    n = len(class_names)
    conf_matrix = np.zeros((n, n), dtype=int)
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_samples = 0
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="详细评估", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
            _, predicted = outputs.max(1)
            # Vectorised confusion-matrix update; np.add.at accumulates repeated
            # (true, pred) index pairs correctly, unlike fancy-index +=.
            np.add.at(conf_matrix, (labels.cpu().numpy(), predicted.cpu().numpy()), 1)
    # Row sums = samples per true class; diagonal = correctly classified samples.
    total_per_class = conf_matrix.sum(axis=1)
    correct_per_class = np.diag(conf_matrix)
    avg_loss = total_loss / max(total_samples, 1)
    overall_acc = 100.0 * correct_per_class.sum() / max(total_per_class.sum(), 1)
    per_class_acc = 100.0 * correct_per_class / np.maximum(total_per_class, 1)
    return avg_loss, overall_acc, per_class_acc, conf_matrix
| # ── 绘图 ────────────────────────────────────────────── | |
def plot_training_curves(history, save_path):
    """Draw loss and accuracy curves side by side and save them as an image.

    Args:
        history: dict with per-epoch lists ``"train_loss"``, ``"val_loss"``,
            ``"train_acc"``, ``"val_acc"`` (accuracy in percent, all the same
            length as ``"train_loss"``) and a scalar ``"best_acc"`` drawn as a
            dashed reference line.
        save_path: output image path passed to ``savefig``.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4.5))
    try:
        # Titles are English like every other label in the figure: matplotlib's
        # default font (DejaVu Sans) has no CJK glyphs, so the former Chinese
        # titles rendered as empty boxes with "Glyph missing" warnings.
        ax_loss.plot(epochs, history["train_loss"], "o-", label="Train Loss", color="#2196F3")
        ax_loss.plot(epochs, history["val_loss"], "s-", label="Val Loss", color="#FF5722")
        ax_loss.set_xlabel("Epoch")
        ax_loss.set_ylabel("Loss")
        ax_loss.set_title("Loss Curve")
        ax_loss.legend()
        ax_loss.grid(True, alpha=0.3)

        ax_acc.plot(epochs, history["train_acc"], "o-", label="Train Acc", color="#2196F3")
        ax_acc.plot(epochs, history["val_acc"], "s-", label="Val Acc", color="#FF5722")
        ax_acc.axhline(y=history["best_acc"], color="green", linestyle="--", alpha=0.5,
                       label=f"Best Val {history['best_acc']:.1f}%")
        ax_acc.set_xlabel("Epoch")
        ax_acc.set_ylabel("Accuracy (%)")
        ax_acc.set_title("Accuracy Curve")
        ax_acc.legend()
        ax_acc.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even if saving fails, so figures don't accumulate.
        plt.close(fig)
    print(f" 📊 训练曲线已保存: {save_path}")
| # ── 评估报告 ────────────────────────────────────────── | |
def print_evaluation_report(class_names, per_class_acc, conf_matrix):
    """Print a detailed evaluation report to stdout.

    The report contains per-class accuracy with a 20-cell text bar (one cell per
    5%), the confusion matrix (rows = true label, columns = predicted label), and
    the three largest off-diagonal (true -> predicted) counts with their share of
    that true class.

    Args:
        class_names: class names in label-index order.
        per_class_acc: per-class accuracy in percent, indexed like ``class_names``.
        conf_matrix: 2-D count array indexed ``[true, predicted]``.
    """
    n = len(class_names)
    # Look up Chinese display names by class *name*, not by position: label order
    # comes from the dataset (torchvision ImageFolder sorts folder names), which
    # need not match the order of the config lists. Assumes CLASS_NAMES and
    # CLASS_NAMES_CN are parallel lists -- TODO confirm in config. Falls back to
    # the previous positional lookup, then to the raw class name.
    cn_by_name = dict(zip(CLASS_NAMES, CLASS_NAMES_CN))

    def cn_name(i):
        fallback = CLASS_NAMES_CN[i] if i < len(CLASS_NAMES_CN) else class_names[i]
        return cn_by_name.get(class_names[i], fallback)

    print(f"\n{'='*55}")
    print(" 详细评估报告")
    print(f"{'='*55}")
    print(" 类别准确率:")
    for i, name in enumerate(class_names):
        filled = int(per_class_acc[i] // 5)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {i}. {cn_name(i):8s} ({name:10s}): {per_class_acc[i]:5.1f}% [{bar}]")
    print(f"{'─'*55}")

    # Confusion matrix
    print(" 混淆矩阵 (行=真实, 列=预测):")
    header = "".join(f"{c[:5]:>6}" for c in class_names)
    print(f" {'':>6}{header}")
    for i in range(n):
        row = "".join(f"{conf_matrix[i, j]:>6}" for j in range(n))
        print(f" {cn_name(i)[:2]:>4}: {row} {per_class_acc[i]:.1f}%")

    # Most frequent confusions (largest off-diagonal counts first).
    print("\n 易混淆组合 (非对角线最高):")
    pairs = sorted(
        (
            (conf_matrix[i, j], i, j)
            for i in range(n)
            for j in range(n)
            if i != j and conf_matrix[i, j] > 0
        ),
        reverse=True,
    )
    for count, i, j in pairs[:3]:
        # Share of true-class-i samples that were predicted as j.
        ratio = count / max(conf_matrix[i].sum(), 1) * 100
        print(f" {cn_name(i)} → {cn_name(j)}: {count} 次 ({ratio:.1f}%)")
| # ── 数据加载 ────────────────────────────────────────── | |
def load_split_data(data_dir, train_tf, eval_tf, batch_size):
    """Load a dataset that is already split into ``train/``, ``val/`` and optional ``test/``.

    Args:
        data_dir: ``Path`` whose ``train/`` and ``val/`` (and optionally
            ``test/``) sub-folders each contain one folder per class.
        train_tf: transform applied to training images.
        eval_tf: transform applied to validation/test images.
        batch_size: batch size for every loader.

    Returns:
        ``(train_loader, val_loader, test_loader, classes)``, where
        ``test_loader`` is None if ``test/`` does not exist; or None when
        ``train/`` or ``val/`` is missing.

    Raises:
        ValueError: if ``val/`` or ``test/`` does not contain exactly the same
            class folders as ``train/``.
    """
    train_dir = data_dir / "train"
    val_dir = data_dir / "val"
    if not train_dir.exists() or not val_dir.exists():
        return None

    train_dataset = datasets.ImageFolder(root=str(train_dir), transform=train_tf)
    classes = train_dataset.classes

    def eval_loader(split_dir):
        # ImageFolder assigns label indices from the class folders of each split
        # independently, so a missing or extra class folder would silently shift
        # every label after it. Refuse such a split rather than mis-score it.
        dataset = datasets.ImageFolder(root=str(split_dir), transform=eval_tf)
        if dataset.classes != classes:
            raise ValueError(
                f"{split_dir} 的类别 {dataset.classes} 与训练集类别 {classes} 不一致"
            )
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = eval_loader(val_dir)
    test_dir = data_dir / "test"
    test_loader = eval_loader(test_dir) if test_dir.exists() else None
    return train_loader, val_loader, test_loader, classes
def load_random_split_data(data_dir, train_tf, eval_tf, batch_size, train_ratio=0.8, seed=None):
    """Randomly split a single class-per-folder directory into train/val loaders.

    Args:
        data_dir: ``Path`` containing one sub-folder per class.
        train_tf: transform applied to the training split only.
        eval_tf: transform applied to the validation split only.
        batch_size: batch size for both loaders.
        train_ratio: fraction of images used for training (default 0.8, i.e.
            the 80/20 split used before); the rest goes to validation.
        seed: optional int for a reproducible split; None (default) uses
            torch's global RNG as before.

    Returns:
        ``(train_loader, val_loader, None, classes)``.
    """
    from torch.utils.data import Subset

    # Two ImageFolder views over the same files so each split keeps its own
    # transform. The previous code set ``val_dataset.dataset.transform``, but both
    # random_split subsets share that one parent dataset, so training silently
    # used the eval transform and lost all augmentation.
    train_view = datasets.ImageFolder(root=str(data_dir), transform=train_tf)
    val_view = datasets.ImageFolder(root=str(data_dir), transform=eval_tf)

    total = len(train_view)
    train_size = int(train_ratio * total)
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    indices = torch.randperm(total, generator=generator).tolist()

    train_dataset = Subset(train_view, indices[:train_size])
    val_dataset = Subset(val_view, indices[train_size:])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, val_loader, None, train_view.classes
| # ── 主训练流程 ──────────────────────────────────────── | |
def _write_evaluation_report(report_path, args, device, best_acc, test_acc,
                             classes, per_class_acc, conf_matrix):
    """Write the plain-text test-set evaluation report to ``report_path`` (overwrites it)."""
    # Chinese names are looked up by class name (assumes CLASS_NAMES and
    # CLASS_NAMES_CN are parallel lists -- TODO confirm in config), falling back
    # to the positional lookup and then to the folder name.
    cn_by_name = dict(zip(CLASS_NAMES, CLASS_NAMES_CN))
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("AI 垃圾分类助手 - 模型评估报告\n")
        f.write(f"{'='*55}\n")
        f.write(f"训练设备: {device}\n")
        f.write(f"训练轮数: {args.epochs}\n")
        f.write(f"批次大小: {args.batch_size}\n")
        f.write(f"学习率: {args.lr}\n\n")
        f.write(f"最佳验证准确率: {best_acc:.2f}%\n")
        f.write(f"测试集准确率: {test_acc:.2f}%\n\n")
        f.write("各类别准确率:\n")
        for i, name in enumerate(classes):
            fallback = CLASS_NAMES_CN[i] if i < len(CLASS_NAMES_CN) else name
            cn = cn_by_name.get(name, fallback)
            f.write(f" {cn} ({name}): {per_class_acc[i]:.2f}%\n")
        f.write("\n混淆矩阵:\n")
        f.write(f"{'':>6}" + "".join(f"{c[:5]:>6}" for c in classes) + "\n")
        for i in range(len(classes)):
            f.write(f"{classes[i][:4]:>4}: " + "".join(f"{conf_matrix[i,j]:>6}" for j in range(len(classes))) + "\n")


def train(args):
    """Run the full training pipeline.

    Steps: detect the dataset layout under ``args.data_dir`` (pre-split
    ``train/`` + ``val/`` [+ ``test/``], or one folder split 80/20 at random);
    fine-tune MobileNetV3-Small with Adam and a cosine LR schedule for
    ``args.epochs`` epochs; save the checkpoint with the best validation accuracy
    to ``args.model_dir/garbage_model.pth``; plot training curves; and, if a test
    split exists, evaluate the best checkpoint on it and write
    ``evaluation_report.txt``.

    Args:
        args: namespace with ``data_dir``, ``model_dir``, ``epochs``,
            ``batch_size`` and ``lr`` attributes (as produced by this module's
            argparse CLI).
    """
    device = get_device()
    data_dir = Path(args.data_dir)
    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    if not data_dir.exists():
        print(f"✗ 数据集路径不存在: {data_dir}")
        print("请将数据集放在以下结构之一:")
        print(f" {data_dir}/ ├── cardboard/ └── ... (自动 80/20 划分)")
        print(" 或运行 split_dataset.py 划分后使用:")
        print(f" {data_dir}/train/ ├── cardboard/ └── ...")
        print(f" {data_dir}/val/ ├── cardboard/ └── ...")
        return

    train_tf, eval_tf = get_transforms()
    if (data_dir / "train").exists():
        result = load_split_data(data_dir, train_tf, eval_tf, args.batch_size)
        if result is None:
            # train/ exists but val/ does not. Previously this fell through and
            # crashed with a NameError on train_loader below.
            print(f"✗ 找到 {data_dir / 'train'} 但缺少验证集目录 {data_dir / 'val'}")
            return
        layout_msg = "\n检测到已划分的数据集"
    else:
        result = load_random_split_data(data_dir, train_tf, eval_tf, args.batch_size)
        layout_msg = "\n检测到未划分的数据集 (自动 80/20 随机划分)"
    train_loader, val_loader, test_loader, classes = result
    print(layout_msg)
    print(f" 类别 ({len(classes)}): {classes}")
    print(f" 训练集: {len(train_loader.dataset)} 张")
    print(f" 验证集: {len(val_loader.dataset)} 张")
    if test_loader is not None:
        print(f" 测试集: {len(test_loader.dataset)} 张")

    model = create_model(num_classes=len(classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # ── Training loop ──
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "best_acc": 0.0}
    best_acc = 0.0
    best_state = None  # in-memory copy of the best weights, restored before testing
    model_path = model_dir / "garbage_model.pth"
    print(f"\n开始训练 (共 {args.epochs} 轮)...")
    print(f"{'─'*65}")
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f"Epoch {epoch}/{args.epochs}",
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, device, desc="Validating")
        scheduler.step()
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        print(
            f"Epoch {epoch:2d}/{args.epochs} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | "
            f"LR: {scheduler.get_last_lr()[0]:.2e}"
        )
        if val_acc > best_acc:
            best_acc = val_acc
            history["best_acc"] = best_acc
            # Detached clones, so later optimizer steps don't modify the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_acc": best_acc,
                "class_names": classes,
            }, str(model_path))
            print(f" ✓ 保存最佳模型 (验证准确率: {best_acc:.2f}%)")

    # ── End of training ──
    print(f"{'─'*65}")
    print(f"训练完成!最佳验证准确率: {best_acc:.2f}%")

    plot_path = model_dir / "training_curves.png"
    plot_training_curves(history, plot_path)

    # Detailed test-set evaluation
    if test_loader is not None:
        if best_state is not None:
            # Evaluate the weights that were actually saved (best validation
            # accuracy), not whatever the final epoch left in the model.
            model.load_state_dict(best_state)
        print(f"\n{'='*55}")
        print(" 测试集最终评估")
        print(f"{'='*55}")
        test_loss, test_acc, per_class_acc, conf_matrix = detailed_evaluate(
            model, test_loader, classes, device
        )
        print(f" 测试集 Loss: {test_loss:.4f} | 准确率: {test_acc:.2f}%")
        print_evaluation_report(classes, per_class_acc, conf_matrix)
        # Write (overwrite) the plain-text report file.
        report_path = model_dir / "evaluation_report.txt"
        _write_evaluation_report(
            report_path, args, device, best_acc, test_acc,
            classes, per_class_acc, conf_matrix,
        )
        print(f" 📄 评估报告已保存: {report_path}")

    print(f"\n✓ 模型: {model_path}")
    print(f"✓ 曲线图: {plot_path}")
    print(" 如需启动 Web 界面: python main.py webui")
| # ── CLI ─────────────────────────────────────────────── | |
if __name__ == "__main__":
    # Command-line entry point. Each (flag, options) pair becomes one argparse
    # option; argparse maps "--data-dir" to args.data_dir etc., which train() reads.
    cli_options = (
        ("--data-dir", {"default": "dataset"}),
        ("--model-dir", {"default": "models"}),
        ("--epochs", {"type": int, "default": 30}),
        ("--batch-size", {"type": int, "default": 32}),
        ("--lr", {"type": float, "default": 0.001}),
    )
    cli = argparse.ArgumentParser(description="训练垃圾分类模型")
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    train(cli.parse_args())