# iad-explainable-hf / models / model_loader.py
# Author: Parikshit Rathode
# Last change: added lazy loading and fixed floating-point issue in inference
# (commit fdca5d6)
"""
Model loading and caching module.
This module provides functions to load anomaly detection models from
Hugging Face Hub with caching support to avoid reloading the same model multiple times.
"""
import os
import torch
from collections import OrderedDict
from huggingface_hub import hf_hub_download
from anomalib.models import Patchcore, EfficientAd
from config import HF_REPO_ID, MODEL_TO_DIR
# Maximum number of models to keep in cache (prevents unbounded memory growth).
# Reduced for HF Spaces limited storage.
MAX_MODEL_CACHE_SIZE: int = 30
# Global model cache with LRU eviction. OrderedDict preserves insertion order,
# so the first entry is always the least recently used. Keys are
# "model_category" strings; values are loaded model instances.
_model_cache: OrderedDict = OrderedDict()
def get_ckpt_path(model_name: str, category: str) -> str:
    """
    Download or retrieve the checkpoint file for a given model and category.

    The file is placed under the local ``models/`` directory; repeated calls
    for the same checkpoint reuse the already-downloaded copy.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad");
            must be a key of MODEL_TO_DIR
        category: MVTec AD category (e.g., "bottle", "cable")

    Returns:
        Path to the downloaded checkpoint file
    """
    dirname = MODEL_TO_DIR[model_name]
    hf_path = f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"
    # NOTE: `local_dir_use_symlinks` was removed here — it is deprecated and
    # ignored by modern huggingface_hub (files passed via local_dir are
    # downloaded/copied there directly), and passing it only triggers a
    # FutureWarning.
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=hf_path,
        local_dir="models",
    )
def load_model(model_name: str, category: str):
    """
    Load an anomaly detection model with caching and LRU eviction.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad")
        category: MVTec AD category

    Returns:
        Loaded model in eval mode on the appropriate device
        (CUDA if available)

    Raises:
        ValueError: If an unknown model name is provided
    """
    # Validate the model name BEFORE any download work. Previously an
    # unknown name raised a KeyError from MODEL_TO_DIR inside
    # get_ckpt_path, so the documented ValueError was unreachable.
    if model_name not in ("patchcore", "efficientad"):
        raise ValueError(f"Unknown model: {model_name}")

    key = f"{model_name}_{category}"

    # Return cached model if available (move to end to mark as recently used)
    if key in _model_cache:
        _model_cache.move_to_end(key)
        return _model_cache[key]

    # Evict least recently used model if cache is full
    if len(_model_cache) >= MAX_MODEL_CACHE_SIZE:
        _model_cache.popitem(last=False)  # Remove first (oldest) item

    # Download checkpoint (no-op if already present on disk)
    ckpt = get_ckpt_path(model_name, category)

    # Load the appropriate model type
    if model_name == "patchcore":
        model = Patchcore.load_from_checkpoint(ckpt)
    else:
        model = EfficientAd.load_from_checkpoint(ckpt)

    # Set evaluation mode and move to device
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Cache the model (appended at the end = most recently used)
    _model_cache[key] = model
    return model
def clear_model_cache():
    """Empty the model cache so the held models can be garbage-collected."""
    # .clear() mutates the existing OrderedDict in place, so no `global`
    # declaration is needed (we never rebind the module-level name).
    _model_cache.clear()
def warmup_cache(model_names: list = None, categories: list = None):
    """
    Pre-download and cache models in background threads to reduce
    first-inference latency.

    Args:
        model_names: List of model names to warm up.
            Default: ["patchcore", "efficientad"]
        categories: List of categories to warm up. Default: ["bottle"]

    Returns:
        dict: Mapping of "model_category" keys to their cached instances.
            The background threads keep mutating this dict after the
            function returns — they are intentionally not joined, so the
            dict is usually still empty at return time.
    """
    from threading import Thread

    if model_names is None:
        model_names = ["patchcore", "efficientad"]
    if categories is None:
        categories = ["bottle"]

    results = {}

    def _warmup_single(model_name, category):
        # Best-effort: a failed warmup must never crash the app, so swallow
        # and log any load error.
        try:
            model = load_model(model_name, category)
            results[f"{model_name}_{category}"] = model
        except Exception as e:
            print(f"[WARMUP] Failed to load {model_name}/{category}: {e}")

    # Fire-and-forget daemon threads: we deliberately do not join them so
    # the caller is never blocked waiting for downloads.
    for model_name in model_names:
        for category in categories:
            Thread(
                target=_warmup_single,
                args=(model_name, category),
                daemon=True,
            ).start()

    return results