"""
Model loading and caching module.

This module provides functions to load anomaly detection models from
Hugging Face Hub, with caching support to avoid reloading the same model
multiple times.
"""
import os
import threading
from collections import OrderedDict

import torch
from huggingface_hub import hf_hub_download

from anomalib.models import Patchcore, EfficientAd
from config import HF_REPO_ID, MODEL_TO_DIR
| # Maximum number of models to keep in cache (prevents unbounded memory growth) | |
| # Reduced for HF Spaces limited storage | |
| MAX_MODEL_CACHE_SIZE = 30 | |
| # Global model cache with LRU eviction (using OrderedDict) | |
| _model_cache = OrderedDict() | |
| def get_ckpt_path(model_name: str, category: str) -> str: | |
| """ | |
| Download or retrieve the checkpoint file for a given model and category. | |
| Args: | |
| model_name: Name of the model ("patchcore" or "efficientad") | |
| category: MVTec AD category (e.g., "bottle", "cable") | |
| Returns: | |
| Path to the downloaded checkpoint file | |
| """ | |
| dirname = MODEL_TO_DIR[model_name] | |
| hf_path = f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt" | |
| return hf_hub_download( | |
| repo_id=HF_REPO_ID, | |
| filename=hf_path, | |
| local_dir="models", | |
| local_dir_use_symlinks=False, | |
| ) | |
| def load_model(model_name: str, category: str): | |
| """ | |
| Load an anomaly detection model with caching and LRU eviction. | |
| Args: | |
| model_name: Name of the model ("patchcore" or "efficientad") | |
| category: MVTec AD category | |
| Returns: | |
| Loaded model on the appropriate device (CUDA if available) | |
| Raises: | |
| ValueError: If an unknown model name is provided | |
| """ | |
| key = f"{model_name}_{category}" | |
| # Return cached model if available (move to end to mark as recently used) | |
| if key in _model_cache: | |
| _model_cache.move_to_end(key) | |
| return _model_cache[key] | |
| # Evict least recently used model if cache is full | |
| if len(_model_cache) >= MAX_MODEL_CACHE_SIZE: | |
| _model_cache.popitem(last=False) # Remove first (oldest) item | |
| # Download checkpoint | |
| ckpt = get_ckpt_path(model_name, category) | |
| # Load the appropriate model type | |
| if model_name == "patchcore": | |
| model = Patchcore.load_from_checkpoint(ckpt) | |
| elif model_name == "efficientad": | |
| model = EfficientAd.load_from_checkpoint(ckpt) | |
| else: | |
| raise ValueError(f"Unknown model: {model_name}") | |
| # Set evaluation mode and move to device | |
| model.eval() | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model.to(device) | |
| # Cache the model (add to end) | |
| _model_cache[key] = model | |
| return model | |
| def clear_model_cache(): | |
| """Clear the model cache to free memory.""" | |
| global _model_cache | |
| _model_cache.clear() | |
| def warmup_cache(model_names: list = None, categories: list = None): | |
| """ | |
| Pre-download and cache models in background to reduce first-inference latency. | |
| Args: | |
| model_names: List of model names to warmup. Default: ["patchcore", "efficientad"] | |
| categories: List of categories to warmup. Default: ["bottle"] | |
| Returns: | |
| dict: Mapping of model keys to their cached instances | |
| """ | |
| import os | |
| from threading import Thread | |
| if model_names is None: | |
| model_names = ["patchcore", "efficientad"] | |
| if categories is None: | |
| categories = ["bottle"] | |
| results = {} | |
| def _warmup_single(model_name, category): | |
| try: | |
| model = load_model(model_name, category) | |
| key = f"{model_name}_{category}" | |
| results[key] = model | |
| except Exception as e: | |
| print(f"[WARMUP] Failed to load {model_name}/{category}: {e}") | |
| threads = [] | |
| for model_name in model_names: | |
| for category in categories: | |
| t = Thread(target=_warmup_single, args=(model_name, category), daemon=True) | |
| t.start() | |
| threads.append(t) | |
| # Don't wait for threads - they run in background | |
| return results | |