""" AI 垃圾分类助手 - 训练模块 支持两种数据目录结构: 1. 已划分: dataset/train/ + dataset/val/ [+ dataset/test/] 2. 未划分: dataset/trashnet/ (自动随机划分) 训练结束后保存 loss 曲线图并输出详细评估报告 """ import argparse import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights from pathlib import Path from tqdm import tqdm import matplotlib matplotlib.use("Agg") # 不依赖 GUI 后端 import matplotlib.pyplot as plt import numpy as np from config import CLASS_NAMES, CLASS_NAMES_CN # ── 设备 ────────────────────────────────────────────── def get_device(): if torch.backends.mps.is_available(): device = torch.device("mps") print("✓ 使用 MPS (Apple Silicon) 加速训练") elif torch.cuda.is_available(): device = torch.device("cuda") print("✓ 使用 CUDA 加速训练") else: device = torch.device("cpu") print("⚠ 使用 CPU 训练 (建议使用 MPS/CUDA)") return device # ── 数据预处理 ──────────────────────────────────────── def get_transforms(): train_tf = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) eval_tf = transforms.Compose([ transforms.Resize((256, 256)), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) return train_tf, eval_tf # ── 模型 ────────────────────────────────────────────── def create_model(num_classes=6): model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1) in_features = model.classifier[3].in_features model.classifier[3] = nn.Linear(in_features, num_classes) return model # ── 训练 / 评估 ─────────────────────────────────────── def train_epoch(model, loader, criterion, optimizer, device, desc="Training"): model.train() running_loss = correct = total = 0 pbar = tqdm(loader, desc=desc, leave=False) for inputs, labels in pbar: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() acc = 100.0 * correct / total if total > 0 else 0 pbar.set_postfix(loss=f"{running_loss/total:.4f}", acc=f"{acc:.1f}%") return running_loss / total, 100.0 * correct / total @torch.no_grad() def evaluate(model, loader, criterion, device, desc="Evaluating"): model.eval() running_loss = correct = total = 0 pbar = tqdm(loader, desc=desc, leave=False) for inputs, labels in pbar: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) running_loss += loss.item() * inputs.size(0) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() acc = 100.0 * correct / total if total > 0 else 0 pbar.set_postfix(loss=f"{running_loss/total:.4f}", acc=f"{acc:.1f}%") return running_loss / total, 100.0 * correct / total @torch.no_grad() def detailed_evaluate(model, loader, class_names, device): """返回: (loss, acc, per_class_acc, confusion_matrix)""" model.eval() n = len(class_names) correct_per_class = np.zeros(n) total_per_class = np.zeros(n) conf_matrix = np.zeros((n, n), dtype=int) criterion = nn.CrossEntropyLoss() total_loss = total_samples = 0 for inputs, labels in tqdm(loader, desc="详细评估", leave=False): inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) total_loss += loss.item() * inputs.size(0) total_samples += inputs.size(0) _, predicted = outputs.max(1) for t, p in zip(labels.cpu().numpy(), predicted.cpu().numpy()): conf_matrix[t, p] += 1 total_per_class[t] += 1 if t == p: correct_per_class[t] += 1 avg_loss = total_loss / total_samples overall_acc = 100.0 * correct_per_class.sum() / total_per_class.sum() per_class_acc = 100.0 * correct_per_class / np.maximum(total_per_class, 1) return avg_loss, overall_acc, per_class_acc, conf_matrix # ── 绘图 ────────────────────────────────────────────── def plot_training_curves(history, save_path): """绘制并保存 Loss / Accuracy 曲线图""" epochs = range(1, len(history["train_loss"]) + 1) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5)) # Loss ax1.plot(epochs, history["train_loss"], "o-", label="Train Loss", color="#2196F3") ax1.plot(epochs, history["val_loss"], "s-", label="Val Loss", color="#FF5722") ax1.set_xlabel("Epoch") ax1.set_ylabel("Loss") ax1.set_title("Loss 曲线") ax1.legend() ax1.grid(True, alpha=0.3) # Accuracy ax2.plot(epochs, history["train_acc"], "o-", label="Train Acc", color="#2196F3") ax2.plot(epochs, history["val_acc"], "s-", label="Val Acc", color="#FF5722") ax2.axhline(y=history["best_acc"], color="green", linestyle="--", alpha=0.5, label=f"Best Val {history['best_acc']:.1f}%") ax2.set_xlabel("Epoch") ax2.set_ylabel("Accuracy (%)") ax2.set_title("Accuracy 曲线") ax2.legend() ax2.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f" 📊 训练曲线已保存: {save_path}") # ── 评估报告 ────────────────────────────────────────── def print_evaluation_report(class_names, per_class_acc, conf_matrix): """打印详细的评估报告""" print(f"\n{'='*55}") print(f" 详细评估报告") print(f"{'='*55}") print(f" 类别准确率:") for i, name in enumerate(class_names): cn = CLASS_NAMES_CN[i] if i < len(CLASS_NAMES_CN) else name bar = "█" * int(per_class_acc[i] // 5) + "░" * (20 - int(per_class_acc[i] // 5)) print(f" {i}. {cn:8s} ({name:10s}): {per_class_acc[i]:5.1f}% [{bar}]") print(f"{'─'*55}") # 混淆矩阵 print(f" 混淆矩阵 (行=真实, 列=预测):") header = "".join(f"{short:>6}" for short in [c[:5] for c in class_names]) print(f" {'':>6}{header}") for i in range(len(class_names)): row = "".join(f"{conf_matrix[i, j]:>6}" for j in range(len(class_names))) cn = CLASS_NAMES_CN[i][:2] if i < len(CLASS_NAMES_CN) else class_names[i][:2] print(f" {cn:>4}: {row} {per_class_acc[i]:.1f}%") # 易混淆对 print(f"\n 易混淆组合 (非对角线最高):") n = len(class_names) pairs = [] for i in range(n): for j in range(n): if i != j and conf_matrix[i, j] > 0: pairs.append((conf_matrix[i, j], i, j)) pairs.sort(reverse=True) for count, i, j in pairs[:3]: cn_i = CLASS_NAMES_CN[i] if i < len(CLASS_NAMES_CN) else class_names[i] cn_j = CLASS_NAMES_CN[j] if j < len(CLASS_NAMES_CN) else class_names[j] ratio = count / max(conf_matrix[i].sum(), 1) * 100 print(f" {cn_i} → {cn_j}: {count} 次 ({ratio:.1f}%)") # ── 数据加载 ────────────────────────────────────────── def load_split_data(data_dir, train_tf, eval_tf, batch_size): """加载已划分的数据集 (train/val/test 子目录)""" train_dir = data_dir / "train" val_dir = data_dir / "val" if not train_dir.exists() or not val_dir.exists(): return None train_dataset = datasets.ImageFolder(root=str(train_dir), transform=train_tf) val_dataset = datasets.ImageFolder(root=str(val_dir), transform=eval_tf) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0) test_loader = None test_dir = data_dir / "test" if test_dir.exists(): test_dataset = datasets.ImageFolder(root=str(test_dir), transform=eval_tf) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0) return train_loader, val_loader, test_loader, train_dataset.classes def load_random_split_data(data_dir, train_tf, eval_tf, batch_size): """从单目录随机划分""" full_dataset = datasets.ImageFolder(root=str(data_dir), transform=train_tf) train_size = int(0.8 * len(full_dataset)) val_size = len(full_dataset) - train_size train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) val_dataset.dataset.transform = eval_tf train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0) return train_loader, val_loader, None, full_dataset.classes # ── 主训练流程 ──────────────────────────────────────── def train(args): device = get_device() data_dir = Path(args.data_dir) model_dir = Path(args.model_dir) model_dir.mkdir(parents=True, exist_ok=True) if not data_dir.exists(): print(f"✗ 数据集路径不存在: {data_dir}") print("请将数据集放在以下结构之一:") print(f" {data_dir}/ ├── cardboard/ └── ... (自动 80/20 划分)") print(f" 或运行 split_dataset.py 划分后使用:") print(f" {data_dir}/train/ ├── cardboard/ └── ...") print(f" {data_dir}/val/ ├── cardboard/ └── ...") return train_tf, eval_tf = get_transforms() if (data_dir / "train").exists(): result = load_split_data(data_dir, train_tf, eval_tf, args.batch_size) if result: train_loader, val_loader, test_loader, classes = result print(f"\n检测到已划分的数据集") else: result = load_random_split_data(data_dir, train_tf, eval_tf, args.batch_size) if result: train_loader, val_loader, test_loader, classes = result print(f"\n检测到未划分的数据集 (自动 80/20 随机划分)") print(f" 类别 ({len(classes)}): {classes}") print(f" 训练集: {len(train_loader.dataset)} 张") print(f" 验证集: {len(val_loader.dataset)} 张") if test_loader: print(f" 测试集: {len(test_loader.dataset)} 张") model = create_model(num_classes=len(classes)).to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=args.lr) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) # ── 训练循环 ── history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "best_acc": 0.0} best_acc = 0.0 print(f"\n开始训练 (共 {args.epochs} 轮)...") print(f"{'─'*65}") for epoch in range(1, args.epochs + 1): train_loss, train_acc = train_epoch( model, train_loader, criterion, optimizer, device, desc=f"Epoch {epoch}/{args.epochs}", ) val_loss, val_acc = evaluate(model, val_loader, criterion, device, desc="Validating") scheduler.step() history["train_loss"].append(train_loss) history["train_acc"].append(train_acc) history["val_loss"].append(val_loss) history["val_acc"].append(val_acc) print( f"Epoch {epoch:2d}/{args.epochs} | " f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | " f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | " f"LR: {scheduler.get_last_lr()[0]:.2e}" ) if val_acc > best_acc: best_acc = val_acc history["best_acc"] = best_acc model_path = model_dir / "garbage_model.pth" torch.save({ "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "best_acc": best_acc, "class_names": classes, }, str(model_path)) print(f" ✓ 保存最佳模型 (验证准确率: {best_acc:.2f}%)") # ── 训练结束 ── print(f"{'─'*65}") print(f"训练完成!最佳验证准确率: {best_acc:.2f}%") # 绘制训练曲线 plot_path = model_dir / "training_curves.png" plot_training_curves(history, plot_path) # 测试集详细评估 if test_loader: print(f"\n{'='*55}") print(f" 测试集最终评估") print(f"{'='*55}") test_loss, test_acc, per_class_acc, conf_matrix = detailed_evaluate( model, test_loader, classes, device ) print(f" 测试集 Loss: {test_loss:.4f} | 准确率: {test_acc:.2f}%") print_evaluation_report(classes, per_class_acc, conf_matrix) # 追加测试结果到报告文件 report_path = model_dir / "evaluation_report.txt" with open(report_path, "w", encoding="utf-8") as f: f.write(f"AI 垃圾分类助手 - 模型评估报告\n") f.write(f"{'='*55}\n") f.write(f"训练设备: {device}\n") f.write(f"训练轮数: {args.epochs}\n") f.write(f"批次大小: {args.batch_size}\n") f.write(f"学习率: {args.lr}\n\n") f.write(f"最佳验证准确率: {best_acc:.2f}%\n") f.write(f"测试集准确率: {test_acc:.2f}%\n\n") f.write(f"各类别准确率:\n") for i, name in enumerate(classes): cn = CLASS_NAMES_CN[i] if i < len(CLASS_NAMES_CN) else name f.write(f" {cn} ({name}): {per_class_acc[i]:.2f}%\n") f.write(f"\n混淆矩阵:\n") f.write(f"{'':>6}" + "".join(f"{c[:5]:>6}" for c in classes) + "\n") for i in range(len(classes)): f.write(f"{classes[i][:4]:>4}: " + "".join(f"{conf_matrix[i,j]:>6}" for j in range(len(classes))) + "\n") print(f" 📄 评估报告已保存: {report_path}") print(f"\n✓ 模型: {model_dir / 'garbage_model.pth'}") print(f"✓ 曲线图: {plot_path}") print(f" 如需启动 Web 界面: python main.py webui") # ── CLI ─────────────────────────────────────────────── if __name__ == "__main__": parser = argparse.ArgumentParser(description="训练垃圾分类模型") parser.add_argument("--data-dir", default="dataset") parser.add_argument("--model-dir", default="models") parser.add_argument("--epochs", type=int, default=30) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--lr", type=float, default=0.001) args = parser.parse_args() train(args)