import gc

import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

from app.core.device import DEVICE

# Process-wide model cache: config key -> (processor, model).
# Failed loads are NOT stored here; caching the failure would make a model
# permanently unloadable after one transient error (see load_image_model).
_cache: dict = {}


def load_image_model(cfg: dict):
    """Lazy-load a model by config key. Returns (processor, model) or None on failure.

    Expects ``cfg`` to contain:
      - "key":  cache key for this model
      - "name": Hugging Face model identifier passed to ``from_pretrained``
      - "desc": human-readable description used in the loading message

    Successful loads are cached module-wide; repeated calls with the same
    key return the cached (processor, model) pair. A failed load returns
    None without caching, so a later call retries the download/load.
    """
    key = cfg["key"]
    if key in _cache:
        return _cache[key]
    print(f"Loading {cfg['desc']} ({cfg['name']})...")
    try:
        proc = AutoImageProcessor.from_pretrained(cfg["name"])
        model = AutoModelForImageClassification.from_pretrained(cfg["name"]).to(DEVICE)
        model.eval()  # inference only: freeze dropout/batch-norm behavior
    except Exception as e:
        # Best-effort contract: report and return None. Deliberately do not
        # cache the failure -- transient errors should be retryable.
        print(f"Failed to load {key}: {e}")
        return None
    _cache[key] = (proc, model)
    print(f"{key} ready, labels: {model.config.id2label}")
    return _cache[key]


def unload_all():
    """Drop all cached processors/models and release accelerator memory.

    Clearing the cache in place removes the only strong references held by
    this module (per-entry ``del`` of loop locals, as the old version did,
    freed nothing while the cache still referenced the tuples). ``clear()``
    also avoids rebinding ``_cache``, so any module that imported the dict
    object directly keeps seeing the live cache.
    """
    _cache.clear()
    gc.collect()
    # Flush the allocator cache of whichever accelerator backend is active
    # so the freed tensors are actually returned to the device.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()