""" Standalone inference module for the Hugging Face model repo. This file is uploaded as ``load_model.py`` next to ``model.pth``. Copy-paste friendly: only needs ``torch``, ``torchvision``, ``pillow``, ``huggingface_hub``. """ from __future__ import annotations import os import torch import torch.nn as nn from huggingface_hub import hf_hub_download from torchvision import models, transforms from PIL import Image IMAGE_SIZE = 299 CLASS_NAMES = ("No-Stroke", "Stroke") class StrokeHead(nn.Sequential): def __init__(self, in_features: int, num_classes: int = 2): super().__init__( nn.Linear(in_features, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, num_classes), ) def build_model(num_classes: int = 2) -> nn.Module: m = models.efficientnet_b0(weights=None) in_features = m.classifier[1].in_features m.classifier = StrokeHead(in_features, num_classes) return m def _normalize_state_dict(state: dict) -> dict: if any(k.startswith("model.") for k in state): state = {k.replace("model.", ""): v for k, v in state.items()} if any(k.startswith("module.") for k in state): state = {k.replace("module.", ""): v for k, v in state.items()} return state def _unwrap(raw: object) -> dict: if isinstance(raw, dict) and "state_dict" in raw: return dict(raw["state_dict"]) if isinstance(raw, dict) and "model_state_dict" in raw: return dict(raw["model_state_dict"]) if not isinstance(raw, dict): raise TypeError(f"Expected state dict, got {type(raw)}") return raw def eval_transform(image_size: int = IMAGE_SIZE) -> transforms.Compose: return transforms.Compose( [ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) def load_model( repo_id: str, *, weights_name: str = "model.pth", image_size: int = IMAGE_SIZE, token: str | None = None, ): """ Download weights from the Hub, build EfficientNet-B0 + head, return ``(model, transform)``. """ tok = token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") path = hf_hub_download( repo_id=repo_id, filename=weights_name, repo_type="model", token=tok, ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") raw = torch.load(path, map_location=device, weights_only=False) state = _normalize_state_dict(_unwrap(raw)) model = build_model() model.load_state_dict(state, strict=True) model.to(device) model.eval() return model, eval_transform(image_size) def predict_proba(model: nn.Module, tfm: transforms.Compose, img: Image.Image): device = next(model.parameters()).device x = tfm(img.convert("RGB")).unsqueeze(0).to(device) with torch.no_grad(): p = torch.softmax(model(x), dim=1)[0] return {CLASS_NAMES[i]: p[i].item() for i in range(len(CLASS_NAMES))}