# src/inference/resnet_pt_svm_model.py
import os
import json
from typing import Dict, Any, List, Optional
import numpy as np
from PIL import Image
import torch
from torchvision.models import resnet18, ResNet18_Weights
import joblib
class ResNetPTSVMModel:
    """
    ResNet18 (pretrained, frozen) + Linear SVM head.

    Pipeline:
        - PIL image
        - ImageNet transforms
        - ResNet18 backbone (fc -> Identity) -> feature vector
        - Linear SVM decision_function
        - Softmax over scores to get probabilities
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Load the SVM head, labels mapping, and ImageNet-pretrained backbone.

        Args:
            ckpt_path: joblib file holding either a bare SVM estimator or a
                dict payload with keys "model", "feature_dim", "backbone",
                "labels_path".
            labels_path: JSON file mapping class id (string) -> class name.
            device: torch device string; auto-detects CUDA when None.

        Raises:
            FileNotFoundError: if either input file is missing.
        """
        # Explicit exceptions instead of `assert`: asserts are stripped when
        # Python runs with -O, which would turn a clear error here into a
        # confusing failure deeper in the pipeline.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + SVM checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")

        # Device selection: prefer CUDA when available unless caller overrides.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTSVMModel] Using device: {self.device}")

        # --- Load SVM head ---
        print(f"[ResNetPTSVMModel] Loading SVM head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        if isinstance(payload, dict) and "model" in payload:
            # Rich payload written by the training script.
            self.svm_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Bare estimator: fall back to defaults.
            self.svm_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

        # --- Load labels mapping ---
        # Prefer the labels file recorded at training time when it still exists.
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTSVMModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)
        # JSON object keys are strings; normalize to int class ids.
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}

        # --- Build ResNet18 backbone + preprocess ---
        print("[ResNetPTSVMModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        import torch.nn as nn
        # Drop the classification layer so the forward pass yields raw features.
        model.fc = nn.Identity()
        model.to(self.device)
        model.eval()
        self.backbone = model
        self.preprocess_tf = weights.transforms()

        # Optional sanity check: confirm the backbone output width matches
        # what the saved SVM head was trained on.
        if self.feature_dim is not None:
            try:
                test_input = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(test_input)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTSVMModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:
                # Best-effort probe only -- a failure here should not block loading.
                print(f"[ResNetPTSVMModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply the ImageNet-style transforms and return a (1, 3, H, W) tensor
        on the model's device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add batch dimension
        return t.to(self.device)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over a 1-D score vector."""
        scores = scores - np.max(scores)  # shift so exp() cannot overflow
        exp = np.exp(scores)
        return exp / np.sum(exp)

    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run the image through the frozen ResNet backbone.

        Returns:
            (1, D) numpy feature array.
        """
        x = self.preprocess(img)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()  # (1, D)

    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Args:
            img: input PIL image.
            top_k: number of top predictions to return (clamped to the
                number of classes).

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)

        # LinearSVC has no predict_proba -> use decision_function.
        scores = self.svm_head.decision_function(feats_np)
        scores = np.atleast_2d(scores)[0]  # (C,) for multiclass, (1,) for binary

        # Bug fix: for a binary classifier, decision_function returns a single
        # margin per sample (positive => class 1). Softmax over one value is
        # always 1.0, so expand to a two-class score vector first.
        if scores.size == 1 and len(self.id_to_name) == 2:
            scores = np.array([-scores[0], scores[0]])

        probs = self._softmax(scores)  # (C,)

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p) for i, p in enumerate(probs)
        }

        sorted_indices = np.argsort(probs)[::-1]
        k = min(top_k, len(sorted_indices))  # local name avoids shadowing the parameter
        top_k_list: List[Dict[str, Any]] = [
            {
                "class_id": int(cid),
                "class_name": self.id_to_name[int(cid)],
                "probability": float(probs[int(cid)]),
            }
            for cid in sorted_indices[:k]
        ]

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }