|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple, Dict, Any |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
# Directory where the training run writes its outputs (checkpoint + label metadata).
ARTIFACTS_DIR = Path("artifacts")


# Trained model checkpoint; load_model expects a dict containing "model_state_dict".
CKPT_PATH = ARTIFACTS_DIR / "model.pt"


# JSON list of class names, index-aligned with the model's output logits.
LABELS_PATH = ARTIFACTS_DIR / "label_names.json"




# Square input resolution fed to the network (standard ImageNet/ResNet size).
IMG_SIZE = 224
|
|
|
|
|
|
|
|
def get_device() -> torch.device:
    """Select the compute device: CUDA when a GPU is visible, else CPU."""
    use_cuda = torch.cuda.is_available()
    return torch.device("cuda") if use_cuda else torch.device("cpu")
|
|
|
|
|
|
|
|
def load_label_names() -> List[str]:
    """Read the index-aligned class-name list saved by training.

    Raises:
        FileNotFoundError: if the labels artifact has not been created yet.
    """
    if LABELS_PATH.exists():
        return json.loads(LABELS_PATH.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Missing {LABELS_PATH}. Run training first to create artifacts.")
|
|
|
|
|
|
|
|
def build_model(num_classes: int) -> nn.Module:
    """Construct a ResNet-18 (randomly initialized) with a new classification head.

    The final fully-connected layer is replaced so the network emits
    `num_classes` logits; weights are expected to be loaded from a checkpoint.
    """
    net = models.resnet18(weights=None)
    in_features = net.fc.in_features
    net.fc = nn.Linear(in_features, num_classes)
    return net
|
|
|
|
|
|
|
|
def load_model() -> Tuple[nn.Module, List[str], torch.device]:
    """Load the trained checkpoint and label names for inference.

    Returns:
        Tuple of (model in eval mode on the chosen device, label names, device).

    Raises:
        FileNotFoundError: if the checkpoint artifact does not exist.
    """
    if not CKPT_PATH.exists():
        raise FileNotFoundError(f"Missing {CKPT_PATH}. Train and save model first.")

    label_names = load_label_names()
    device = get_device()

    # Build a fresh network sized to the label set, then restore trained weights.
    # Weights are loaded onto CPU first and moved afterwards.
    model = build_model(len(label_names))
    state = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])

    model.to(device)
    model.eval()  # inference only: freeze dropout / batch-norm behavior
    return model, label_names, device
|
|
|
|
|
|
|
|
def get_preprocess() -> transforms.Compose:
    """Inference transform: resize to IMG_SIZE, convert to tensor, ImageNet-normalize."""
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
|
|
@torch.no_grad()
def predict_image(
    model: nn.Module,
    label_names: List[str],
    device: torch.device,
    image: Image.Image,
    top_k: int = 3
) -> Dict[str, Any]:
    """Run a single-image forward pass and report class probabilities.

    Args:
        model: network already on `device` and in eval mode.
        label_names: class names index-aligned with the model's logits.
        device: device to run inference on.
        image: input PIL image (converted to RGB internally).
        top_k: number of top predictions to include (clamped to class count).

    Returns:
        Dict with the argmax class, its confidence, the top-k
        label/confidence pairs, and the full per-class distribution.
    """
    preprocess = get_preprocess()
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    # Softmax over the class dimension, then bring probabilities back to CPU.
    probs = torch.softmax(model(batch), dim=1).squeeze(0).detach().cpu()

    best = int(torch.argmax(probs).item())

    k = min(top_k, len(label_names))
    scores, indices = torch.topk(probs, k=k)
    topk: List[Dict[str, float]] = [
        {"label": label_names[int(i)], "confidence": float(s)}
        for s, i in zip(scores.tolist(), indices.tolist())
    ]

    all_probs = {name: float(probs[i].item()) for i, name in enumerate(label_names)}

    return {
        "predicted_class": label_names[best],
        "confidence": float(probs[best].item()),
        "top_k": topk,
        "all_probs": all_probs,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|