File size: 2,149 Bytes
ab794cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Any

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


@dataclass
class Prediction:
    """Outcome of classifying a single image.

    Attributes are documented inline below.
    """

    # Hard class label after thresholding the probability (0 or 1 when
    # produced by ``Predictor.predict_pil``).
    pred: int
    # Sigmoid probability of the positive class.
    prob: float


def _fix_state_dict_keys(state: dict) -> dict:
    # If checkpoint wraps the state_dict
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]

    # Strip common DataParallel prefix
    fixed = {}
    for k, v in state.items():
        if k.startswith("module."):
            fixed[k[len("module."):]] = v
        else:
            fixed[k] = v
    return fixed


class Predictor:
    """Single-image binary classifier around a one-logit ResNet-18.

    The network is ``models.resnet18(weights=None)`` with its ``fc`` head
    replaced by a 1-output linear layer; weights come from a checkpoint file
    and must match this architecture exactly. ``predict_pil`` turns the
    logit into a probability with a sigmoid and thresholds it.
    """

    def __init__(self, weights_path: str, threshold: float = 0.5, device: str | None = None):
        """Build the network, load its weights and set up preprocessing.

        Args:
            weights_path: Path to a checkpoint readable by ``torch.load``:
                either a state_dict or a dict wrapping one under
                ``"state_dict"``; ``module.`` key prefixes are stripped.
            threshold: Probability cut-off in ``[0, 1]``; an image whose
                probability is ``>= threshold`` is predicted as class 1.
            device: Torch device string. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``.

        Raises:
            ValueError: If ``threshold`` is not a number in ``[0, 1]``.
            RuntimeError: From ``load_state_dict`` when the checkpoint's keys
                or shapes do not match the model.
        """
        # Validate first so a bad argument fails before the expensive load.
        self.threshold = self._validate_threshold(threshold)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Build a plain ResNet18 to match training keys
        self.net = models.resnet18(weights=None)
        # Replace the classifier head with a single output logit.
        self.net.fc = nn.Linear(self.net.fc.in_features, 1)
        self.net = self.net.to(self.device)
        # Inference only: eval() makes BatchNorm use its running statistics.
        self.net.eval()

        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources
        # (torch >= 1.13 accepts weights_only=True to restrict unpickling).
        state = torch.load(weights_path, map_location=self.device)
        state = _fix_state_dict_keys(state)
        # strict=True: fail loudly rather than run with partially loaded weights.
        self.net.load_state_dict(state, strict=True)

        # Resize to a fixed 224x224, then normalize with the standard
        # ImageNet channel mean/std.
        self.tfm = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @staticmethod
    def _validate_threshold(threshold: float) -> float:
        """Return ``threshold`` as a float, rejecting NaN and values outside [0, 1].

        A sigmoid probability always lies in [0, 1], so an out-of-range
        threshold (or NaN, which fails every comparison) would make every
        prediction the same class.
        """
        value = float(threshold)
        # The chained comparison is also False for NaN, so NaN is rejected here.
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"threshold must be in [0, 1], got {threshold!r}")
        return value

    @torch.no_grad()  # inference only; skip autograd bookkeeping
    def predict_pil(self, img: Image.Image) -> Prediction:
        """Classify one PIL image.

        The image is converted to RGB, preprocessed with ``self.tfm`` and run
        through the network as a batch of one.

        Returns:
            ``Prediction`` whose ``prob`` is the sigmoid of the model's logit
            and whose ``pred`` is 1 when ``prob >= self.threshold``, else 0.
        """
        # convert("RGB") gives the 3-channel input the network expects;
        # unsqueeze(0) adds the batch dimension.
        x = self.tfm(img.convert("RGB")).unsqueeze(0).to(self.device)
        logits = self.net(x)
        prob = torch.sigmoid(logits).item()  # .item() yields a Python float
        pred = 1 if prob >= self.threshold else 0
        return Prediction(pred=pred, prob=prob)

    def info(self) -> Dict[str, Any]:
        """Return a summary dict of the predictor's runtime configuration."""
        return {
            "device": self.device,
            "threshold": self.threshold,
            "cuda_available": bool(torch.cuda.is_available()),
        }