Spaces:
Sleeping
Sleeping
| """ | |
| AI 垃圾分类助手 - 预测模块 | |
| 使用训练好的模型进行垃圾图像分类 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from torchvision.models import mobilenet_v3_small | |
| from PIL import Image | |
| from pathlib import Path | |
| from config import CLASS_NAMES, CLASS_NAMES_CN, MODEL_PATH as DEFAULT_MODEL_PATH | |
class GarbageClassifier:
    """Classify a garbage photo into one of the categories in ``CLASS_NAMES``.

    Wraps a MobileNetV3-Small whose last classifier layer is replaced by a
    ``Linear`` with one output per entry of ``CLASS_NAMES``, and loads its
    weights from a checkpoint produced by ``train.py``.
    """

    def __init__(self, model_path=None, device=None):
        """Load the checkpoint and build the preprocessing pipeline.

        Args:
            model_path: Path to the checkpoint file. Falls back to
                ``config.MODEL_PATH`` when omitted (or any falsy value).
            device: ``torch.device`` to run on. Auto-detected when omitted.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_path = model_path or str(DEFAULT_MODEL_PATH)
        self.device = device or self._get_device()
        self.class_names = CLASS_NAMES
        self.class_names_cn = CLASS_NAMES_CN
        self.model = self._load_model(model_path)
        self.model.eval()
        # Resize -> center crop -> ImageNet mean/std normalisation.
        # NOTE(review): assumed to mirror the eval transform used in train.py — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _get_device(self):
        """Return the best available device: Apple MPS, then CUDA, else CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def _format_accuracy(value):
        """Render ``value`` with two decimals and a trailing ``%``.

        Returns ``"N/A"`` when the value is missing or not numeric.
        """
        if value is None:
            return "N/A"
        try:
            return f"{float(value):.2f}%"
        except (TypeError, ValueError):
            return "N/A"

    def _load_model(self, model_path):
        """Build the network and load its weights from ``model_path``.

        The checkpoint must be a mapping holding the weights under
        ``"model_state_dict"``. It may optionally hold ``"best_acc"``, which
        is printed for information only.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        model = mobilenet_v3_small(weights=None)
        # Swap the final Linear layer so the head emits one logit per class.
        in_features = model.classifier[3].in_features
        model.classifier[3] = nn.Linear(in_features, len(self.class_names))

        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}\n请先运行 train.py 训练模型")

        # SECURITY: weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(self.device)

        print(f"✓ 模型加载成功 ({model_path})")
        # Formatting the 'N/A' fallback with ':.2f' used to raise ValueError
        # when the checkpoint had no 'best_acc'; format it safely instead.
        print(f" 验证准确率: {self._format_accuracy(checkpoint.get('best_acc'))}")
        return model

    def _rank_predictions(self, probabilities, top_k):
        """Turn a ``(1, num_classes)`` probability tensor into top-k result dicts.

        ``top_k`` is clamped to the number of classes. A negative ``top_k``
        is passed through, so ``torch.topk`` raises as before.
        """
        k = min(top_k, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities[0], k)
        return [
            {
                "class_name": self.class_names[idx],
                "class_name_cn": self.class_names_cn[idx],
                "confidence": float(prob),
            }
            for prob, idx in zip(top_probs.cpu().tolist(), top_indices.cpu().tolist())
        ]

    def predict(self, image_path, top_k=3):
        """Classify a single image file.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            top_k: Number of candidates to return. Values larger than the
                number of classes are clamped.

        Returns:
            A list of dicts with keys ``class_name``, ``class_name_cn`` and
            ``confidence`` (a softmax probability), sorted by descending
            confidence.
        """
        # Close the file handle once the pixels have been decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        return self._rank_predictions(probabilities, top_k)

    def predict_batch(self, image_paths, top_k=3):
        """Run ``predict`` on each path and return ``{path: predictions}``."""
        return {path: self.predict(path, top_k=top_k) for path in image_paths}