File size: 4,024 Bytes
c5732cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdca5d6
c5732cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdca5d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Model loading and caching module.

This module provides functions to load anomaly detection models from
Hugging Face Hub with caching support to avoid reloading the same model multiple times.
"""

import os
import torch
from collections import OrderedDict
from huggingface_hub import hf_hub_download
from anomalib.models import Patchcore, EfficientAd

from config import HF_REPO_ID, MODEL_TO_DIR

# Maximum number of models to keep in cache (prevents unbounded memory growth)
# Reduced for HF Spaces limited storage
MAX_MODEL_CACHE_SIZE = 30

# Global model cache with LRU eviction (using OrderedDict)
_model_cache = OrderedDict()


def get_ckpt_path(model_name: str, category: str) -> str:
    """
    Download or retrieve the checkpoint file for a given model and category.

    The file is cached by huggingface_hub under ``local_dir``, so repeated
    calls for the same model/category do not re-download.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad").
            Must be a key of ``MODEL_TO_DIR``.
        category: MVTec AD category (e.g., "bottle", "cable")

    Returns:
        Path to the downloaded checkpoint file

    Raises:
        KeyError: If ``model_name`` is not present in ``MODEL_TO_DIR``.
    """
    dirname = MODEL_TO_DIR[model_name]
    # Repository layout: <model-dir>/MVTecAD/<category>/latest/weights/lightning/model.ckpt
    hf_path = f"{dirname}/MVTecAD/{category}/latest/weights/lightning/model.ckpt"

    # NOTE: `local_dir_use_symlinks` is deprecated and ignored by modern
    # huggingface_hub (>=0.23): files given a `local_dir` are always
    # materialized as real files, which is the behavior we want here.
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=hf_path,
        local_dir="models",
    )


def load_model(model_name: str, category: str):
    """
    Load an anomaly detection model with caching and LRU eviction.

    Args:
        model_name: Name of the model ("patchcore" or "efficientad")
        category: MVTec AD category

    Returns:
        Loaded model on the appropriate device (CUDA if available)

    Raises:
        ValueError: If an unknown model name is provided
    """
    key = f"{model_name}_{category}"

    # Return cached model if available (move to end to mark as recently used)
    if key in _model_cache:
        _model_cache.move_to_end(key)
        return _model_cache[key]

    # Download checkpoint (may raise on network/auth failure — do this BEFORE
    # evicting anything so a failed download never costs us a cached model)
    ckpt = get_ckpt_path(model_name, category)

    # Load the appropriate model type
    if model_name == "patchcore":
        model = Patchcore.load_from_checkpoint(ckpt)
    elif model_name == "efficientad":
        model = EfficientAd.load_from_checkpoint(ckpt)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # Set evaluation mode and move to device
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Evict the least recently used model only now that the new one loaded
    # successfully, and release its memory explicitly so GPU/host memory is
    # reclaimed promptly instead of waiting on garbage collection.
    if len(_model_cache) >= MAX_MODEL_CACHE_SIZE:
        _, evicted = _model_cache.popitem(last=False)  # oldest entry
        del evicted
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Cache the model (appended at the end = most recently used)
    _model_cache[key] = model

    return model


def clear_model_cache():
    """Drop every cached model so their memory can be reclaimed.

    Mutates the module-level cache in place; no ``global`` declaration is
    needed because the dict object itself is never rebound.
    """
    _model_cache.clear()


def warmup_cache(model_names: list = None, categories: list = None, *, wait: bool = False):
    """
    Pre-download and cache models to reduce first-inference latency.

    Each (model, category) pair is loaded on its own daemon thread. By
    default this function returns immediately, so the returned dict is
    usually still empty — it fills in asynchronously as threads finish.
    Pass ``wait=True`` to block until all warmup threads complete, in
    which case the dict contains every successfully loaded model.

    Args:
        model_names: List of model names to warmup. Default: ["patchcore", "efficientad"]
        categories: List of categories to warmup. Default: ["bottle"]
        wait: If True, join all warmup threads before returning.

    Returns:
        dict: Mapping of "<model>_<category>" keys to loaded model
        instances. Populated asynchronously unless ``wait=True``.
    """
    from threading import Thread

    if model_names is None:
        model_names = ["patchcore", "efficientad"]
    if categories is None:
        categories = ["bottle"]

    results = {}

    def _warmup_single(model_name, category):
        # Best-effort: warmup failures must never crash the caller.
        try:
            model = load_model(model_name, category)
            results[f"{model_name}_{category}"] = model
        except Exception as e:
            print(f"[WARMUP] Failed to load {model_name}/{category}: {e}")

    threads = []
    for model_name in model_names:
        for category in categories:
            t = Thread(target=_warmup_single, args=(model_name, category), daemon=True)
            t.start()
            threads.append(t)

    if wait:
        for t in threads:
            t.join()

    return results