# NOTE(review): removed "Spaces: / Sleeping / Sleeping" — Hugging Face Spaces
# page residue from extraction; not valid Python and not part of this module.
| # src/inference/resnet_pt_svm_model.py | |
| import os | |
| import json | |
| from typing import Dict, Any, List, Optional | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torchvision.models import resnet18, ResNet18_Weights | |
| import joblib | |
class ResNetPTSVMModel:
    """
    ResNet18 (pretrained, frozen) + Linear SVM head.

    Pipeline:
        - PIL image
        - ImageNet transforms
        - ResNet18 backbone (fc -> Identity) -> feature vector
        - Linear SVM decision_function
        - Softmax over scores to get probabilities
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Load the SVM head and labels mapping, and build the frozen backbone.

        Args:
            ckpt_path: joblib checkpoint — either a bare sklearn estimator or
                a dict payload with keys "model", "feature_dim", "backbone",
                "labels_path".
            labels_path: JSON file mapping class id (string key) -> class name.
            device: explicit torch device string (e.g. "cpu", "cuda:0");
                auto-detects CUDA when None.
        """
        assert os.path.exists(ckpt_path), f"ResNet PT + SVM checkpoint not found: {ckpt_path}"
        assert os.path.exists(labels_path), f"Labels mapping not found: {labels_path}"

        # Device: prefer CUDA when available unless the caller pins one.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTSVMModel] Using device: {self.device}")

        # --- Load SVM head ---
        print(f"[ResNetPTSVMModel] Loading SVM head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        if isinstance(payload, dict) and "model" in payload:
            # New-style checkpoint: dict carrying metadata next to the estimator.
            self.svm_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Legacy checkpoint: the estimator was dumped directly.
            self.svm_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the labels file recorded in the checkpoint; fall back to the
        # constructor argument when that path does not exist on this machine.
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTSVMModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)
        # JSON object keys are strings; normalize to int class ids.
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess ---
        print("[ResNetPTSVMModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        import torch.nn as nn
        # Drop the classification head so forward() yields pooled features.
        model.fc = nn.Identity()
        model.to(self.device)
        model.eval()
        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional: sanity-check feature_dim with a dummy forward pass.
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTSVMModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                # Best-effort check only; never block model construction.
                print(f"[ResNetPTSVMModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply the ImageNet-style transforms and return (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add batch dimension
        return t.to(self.device)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """
        Numerically stable softmax over a 1-D score vector.

        FIX: previously defined without `self` and without @staticmethod, so
        the `self._softmax(scores)` call in predict() raised TypeError.
        """
        scores = scores - np.max(scores)  # shift for numerical stability
        exp = np.exp(scores)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run image through ResNet backbone to get (1, D) feature vector.
        """
        x = self.preprocess(img)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Args:
            img: input PIL image.
            top_k: maximum number of entries in the returned "top_k" list
                (clamped to the number of classes).

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)
        # LinearSVC has no predict_proba -> use decision_function
        scores = self.svm_head.decision_function(feats_np)
        if scores.ndim == 1:
            scores = scores[np.newaxis, :]
        scores = scores[0]  # (C,) multiclass, (1,) binary
        # FIX: binary sklearn classifiers emit one signed score per sample
        # (positive -> class 1). Expand to a 2-class score vector so the
        # softmax below yields meaningful probabilities for both classes.
        if scores.shape[0] == 1:
            scores = np.array([-scores[0], scores[0]])
        probs = self._softmax(scores)  # (C,)

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]
        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        sorted_indices = np.argsort(probs)[::-1]
        k = min(top_k, len(sorted_indices))  # clamp; avoid shadowing the param
        top_k_list: List[Dict[str, Any]] = []
        for i in range(k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }