# NautilusAI — model_loader.py (uploaded to the Hub via huggingface_hub)
"""
ModelLoader - Smart Model Sharing between Spaces
Downloads models ONCE from gionuibk/NautilusModels and caches locally.
Subsequent calls use cached version (no re-download).
"""
import os
import json
from pathlib import Path
from typing import Optional, Dict, Any
from huggingface_hub import HfApi, hf_hub_download
# HuggingFace Hub repo holding the trained models plus the best_models.json manifest.
HF_MODEL_REPO = "gionuibk/NautilusModels"
print("📦 ModelLoader Module Initialized") # Force Upload Hash Change
# Use HF's default cache - files are cached and only re-downloaded if changed
# NOTE(review): LOCAL_CACHE is never read anywhere in this file — presumably kept
# so other modules can import it; confirm before removing.
LOCAL_CACHE = None # Let hf_hub_download use default cache
class ModelLoader:
    """
    Smart model loader with caching.

    - Downloads from HF only if not cached or the remote file changed
      (hf_hub_download performs an ETag check and reuses the local cache).
    - Loads the best_models.json manifest once per instance and memoizes it.
    """

    def __init__(self, token: Optional[str] = None):
        # Fall back to the HF_TOKEN env var so Spaces work without explicit config.
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        self._manifest: Optional[Dict[str, Any]] = None
        self._manifest_loaded = False

    def _load_manifest(self) -> Dict[str, Any]:
        """Load the best_models.json manifest from HuggingFace (memoized)."""
        if self._manifest_loaded:
            return self._manifest or {}
        try:
            manifest_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="best_models.json",
                repo_type="model",
                token=self.token
                # Uses HF default cache - fast if already downloaded
            )
            with open(manifest_path, 'r') as f:
                self._manifest = json.load(f)
            print(f"📋 Model manifest loaded: {list(self._manifest.keys())}")
        except Exception as e:
            # Best-effort: a missing manifest degrades to "no models available".
            print(f"⚠️ Could not load best_models.json: {e}")
            self._manifest = {}
        self._manifest_loaded = True
        return self._manifest

    def _select_filename(self, model_type: str, model_info: Dict[str, Any],
                         format: str, *, verbose: bool = False) -> Optional[str]:
        """Pick the model file for the requested format: ONNX preferred, PT fallback.

        Returns None when the manifest entry lists no usable file.
        Shared by get_best_model and check_for_update so both resolve
        filenames identically.
        """
        if format == "onnx" and model_info.get("onnx_file"):
            return model_info["onnx_file"]
        if model_info.get("pt_file"):
            if verbose and format == "onnx":
                print(f"⚠️ ONNX not available for {model_type}, using PT")
            return model_info["pt_file"]
        if verbose:
            print(f"⚠️ No model file for {model_type}")
        return None

    def get_best_model(self, model_type: str, format: str = "onnx") -> Optional[str]:
        """
        Get path to the best model. Downloads once, then uses cache.

        Args:
            model_type: "deeplob", "trm", "lstm", etc.
            format: "onnx" or "pt"

        Returns:
            Local path to the cached model file, or None if unavailable.
        """
        manifest = self._load_manifest()
        if model_type not in manifest:
            print(f"⚠️ No model found for: {model_type}")
            return None
        model_info = manifest[model_type]
        filename = self._select_filename(model_type, model_info, format, verbose=True)
        if filename is None:
            return None
        # Download (or reuse the cached copy — hf_hub_download is ETag-aware).
        try:
            local_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename=filename,
                repo_type="model",
                token=self.token
            )
            # Log whichever metric the manifest advertises for this model.
            metric_name = model_info.get("metric_name", "acc")
            metric_val = model_info.get("metric_value", model_info.get("accuracy", "N/A"))
            if isinstance(metric_val, (int, float)):
                # Ratio-like values (< 1) get 4 decimals; larger scores get 2.
                val_str = f"{metric_val:.4f}" if metric_val < 1 else f"{metric_val:.2f}"
            else:
                val_str = str(metric_val)
            print(f"✅ Model ready: {filename} ({metric_name}={val_str})")
            return local_path
        except Exception as e:
            print(f"❌ Failed to get {model_type}: {e}")
            return None

    def check_for_update(self, model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
        """
        Check if a newer model is available.

        Returns the local path of the new model if one is available,
        or None when the current model is still the best (or on error).
        """
        # hf_hub_download handles the ETag check, so re-fetching the manifest
        # is cheap; just invalidate the memoized copy first.
        try:
            self._manifest_loaded = False
            new_manifest = self._load_manifest()
            if model_type not in new_manifest:
                return None
            model_info = new_manifest[model_type]
            filename = self._select_filename(model_type, model_info, format)
            if filename is None:
                return None
            # No current model yet -> any available model counts as an update.
            if not current_path:
                return self.get_best_model(model_type, format)
            # Substring check: the cached path embeds the repo-relative filename.
            if filename not in current_path:
                print(f"🆕 New model detected: {filename} (Replacing {os.path.basename(current_path)})")
                return self.get_best_model(model_type, format)
            return None
        except Exception as e:
            print(f"⚠️ Check for update failed: {e}")
            return None

    def get_model_info(self, model_type: str) -> Optional[Dict[str, Any]]:
        """Get manifest metadata for a model type, or None if unknown."""
        return self._load_manifest().get(model_type)

    def list_available_models(self) -> list:
        """List all model types present in the manifest."""
        return list(self._load_manifest().keys())
# Lazily-created process-wide loader shared by the module-level helpers below.
_loader: Optional[ModelLoader] = None


def get_model_loader() -> ModelLoader:
    """Return the shared ModelLoader, constructing it on first use."""
    global _loader
    if _loader is not None:
        return _loader
    _loader = ModelLoader()
    return _loader
def get_best_model(model_type: str, format: str = "onnx") -> Optional[str]:
    """Get path to best model (downloads once, then cached)."""
    loader = get_model_loader()
    return loader.get_best_model(model_type, format)
def check_for_model_update(model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
    """Check and fetch update if available."""
    loader = get_model_loader()
    return loader.check_for_update(model_type, current_path, format)