# hutiger's picture
# Fix model loading with weights_only=False
# 7e18bc7 verified
# Raw
# History Blame Contribute Delete
# 2.9 kB
"""
AI 垃圾分类助手 - 预测模块
使用训练好的模型进行垃圾图像分类
"""
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import mobilenet_v3_small
from PIL import Image
from pathlib import Path
from config import CLASS_NAMES, CLASS_NAMES_CN, MODEL_PATH as DEFAULT_MODEL_PATH
class GarbageClassifier:
    """Garbage image classifier built on torchvision's MobileNetV3-Small.

    The network's last classifier layer is replaced with a linear layer that
    outputs ``len(CLASS_NAMES)`` logits, and the weights are restored from a
    checkpoint file. Predictions are reported with both the entry from
    ``CLASS_NAMES`` and the matching entry from ``CLASS_NAMES_CN``.
    """

    def __init__(self, model_path=None, device=None):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path to the checkpoint file. Falls back to the
                ``MODEL_PATH`` value from ``config`` when falsy.
            device: Device to run inference on. When falsy, one is picked by
                ``_get_device``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_path = model_path or str(DEFAULT_MODEL_PATH)
        self.device = device or self._get_device()
        self.class_names = CLASS_NAMES
        self.class_names_cn = CLASS_NAMES_CN
        self.model = self._load_model(model_path)
        # Inference only: switch off dropout / batch-norm updates.
        self.model.eval()
        # Resize, centre-crop to 224x224, convert to a tensor and normalise
        # with the usual ImageNet mean/std values.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def _get_device(self):
        """Return the preferred available device: MPS, then CUDA, then CPU."""
        # ``torch.backends.mps`` is missing on older PyTorch builds, so look
        # it up defensively instead of assuming the attribute exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def _format_accuracy(best_acc):
        """Render a checkpoint's best accuracy for display.

        Returns ``"N/A"`` when the value is missing, ``"<value>%"`` with two
        decimals when it is numeric, and ``str(best_acc)`` otherwise.
        """
        if best_acc is None:
            return "N/A"
        try:
            return f"{float(best_acc):.2f}%"
        except (TypeError, ValueError):
            return str(best_acc)

    def _load_model(self, model_path):
        """Build the network and restore its weights from ``model_path``.

        Args:
            model_path: Path to a checkpoint containing a
                ``"model_state_dict"`` entry and, optionally, ``"best_acc"``.

        Returns:
            The model with weights loaded, moved to ``self.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        model = mobilenet_v3_small(weights=None)
        # Swap the final linear layer so the output size matches our classes.
        in_features = model.classifier[3].in_features
        model.classifier[3] = nn.Linear(in_features, len(CLASS_NAMES))

        if not Path(model_path).exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}\n请先运行 train.py 训练模型")

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code embedded in the file. Only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(self.device)

        print(f"✓ 模型加载成功 ({model_path})")
        # A missing "best_acc" must not abort loading: format it defensively
        # rather than applying a float format spec to a placeholder string.
        print(f" 验证准确率: {self._format_accuracy(checkpoint.get('best_acc'))}")
        return model

    def _decode_top_k(self, probabilities, top_k):
        """Turn a ``(1, num_classes)`` probability tensor into result dicts.

        ``top_k`` is clamped to the number of classes so that asking for more
        results than there are classes returns all of them instead of raising.
        """
        k = min(top_k, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k)
        # Index the single batch row rather than squeeze(): the result is
        # always one-dimensional, so k == 1 needs no special case, and
        # tolist() yields plain Python floats / ints.
        probs = top_probs[0].cpu().tolist()
        indices = top_indices[0].cpu().tolist()
        return [
            {
                "class_name": self.class_names[idx],
                "class_name_cn": self.class_names_cn[idx],
                "confidence": float(prob),
            }
            for prob, idx in zip(probs, indices)
        ]

    def predict(self, image_path, top_k: int = 3):
        """Classify one image and return its most likely classes.

        Args:
            image_path: Anything accepted by ``PIL.Image.open``.
            top_k: Maximum number of classes to return; clamped to the number
                of classes the model outputs.

        Returns:
            A list of dicts ordered by decreasing confidence, each with the
            keys ``"class_name"``, ``"class_name_cn"`` and ``"confidence"``.
        """
        # Use a context manager so the underlying file handle is released as
        # soon as the pixels have been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        return self._decode_top_k(probabilities, top_k)

    def predict_batch(self, image_paths):
        """Run ``predict`` on each path and return a ``{path: results}`` dict."""
        return {path: self.predict(path) for path in image_paths}