# melisklc0's picture
# Upload load_model.py with huggingface_hub
# 85225b8 verified
"""
Standalone inference module for the Hugging Face model repo.
This file is uploaded as ``load_model.py`` next to ``model.pth``. Copy-paste friendly:
only needs ``torch``, ``torchvision``, ``pillow``, ``huggingface_hub``.
"""
from __future__ import annotations
import os
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torchvision import models, transforms
from PIL import Image
# Default side length (in pixels) of the square that eval_transform resizes images to.
IMAGE_SIZE = 299
# Class labels in output-index order: predict_proba pairs softmax entry i with CLASS_NAMES[i].
CLASS_NAMES = ("No-Stroke", "Stroke")
class StrokeHead(nn.Sequential):
    """Classification head: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Both hidden layers are 256 units wide. Because this is an
    ``nn.Sequential``, the three linear layers sit at indices 0, 2 and 4,
    which is what determines their parameter names in a state dict.

    Args:
        in_features: Width of the feature vector fed into the head.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int, num_classes: int = 2):
        widths = (in_features, 256, 256)
        stack = []
        # Each consecutive width pair becomes one Linear + ReLU stage.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Final projection to logits has no activation after it.
        stack.append(nn.Linear(widths[-1], num_classes))
        super().__init__(*stack)
def build_model(num_classes: int = 2) -> nn.Module:
    """Create an EfficientNet-B0 whose classifier is replaced by ``StrokeHead``.

    The network is built with ``weights=None``, so no pretrained weights are
    requested here.

    Args:
        num_classes: Number of output logits for the new head.

    Returns:
        The network with ``classifier`` swapped for a ``StrokeHead``.
    """
    backbone = models.efficientnet_b0(weights=None)
    # The stock classifier's element at index 1 supplies the feature width
    # the replacement head must accept.
    stock_layer = backbone.classifier[1]
    backbone.classifier = StrokeHead(stock_layer.in_features, num_classes)
    return backbone
def _normalize_state_dict(state: dict) -> dict:
if any(k.startswith("model.") for k in state):
state = {k.replace("model.", ""): v for k, v in state.items()}
if any(k.startswith("module.") for k in state):
state = {k.replace("module.", ""): v for k, v in state.items()}
return state
def _unwrap(raw: object) -> dict:
if isinstance(raw, dict) and "state_dict" in raw:
return dict(raw["state_dict"])
if isinstance(raw, dict) and "model_state_dict" in raw:
return dict(raw["model_state_dict"])
if not isinstance(raw, dict):
raise TypeError(f"Expected state dict, got {type(raw)}")
return raw
def eval_transform(image_size: int = IMAGE_SIZE) -> transforms.Compose:
    """Build the preprocessing pipeline used at inference time.

    The pipeline resizes to a square of ``image_size`` x ``image_size``,
    converts to a tensor, then normalizes each channel with the fixed
    mean / std values below.

    Args:
        image_size: Side length of the square the image is resized to.

    Returns:
        A ``transforms.Compose`` applying the three steps in order.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    resize = transforms.Resize((image_size, image_size))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(channel_mean, channel_std)
    return transforms.Compose([resize, to_tensor, normalize])
def load_model(
    repo_id: str,
    *,
    weights_name: str = "model.pth",
    image_size: int = IMAGE_SIZE,
    token: str | None = None,
):
    """Download weights from the Hub and return a ready-to-use ``(model, transform)``.

    Args:
        repo_id: Hub model repository to download from.
        weights_name: File name of the checkpoint inside the repository.
        image_size: Side length passed to ``eval_transform``.
        token: Access token. When falsy, the ``HF_TOKEN`` and then the
            ``HUGGING_FACE_HUB_TOKEN`` environment variables are tried.

    Returns:
        A tuple of the model (in eval mode, on CUDA when available, otherwise
        CPU) and the transform built by ``eval_transform(image_size)``.
    """
    auth_token = token
    if not auth_token:
        auth_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=weights_name,
        repo_type="model",
        token=auth_token,
    )

    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    # NOTE(security): weights_only=False performs a full pickle load, which can
    # execute arbitrary code embedded in the checkpoint. Only use this with
    # repositories you trust.
    checkpoint = torch.load(checkpoint_path, map_location=target, weights_only=False)
    weights = _normalize_state_dict(_unwrap(checkpoint))

    net = build_model()
    # strict=True: any missing or unexpected key raises instead of being ignored.
    net.load_state_dict(weights, strict=True)
    net.to(target)
    net.eval()

    preprocessing = eval_transform(image_size)
    return net, preprocessing
def predict_proba(model: nn.Module, tfm: transforms.Compose, img: Image.Image):
    """Return per-class softmax probabilities for a single image.

    Args:
        model: Network to run; the input is moved to the device of its first
            parameter.
        tfm: Preprocessing applied to the RGB-converted image.
        img: Image to classify.

    Returns:
        A dict mapping each name in ``CLASS_NAMES`` to the probability at the
        same index of the softmax output, as a Python float.
    """
    target = next(model.parameters()).device
    rgb = img.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension for the single image.
    batch = tfm(rgb).unsqueeze(0).to(target)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]
    return {label: probs[idx].item() for idx, label in enumerate(CLASS_NAMES)}