# Source: Machine_learning_CS-6140 / src/inference/resnet_pt_lr_model.py
# (Hugging Face page residue removed: uploader "Shashwat98", "Upload 37 files", commit 52dd1ca verified)
# src/inference/resnet_pt_lr_model.py
import os
import json
from typing import Dict, Any, List, Optional
import numpy as np
from PIL import Image
import torch
from torchvision.models import resnet18, ResNet18_Weights
import joblib
class ResNetPTLRModel:
    """
    End-to-end inference wrapper:
      - ResNet18 (pretrained on ImageNet) as a frozen feature backbone
      - Logistic Regression head trained on the extracted features

    The checkpoint is expected to be the dict saved by train_resnet_pt_lr.py
    ({"model": ..., "feature_dim": ..., "backbone": ..., "labels_path": ...}),
    but a raw estimator saved directly via joblib is also accepted.
    """

    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Args:
            ckpt_path: joblib checkpoint containing the trained LR head.
            labels_path: JSON file mapping class id -> class name.
            device: explicit torch device string ("cpu", "cuda", ...);
                auto-detected when None.

        Raises:
            FileNotFoundError: if the checkpoint or labels file is missing.
        """
        # Raise real exceptions instead of `assert` — asserts are stripped
        # when Python runs with -O, silently skipping the validation.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + LR checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")

        # Decide device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTLRModel] Using device: {self.device}")

        self._load_lr_head(ckpt_path, labels_path)
        self._load_labels(labels_path)
        self._build_backbone()
        self._verify_feature_dim()

    def _load_lr_head(self, ckpt_path: str, labels_path: str) -> None:
        """Load the LR head and its training metadata from the joblib checkpoint."""
        print(f"[ResNetPTLRModel] Loading LR head from {ckpt_path} ...")
        # SECURITY NOTE: joblib.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        payload = joblib.load(ckpt_path)
        if isinstance(payload, dict) and "model" in payload:
            # Dict payload as saved by train_resnet_pt_lr.py.
            self.lr_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Fallback if someone saved the raw estimator instead of the dict.
            self.lr_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path

    def _load_labels(self, labels_path: str) -> None:
        """Load the id -> name mapping, preferring the path recorded in the checkpoint."""
        labels_file = (
            self.saved_labels_path
            if os.path.exists(self.saved_labels_path)
            else labels_path
        )
        print(f"[ResNetPTLRModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            raw = json.load(f)
        # JSON object keys are always strings; normalize them to int class ids.
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in raw.items()}

    def _build_backbone(self) -> None:
        """Build the frozen ResNet18 feature extractor and its preprocessing transform."""
        print("[ResNetPTLRModel] Building ResNet18 backbone ...")
        import torch.nn as nn

        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        # Drop the ImageNet classifier so the forward pass emits pooled features.
        model.fc = nn.Identity()
        model.to(self.device)
        model.eval()
        self.backbone = model
        # Same ImageNet-style transforms used during feature extraction at train time.
        self.preprocess_tf = weights.transforms()

    def _verify_feature_dim(self) -> None:
        """Warn (never fail) if the backbone's output width disagrees with the LR head."""
        if self.feature_dim is None:
            return
        try:
            probe = torch.zeros(1, 3, 224, 224).to(self.device)
            with torch.no_grad():
                out = self.backbone(probe)
            actual_dim = out.shape[1]
            if actual_dim != self.feature_dim:
                print(
                    f"[ResNetPTLRModel][WARN] feature_dim mismatch: "
                    f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                )
        except Exception as e:
            # Verification is best-effort; inference may still work.
            print(f"[ResNetPTLRModel][WARN] could not verify feature_dim: {e}")

    def preprocess(self, img: "Image.Image") -> torch.Tensor:
        """
        Apply ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # add batch dim -> (1, 3, H, W)
        return t.to(self.device)

    @staticmethod
    def _to_probabilities_from_logits(logits: np.ndarray) -> np.ndarray:
        """
        Convert a 1-D vector of raw scores/logits to probabilities via softmax.
        Max-subtraction keeps the exponentials numerically stable.
        """
        logits = logits - np.max(logits)
        exp = np.exp(logits)
        return exp / np.sum(exp)

    def _extract_features(self, img: "Image.Image") -> np.ndarray:
        """
        Run a PIL image through the backbone and return a (1, D) numpy feature matrix.
        """
        x = self.preprocess(img)  # (1, 3, H, W)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()

    def predict(
        self,
        img: "Image.Image",
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict the class of a single PIL image.

        Args:
            img: input image.
            top_k: number of top predictions to return (clamped to the
                number of classes).

        Returns:
            {
                "class_id": int,
                "class_name": str,
                "probabilities": {class_name: prob_float},
                "top_k": [
                    {"class_id": int, "class_name": str, "probability": float},
                    ...
                ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)
        if hasattr(self.lr_head, "predict_proba"):
            # LR exposes calibrated probabilities directly.
            probs = self.lr_head.predict_proba(feats_np)[0]  # (C,)
        else:
            # Fallback for heads without predict_proba: softmax decision scores.
            scores = self.lr_head.decision_function(feats_np)
            if scores.ndim == 1:
                scores = scores[np.newaxis, :]
            probs = self._to_probabilities_from_logits(scores[0])

        pred_id = int(np.argmax(probs))
        pred_name = self.id_to_name[pred_id]

        # Full distribution keyed by class name.
        prob_dict: Dict[str, float] = {
            self.id_to_name[i]: float(p) for i, p in enumerate(probs)
        }

        # Top-k, best first; use a distinct local so the parameter isn't rebound.
        order = np.argsort(probs)[::-1]
        k = min(top_k, len(order))
        top_k_list: List[Dict[str, Any]] = []
        for cid in order[:k]:
            cid = int(cid)
            top_k_list.append({
                "class_id": cid,
                "class_name": self.id_to_name[cid],
                "probability": float(probs[cid]),
            })

        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }