Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Dict, Any | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torchvision import models, transforms | |
@dataclass
class Prediction:
    """Result of classifying a single image with ``Predictor.predict_pil``.

    Declared as a dataclass so it can be built with keyword arguments
    (``Prediction(pred=..., prob=...)``) and gets ``__eq__``/``__repr__``.
    """

    pred: int  # hard label: 1 if prob >= the predictor's threshold, else 0
    prob: float  # sigmoid of the model's single output logit
| def _fix_state_dict_keys(state: dict) -> dict: | |
| # If checkpoint wraps the state_dict | |
| if isinstance(state, dict) and "state_dict" in state: | |
| state = state["state_dict"] | |
| # Strip common DataParallel prefix | |
| fixed = {} | |
| for k, v in state.items(): | |
| if k.startswith("module."): | |
| fixed[k[len("module."):]] = v | |
| else: | |
| fixed[k] = v | |
| return fixed | |
class Predictor:
    """Binary image classifier: a ResNet-18 backbone with a single-logit head.

    Weights are loaded once at construction; ``predict_pil`` then scores one
    PIL image at a time.
    """

    def __init__(
        self,
        weights_path: str,
        threshold: float = 0.5,
        device: str | None = None,
    ):
        """Build the network and load its weights.

        Args:
            weights_path: Path to a checkpoint readable by ``torch.load``;
                key normalisation is done by ``_fix_state_dict_keys``.
            threshold: Probability at or above which ``predict_pil``
                returns label 1.
            device: Torch device string; defaults to ``"cuda"`` when CUDA
                is available, otherwise ``"cpu"``.

        Raises:
            RuntimeError: From ``load_state_dict`` when the checkpoint's
                keys or shapes do not match the network built here
                (loading is ``strict=True``).
        """
        self.threshold = float(threshold)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Plain ResNet-18 (no pretrained weights) with a 1-unit head so the
        # parameter names/shapes match the training checkpoint under strict
        # loading.
        self.net = models.resnet18(weights=None)
        self.net.fc = nn.Linear(self.net.fc.in_features, 1)
        self.net = self.net.to(self.device)
        self.net.eval()

        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources (consider weights_only=True on
        # torch versions that support it).
        state = torch.load(weights_path, map_location=self.device)
        state = _fix_state_dict_keys(state)
        self.net.load_state_dict(state, strict=True)

        # Resize to 224x224, convert to a tensor, and normalise with the
        # standard ImageNet channel mean/std (presumably matching training).
        self.tfm = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def predict_pil(self, img: Image.Image) -> Prediction:
        """Classify a single PIL image.

        The image is converted to RGB, transformed, and run through the
        network as a batch of one.

        Returns:
            ``Prediction`` whose ``prob`` is the sigmoid of the single output
            logit and whose ``pred`` is 1 if ``prob >= self.threshold``,
            else 0.
        """
        x = self.tfm(img.convert("RGB")).unsqueeze(0).to(self.device)
        # Inference only: disable autograd so no graph is built per call.
        with torch.inference_mode():
            logits = self.net(x)
        prob = float(torch.sigmoid(logits).item())
        pred = 1 if prob >= self.threshold else 0
        return Prediction(pred=pred, prob=prob)

    def info(self) -> Dict[str, Any]:
        """Return the predictor's runtime configuration."""
        return {
            "device": self.device,
            "threshold": self.threshold,
            "cuda_available": bool(torch.cuda.is_available()),
        }