File size: 3,210 Bytes
85225b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Standalone inference module for the Hugging Face model repo.

This file is uploaded as ``load_model.py`` next to ``model.pth``. It is
copy-paste friendly: it only needs ``torch``, ``torchvision``, ``pillow`` and
``huggingface_hub``.
"""
from __future__ import annotations

import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torchvision import models, transforms
from PIL import Image

# Default side length (pixels) of the square image produced by eval_transform.
IMAGE_SIZE: int = 299
# Output labels in logit order: predict_proba maps logit i to CLASS_NAMES[i].
CLASS_NAMES: tuple[str, ...] = ("No-Stroke", "Stroke")


class StrokeHead(nn.Sequential):
    """MLP classification head: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        in_features: Width of the feature vector fed into the head.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int, num_classes: int = 2):
        hidden_width = 256
        # Order matters: nn.Sequential names children "0".."4", and saved
        # checkpoints refer to the Linear layers as indices 0, 2 and 4.
        layer_stack = [
            nn.Linear(in_features, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_classes),
        ]
        super().__init__(*layer_stack)


def build_model(num_classes: int = 2) -> nn.Module:
    """Create an EfficientNet-B0 whose classifier is replaced by a ``StrokeHead``.

    No pretrained weights are requested (``weights=None``), so all parameters
    start randomly initialized until a state dict is loaded into the network.

    Args:
        num_classes: Number of logits produced by the new head.

    Returns:
        The assembled network.
    """
    network = models.efficientnet_b0(weights=None)
    # classifier[1] is the stock final layer; reuse the feature width it consumes.
    feature_dim = network.classifier[1].in_features
    network.classifier = StrokeHead(feature_dim, num_classes)
    return network


def _normalize_state_dict(state: dict) -> dict:
    """Return a copy of ``state`` with training-wrapper key prefixes removed.

    Checkpoints saved through a wrapper have every parameter name prefixed. For
    example, ``nn.DataParallel`` adds ``module.``, and a wrapper storing the
    network as ``self.model`` adds ``model.``. Leading ``model.`` / ``module.``
    segments are stripped repeatedly and in any order, so
    ``module.model.features.0.weight`` becomes ``features.0.weight``.

    Only *leading* occurrences are removed. A prefix string appearing later in a
    key (e.g. inside ``submodel.``) is left intact. The input dict is not
    modified.
    """
    prefixes = ("model.", "module.")

    def strip(key: str) -> str:
        # Keep peeling prefixes until none match, so nesting order doesn't matter.
        while True:
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    break
            else:
                return key

    return {strip(k): v for k, v in state.items()}


def _unwrap(raw: object) -> dict:
    """Extract a plain state dict from whatever ``torch.load`` returned.

    Accepted inputs:

    * a whole pickled ``nn.Module`` (saved via ``torch.save(model)``); its
      ``state_dict()`` is used;
    * a checkpoint dict holding the weights under ``"state_dict"`` (checked
      first) or ``"model_state_dict"``; a ``dict`` copy of that entry is returned;
    * a bare state dict, which is returned as-is.

    Raises:
        TypeError: if ``raw`` is none of the above.
    """
    if isinstance(raw, nn.Module):
        return dict(raw.state_dict())
    if not isinstance(raw, dict):
        raise TypeError(f"Expected state dict, got {type(raw)}")
    # Checkpoint containers: the first matching key wins.
    for key in ("state_dict", "model_state_dict"):
        if key in raw:
            return dict(raw[key])
    return raw


def eval_transform(image_size: int = IMAGE_SIZE) -> transforms.Compose:
    """Build the deterministic preprocessing pipeline used at inference time.

    The pipeline has three steps:

    1. Resize to exactly ``image_size`` x ``image_size`` (aspect ratio is not
       preserved).
    2. Convert to a float tensor.
    3. Normalize each channel with the ImageNet mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    return transforms.Compose(steps)


def load_model(
    repo_id: str,
    *,
    weights_name: str = "model.pth",
    image_size: int = IMAGE_SIZE,
    token: str | None = None,
    revision: str | None = None,
    device: str | torch.device | None = None,
    weights_only: bool = False,
) -> tuple[nn.Module, transforms.Compose]:
    """Download weights from the Hub and build the EfficientNet-B0 + StrokeHead model.

    Args:
        repo_id: Hugging Face model repo id hosting ``weights_name``.
        weights_name: File name of the checkpoint inside the repo.
        image_size: Square side length used by the returned eval transform.
        token: Hub access token. If falsy, falls back to the ``HF_TOKEN`` and
            then the ``HUGGING_FACE_HUB_TOKEN`` environment variables.
        revision: Optional branch, tag or commit hash to pin the download to.
        device: Target device. Defaults to CUDA when available, otherwise CPU.
        weights_only: Forwarded to ``torch.load``. The default ``False`` (kept
            for backward compatibility) unpickles arbitrary objects, which can
            execute code stored in the downloaded file. Pass ``True`` when the
            checkpoint only holds tensors and plain containers.

    Returns:
        ``(model, transform)``: the model in eval mode on ``device``, and the
        preprocessing pipeline from ``eval_transform(image_size)``.

    Raises:
        TypeError: if the checkpoint is not a recognizable state dict.
        RuntimeError: from ``load_state_dict`` (strict) if keys or shapes don't match.
    """
    tok = token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    path = hf_hub_download(
        repo_id=repo_id,
        filename=weights_name,
        repo_type="model",
        revision=revision,
        token=tok,
    )
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device)

    # SECURITY: with weights_only=False, torch.load unpickles arbitrary objects and
    # can run code embedded in the file. Only use it with repos you trust
    # (ideally pinned via `revision`).
    # Load onto CPU: the freshly built model lives on CPU, so mapping straight to
    # the GPU would only copy every tensor back to host inside load_state_dict.
    raw = torch.load(path, map_location="cpu", weights_only=weights_only)
    state = _normalize_state_dict(_unwrap(raw))

    model = build_model(num_classes=len(CLASS_NAMES))
    model.load_state_dict(state, strict=True)
    model.to(target)
    model.eval()
    return model, eval_transform(image_size)


def predict_proba(
    model: nn.Module, tfm: transforms.Compose, img: Image.Image
) -> dict[str, float]:
    """Return class probabilities for a single image.

    Args:
        model: Classifier, e.g. from ``load_model``. It is run as-is, so it
            should already be in eval mode.
        tfm: Preprocessing pipeline, e.g. from ``eval_transform``.
        img: Input image; converted to RGB before ``tfm`` is applied.

    Returns:
        Mapping from each name in ``CLASS_NAMES`` to its softmax probability.

    Raises:
        ValueError: if the model does not emit exactly one logit per class name.
    """
    # Run on whichever device the model's weights live on.
    device = next(model.parameters()).device
    batch = tfm(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # One .tolist() transfers all probabilities at once instead of a
        # per-class .item() (each of which is a separate device sync).
        probs = torch.softmax(model(batch), dim=1)[0].tolist()
    if len(probs) != len(CLASS_NAMES):
        raise ValueError(
            f"Model produced {len(probs)} logits but {len(CLASS_NAMES)} class names are defined"
        )
    return dict(zip(CLASS_NAMES, probs))