"""AI garbage-classification assistant - prediction module.

Classifies garbage images using a previously trained model.
"""

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import mobilenet_v3_small
from PIL import Image
from pathlib import Path

from config import CLASS_NAMES, CLASS_NAMES_CN, MODEL_PATH as DEFAULT_MODEL_PATH


class GarbageClassifier:
    """Classify garbage images with a fine-tuned MobileNetV3-Small.

    The network's last classifier layer is replaced by a linear head with one
    output per entry in ``CLASS_NAMES``, and its weights are restored from a
    checkpoint file.
    """

    def __init__(self, model_path=None, device=None):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path to the checkpoint file. Falls back to
                ``config.MODEL_PATH`` when falsy.
            device: ``torch.device`` to run on. Auto-selected (MPS, then CUDA,
                then CPU) when falsy.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_path = model_path or str(DEFAULT_MODEL_PATH)
        self.device = device or self._get_device()
        self.class_names = CLASS_NAMES
        self.class_names_cn = CLASS_NAMES_CN
        self.model = self._load_model(model_path)
        self.model.eval()
        # NOTE(review): presumably mirrors the validation transform used in
        # train.py -- confirm. mean/std are the standard ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def _get_device(self):
        """Return the preferred available device: MPS, then CUDA, else CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _load_model(self, model_path):
        """Build the network and load weights from ``model_path``.

        The checkpoint must be a mapping with a ``"model_state_dict"`` entry.
        It may optionally contain ``"best_acc"``, which is reported if present.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        model = mobilenet_v3_small(weights=None)
        # Replace the final classifier layer with one sized to our label set.
        in_features = model.classifier[3].in_features
        model.classifier[3] = nn.Linear(in_features, len(self.class_names))

        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}\n请先运行 train.py 训练模型")

        # SECURITY(review): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from trusted sources. Consider switching to
        # weights_only=True once the checkpoint is known to hold only tensors
        # and primitive values.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(self.device)

        print(f"✓ 模型加载成功 ({model_path})")
        # Format defensively. The old code applied ".2f" to the "N/A"
        # fallback string, which raised ValueError when best_acc was missing.
        best_acc = checkpoint.get("best_acc")
        try:
            acc_text = f"{best_acc:.2f}%"
        except (TypeError, ValueError):
            acc_text = "N/A"
        print(f" 验证准确率: {acc_text}")
        return model

    def predict(self, image_path, top_k=3):
        """Classify one image and return the ``top_k`` most likely classes.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            top_k: Number of predictions to return. Clamped to the number of
                classes.

        Returns:
            A list of dicts with keys ``"class_name"``, ``"class_name_cn"`` and
            ``"confidence"`` (softmax probability as a float), ordered from
            most to least likely.
        """
        # Close the underlying file promptly. convert() returns a new image,
        # so it remains usable after the source is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            # Asking for more results than there are classes would make topk raise.
            k = min(top_k, probabilities.size(1))
            top_probs, top_indices = torch.topk(probabilities, k)

        # Take row 0 of the single-image batch. Unlike squeeze(), this always
        # yields a flat list, even when k == 1.
        top_probs = top_probs[0].cpu().tolist()
        top_indices = top_indices[0].cpu().tolist()

        return [
            {
                "class_name": self.class_names[idx],
                "class_name_cn": self.class_names_cn[idx],
                "confidence": float(prob),
            }
            for prob, idx in zip(top_probs, top_indices)
        ]

    def predict_batch(self, image_paths, top_k=3):
        """Run :meth:`predict` on each path and map each path to its result list."""
        return {path: self.predict(path, top_k=top_k) for path in image_paths}