# logo_classifier/src/models/predictor.py
# yoavraytz: "Deploy FastAPI logo classifier demo" (commit ab794cc)
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Any
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
@dataclass
class Prediction:
    """Result of classifying one image.

    Bundles the hard 0/1 decision with the probability score it was
    derived from, so callers can apply their own cut-off if needed.
    """

    # Binary label: 1 when ``prob`` reached the predictor's threshold, else 0.
    pred: int
    # Sigmoid of the network's single output logit.
    prob: float
def _fix_state_dict_keys(state: dict) -> dict:
    """Return a new mapping of parameter names to values, ready for loading.

    Two checkpoint layouts are normalised:

    * a wrapper dict that stores the real weights under ``"state_dict"``
      is unwrapped first;
    * a single leading ``"module."`` on a key (the prefix added by
      ``DataParallel``) is removed.

    Keys without the prefix are copied through untouched. The input
    mapping is never modified.
    """
    prefix = "module."

    # Unwrap checkpoints that nest the weights one level down.
    if isinstance(state, dict) and "state_dict" in state:
        weights = state["state_dict"]
    else:
        weights = state

    # Only the first occurrence of the prefix is removed from each key.
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in weights.items()
    }
class Predictor:
    """Binary image classifier built on a ResNet18 with a single-logit head.

    The constructor builds the network, loads weights from a checkpoint
    file and prepares the image preprocessing pipeline. ``predict_pil``
    then scores one PIL image at a time.

    Attributes:
        threshold: Probability cut-off at or above which ``pred`` is 1.
        device: Device the network and inputs are placed on.
        net: The ResNet18 model, in eval mode.
        tfm: Preprocessing applied to every input image.
    """

    def __init__(self, weights_path: str, threshold: float = 0.5, device: str | None = None):
        """Build the model and load its weights.

        Args:
            weights_path: Path to a checkpoint readable by ``torch.load``.
                It may be a raw state dict or a dict holding one under
                ``"state_dict"``; ``"module."`` key prefixes are stripped.
            threshold: Decision cut-off applied to the sigmoid score.
            device: Target device. When ``None``, ``"cuda"`` is used if
                available, otherwise ``"cpu"``.

        Raises:
            RuntimeError: If the checkpoint keys do not exactly match the
                model (weights are loaded with ``strict=True``).
        """
        self.threshold = float(threshold)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Build a plain ResNet18 (no pretrained weights) to match training
        # keys, then swap the classifier for a one-output linear layer.
        self.net = models.resnet18(weights=None)
        self.net.fc = nn.Linear(self.net.fc.in_features, 1)
        self.net = self.net.to(self.device)
        self.net.eval()

        # SECURITY NOTE(review): torch.load unpickles the file, which can
        # execute arbitrary code if the checkpoint is untrusted. Only load
        # trusted files, or consider weights_only=True if the checkpoint
        # contains nothing but tensors and primitive containers.
        state = torch.load(weights_path, map_location=self.device)
        state = _fix_state_dict_keys(state)
        self.net.load_state_dict(state, strict=True)

        # Fixed 224x224 input with per-channel normalisation. These values
        # presumably mirror the preprocessing used at training time;
        # confirm against the training code before changing them.
        self.tfm = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @torch.no_grad()
    def predict_pil(self, img: Image.Image) -> Prediction:
        """Classify a single PIL image.

        The image is converted to RGB, preprocessed, and run through the
        network as a batch of one with gradient tracking disabled.

        Args:
            img: Input image in any PIL mode.

        Returns:
            A ``Prediction`` whose ``prob`` is the sigmoid of the output
            logit and whose ``pred`` is 1 when ``prob >= threshold``.
        """
        x = self.tfm(img.convert("RGB")).unsqueeze(0).to(self.device)
        logits = self.net(x)
        prob = torch.sigmoid(logits).item()
        pred = 1 if prob >= self.threshold else 0
        return Prediction(pred=int(pred), prob=float(prob))

    def info(self) -> Dict[str, Any]:
        """Return a small runtime report about this predictor.

        Returns:
            A dict with the device name, the decision threshold, and
            whether CUDA is currently available.
        """
        return {
            # str() keeps the report JSON-serialisable even when the caller
            # supplied a torch.device; plain strings pass through unchanged.
            "device": str(self.device),
            "threshold": self.threshold,
            "cuda_available": bool(torch.cuda.is_available()),
        }