# NOTE: removed non-Python page-scrape artifacts (Space status lines, file size, gutter line numbers)
# src/inference/resnet_pt_lr_model.py
import os
import json
from typing import Dict, Any, List, Optional
import numpy as np
from PIL import Image
import torch
from torchvision.models import resnet18, ResNet18_Weights
import joblib
class ResNetPTLRModel:
    """
    End-to-end inference wrapper:
      - ResNet18 (pretrained on ImageNet) as a frozen feature backbone
      - Logistic Regression head trained on the extracted features

    The checkpoint is a joblib payload saved by train_resnet_pt_lr.py:
    either a dict with keys "model"/"feature_dim"/"backbone"/"labels_path",
    or a bare sklearn estimator (legacy fallback).
    """
    def __init__(
        self,
        ckpt_path: str = "checkpoints/resnet_pt_lr_head.joblib",
        labels_path: str = "configs/labels.json",
        device: Optional[str] = None,
    ):
        """
        Args:
            ckpt_path: joblib file holding the trained LR head (+ metadata).
            labels_path: JSON mapping of class id -> class name; used as a
                fallback when the checkpoint's own labels_path is missing.
            device: explicit torch device string ("cpu", "cuda", ...);
                auto-selects CUDA when available if None.

        Raises:
            FileNotFoundError: if ckpt_path or labels_path does not exist.
        """
        # Real exceptions, not assert: asserts are stripped under `python -O`.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ResNet PT + LR checkpoint not found: {ckpt_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels mapping not found: {labels_path}")
        # Decide device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        print(f"[ResNetPTLRModel] Using device: {self.device}")
        self._load_lr_head(ckpt_path, labels_path)
        self._load_labels(labels_path)
        self._build_backbone()
        self._verify_feature_dim()
    def _load_lr_head(self, ckpt_path: str, labels_path: str) -> None:
        """Load the LR head and its metadata from the joblib checkpoint."""
        print(f"[ResNetPTLRModel] Loading LR head from {ckpt_path} ...")
        payload = joblib.load(ckpt_path)
        # payload was saved as dict in train_resnet_pt_lr.py
        if isinstance(payload, dict) and "model" in payload:
            self.lr_head = payload["model"]
            self.feature_dim = int(payload.get("feature_dim", 512))
            self.backbone_name = payload.get("backbone", "resnet18_imagenet")
            self.saved_labels_path = payload.get("labels_path", labels_path)
        else:
            # Fallback if someone saved the raw estimator instead of the dict.
            self.lr_head = payload
            self.feature_dim = None
            self.backbone_name = "resnet18_imagenet"
            self.saved_labels_path = labels_path
    def _load_labels(self, labels_path: str) -> None:
        """Load the id -> name mapping, preferring the path stored in the checkpoint."""
        labels_file = self.saved_labels_path if os.path.exists(self.saved_labels_path) else labels_path
        print(f"[ResNetPTLRModel] Loading labels from {labels_file} ...")
        with open(labels_file, "r") as f:
            id_to_name = json.load(f)
        # JSON object keys are strings; convert them back to int class ids.
        self.id_to_name: Dict[int, str] = {int(k): v for k, v in id_to_name.items()}
    def _build_backbone(self) -> None:
        """Build the frozen ResNet18 backbone and its matching ImageNet preprocessing."""
        print("[ResNetPTLRModel] Building ResNet18 backbone ...")
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        import torch.nn as nn
        # Drop the classifier head so the forward pass yields raw pooled features.
        model.fc = nn.Identity()
        model.to(self.device)
        model.eval()
        self.backbone = model
        self.preprocess_tf = weights.transforms()
    def _verify_feature_dim(self) -> None:
        """Warn (never fail) if the backbone output dim disagrees with the LR head."""
        if self.feature_dim is None:
            return
        try:
            test_input = torch.zeros(1, 3, 224, 224).to(self.device)
            with torch.no_grad():
                out = self.backbone(test_input)
            actual_dim = out.shape[1]
            if actual_dim != self.feature_dim:
                print(
                    f"[ResNetPTLRModel][WARN] feature_dim mismatch: "
                    f"head expects {self.feature_dim}, backbone outputs {actual_dim}"
                )
        except Exception as e:
            # Best-effort sanity check only; inference may still work.
            print(f"[ResNetPTLRModel][WARN] could not verify feature_dim: {e}")
    def preprocess(self, img: Image.Image) -> torch.Tensor:
        """
        Apply ImageNet-style transforms and return a (1, 3, H, W) tensor on device.
        """
        t = self.preprocess_tf(img)  # (3, H, W)
        if t.ndim == 3:
            t = t.unsqueeze(0)  # (1, 3, H, W)
        return t.to(self.device)
    @staticmethod
    def _to_probabilities_from_logits(logits: np.ndarray) -> np.ndarray:
        """
        Convert raw scores/logits to probabilities using a numerically
        stable softmax (max-subtraction avoids overflow in exp).
        """
        logits = logits - np.max(logits)
        exp = np.exp(logits)
        return exp / np.sum(exp)
    def _extract_features(self, img: Image.Image) -> np.ndarray:
        """
        Run a PIL image through the backbone and get a (1, D) numpy feature vector.
        """
        x = self.preprocess(img)  # (1, 3, H, W)
        with torch.no_grad():
            feats = self.backbone(x)  # (1, D)
        return feats.cpu().numpy()  # (1, D)
    def predict(
        self,
        img: Image.Image,
        top_k: int = 5,
    ) -> Dict[str, Any]:
        """
        Predict class for a single image.
        Returns:
            {
              "class_id": int,
              "class_name": str,
              "probabilities": {class_name: prob_float},
              "top_k": [
                 {"class_id": int, "class_name": str, "probability": float},
                 ...
              ]
            }
        """
        feats_np = self._extract_features(img)  # (1, D)
        # LR has predict_proba, use that directly
        if hasattr(self.lr_head, "predict_proba"):
            probs = self.lr_head.predict_proba(feats_np)[0]  # (C,)
        else:
            scores = np.atleast_2d(self.lr_head.decision_function(feats_np))
            if scores.shape[1] == 1:
                # Binary classifier: decision_function returns ONE margin per
                # sample, so softmax over it would always yield [1.0]. Map the
                # margin through a sigmoid to get [P(class 0), P(class 1)].
                p_pos = 1.0 / (1.0 + np.exp(-float(scores[0, 0])))
                probs = np.array([1.0 - p_pos, p_pos])
            else:
                probs = self._to_probabilities_from_logits(scores[0])
        pred_id = int(np.argmax(probs))
        # .get with a synthesized name: an incomplete labels file must not
        # crash inference with a KeyError.
        pred_name = self.id_to_name.get(pred_id, f"class_{pred_id}")
        # Full distribution
        prob_dict: Dict[str, float] = {
            self.id_to_name.get(i, f"class_{i}"): float(p)
            for i, p in enumerate(probs)
        }
        # Top-k, highest probability first; clamp k into [0, C] without
        # shadowing the top_k parameter.
        k = max(0, min(top_k, len(probs)))
        order = np.argsort(probs)[::-1][:k]
        top_k_list: List[Dict[str, Any]] = [
            {
                "class_id": int(cid),
                "class_name": self.id_to_name.get(int(cid), f"class_{int(cid)}"),
                "probability": float(probs[int(cid)]),
            }
            for cid in order
        ]
        return {
            "class_id": pred_id,
            "class_name": pred_name,
            "probabilities": prob_dict,
            "top_k": top_k_list,
        }