# src/inference/resnet_pt_lr_model.py
import os
import json
from typing import Dict, Any, List, Optional

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
import joblib


class ResNetPTLRModel:
    """End-to-end inference wrapper.

    - ResNet18 (pretrained on ImageNet) as a frozen feature backbone.
    - Logistic Regression head trained on the extracted features.
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """Load the LR head and labels mapping, then build the backbone.

        Args:
            ckpt_path: joblib checkpoint saved by train_resnet_pt_lr.py —
                either a payload dict with a "model" key, or a raw model.
            labels_path: JSON file mapping class id (str of int) -> name.
            device: explicit torch device string; auto-detects CUDA if None.

        Raises:
            FileNotFoundError: if the checkpoint or labels file is missing.
        """
        # Real exceptions, not `assert`: asserts are stripped under `python -O`.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + LR checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")

        # Decide device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTLRModel] Using device: {self.device}")

        # --- Load LR head ---
        print(f"[ResNetPTLRModel] Loading LR head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        # payload was saved as dict in train_resnet_pt_lr.py
        if isinstance(payload, dict) and "model" in payload:
            self.lr_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Fallback if someone saved the raw model
            self.lr_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the path recorded at training time, if it still exists.
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTLRModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)
        # JSON keys are strings; ensure keys are ints
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess (same as in feature extraction) ---
        print("[ResNetPTLRModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        # Drop the classification layer so the forward pass yields pooled features.
        model.fc = nn.Identity()
        model.to(self.device)
        model.eval()
        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional: check feature_dim consistency if available
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTLRModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                # Best-effort sanity check only; never block construction on it.
                print(f"[ResNetPTLRModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # (1, 3, H, W)
        return t.to(self.device)

    @staticmethod
    def _to_probabilities_from_logits(logits: np.ndarray) -> np.ndarray:
        """
        Convert raw scores/logits to probabilities using a numerically
        stable softmax (max-subtraction avoids overflow in exp).
        """
        logits = logits - np.max(logits)
        exp = np.exp(logits)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run a PIL image through the backbone and get a (1, D) numpy feature vector.
        """
        x = self.preprocess(img)  # (1, 3, H, W)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        feats_np = feats.cpu().numpy()
        return feats_np  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Args:
            img: input PIL image.
            top_k: number of highest-probability classes to report.

        Returns:
            {
              "class_id": int,
              "class_name": str,
              "probabilities": {class_name: prob_float},
              "top_k": [
                {"class_id": int, "class_name": str, "probability": float},
                ...
              ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LR has predict_proba, use that directly
        if hasattr(self.lr_head, "predict_proba"):
            probs = self.lr_head.predict_proba(feats_np)[0]  # (C,)
        else:
            # Fallback: derive probabilities from decision_function scores.
            scores = np.atleast_2d(self.lr_head.decision_function(feats_np))
            if scores.shape[1] == 1:
                # Binary classifier: decision_function returns ONE margin per
                # sample, not per-class logits — softmax over a single value
                # would always yield [1.0]. Use the sigmoid instead.
                p_pos = 1.0 / (1.0 + np.exp(-scores[0, 0]))
                probs = np.array([1.0 - p_pos, p_pos])
            else:
                probs = self._to_probabilities_from_logits(scores[0])

        # Column j of the probability vector corresponds to classes_[j];
        # do NOT assume the model's classes are 0..C-1 in order.
        class_ids = getattr(self.lr_head, "classes_", None)
        if class_ids is None:
            class_ids = np.arange(len(probs))
        class_ids = np.asarray(class_ids).astype(int)

        best_col = int(np.argmax(probs))
        pred_id = int(class_ids[best_col])
        pred_name = self.id_to_name[pred_id]

        # Full distribution
        prob_dict: Dict[str, float] = {
            self.id_to_name[int(cid)]: float(p) for cid, p in zip(class_ids, probs)
        }

        # Top-k sorted (descending probability)
        sorted_cols = np.argsort(probs)[::-1]
        k = min(top_k, len(sorted_cols))
        top_k_list: List[Dict[str, Any]] = []
        for col in sorted_cols[:k]:
            cid = int(class_ids[col])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[col]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }