# inference/model_loader.py """ Download and cache model weights from HuggingFace Hub. Models are loaded once at startup and cached in memory. """ import logging import os import torch from huggingface_hub import hf_hub_download from safetensors.torch import load_file from training.models.efficientnet_b4 import AnemiaModel log = logging.getLogger(__name__) _MODEL_CACHE: dict[str, AnemiaModel] = {} HF_REPOS = { "conjunctiva": os.getenv("HF_CONJ_MODEL_REPO", "hssling/anemia-efficientnet-b4-conjunctiva"), "nailbed": os.getenv("HF_NAIL_MODEL_REPO", "hssling/anemia-efficientnet-b4-nailbed"), } HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") DEVICE = torch.device("cpu") # HF Spaces free tier is CPU-only def load_model(site: str) -> AnemiaModel: """Load and cache model for a given site ('conjunctiva' or 'nailbed').""" if site in _MODEL_CACHE: return _MODEL_CACHE[site] repo_id = HF_REPOS.get(site) if repo_id is None: raise ValueError(f"Unknown site: {site!r}. Must be 'conjunctiva' or 'nailbed'.") log.info(f"Downloading model weights from {repo_id} ...") ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", token=HF_TOKEN) model = AnemiaModel(pretrained=False) state_dict = load_file(ckpt_path, device="cpu") model.load_state_dict(state_dict) model.to(DEVICE) model.eval() _MODEL_CACHE[site] = model log.info(f"Model loaded for site: {site}") return model def preload_all_models(): """Eagerly load all models at startup to avoid cold-start delays.""" for site in HF_REPOS: try: load_model(site) except Exception as e: log.warning(f"Could not preload {site} model: {e}")