# Machine_learning_CS-6140 / src / inference / resnet_pt_svm_model.py
# Provenance: uploaded by Shashwat98 ("Upload 37 files", commit 52dd1ca, verified)
# src/inference/resnet_pt_svm_model.py
import os
import json
from typing import Dict, Any, List, Optional
import numpy as np
from PIL import Image
import torch
from torchvision.models import resnet18, ResNet18_Weights
import joblib
class ResNetPTSVMModel:
    """
    ResNet18 (pretrained, frozen) + Linear SVM head.
    Pipeline:
    - PIL image
    - ImageNet transforms (taken from the torchvision weights object)
    - ResNet18 backbone (fc -> Identity) -> feature vector
    - Linear SVM decision_function
    - Softmax over scores to get (pseudo-)probabilities
    """
    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_svm_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Load the SVM head and labels mapping, and build the frozen backbone.

        Args:
            ckpt_path: joblib checkpoint -- either a bare sklearn estimator or
                a dict payload with keys "model", "feature_dim", "backbone",
                "labels_path".
            labels_path: JSON file mapping class id (as string) -> class name.
            device: torch device string ("cuda"/"cpu"); auto-detected if None.

        Raises:
            FileNotFoundError: if ckpt_path or labels_path does not exist.
        """
        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + SVM checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")
        # --- Device selection ---
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTSVMModel] Using device: {self.device}")
        # --- Load SVM head ---
        print(f"[ResNetPTSVMModel] Loading SVM head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        if isinstance(payload, dict) and "model" in payload:
            # Rich payload written by the training script.
            self.svm_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Bare sklearn estimator; no metadata available to validate against.
            self.svm_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path
        # --- Load labels mapping (prefer the path recorded in the payload) ---
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTSVMModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r", encoding="utf-8") as f:
            id_to_name = json.load(f)
        # JSON object keys are strings; normalize them to int class ids.
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}
        # --- Build ResNet18 backbone + preprocess ---
        print("[ResNetPTSVMModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        import torch.nn as nn
        model.fc = nn.Identity()  # strip the classifier head; expose pooled features
        model.to(self.device)
        model.eval()
        self.backbone = model
        self.preprocess_tf = weights.transforms()
        # Optional sanity check: warn when the backbone's output width does not
        # match the feature dimension the SVM head was trained on.
        if self.feature_dim is not None:
            try:
                probe = torch.zeros(1, 3, 224, 224).to(self.device)
                with torch.no_grad():
                    out = self.backbone(probe)
                actual_dim = out.shape[1]
                if actual_dim != self.feature_dim:
                    print(
                        f"[ResNetPTSVMModel][WARN] feature_dim mismatch: "
                        f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                    )
            except Exception as e:  # best-effort probe; never block construction
                print(f"[ResNetPTSVMModel][WARN] could not verify feature_dim: {e}")
    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply the ImageNet-style transforms and return (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add batch dimension
        return t.to(self.device)
    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over a 1-D score vector."""
        scores = scores - np.max(scores)  # shift so exp() cannot overflow
        exp = np.exp(scores)
        return exp / np.sum(exp)
    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run image through ResNet backbone to get a (1, D) feature vector.
        """
        x = self.preprocess(img)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()  # (1, D)
    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.

        Args:
            img: input PIL image.
            top_k: number of top predictions to return (clamped to #classes).

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)
        # LinearSVC has no predict_proba -> use decision_function.
        raw = np.asarray(self.svm_head.decision_function(feats_np))
        if raw.ndim == 1:
            # Binary classifiers return ONE signed score per sample, not one
            # per class. Expand to per-class scores [-s, s] so the softmax
            # yields a probability for both classes (the previous code
            # collapsed this case into a single [1.0] entry).
            s = float(raw[0])
            scores = np.array([-s, s])
        else:
            scores = raw[0]  # (C,)
        probs = self._softmax(scores)  # (C,)
        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]
        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p)
            for i, p in enumerate(probs)
        }
        sorted_indices = np.argsort(probs)[::-1]
        k = min(top_k, len(sorted_indices))  # clamp: never exceed class count
        top_k_list: List[Dict[str, Any]] = []
        for rank in range(k):
            cid = int(sorted_indices[rank])
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })
        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }