File size: 2,901 Bytes
bf5b4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e18bc7
bf5b4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
AI 垃圾分类助手 - 预测模块
使用训练好的模型进行垃圾图像分类
"""

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import mobilenet_v3_small
from PIL import Image
from pathlib import Path
from config import CLASS_NAMES, CLASS_NAMES_CN, MODEL_PATH as DEFAULT_MODEL_PATH


class GarbageClassifier:
    """Garbage image classifier backed by a fine-tuned MobileNetV3-small.

    On construction the checkpoint is loaded, the model is moved to the
    selected device and switched to evaluation mode, and the preprocessing
    pipeline is built.

    Args:
        model_path: Path to the checkpoint file. Falls back to the
            ``MODEL_PATH`` value from ``config`` when falsy.
        device: ``torch.device`` to run on. When falsy, the best available
            backend is chosen by ``_get_device``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """

    def __init__(self, model_path=None, device=None):
        model_path = model_path or str(DEFAULT_MODEL_PATH)
        self.device = device or self._get_device()
        self.class_names = CLASS_NAMES
        self.class_names_cn = CLASS_NAMES_CN
        self.model = self._load_model(model_path)
        # Inference only: disable dropout / use running batch-norm stats.
        self.model.eval()

        # Resize, centre-crop to 224x224 and normalize with the usual
        # ImageNet channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _get_device(self):
        """Return the preferred device: MPS, then CUDA, then CPU."""
        # Older torch builds have no ``torch.backends.mps``; treat that as
        # "MPS unavailable" instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def _load_model(self, model_path):
        """Build the network and load trained weights from ``model_path``.

        Args:
            model_path: Path to a checkpoint containing a
                ``"model_state_dict"`` entry and, optionally, ``"best_acc"``.

        Returns:
            The model with weights loaded, placed on ``self.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        # Fail fast before paying for network construction.
        if not Path(model_path).exists():
            raise FileNotFoundError(f"模型文件不存在: {model_path}\n请先运行 train.py 训练模型")

        model = mobilenet_v3_small(weights=None)
        # Replace the final classifier layer so the output size matches the
        # number of garbage classes.
        in_features = model.classifier[3].in_features
        model.classifier[3] = nn.Linear(in_features, len(CLASS_NAMES))

        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from a trusted source; consider
        # weights_only=True if the checkpoint format allows it.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(self.device)
        print(f"✓ 模型加载成功 ({model_path})")

        # "best_acc" is optional; applying a float format spec to a missing
        # or non-numeric value would raise, so fall back to "N/A".
        try:
            acc_text = f"{float(checkpoint.get('best_acc')):.2f}%"
        except (TypeError, ValueError):
            acc_text = "N/A"
        print(f"  验证准确率: {acc_text}")
        return model

    def _format_top_k(self, probabilities, top_k):
        """Convert a ``(1, num_classes)`` probability tensor into ranked dicts."""
        # Never request more entries than there are classes.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)
        # Select the single batch row explicitly; this yields a list even
        # when k == 1, so no special case is needed.
        probs = top_probs[0].tolist()
        indices = top_indices[0].tolist()
        return [
            {
                "class_name": self.class_names[idx],
                "class_name_cn": self.class_names_cn[idx],
                "confidence": float(prob),
            }
            for prob, idx in zip(probs, indices)
        ]

    def predict(self, image_path, top_k=3):
        """Classify one image and return its most likely classes.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            top_k: Number of results to return; capped at the number of
                classes produced by the model.

        Returns:
            A list of dicts ordered by descending confidence, each with the
            keys ``"class_name"``, ``"class_name_cn"`` and ``"confidence"``.
        """
        # Close the underlying file as soon as the pixels are converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        return self._format_top_k(probabilities, top_k)

    def predict_batch(self, image_paths):
        """Run ``predict`` on each path and return a ``{path: results}`` dict."""
        return {path: self.predict(path) for path in image_paths}