# Repository page header captured with this file (kept as comments so the module parses):
# efnanaladagg's picture
# Clean push
# 6f6eb85
# src/app/inference.py
import json
from pathlib import Path
from typing import List, Tuple, Dict, Any
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
# --- Artifact locations ---
# These are relative paths, so they resolve against the process's current
# working directory; start the app from the project root for them to be found.
ARTIFACTS_DIR = Path("artifacts")
CKPT_PATH = ARTIFACTS_DIR.joinpath("model.pt")  # trained weights checkpoint
LABELS_PATH = ARTIFACTS_DIR.joinpath("label_names.json")  # JSON list of class names
IMG_SIZE = 224  # side length (pixels) of the square input images
def get_device() -> torch.device:
    """Return the CUDA device when PyTorch can see a GPU, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def load_label_names() -> List[str]:
    """Read the ordered list of class names from ``LABELS_PATH``.

    Position ``i`` in the list names model output ``i``: the list length is used
    as the classifier's class count, and predictions are mapped back to names by
    index.

    Returns:
        The label names parsed from the JSON file.

    Raises:
        FileNotFoundError: if ``LABELS_PATH`` does not exist.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``), or
            does not hold a non-empty JSON list of strings.
    """
    if not LABELS_PATH.exists():
        raise FileNotFoundError(f"Missing {LABELS_PATH}. Run training first to create artifacts.")
    label_names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
    # Validate eagerly: an empty list or a JSON object would otherwise only fail
    # later with an obscure error (zero-width classifier head, KeyError on lookup).
    if not isinstance(label_names, list) or not label_names:
        raise ValueError(f"{LABELS_PATH} must contain a non-empty JSON list of label names.")
    if not all(isinstance(name, str) for name in label_names):
        raise ValueError(f"{LABELS_PATH} must contain only string label names.")
    return label_names
def build_model(num_classes: int) -> nn.Module:
    """Create an untrained ResNet-18 whose final layer emits ``num_classes`` logits.

    ``weights=None`` means nothing is downloaded, so this works offline; the
    trained parameters are loaded into the returned network afterwards.
    """
    backbone = models.resnet18(weights=None)  # OFFLINE SAFE: no pretrained downloads
    feature_dim = backbone.fc.in_features
    # Replace torchvision's default classification head with one sized for our classes.
    backbone.fc = nn.Linear(feature_dim, num_classes)
    return backbone
def load_model() -> Tuple[nn.Module, List[str], torch.device]:
    """Load the trained classifier and its label names from the artifacts directory.

    Every call reads the checkpoint from disk and builds a fresh model, so callers
    should invoke this once (e.g. at startup) and reuse the returned objects.

    The checkpoint may be either a dict holding the weights under
    ``"model_state_dict"`` or a bare ``state_dict`` as produced by
    ``torch.save(model.state_dict(), ...)``.

    Returns:
        ``(model, label_names, device)``; ``model`` has been moved to ``device``
        and switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint or the label file is missing.
        TypeError: if the checkpoint does not deserialize to a dict.
        RuntimeError: if the weights do not fit the architecture, e.g. the number
            of label names differs from the checkpoint's output size.
    """
    if not CKPT_PATH.exists():
        raise FileNotFoundError(f"Missing {CKPT_PATH}. Train and save model first.")
    label_names = load_label_names()
    num_classes = len(label_names)
    device = get_device()
    model = build_model(num_classes)

    # Deserialize onto the CPU first so a checkpoint saved on a GPU still loads
    # on CPU-only hosts; the model is moved to `device` afterwards.
    # NOTE(review): torch.load unpickles the file, so only load trusted artifacts.
    # Consider weights_only=True once the checkpoint is confirmed to hold only
    # tensors and primitive metadata.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Unsupported checkpoint format in {CKPT_PATH}: expected a dict, "
            f"got {type(ckpt).__name__}."
        )
    # Accept both the wrapped {"model_state_dict": ...} format and a bare state_dict.
    state_dict = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()  # inference mode for dropout/batch-norm layers
    return model, label_names, device
def get_preprocess() -> transforms.Compose:
    """Build the PIL-image -> normalized-tensor pipeline used for inference.

    Resizes to ``IMG_SIZE`` x ``IMG_SIZE`` (aspect ratio is not preserved),
    converts to a float tensor, then normalizes each RGB channel with the
    standard ImageNet mean/std. Must match training/evaluation preprocessing.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
@torch.no_grad()
def predict_image(
    model: nn.Module,
    label_names: List[str],
    device: torch.device,
    image: Image.Image,
    top_k: int = 3
) -> Dict[str, Any]:
    """Classify a single PIL image.

    Args:
        model: classifier producing one logit per entry of ``label_names``.
            Expected to already be on ``device`` and in eval mode (as returned
            by ``load_model``).
        label_names: class names indexed by model output position.
        device: device the input tensor is moved to before the forward pass.
        image: input image in any PIL mode; it is converted to RGB first.
        top_k: how many highest-probability classes to report; clamped to
            ``[0, len(label_names)]``.

    Returns:
        A dict with:
          - ``"predicted_class"``: label with the highest probability,
          - ``"confidence"``: its softmax probability,
          - ``"top_k"``: list of ``{"label", "confidence"}`` dicts, highest first,
          - ``"all_probs"``: mapping label -> probability (handy for UI charts).

    Raises:
        ValueError: if the model's number of outputs differs from
            ``len(label_names)``.
    """
    tf = get_preprocess()
    x = tf(image.convert("RGB")).unsqueeze(0).to(device)  # batch of one: (1, 3, H, W)
    logits = model(x)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu()  # (num_outputs,)

    # Fail clearly on a label-file / checkpoint mismatch instead of an
    # IndexError or topk RuntimeError further down.
    if probs.numel() != len(label_names):
        raise ValueError(
            f"Model produced {probs.numel()} class scores but {len(label_names)} "
            "label names were given; check that the label file matches the model."
        )

    # Convert to Python floats once instead of calling .item() per element.
    prob_list = probs.tolist()

    pred_id = int(torch.argmax(probs).item())
    pred_label = label_names[pred_id]
    pred_conf = prob_list[pred_id]

    # A negative top_k would make torch.topk raise; clamp into the valid range.
    k = max(0, min(top_k, len(label_names)))
    top = torch.topk(probs, k=k)  # sorted=True by default -> descending order
    topk: List[Dict[str, Any]] = [
        {"label": label_names[idx], "confidence": float(score)}
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]

    # all_probs is sometimes useful for debugging/UI charts
    all_probs = dict(zip(label_names, prob_list))
    return {
        "predicted_class": pred_label,
        "confidence": pred_conf,
        "top_k": topk,
        "all_probs": all_probs,
    }
# --- IGNORE ---
# This module provides functions to load a trained ResNet18 model,
# preprocess images, and perform inference to obtain class predictions
# and confidence scores for the "comprehensive-car-damage" dataset.