hutiger's picture
Upload folder using huggingface_hub
bf5b4d8 verified
Raw
History Blame Contribute Delete
16.3 kB
"""
AI 垃圾分类助手 - 训练模块
支持两种数据目录结构:
1. 已划分: dataset/train/ + dataset/val/ [+ dataset/test/]
2. 未划分: dataset/trashnet/ (自动随机划分)
训练结束后保存 loss 曲线图并输出详细评估报告
"""
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from pathlib import Path
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg") # 不依赖 GUI 后端
import matplotlib.pyplot as plt
import numpy as np
from config import CLASS_NAMES, CLASS_NAMES_CN
# ── 设备 ──────────────────────────────────────────────
def get_device():
    """Pick the fastest available torch device and announce the choice.

    Preference order is Apple MPS, then CUDA, then CPU.

    Returns:
        torch.device: the selected device.
    """
    # Probed in priority order; the first backend that reports availability wins.
    accelerators = (
        ("mps", torch.backends.mps.is_available, "✓ 使用 MPS (Apple Silicon) 加速训练"),
        ("cuda", torch.cuda.is_available, "✓ 使用 CUDA 加速训练"),
    )
    for backend, is_available, banner in accelerators:
        if is_available():
            print(banner)
            return torch.device(backend)

    print("⚠ 使用 CPU 训练 (建议使用 MPS/CUDA)")
    return torch.device("cpu")
# ── 数据预处理 ────────────────────────────────────────
def get_transforms():
    """Build the image preprocessing pipelines.

    Returns:
        (train_tf, eval_tf): the training pipeline resizes to 256x256, applies
        random crop/flip/rotation/colour jitter and normalises; the evaluation
        pipeline resizes to 256x256, center-crops 224 and normalises.
    """
    def resize_square():
        return transforms.Resize((256, 256))

    def to_normalized_tensor():
        # Standard ImageNet channel statistics, matching the ImageNet-pretrained backbone.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    augmentation = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    ]
    train_pipeline = transforms.Compose([resize_square(), *augmentation, *to_normalized_tensor()])
    eval_pipeline = transforms.Compose(
        [resize_square(), transforms.CenterCrop(224), *to_normalized_tensor()]
    )
    return train_pipeline, eval_pipeline
# ── 模型 ──────────────────────────────────────────────
def create_model(num_classes=6):
    """Return an ImageNet-pretrained MobileNetV3-Small with a new classification head.

    Args:
        num_classes: number of output logits for the replaced final layer.
    """
    net = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)
    head_idx = 3  # position of the final Linear layer inside net.classifier
    pretrained_head = net.classifier[head_idx]
    net.classifier[head_idx] = nn.Linear(pretrained_head.in_features, num_classes)
    return net
# ── 训练 / 评估 ───────────────────────────────────────
def train_epoch(model, loader, criterion, optimizer, device, desc="Training"):
    """Run one training pass over ``loader``, updating ``model`` in place.

    Args:
        model: network to train (switched to train mode here).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to.
        desc: label for the tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen; both are
        ``0.0`` when ``loader`` yields no samples instead of dividing by zero.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight by batch size so a short final batch does not skew the mean;
        # assumes criterion averages over the batch (true for the default
        # nn.CrossEntropyLoss used by train()).
        running_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size

        seen = max(total, 1)
        pbar.set_postfix(loss=f"{running_loss / seen:.4f}", acc=f"{100.0 * correct / seen:.1f}%")

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, 100.0 * correct / total
@torch.no_grad()
def evaluate(model, loader, criterion, device, desc="Evaluating"):
    """Compute loss and accuracy of ``model`` on ``loader`` without gradients.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device every batch is moved to.
        desc: label for the tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples; both are ``0.0``
        when ``loader`` yields no samples instead of dividing by zero.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        batch_size = labels.size(0)
        # Assumes criterion returns the batch mean; re-weight to a per-sample sum.
        running_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size

        seen = max(total, 1)
        pbar.set_postfix(loss=f"{running_loss / seen:.4f}", acc=f"{100.0 * correct / seen:.1f}%")

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, 100.0 * correct / total
@torch.no_grad()
def detailed_evaluate(model, loader, class_names, device):
    """Evaluate ``model`` on ``loader`` and collect per-class statistics.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: iterable of ``(inputs, labels)`` batches with integer labels
            in ``range(len(class_names))``.
        class_names: class names; only their count is used here.
        device: device every batch is moved to.

    Returns:
        ``(avg_loss, overall_acc, per_class_acc, conf_matrix)`` where
        ``conf_matrix[t, p]`` counts samples of true class ``t`` predicted as
        ``p``, ``per_class_acc[k]`` is the percentage of class ``k`` samples
        predicted correctly (0.0 for classes with no samples) and
        ``overall_acc`` is a percentage. For an empty loader the loss and
        overall accuracy are ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    n = len(class_names)
    conf_matrix = np.zeros((n, n), dtype=int)
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_samples = 0

    for inputs, labels in tqdm(loader, desc="详细评估", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        batch_size = inputs.size(0)
        # CrossEntropyLoss averages over the batch; scale back to a sum.
        total_loss += criterion(outputs, labels).item() * batch_size
        total_samples += batch_size
        _, predicted = outputs.max(1)
        # Vectorised confusion-matrix update: encode each (true, pred) pair as
        # true * n + pred and histogram all pairs of the batch at once.
        pair_codes = labels.cpu().numpy() * n + predicted.cpu().numpy()
        conf_matrix += np.bincount(pair_codes, minlength=n * n).reshape(n, n)

    if total_samples == 0:
        return 0.0, 0.0, np.zeros(n), conf_matrix

    total_per_class = conf_matrix.sum(axis=1)  # row sums = samples per true class
    correct_per_class = np.diag(conf_matrix)  # diagonal = correct predictions
    avg_loss = total_loss / total_samples
    overall_acc = 100.0 * correct_per_class.sum() / total_per_class.sum()
    per_class_acc = 100.0 * correct_per_class / np.maximum(total_per_class, 1)
    return avg_loss, overall_acc, per_class_acc, conf_matrix
# ── 绘图 ──────────────────────────────────────────────
def plot_training_curves(history, save_path):
    """Plot per-epoch loss and accuracy curves side by side and save them.

    Args:
        history: dict with per-epoch lists ``train_loss``, ``val_loss``,
            ``train_acc`` and ``val_acc`` (accuracies in percent) plus a scalar
            ``best_acc`` drawn as a horizontal reference line.
        save_path: output image path.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))
    try:
        # Loss
        ax1.plot(epochs, history["train_loss"], "o-", label="Train Loss", color="#2196F3")
        ax1.plot(epochs, history["val_loss"], "s-", label="Val Loss", color="#FF5722")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        # ASCII titles: matplotlib's default font (DejaVu Sans) has no CJK glyphs,
        # so Chinese titles rendered as empty boxes unless a CJK font is configured.
        ax1.set_title("Loss Curve")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Accuracy
        ax2.plot(epochs, history["train_acc"], "o-", label="Train Acc", color="#2196F3")
        ax2.plot(epochs, history["val_acc"], "s-", label="Val Acc", color="#FF5722")
        ax2.axhline(y=history["best_acc"], color="green", linestyle="--", alpha=0.5,
                    label=f"Best Val {history['best_acc']:.1f}%")
        ax2.set_xlabel("Epoch")
        ax2.set_ylabel("Accuracy (%)")
        ax2.set_title("Accuracy Curve")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this exact figure even if saving fails, so figures never accumulate.
        plt.close(fig)
    print(f" 📊 训练曲线已保存: {save_path}")
# ── 评估报告 ──────────────────────────────────────────
def print_evaluation_report(class_names, per_class_acc, conf_matrix):
    """Print per-class accuracy bars, the confusion matrix and the top confused pairs.

    Args:
        class_names: class names, index-aligned with ``per_class_acc`` and the
            rows/columns of ``conf_matrix``.
        per_class_acc: per-class accuracy in percent (0-100).
        conf_matrix: square count matrix indexed ``[true, predicted]``.

    Display names come from ``CLASS_NAMES_CN`` by index, falling back to the
    entry of ``class_names``. NOTE(review): this assumes ``CLASS_NAMES_CN`` is
    ordered like ``class_names`` — confirm against config.py.
    """
    def display_name(idx):
        # Fall back to the original name when CLASS_NAMES_CN is shorter.
        return CLASS_NAMES_CN[idx] if idx < len(CLASS_NAMES_CN) else class_names[idx]

    n = len(class_names)

    print(f"\n{'='*55}")
    print(f" 详细评估报告")
    print(f"{'='*55}")
    print(f" 类别准确率:")
    for i, name in enumerate(class_names):
        # 20-cell bar: one filled cell per 5 percentage points.
        filled = int(per_class_acc[i] // 5)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {i}. {display_name(i):8s} ({name:10s}): {per_class_acc[i]:5.1f}% [{bar}]")
    print(f"{'─'*55}")

    # Confusion matrix
    print(f" 混淆矩阵 (行=真实, 列=预测):")
    header = "".join(f"{name[:5]:>6}" for name in class_names)
    print(f" {'':>6}{header}")
    for i in range(n):
        row = "".join(f"{conf_matrix[i, j]:>6}" for j in range(n))
        print(f" {display_name(i)[:2]:>4}: {row} {per_class_acc[i]:.1f}%")

    # Most frequent off-diagonal (true -> predicted) confusions.
    print(f"\n 易混淆组合 (非对角线最高):")
    pairs = sorted(
        (
            (conf_matrix[i, j], i, j)
            for i in range(n)
            for j in range(n)
            if i != j and conf_matrix[i, j] > 0
        ),
        reverse=True,
    )
    for count, i, j in pairs[:3]:
        ratio = count / max(conf_matrix[i].sum(), 1) * 100
        # The arrow separates the true class from the predicted one; without it
        # the two names were printed run together.
        print(f" {display_name(i)}{display_name(j)}: {count} 次 ({ratio:.1f}%)")
# ── 数据加载 ──────────────────────────────────────────
def load_split_data(data_dir, train_tf, eval_tf, batch_size, num_workers=0):
    """Load a dataset that is already split into train/ and val/ (and optionally test/).

    Args:
        data_dir: ``Path`` containing the split sub-directories.
        train_tf: transform applied to the training split.
        eval_tf: transform applied to the val and test splits.
        batch_size: batch size for every loader.
        num_workers: DataLoader worker processes (default 0, as before).

    Returns:
        ``(train_loader, val_loader, test_loader_or_None, classes)``, or
        ``None`` when ``train/`` or ``val/`` is missing.

    Raises:
        ValueError: if val/ or test/ does not contain the same class folders as train/.
    """
    train_dir = data_dir / "train"
    val_dir = data_dir / "val"
    if not train_dir.exists() or not val_dir.exists():
        return None

    train_dataset = datasets.ImageFolder(root=str(train_dir), transform=train_tf)
    classes = train_dataset.classes

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    def load_eval_split(split_dir):
        dataset = datasets.ImageFolder(root=str(split_dir), transform=eval_tf)
        # ImageFolder derives label indices from each directory's own class
        # folders, so a missing/extra folder would silently shift every label.
        if dataset.classes != classes:
            raise ValueError(
                f"✗ {split_dir} 的类别 {dataset.classes} 与 {train_dir} 的类别 {classes} 不一致"
            )
        return dataset

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(load_eval_split(val_dir), shuffle=False)

    test_dir = data_dir / "test"
    test_loader = make_loader(load_eval_split(test_dir), shuffle=False) if test_dir.exists() else None
    return train_loader, val_loader, test_loader, classes
def load_random_split_data(data_dir, train_tf, eval_tf, batch_size, train_ratio=0.8, seed=None):
    """Randomly split one ImageFolder directory into training and validation loaders.

    Args:
        data_dir: ``Path`` whose sub-folders are the classes.
        train_tf: transform for the training subset (augmentation).
        eval_tf: transform for the validation subset.
        batch_size: batch size for both loaders.
        train_ratio: fraction of samples used for training (default 0.8).
        seed: optional int for a reproducible split; ``None`` keeps it random.

    Returns:
        ``(train_loader, val_loader, None, classes)``; there is no test split.
    """
    from torch.utils.data import Subset  # not among the module-level imports

    # One dataset object per transform. Subsets share their parent dataset, so
    # assigning eval_tf to a single shared ImageFolder (the old approach) also
    # removed augmentation from the training subset.
    train_view = datasets.ImageFolder(root=str(data_dir), transform=train_tf)
    eval_view = datasets.ImageFolder(root=str(data_dir), transform=eval_tf)

    n_total = len(train_view)
    train_size = int(train_ratio * n_total)
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    order = torch.randperm(n_total, generator=generator).tolist()
    # Both views scan the same folder in the same sorted order, so a single
    # permutation yields disjoint train/val index sets across them.
    train_dataset = Subset(train_view, order[:train_size])
    val_dataset = Subset(eval_view, order[train_size:])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, val_loader, None, train_view.classes
# ── 主训练流程 ────────────────────────────────────────
def _print_missing_dataset_help(data_dir):
    """Print the dataset directory layouts that train() accepts."""
    print("请将数据集放在以下结构之一:")
    print(f" {data_dir}/ ├── cardboard/ └── ... (自动 80/20 划分)")
    print(f" 或运行 split_dataset.py 划分后使用:")
    print(f" {data_dir}/train/ ├── cardboard/ └── ...")
    print(f" {data_dir}/val/ ├── cardboard/ └── ...")


def _write_evaluation_report(report_path, device, args, best_acc, test_acc,
                             classes, per_class_acc, conf_matrix):
    """Write the plain-text evaluation report, overwriting any existing file."""
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"AI 垃圾分类助手 - 模型评估报告\n")
        f.write(f"{'='*55}\n")
        f.write(f"训练设备: {device}\n")
        f.write(f"训练轮数: {args.epochs}\n")
        f.write(f"批次大小: {args.batch_size}\n")
        f.write(f"学习率: {args.lr}\n\n")
        f.write(f"最佳验证准确率: {best_acc:.2f}%\n")
        f.write(f"测试集准确率: {test_acc:.2f}%\n\n")
        f.write(f"各类别准确率:\n")
        for i, name in enumerate(classes):
            cn = CLASS_NAMES_CN[i] if i < len(CLASS_NAMES_CN) else name
            f.write(f" {cn} ({name}): {per_class_acc[i]:.2f}%\n")
        f.write(f"\n混淆矩阵:\n")
        f.write(f"{'':>6}" + "".join(f"{c[:5]:>6}" for c in classes) + "\n")
        for i, name in enumerate(classes):
            f.write(f"{name[:4]:>4}: " + "".join(f"{count:>6}" for count in conf_matrix[i]) + "\n")
    print(f" 📄 评估报告已保存: {report_path}")


def train(args):
    """Train the garbage classifier end to end.

    Loads data from ``args.data_dir`` (either pre-split ``train/`` + ``val/``
    [+ ``test/``] or a single folder split 80/20 at random) and fine-tunes the
    model for ``args.epochs`` epochs with Adam and cosine LR decay. The
    checkpoint with the best validation accuracy is saved to
    ``args.model_dir/garbage_model.pth`` and the training curves are plotted.
    If a test split exists, those best weights are evaluated on it and
    ``evaluation_report.txt`` is written.

    Args:
        args: namespace with ``data_dir``, ``model_dir``, ``epochs``,
            ``batch_size`` and ``lr``.
    """
    device = get_device()
    data_dir = Path(args.data_dir)
    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    if not data_dir.exists():
        print(f"✗ 数据集路径不存在: {data_dir}")
        _print_missing_dataset_help(data_dir)
        return

    train_tf, eval_tf = get_transforms()
    if (data_dir / "train").exists():
        result = load_split_data(data_dir, train_tf, eval_tf, args.batch_size)
        if result is None:
            # train/ without val/: previously fell through to a NameError on train_loader.
            print(f"✗ 找到 {data_dir / 'train'} 但缺少验证集目录: {data_dir / 'val'}")
            _print_missing_dataset_help(data_dir)
            return
        train_loader, val_loader, test_loader, classes = result
        print(f"\n检测到已划分的数据集")
    else:
        train_loader, val_loader, test_loader, classes = load_random_split_data(
            data_dir, train_tf, eval_tf, args.batch_size
        )
        print(f"\n检测到未划分的数据集 (自动 80/20 随机划分)")
    print(f" 类别 ({len(classes)}): {classes}")
    print(f" 训练集: {len(train_loader.dataset)} 张")
    print(f" 验证集: {len(val_loader.dataset)} 张")
    if test_loader:
        print(f" 测试集: {len(test_loader.dataset)} 张")

    model = create_model(num_classes=len(classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    model_path = model_dir / "garbage_model.pth"

    # ── 训练循环 ──
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "best_acc": 0.0}
    best_acc = 0.0
    best_epoch = None
    best_state = None
    print(f"\n开始训练 (共 {args.epochs} 轮)...")
    print(f"{'─'*65}")
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f"Epoch {epoch}/{args.epochs}",
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, device, desc="Validating")
        scheduler.step()
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        print(
            f"Epoch {epoch:2d}/{args.epochs} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | "
            f"LR: {scheduler.get_last_lr()[0]:.2e}"
        )
        # The first epoch always counts as best so a checkpoint exists even when
        # validation accuracy stays at 0%.
        if best_epoch is None or val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            history["best_acc"] = best_acc
            # In-memory copy so the test set is scored on the same weights that are saved.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_acc": best_acc,
                "class_names": classes,
            }, str(model_path))
            print(f" ✓ 保存最佳模型 (验证准确率: {best_acc:.2f}%)")

    # ── 训练结束 ──
    print(f"{'─'*65}")
    print(f"训练完成!最佳验证准确率: {best_acc:.2f}%")
    plot_path = model_dir / "training_curves.png"
    plot_training_curves(history, plot_path)

    # Detailed test-set evaluation of the best checkpoint (not the last epoch).
    if test_loader:
        if best_state is not None:
            model.load_state_dict(best_state)
        print(f"\n{'='*55}")
        print(f" 测试集最终评估")
        print(f"{'='*55}")
        if best_epoch is not None:
            print(f" 使用第 {best_epoch} 轮的最佳模型")
        test_loss, test_acc, per_class_acc, conf_matrix = detailed_evaluate(
            model, test_loader, classes, device
        )
        print(f" 测试集 Loss: {test_loss:.4f} | 准确率: {test_acc:.2f}%")
        print_evaluation_report(classes, per_class_acc, conf_matrix)
        # Overwrites any previous report (file opened in "w" mode).
        _write_evaluation_report(
            model_dir / "evaluation_report.txt", device, args, best_acc, test_acc,
            classes, per_class_acc, conf_matrix,
        )

    print(f"\n✓ 模型: {model_path}")
    print(f"✓ 曲线图: {plot_path}")
    print(f" 如需启动 Web 界面: python main.py webui")
# ── CLI ───────────────────────────────────────────────
if __name__ == "__main__":
    # Command-line entry point; each option becomes an attribute read by train().
    cli = argparse.ArgumentParser(description="训练垃圾分类模型")
    option_specs = (
        ("--data-dir", {"default": "dataset"}),
        ("--model-dir", {"default": "models"}),
        ("--epochs", {"type": int, "default": 30}),
        ("--batch-size", {"type": int, "default": 32}),
        ("--lr", {"type": float, "default": 0.001}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    train(cli.parse_args())