Spaces:
Running
Running
File size: 1,766 Bytes
# inference/model_loader.py
"""
Download and cache model weights from HuggingFace Hub.
Models are loaded once at startup and cached in memory.
"""
import logging
import os
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from training.models.efficientnet_b4 import AnemiaModel
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# In-process cache of already-loaded models, keyed by site name.
_MODEL_CACHE: dict[str, AnemiaModel] = {}

# Hub repository to pull weights from for each supported site. Each entry
# falls back to the default repo when its environment variable is not set.
HF_REPOS = {
    "conjunctiva": os.environ.get(
        "HF_CONJ_MODEL_REPO",
        "hssling/anemia-efficientnet-b4-conjunctiva",
    ),
    "nailbed": os.environ.get(
        "HF_NAIL_MODEL_REPO",
        "hssling/anemia-efficientnet-b4-nailbed",
    ),
}

# Access token for the Hub: HF_TOKEN wins; if it is unset or empty, fall back
# to HUGGING_FACE_HUB_TOKEN. The result is None when neither variable is set.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

DEVICE = torch.device("cpu")  # HF Spaces free tier is CPU-only
def load_model(site: str) -> AnemiaModel:
    """Return the model for ``site``, loading and caching it on first use.

    On a cache miss the ``model.safetensors`` file is downloaded from the
    repository registered for ``site`` in ``HF_REPOS``, its weights are loaded
    into a fresh ``AnemiaModel``, and the model is moved to ``DEVICE``, put in
    eval mode and stored in ``_MODEL_CACHE`` so later calls return the same
    instance.

    Args:
        site: Key into ``HF_REPOS`` ('conjunctiva' or 'nailbed').

    Returns:
        The cached or newly loaded model for ``site``.

    Raises:
        ValueError: If ``site`` has no entry in ``HF_REPOS``.
    """
    # Fast path: this site was already loaded earlier in the process.
    try:
        return _MODEL_CACHE[site]
    except KeyError:
        pass

    if site not in HF_REPOS:
        raise ValueError(f"Unknown site: {site!r}. Must be 'conjunctiva' or 'nailbed'.")
    repo_id = HF_REPOS[site]

    log.info(f"Downloading model weights from {repo_id} ...")
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        token=HF_TOKEN,
    )

    # Built with pretrained=False; the checkpoint loaded below supplies the weights.
    net = AnemiaModel(pretrained=False)
    net.load_state_dict(load_file(weights_path, device="cpu"))
    net.to(DEVICE)
    net.eval()

    # Cache only after every step succeeded, so a failed load can be retried.
    _MODEL_CACHE[site] = net
    log.info(f"Model loaded for site: {site}")
    return net
def _preload_one(site):
    """Load the model for one site, logging a warning instead of raising."""
    try:
        load_model(site)
    except Exception as e:
        # Best-effort: a failed preload is reported but must not abort startup.
        log.warning(f"Could not preload {site} model: {e}")


def preload_all_models():
    """Eagerly load the model for every site listed in ``HF_REPOS``.

    Intended to run at startup to avoid cold-start delays. A failure for one
    site is logged as a warning and does not stop the remaining sites from
    being attempted.
    """
    for site in HF_REPOS:
        _preload_one(site)
|