Spaces:
Sleeping
Sleeping
| """ | |
| ModelLoader - Smart Model Sharing between Spaces | |
| Downloads models ONCE from gionuibk/NautilusModels and caches locally. | |
| Subsequent calls use cached version (no re-download). | |
| """ | |
| import os | |
| import json | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| from huggingface_hub import HfApi, hf_hub_download | |
| HF_MODEL_REPO = "gionuibk/NautilusModels" | |
| print("📦 ModelLoader Module Initialized") # Force Upload Hash Change | |
| # Use HF's default cache - files are cached and only re-downloaded if changed | |
| LOCAL_CACHE = None # Let hf_hub_download use default cache | |
class ModelLoader:
    """
    Smart model loader with caching.

    Downloads model artifacts from the HF Hub repo ``HF_MODEL_REPO`` only when
    they are not already in the local HF cache (``hf_hub_download`` performs an
    ETag check), then serves the cached local path on subsequent calls.
    """

    def __init__(self, token: Optional[str] = None):
        """
        Args:
            token: HF access token; falls back to the ``HF_TOKEN`` env var.
        """
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        self._manifest: Optional[Dict[str, Any]] = None  # parsed best_models.json
        self._manifest_loaded = False  # True once a load attempt (success or not) was made

    def _load_manifest(self) -> Dict[str, Any]:
        """Load the best_models.json manifest from HuggingFace (cached in-process).

        Returns:
            The manifest dict, or ``{}`` when it cannot be fetched/parsed.
        """
        if self._manifest_loaded:
            return self._manifest or {}
        try:
            manifest_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="best_models.json",
                repo_type="model",
                token=self.token,
                # Uses HF default cache - fast if already downloaded
            )
            with open(manifest_path, 'r') as f:
                self._manifest = json.load(f)
            print(f"📋 Model manifest loaded: {list(self._manifest.keys())}")
        except Exception as e:
            # Best-effort: a missing/broken manifest degrades to "no models known".
            print(f"⚠️ Could not load best_models.json: {e}")
            self._manifest = {}
        self._manifest_loaded = True
        return self._manifest

    def _pick_filename(self, model_info: Dict[str, Any], format: str,
                       model_type: str = "", verbose: bool = False) -> Optional[str]:
        """Choose the artifact filename from a manifest entry: prefer ONNX, fall back to PT.

        Args:
            model_info: Manifest entry with optional ``onnx_file`` / ``pt_file`` keys.
            format: Requested format, "onnx" or "pt".
            model_type: Model name, used only for log messages.
            verbose: When True, print fallback/missing-file warnings.

        Returns:
            The filename to download, or None when the entry has no usable file.
        """
        if format == "onnx" and model_info.get("onnx_file"):
            return model_info["onnx_file"]
        if model_info.get("pt_file"):
            if verbose and format == "onnx":
                print(f"⚠️ ONNX not available for {model_type}, using PT")
            return model_info["pt_file"]
        if verbose:
            print(f"⚠️ No model file for {model_type}")
        return None

    def get_best_model(self, model_type: str, format: str = "onnx") -> Optional[str]:
        """
        Get path to the best model. Downloads once, then uses cache.

        Args:
            model_type: "deeplob", "trm", "lstm", etc.
            format: "onnx" or "pt"

        Returns:
            Local path to cached model file, or None when unavailable.
        """
        manifest = self._load_manifest()
        if model_type not in manifest:
            print(f"⚠️ No model found for: {model_type}")
            return None
        model_info = manifest[model_type]
        filename = self._pick_filename(model_info, format, model_type, verbose=True)
        if filename is None:
            return None
        # Download (or use cached version)
        try:
            local_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename=filename,
                repo_type="model",
                token=self.token,
            )
        except Exception as e:
            # Fixed: log the actual filename instead of a "(unknown)" placeholder.
            print(f"❌ Failed to get {filename}: {e}")
            return None
        # Dynamic metric logging: manifest may carry a named metric or legacy 'accuracy'.
        metric_name = model_info.get("metric_name", "acc")
        metric_val = model_info.get("metric_value", model_info.get("accuracy", "N/A"))
        if isinstance(metric_val, (int, float)):
            # Ratios (< 1) get 4 decimals; larger values get 2.
            val_str = f"{metric_val:.4f}" if metric_val < 1 else f"{metric_val:.2f}"
        else:
            val_str = str(metric_val)
        print(f"✅ Model ready: {filename} ({metric_name}={val_str})")
        return local_path

    def check_for_update(self, model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
        """
        Check if a newer model is available.

        Returns:
            New local path if a different model file is available, or None.
        """
        try:
            # Snapshot the cached manifest so a failed refresh doesn't wipe it.
            old_manifest = dict(self._manifest) if self._manifest else {}
            # Force a re-fetch; hf_hub_download handles the ETag check, so this is cheap.
            self._manifest_loaded = False
            new_manifest = self._load_manifest()
            if not new_manifest and old_manifest:
                # Refresh failed - keep serving the previously known manifest.
                self._manifest = old_manifest
                return None
            if model_type not in new_manifest:
                return None
            filename = self._pick_filename(new_manifest[model_type], format)
            if filename is None:
                return None
            if not current_path:
                # Nothing loaded yet - any available model counts as an update.
                return self.get_best_model(model_type, format)
            if filename not in current_path:
                # Fixed: log the actual filename instead of a "(unknown)" placeholder.
                print(f"🆕 New model detected: {filename} (Replacing {os.path.basename(current_path)})")
                return self.get_best_model(model_type, format)
            return None
        except Exception as e:
            print(f"⚠️ Check for update failed: {e}")
            return None

    def get_model_info(self, model_type: str) -> Optional[Dict[str, Any]]:
        """Get metadata about a model (the raw manifest entry), or None."""
        return self._load_manifest().get(model_type)

    def list_available_models(self) -> list:
        """List all available model types from the manifest."""
        return list(self._load_manifest().keys())
# Process-wide singleton, created lazily on first access.
_loader: Optional[ModelLoader] = None


def get_model_loader() -> ModelLoader:
    """Return the shared ModelLoader, constructing it on first use."""
    global _loader
    if _loader is not None:
        return _loader
    _loader = ModelLoader()
    return _loader
def get_best_model(model_type: str, format: str = "onnx") -> Optional[str]:
    """Module-level convenience: resolve the best model path via the shared loader."""
    loader = get_model_loader()
    return loader.get_best_model(model_type, format)
def check_for_model_update(model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
    """Module-level convenience: fetch an updated model via the shared loader, if any."""
    loader = get_model_loader()
    return loader.check_for_update(model_type, current_path, format)