"""
ModelLoader - Smart Model Sharing between Spaces

Downloads models ONCE from gionuibk/NautilusModels and caches them locally.
Subsequent calls use the cached version (no re-download).
"""
import os
import json
from pathlib import Path
from typing import Optional, Dict, Any

from huggingface_hub import HfApi, hf_hub_download

HF_MODEL_REPO = "gionuibk/NautilusModels"

print("📦 ModelLoader Module Initialized")

# Use HF's default cache - files are cached and only re-downloaded if changed
LOCAL_CACHE = None  # Let hf_hub_download use default cache


class ModelLoader:
    """
    Smart model loader with caching.

    - Downloads from HF only if not cached or file changed
    - Uses local cache for fast loading

    Which file is "best" for each model type is read from the
    ``best_models.json`` manifest in ``HF_MODEL_REPO``.
    """

    def __init__(self, token: Optional[str] = None):
        # An explicit token wins; otherwise fall back to the HF_TOKEN env var
        # (which may be unset, leaving the token as None).
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        self._manifest: Optional[Dict[str, Any]] = None
        self._manifest_loaded = False

    def _load_manifest(self, force_refresh: bool = False) -> Dict[str, Any]:
        """
        Load the best_models.json manifest from HuggingFace (cached).

        Args:
            force_refresh: Re-fetch even if a manifest was already loaded.

        Returns:
            The manifest dict. If a fetch fails, the last successfully loaded
            manifest is kept (or ``{}`` if there never was one), so a transient
            failure during a refresh does not hide previously known models.
        """
        if self._manifest_loaded and not force_refresh:
            return self._manifest or {}

        try:
            manifest_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="best_models.json",
                repo_type="model",
                token=self.token,  # Uses HF default cache - fast if already downloaded
            )
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)
            if not isinstance(manifest, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(manifest).__name__}"
                )
        except Exception as e:
            print(f"⚠️ Could not load best_models.json: {e}")
            # Keep the last good manifest instead of wiping it.
            if not self._manifest:
                self._manifest = {}
        else:
            self._manifest = manifest
            print(f"📋 Model manifest loaded: {list(manifest.keys())}")

        self._manifest_loaded = True
        return self._manifest

    @staticmethod
    def _select_filename(model_info: Dict[str, Any], format: str) -> Optional[str]:
        """Pick the ONNX file when requested and listed, else the PT file (or None)."""
        if format == "onnx" and model_info.get("onnx_file"):
            return model_info["onnx_file"]
        return model_info.get("pt_file") or None

    @staticmethod
    def _path_matches_filename(path: str, filename: str) -> bool:
        """Return True if *path* ends with the repo-relative *filename*.

        Compares whole path components, so "v1.onnx" does not match
        ".../deeplob_v1.onnx" (a plain substring test would).
        """
        wanted = Path(filename).parts
        actual = Path(path).parts
        return bool(wanted) and actual[-len(wanted):] == wanted

    @staticmethod
    def _format_metric(value: Any) -> str:
        """Render a metric: 4 decimals for numbers below 1, 2 decimals otherwise."""
        if isinstance(value, (int, float)):
            return f"{value:.4f}" if value < 1 else f"{value:.2f}"
        return str(value)

    def get_best_model(self, model_type: str, format: str = "onnx") -> Optional[str]:
        """
        Get path to the best model. Downloads once, then uses cache.

        Args:
            model_type: Manifest key, e.g. "deeplob", "trm", "lstm".
            format: "onnx" or "pt". With "onnx", falls back to the PT file
                when the manifest lists no ONNX file for this model.

        Returns:
            Local path to the cached model file, or None if the model is not
            in the manifest, has no usable file, or the download fails.
        """
        manifest = self._load_manifest()
        if model_type not in manifest:
            print(f"⚠️ No model found for: {model_type}")
            return None

        model_info = manifest[model_type]
        filename = (
            self._select_filename(model_info, format)
            if isinstance(model_info, dict)
            else None
        )
        if filename is None:
            print(f"⚠️ No model file for {model_type}")
            return None
        if format == "onnx" and not model_info.get("onnx_file"):
            # Prefer ONNX, fallback to PT
            print(f"⚠️ ONNX not available for {model_type}, using PT")

        # Download (or use cached version)
        try:
            local_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename=filename,
                repo_type="model",
                token=self.token,
            )
        except Exception as e:
            print(f"❌ Failed to get {filename}: {e}")
            return None

        # Dynamic metric logging: prefer metric_value, then legacy "accuracy".
        metric_name = model_info.get("metric_name", "acc")
        metric_val = model_info.get("metric_value", model_info.get("accuracy", "N/A"))
        print(f"✅ Model ready: {filename} ({metric_name}={self._format_metric(metric_val)})")
        return local_path

    def check_for_update(self, model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
        """
        Check if a newer model is available.

        Re-fetches the manifest (hf_hub_download does the cache validation)
        and compares the manifest's filename with ``current_path``.

        Returns:
            Path to the newly downloaded model if the manifest now points to a
            different file (or ``current_path`` is empty), otherwise None.
        """
        try:
            manifest = self._load_manifest(force_refresh=True)
            model_info = manifest.get(model_type)
            if not isinstance(model_info, dict):
                return None

            filename = self._select_filename(model_info, format)
            if filename is None:
                return None

            # No current model: treat as an update.
            if not current_path:
                return self.get_best_model(model_type, format)

            if not self._path_matches_filename(current_path, filename):
                print(f"🆕 New model detected: {filename} (Replacing {os.path.basename(current_path)})")
                return self.get_best_model(model_type, format)

            return None
        except Exception as e:
            print(f"⚠️ Check for update failed: {e}")
            return None

    def get_model_info(self, model_type: str) -> Optional[Dict[str, Any]]:
        """Get the manifest metadata for a model type (None if unknown)."""
        return self._load_manifest().get(model_type)

    def list_available_models(self) -> list:
        """List all model types present in the manifest."""
        return list(self._load_manifest().keys())


# Singleton instance
_loader: Optional[ModelLoader] = None


def get_model_loader() -> ModelLoader:
    """Get the process-wide ModelLoader, creating it on first use."""
    global _loader
    if _loader is None:
        _loader = ModelLoader()
    return _loader


def get_best_model(model_type: str, format: str = "onnx") -> Optional[str]:
    """Get path to best model (downloads once, then cached)."""
    return get_model_loader().get_best_model(model_type, format)


def check_for_model_update(model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
    """Check and fetch update if available."""
    return get_model_loader().check_for_update(model_type, current_path, format)