"""
Model loading and caching module.

This module provides functions to load anomaly detection models from the
Hugging Face Hub, with an in-memory LRU cache so the same model is not
reloaded on every request.
"""

import os
import threading
from collections import OrderedDict
from typing import Dict, List, Optional

import torch
from huggingface_hub import hf_hub_download
from anomalib.models import Patchcore, EfficientAd

from config import HF_REPO_ID, MODEL_TO_DIR

# Maximum number of models to keep in the in-memory cache (prevents unbounded
# memory growth). Kept modest for the limited resources of HF Spaces.
MAX_MODEL_CACHE_SIZE = 30

# Global model cache with LRU eviction: entries are ordered from least to most
# recently used (move_to_end marks use, popitem(last=False) evicts).
_model_cache = OrderedDict()

# Guards every access to _model_cache. warmup_cache() calls load_model() from
# several background threads, and the check-then-act sequences on the cache
# (membership test + move_to_end, size check + popitem) are not atomic.
_model_cache_lock = threading.Lock()

# Supported model names mapped to the anomalib class that loads their checkpoint.
_MODEL_CLASSES = {
    "patchcore": Patchcore,
    "efficientad": EfficientAd,
}


def get_ckpt_path(model_name: str, category: str) -> str:
    """
    Download or retrieve the checkpoint file for a given model and category.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad"); must be a
            key of config.MODEL_TO_DIR.
        category: MVTec AD category (e.g., "bottle", "cable")

    Returns:
        Local path to the downloaded checkpoint file (stored under "models/").
    """
    dirname = MODEL_TO_DIR[model_name]
    hf_path = f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"
    # NOTE(review): newer huggingface_hub releases deprecate
    # local_dir_use_symlinks; confirm against the pinned version before removing.
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=hf_path,
        local_dir="models",
        local_dir_use_symlinks=False,
    )


def load_model(model_name: str, category: str):
    """
    Load an anomaly detection model with caching and LRU eviction.

    Safe to call from multiple threads: cache access is serialized by a lock,
    while the (slow) download and checkpoint load run outside it.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad")
        category: MVTec AD category

    Returns:
        Loaded model in eval mode on the appropriate device (CUDA if available)

    Raises:
        ValueError: If an unknown model name is provided
    """
    key = f"{model_name}_{category}"

    # Fast path: return a cached model and mark it as most recently used.
    with _model_cache_lock:
        cached = _model_cache.get(key)
        if cached is not None:
            _model_cache.move_to_end(key)
            return cached

    # Validate before doing any work, so callers get the documented ValueError
    # rather than a lookup error from MODEL_TO_DIR inside get_ckpt_path.
    model_cls = _MODEL_CLASSES.get(model_name)
    if model_cls is None:
        raise ValueError(f"Unknown model: {model_name}")

    # Download checkpoint and load it (outside the lock: this is slow).
    ckpt = get_ckpt_path(model_name, category)
    model = model_cls.load_from_checkpoint(ckpt)

    # Set evaluation mode and move to device
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    with _model_cache_lock:
        # Another thread may have loaded the same key meanwhile; reuse its
        # instance so all callers share a single copy.
        cached = _model_cache.get(key)
        if cached is not None:
            _model_cache.move_to_end(key)
            return cached

        # Evict only once the new model is ready, so a failed load never
        # throws away a cached model. The emptiness check prevents popitem()
        # on an empty dict when MAX_MODEL_CACHE_SIZE <= 0.
        while _model_cache and len(_model_cache) >= MAX_MODEL_CACHE_SIZE:
            _model_cache.popitem(last=False)  # Remove least recently used

        # Cache the model (added at the most-recently-used end)
        _model_cache[key] = model

    return model


def clear_model_cache():
    """Clear the model cache to free memory."""
    with _model_cache_lock:
        _model_cache.clear()


def warmup_cache(
    model_names: Optional[List[str]] = None,
    categories: Optional[List[str]] = None,
    wait: bool = False,
) -> Dict[str, object]:
    """
    Pre-download and cache models in background threads to reduce
    first-inference latency.

    Args:
        model_names: List of model names to warm up.
            Default: ["patchcore", "efficientad"]
        categories: List of categories to warm up. Default: ["bottle"]
        wait: If True, block until every warmup thread has finished.
            Default False keeps the original fire-and-forget behavior.

    Returns:
        dict: Mapping of model keys ("<model>_<category>") to loaded models.
        When wait is False the threads are still running on return, so this
        dict is filled in asynchronously and may be empty or partial.
        Models that failed to load are absent.
    """
    if model_names is None:
        model_names = ["patchcore", "efficientad"]
    if categories is None:
        categories = ["bottle"]

    results = {}

    def _warmup_single(model_name, category):
        # Best-effort: a failure is reported but never propagates.
        try:
            model = load_model(model_name, category)
            key = f"{model_name}_{category}"
            results[key] = model
        except Exception as e:
            print(f"[WARMUP] Failed to load {model_name}/{category}: {e}")

    threads = []
    for model_name in model_names:
        for category in categories:
            t = threading.Thread(
                target=_warmup_single, args=(model_name, category), daemon=True
            )
            t.start()
            threads.append(t)

    if wait:
        for t in threads:
            t.join()

    return results