# NOTE(review): the lines below were scrape residue from a Hugging Face Spaces
# page ("Spaces: / Sleeping / Sleeping") — not part of the module.
| # src/inference/resnet_pt_lr_model.py | |
| import os | |
| import json | |
| from typing import Dict, Any, List, Optional | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torchvision.models import resnet18, ResNet18_Weights | |
| import joblib | |
class ResNetPTLRModel:
    """
    End-to-end inference wrapper:
      - ResNet18 (pretrained on ImageNet) as a frozen backbone
      - Logistic Regression head trained on extracted features

    The backbone's final FC layer is replaced with an identity so it emits
    raw pooled features (512-d for ResNet18); the scikit-learn LR head maps
    those features to class probabilities.
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Load the LR head, the labels mapping, and build the ResNet18 backbone.

        Args:
            ckpt_path: joblib checkpoint holding the LR head (dict payload or raw model).
            labels_path: JSON file mapping class id -> class name.
            device: explicit torch device string; auto-selects CUDA/CPU when None.

        Raises:
            FileNotFoundError: if the checkpoint or labels file is missing.
        """
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently skip these existence checks.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + LR checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")

        # Decide device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTLRModel] Using device: {self.device}")

        # --- Load LR head ---
        # SECURITY: joblib.load is pickle-based — only load checkpoints you trust.
        print(f"[ResNetPTLRModel] Loading LR head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        # payload was saved as dict in train_resnet_pt_lr.py
        if isinstance(payload, dict) and "model" in payload:
            self.lr_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Fallback if someone saved the raw model
            self.lr_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the labels path recorded at training time; fall back to the
        # caller-supplied one if that file no longer exists on this machine.
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTLRModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)
        # JSON object keys are always strings; ensure keys are ints
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess (same as in feature extraction) ---
        print("[ResNetPTLRModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        import torch.nn as nn
        # Drop the classification layer so forward() returns pooled features.
        model.fc = nn.Identity()
        model.to(self.device)
        model.eval()
        self.backbone = model
        # Use the weights' own transforms so preprocessing exactly matches
        # what was used during feature extraction / training.
        self.preprocess_tf = weights.transforms()

        # Optional: check feature_dim consistency if available (warn only —
        # a mismatch here means head and backbone were trained apart).
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTLRModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                print(f"[ResNetPTLRModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # (1, 3, H, W)
        return t.to(self.device)

    @staticmethod
    def _to_probabilities_from_logits(logits: np.ndarray) -> np.ndarray:
        """
        Convert raw scores/logits to probabilities using a numerically
        stable softmax (max is subtracted before exponentiation).

        NOTE: declared @staticmethod — the original definition lacked both
        `self` and the decorator, so calling it via `self.` raised TypeError.
        """
        logits = logits - np.max(logits)
        exp = np.exp(logits)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run a PIL image through the backbone and get a (1, D) numpy feature vector.
        """
        x = self.preprocess(img)  # (1, 3, H, W)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        feats_np = feats.cpu().numpy()
        return feats_np  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Args:
            img: input PIL image.
            top_k: number of top classes to return (clamped to class count).

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LR has predict_proba, use that directly
        if hasattr(self.lr_head, "predict_proba"):
            probs = self.lr_head.predict_proba(feats_np)[0]  # (C,)
        else:
            # Fallback: use decision_function and softmax
            scores = self.lr_head.decision_function(feats_np)
            if scores.ndim == 1:
                # Binary classifiers return shape (n_samples,) — lift to 2-D
                # so indexing below is uniform.
                scores = scores[np.newaxis, :]
            probs = self._to_probabilities_from_logits(scores[0])

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        # Full distribution, keyed by human-readable class name
        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }

        # Top-k sorted by descending probability
        sorted_indices = np.argsort(probs)[::-1]
        top_k = min(top_k, len(sorted_indices))
        top_k_list: List[Dict[str, Any]] = []
        for i in range(top_k):
            cid = int(sorted_indices[i])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }