from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Any

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


@dataclass
class Prediction:
    """Result of classifying a single image."""

    # Hard label: 1 when ``prob`` reaches the predictor's threshold, else 0.
    pred: int
    # Sigmoid of the network's single output logit.
    prob: float


def _fix_state_dict_keys(state: dict) -> dict:
    """Return a state dict whose keys match a bare (unwrapped) model.

    Two checkpoint layouts are normalised:

    * a dict that stores the weights under a ``"state_dict"`` entry is
      unwrapped to that entry;
    * keys starting with ``"module."`` (the prefix ``nn.DataParallel`` adds)
      have that prefix removed.

    The input mapping is not modified; a new dict is returned.
    """
    # If checkpoint wraps the state_dict
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]

    # Strip common DataParallel prefix
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


class Predictor:
    """Single-logit ResNet18 classifier for PIL images."""

    def __init__(self, weights_path: str, threshold: float = 0.5, device: str | None = None):
        """Build the network and load its weights.

        Args:
            weights_path: Path to a checkpoint readable by ``torch.load``.
            threshold: Probability at or above which ``pred`` is 1.
            device: Torch device string; when ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            RuntimeError: From ``load_state_dict(strict=True)`` when the
                checkpoint keys or shapes do not match the network.
        """
        self.threshold = float(threshold)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Build a plain ResNet18 to match training keys
        self.net = models.resnet18(weights=None)
        # Replace the classification head with a single output logit.
        self.net.fc = nn.Linear(self.net.fc.in_features, 1)
        self.net = self.net.to(self.device)

        # NOTE(security): torch.load deserialises with pickle. Only load
        # checkpoints from trusted sources, or pass ``weights_only=True`` if
        # the installed torch version supports it and the checkpoint is a
        # plain state dict.
        state = torch.load(weights_path, map_location=self.device)
        state = _fix_state_dict_keys(state)
        self.net.load_state_dict(state, strict=True)

        # Inference only: put batch-norm layers into evaluation mode.
        self.net.eval()

        # Resize to 224x224, convert to a tensor, then normalise with the
        # ImageNet channel statistics used by torchvision's pretrained models.
        self.tfm = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @torch.no_grad()
    def predict_pil(self, img: Image.Image) -> Prediction:
        """Classify one PIL image.

        The image is converted to RGB, preprocessed, and run through the
        network as a batch of one. ``prob`` is the sigmoid of the output
        logit; ``pred`` is 1 when ``prob >= self.threshold`` and 0 otherwise.
        """
        batch = self.tfm(img.convert("RGB")).unsqueeze(0).to(self.device)
        logits = self.net(batch)
        # .item() already yields a Python float (and requires one element).
        prob = torch.sigmoid(logits).item()
        pred = 1 if prob >= self.threshold else 0
        return Prediction(pred=pred, prob=prob)

    def info(self) -> Dict[str, Any]:
        """Return the configured device, threshold and current CUDA availability."""
        return {
            "device": self.device,
            "threshold": self.threshold,
            "cuda_available": bool(torch.cuda.is_available()),
        }