# src/inference/resnet_pt_svm_model.py
"""ResNet18 (pretrained, frozen) feature extractor with a Linear SVM head."""

import json
import os
from typing import Any, Dict, List, Optional

import joblib
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import ResNet18_Weights, resnet18


class ResNetPTSVMModel:
    """
    ResNet18 (pretrained, frozen) + Linear SVM head.

    Pipeline:
        PIL image
        -> ImageNet transforms
        -> ResNet18 backbone (fc -> Identity) -> feature vector
        -> Linear SVM decision_function
        -> softmax over scores to get (pseudo-)probabilities
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Load the SVM head, the labels mapping, and build the frozen backbone.

        Args:
            ckpt_path: joblib checkpoint holding either the bare SVM estimator
                or a dict payload with keys "model", "feature_dim",
                "backbone", "labels_path".
            labels_path: JSON file mapping class id -> class name.
            device: explicit torch device string; auto-selects CUDA when None.

        Raises:
            FileNotFoundError: if the checkpoint or labels file is missing.
                (Was previously an ``assert``, which is silently stripped
                under ``python -O`` — validation must not depend on asserts.)
        """
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + SVM checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")

        # Device selection
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTSVMModel] Using device: {self.device}")

        # --- Load SVM head ---
        print(f"[ResNetPTSVMModel] Loading SVM head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        if isinstance(payload, dict) and "model" in payload:
            # New-style checkpoint: estimator plus training metadata.
            self.svm_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Legacy checkpoint: bare estimator, no metadata recorded.
            self.svm_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the labels path stored with the checkpoint; fall back to the
        # constructor argument if that path no longer exists.
        labels_file = (
            self.saved_labels_path
            if os.path.exists(self.saved_labels_path)
            else labels_path
        )
        print(f"[ResNetPTSVMModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r", encoding="utf-8") as f:
            id_to_name = json.load(f)
        # JSON object keys are strings; normalize them to int class ids.
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocessing transforms ---
        print("[ResNetPTSVMModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        model.fc = nn.Identity()  # strip the classifier; expose pooled features
        model.to(self.device)
        model.eval()
        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional sanity check: the head's expected feature_dim should match
        # what the backbone actually produces. Warn only — never hard-fail.
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTSVMModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                print(f"[ResNetPTSVMModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply the ImageNet-style transforms and return a (1, 3, H, W) tensor
        on the model's device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add batch dimension
        return t.to(self.device)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over a 1-D score vector."""
        scores = scores - np.max(scores)  # shift for numerical stability
        exp = np.exp(scores)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run the image through the frozen ResNet backbone.

        Returns:
            (1, D) numpy feature array on CPU.
        """
        x = self.preprocess(img)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict the class of a single image.

        Args:
            img: input PIL image.
            top_k: number of top predictions to include (clamped to the
                number of classes).

        Returns:
            {
              "class_id": int,
              "class_name": str,
              "probabilities": {class_name: prob_float},
              "top_k": [
                {"class_id": int, "class_name": str, "probability": float},
                ...
              ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LinearSVC has no predict_proba -> use decision_function scores.
        scores = self.svm_head.decision_function(feats_np)
        scores = np.atleast_2d(scores)[0]

        # BUGFIX: for binary classifiers, decision_function returns a single
        # signed margin per sample (shape (1,) after indexing), NOT per-class
        # scores. Softmax over one value is always [1.0], so argmax would
        # always predict class id 0. Expand the margin to a two-class score
        # vector following sklearn's convention: positive margin -> classes_[1].
        if scores.shape[0] == 1 and len(self.id_to_name) == 2:
            scores = np.array([-scores[0], scores[0]])

        probs = self._softmax(scores)  # (C,)

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p) for i, p in enumerate(probs)
        }

        # Rank classes by probability, highest first; clamp k to class count
        # (local `k` avoids shadowing the `top_k` parameter).
        sorted_indices = np.argsort(probs)[::-1]
        k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }