# src/app/inference.py
"""Inference helpers for the trained image classifier.

Loads the ResNet18 checkpoint and label names produced by training, applies
the evaluation-time preprocessing, and returns the predicted class, its
confidence and the top-k classes for a single PIL image.
"""
import json
from pathlib import Path
from typing import List, Tuple, Dict, Any

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image

# --- Paths ---
# NOTE: these are relative paths, so they resolve against the current working
# directory; launch the app from the project root so they find the artifacts.
ARTIFACTS_DIR = Path("artifacts")
CKPT_PATH = ARTIFACTS_DIR / "model.pt"
LABELS_PATH = ARTIFACTS_DIR / "label_names.json"

IMG_SIZE = 224


def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_label_names() -> List[str]:
    """Read the ordered list of class names written by training.

    The position of each name in the list is its class id.

    Raises:
        FileNotFoundError: if ``LABELS_PATH`` does not exist.
        ValueError: if the file is not valid JSON or does not hold a
            non-empty JSON list.
    """
    if not LABELS_PATH.exists():
        raise FileNotFoundError(
            f"Missing {LABELS_PATH}. Run training first to create artifacts."
        )
    names = json.loads(LABELS_PATH.read_text(encoding="utf-8"))
    # Class ids index into this list. A mapping such as {"0": "name"} or an
    # empty list would load fine here but break model building and indexing
    # later, so reject it up front with a clear message.
    if not isinstance(names, list) or not names:
        raise ValueError(
            f"{LABELS_PATH} must contain a non-empty JSON list of class names."
        )
    return names


def build_model(num_classes: int) -> nn.Module:
    """Build an untrained ResNet18 whose final layer has ``num_classes`` outputs."""
    # OFFLINE SAFE: no pretrained downloads
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def load_model() -> Tuple[nn.Module, List[str], torch.device]:
    """Load the trained model artifact and its label names for inference.

    Every call rebuilds the network and re-reads the checkpoint from disk.
    Call it once (e.g. at startup) and reuse the result.

    The checkpoint may be either a dict holding the weights under
    ``"model_state_dict"`` or a bare ``state_dict``.

    Returns:
        ``(model, label_names, device)``, with the model moved to ``device``
        and switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint or the label file is missing.
    """
    if not CKPT_PATH.exists():
        raise FileNotFoundError(f"Missing {CKPT_PATH}. Train and save model first.")

    label_names = load_label_names()
    device = get_device()

    model = build_model(len(label_names))
    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints from a trusted source.
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        # Checkpoint saved directly via torch.save(model.state_dict()).
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, label_names, device


def get_preprocess() -> transforms.Compose:
    """Return the image transform applied before inference."""
    # Must match training/evaluation preprocessing
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


@torch.no_grad()
def predict_image(
    model: nn.Module,
    label_names: List[str],
    device: torch.device,
    image: Image.Image,
    top_k: int = 3,
) -> Dict[str, Any]:
    """Predict class probabilities for a single PIL image.

    Args:
        model: classifier to run (as returned by ``load_model``).
        label_names: class names indexed by class id; their count must equal
            the model's number of outputs.
        device: device the model lives on; the input tensor is moved there.
        image: input image in any PIL mode; it is converted to RGB first.
        top_k: number of best classes to report, clamped to
            ``[0, len(label_names)]``.

    Returns:
        A dict with ``"predicted_class"`` (label), ``"confidence"`` (its
        probability), ``"top_k"`` (list of ``{"label", "confidence"}`` dicts
        in descending probability order) and ``"all_probs"`` (label ->
        probability).

    Raises:
        ValueError: if the model's output width differs from
            ``len(label_names)``.
    """
    tf = get_preprocess()
    x = tf(image.convert("RGB")).unsqueeze(0).to(device)  # (1, 3, H, W)

    logits = model(x)
    # Gradients are already disabled by @torch.no_grad, so no detach needed.
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu()  # (num_classes,)
    if probs.numel() != len(label_names):
        # Without this check a mismatch would surface as an IndexError or
        # as silently truncated/misaligned label-probability pairs.
        raise ValueError(
            f"Model produced {probs.numel()} class scores but "
            f"{len(label_names)} label names were given."
        )

    pred_id = int(torch.argmax(probs).item())
    pred_label = label_names[pred_id]
    pred_conf = float(probs[pred_id].item())

    # Clamp to [0, num_classes]: torch.topk rejects negative or oversized k.
    k = max(0, min(top_k, len(label_names)))
    top = torch.topk(probs, k=k)
    topk: List[Dict[str, Any]] = [
        {"label": label_names[int(idx)], "confidence": float(score)}
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]

    # all_probs is sometimes useful for debugging/UI charts
    all_probs = {name: float(p) for name, p in zip(label_names, probs.tolist())}

    return {
        "predicted_class": pred_label,
        "confidence": pred_conf,
        "top_k": topk,
        "all_probs": all_probs,
    }


# --- IGNORE ---
# This module provides functions to load a trained ResNet18 model,
# preprocess images, and perform inference to obtain class predictions
# and confidence scores for the "comprehensive-car-damage" dataset.