File size: 3,723 Bytes
52dd1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# src/inference/svm_model.py

import os
import json
from typing import Dict, Any, List

import numpy as np
from PIL import Image
from torchvision import transforms
import joblib


class SVMModel:
    """
    Inference wrapper for the Linear SVM trained on raw 64x64 grayscale pixels.

    Images are resized to 64x64, converted to single-channel grayscale, scaled
    to [0, 1] by ``transforms.ToTensor`` and flattened into a (1, 4096) feature
    row, which is scored with the estimator's ``decision_function``.
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/svm_model.joblib",
        labels_path: str = "configs/labels.json",
    ):
        """
        Load the SVM checkpoint and the class-id -> class-name mapping.

        Args:
            ckpt_path: Path to a ``joblib`` checkpoint. It may hold either the
                estimator itself or a dict with the estimator under ``"model"``.
            labels_path: Path to a JSON object whose keys are class ids and
                whose values are class names.

        Raises:
            FileNotFoundError: If either path does not exist.
        """
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"SVM checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")

        print(f"[SVMModel] Loading checkpoint from {ckpt_path} ...")
        # SECURITY: joblib.load unpickles arbitrary objects; only load trusted checkpoints.
        payload = joblib.load(ckpt_path)

        # The checkpoint may be a bare estimator or a dict carrying extra keys.
        if isinstance(payload, dict) and "model" in payload:
            self.model = payload["model"]
        else:
            self.model = payload

        print(f"[SVMModel] Loading labels from {labels_path} ...")
        with open(labels_path, "r", encoding="utf-8") as f:
            raw_labels = json.load(f)

        # JSON object keys are always strings; normalise them to ints.
        self.id_to_name = {int(k): v for k, v in raw_labels.items()}

        self.preprocess_tf = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),  # (1, 64, 64) in [0, 1]
        ])

    def preprocess(self, img: "Image.Image") -> np.ndarray:
        """
        Convert a PIL image to a flattened grayscale feature row of shape (1, 4096).
        """
        t = self.preprocess_tf(img)        # (1, 64, 64) tensor
        arr = t.reshape(-1).numpy()        # (4096,)
        return arr[np.newaxis, :]          # (1, 4096)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Return the softmax of a 1-D score vector (shape (C,) -> (C,))."""
        scores = scores - np.max(scores)   # subtract the max for numerical stability
        exp = np.exp(scores)
        return exp / np.sum(exp)

    @staticmethod
    def _single_sample_scores(raw_scores: Any) -> np.ndarray:
        """
        Normalise ``decision_function`` output for one sample to a (C,) vector.

        Multiclass estimators return shape (1, C). Binary scikit-learn
        estimators return shape (1,): a single signed margin ``s`` where
        positive values favour ``classes_[1]``. That margin is expanded to
        ``[0, s]`` so the softmax assigns ``sigmoid(s)`` to ``classes_[1]``.
        Without this, the softmax of a lone score is always 1.0 and column 0
        would always be predicted.
        """
        scores = np.asarray(raw_scores, dtype=np.float64)
        if scores.ndim == 2:
            scores = scores[0]
        scores = scores.reshape(-1)
        if scores.size == 1:
            scores = np.array([0.0, float(scores[0])])
        return scores

    def _column_class_ids(self, n_columns: int) -> List[int]:
        """
        Return the class id that corresponds to each score column.

        For scikit-learn classifiers, score columns follow ``classes_`` order.
        That order equals the column index only when the labels are exactly
        0..C-1. Falls back to the column index when ``classes_`` is missing,
        has the wrong length, or holds non-integer labels.
        """
        classes = getattr(self.model, "classes_", None)
        if classes is None or len(classes) != n_columns:
            return list(range(n_columns))
        try:
            return [int(c) for c in classes]
        except (TypeError, ValueError):
            return list(range(n_columns))

    def predict(
        self,
        img: "Image.Image",
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict the class of a single image.

        ``probabilities`` are a softmax over the SVM decision scores. They sum
        to 1 and preserve the score ranking, but they are not calibrated
        probabilities.

        Args:
            img: Input PIL image.
            top_k: Number of highest-scoring classes to report; clamped to
                the range [0, number of classes].

        Returns:
            {
              "class_id": int,
              "class_name": str,
              "probabilities": {class_name: prob_float},   # full distribution
              "top_k": List[{"class_id": int, "class_name": str, "probability": float}],
            }

        Raises:
            KeyError: If a predicted class id is missing from the labels mapping.
        """
        x = self.preprocess(img)  # (1, 4096)

        # LinearSVC has no predict_proba; decision_function gives per-class margins.
        scores = self._single_sample_scores(self.model.decision_function(x))  # (C,)
        probs = self._softmax(scores)  # (C,)
        class_ids = self._column_class_ids(len(probs))

        best_col = int(np.argmax(probs))
        pred_id = class_ids[best_col]
        pred_name = self.id_to_name[pred_id]

        # Full distribution keyed by class name.
        prob_dict = {
            self.id_to_name[cid]: float(p)
            for cid, p in zip(class_ids, probs)
        }

        # Columns sorted by descending probability.
        order = np.argsort(probs)[::-1]
        k = max(0, min(top_k, len(order)))
        top_k_list: List[Dict[str, Any]] = []
        for col in order[:k]:
            col = int(col)
            cid = class_ids[col]
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[col]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }