"""
ModelLoader - Smart Model Sharing between Spaces
Downloads models ONCE from gionuibk/NautilusModels and caches locally.
Subsequent calls use cached version (no re-download).
"""
import os
import json
from pathlib import Path
from typing import Optional, Dict, Any
from huggingface_hub import HfApi, hf_hub_download
HF_MODEL_REPO = "gionuibk/NautilusModels"
print("📦 ModelLoader Module Initialized") # Force Upload Hash Change
# Use HF's default cache - files are cached and only re-downloaded if changed
LOCAL_CACHE = None # Let hf_hub_download use default cache
class ModelLoader:
    """
    Smart model loader with caching.

    - Downloads from HF only if not cached or the remote file changed
      (hf_hub_download performs the ETag check against HF_MODEL_REPO)
    - Uses the local HF cache for fast loading on subsequent calls
    """

    def __init__(self, token: Optional[str] = None):
        # Fall back to the HF_TOKEN environment variable when no token is given.
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        # Lazily-loaded manifest; _manifest_loaded distinguishes "not fetched yet"
        # from "fetched but empty / failed".
        self._manifest: Optional[Dict[str, Any]] = None
        self._manifest_loaded = False

    def _load_manifest(self) -> Dict[str, Any]:
        """Load the best_models.json manifest from HuggingFace (cached in memory).

        Returns:
            Mapping of model_type -> metadata dict; empty dict on failure.
        """
        if self._manifest_loaded:
            return self._manifest or {}
        try:
            manifest_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="best_models.json",
                repo_type="model",
                token=self.token
                # Uses HF default cache - fast if already downloaded
            )
            with open(manifest_path, 'r') as f:
                self._manifest = json.load(f)
            print(f"📋 Model manifest loaded: {list(self._manifest.keys())}")
        except Exception as e:
            # Best-effort: a missing/unreachable manifest degrades to "no models".
            print(f"⚠️ Could not load best_models.json: {e}")
            self._manifest = {}
        self._manifest_loaded = True
        return self._manifest

    def get_best_model(self, model_type: str, format: str = "onnx") -> Optional[str]:
        """
        Get path to the best model. Downloads once, then uses cache.

        Args:
            model_type: "deeplob", "trm", "lstm", etc.
            format: "onnx" or "pt"

        Returns:
            Local path to the cached model file, or None if unavailable.
        """
        manifest = self._load_manifest()
        if model_type not in manifest:
            print(f"⚠️ No model found for: {model_type}")
            return None
        model_info = manifest[model_type]
        # Prefer ONNX, fallback to PT
        if format == "onnx" and model_info.get("onnx_file"):
            filename = model_info["onnx_file"]
        elif model_info.get("pt_file"):
            filename = model_info["pt_file"]
            if format == "onnx":
                print(f"⚠️ ONNX not available for {model_type}, using PT")
        else:
            print(f"⚠️ No model file for {model_type}")
            return None
        # Download (or use cached version)
        try:
            local_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename=filename,
                repo_type="model",
                token=self.token
            )
            # Dynamic metric logging: manifest entries may name their own metric.
            metric_name = model_info.get("metric_name", "acc")
            metric_val = model_info.get("metric_value", model_info.get("accuracy", "N/A"))
            # Format value nicely: 4 decimals for ratios (<1), 2 otherwise.
            if isinstance(metric_val, (int, float)):
                val_str = f"{metric_val:.4f}" if metric_val < 1 else f"{metric_val:.2f}"
            else:
                val_str = str(metric_val)
            # Fix: log the actual file instead of a literal "(unknown)" placeholder.
            print(f"✅ Model ready: {filename} ({metric_name}={val_str})")
            return local_path
        except Exception as e:
            # Fix: identify which model failed instead of "(unknown)".
            print(f"❌ Failed to get {model_type}: {e}")
            return None

    def check_for_update(self, model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
        """
        Check if a newer model is available.

        Re-fetches the manifest (hf_hub_download's ETag check makes this cheap)
        and compares the advertised filename with the currently-loaded path.

        Returns:
            New local path if an update was fetched, or None if no update.
        """
        try:
            # Re-fetch manifest (populating self._manifest with fresh data)
            self._manifest_loaded = False
            new_manifest = self._load_manifest()
            if model_type not in new_manifest:
                return None
            model_info = new_manifest[model_type]
            # Determine filename with the same ONNX-first preference as get_best_model
            if format == "onnx" and model_info.get("onnx_file"):
                filename = model_info["onnx_file"]
            elif model_info.get("pt_file"):
                filename = model_info["pt_file"]
            else:
                return None
            # If we don't have a current path, treat as update
            if not current_path:
                return self.get_best_model(model_type, format)
            if filename not in current_path:
                # Fix: log the incoming filename instead of "(unknown)".
                print(f"🆕 New model detected: {filename} (Replacing {os.path.basename(current_path)})")
                return self.get_best_model(model_type, format)
            return None
        except Exception as e:
            print(f"⚠️ Check for update failed: {e}")
            return None

    def get_model_info(self, model_type: str) -> Optional[Dict[str, Any]]:
        """Get metadata about a model, or None if unknown."""
        return self._load_manifest().get(model_type)

    def list_available_models(self) -> list:
        """List all available model types."""
        return list(self._load_manifest().keys())
# Singleton instance
# Module-level cache for the one shared ModelLoader; created lazily by
# get_model_loader() so importing this module does no network/auth work.
_loader: Optional[ModelLoader] = None
def get_model_loader() -> ModelLoader:
    """Get singleton ModelLoader.

    Creates the instance on first call (ModelLoader.__init__ reads HF_TOKEN
    from the environment when no token is passed) and reuses it afterwards.
    """
    global _loader
    if _loader is None:
        _loader = ModelLoader()
    return _loader
def get_best_model(model_type: str, format: str = "onnx") -> Optional[str]:
    """Convenience wrapper: resolve the best model via the shared loader.

    Downloads once, then serves from cache; returns the local file path
    or None when the model is unavailable.
    """
    loader = get_model_loader()
    return loader.get_best_model(model_type, format)
def check_for_model_update(model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
    """Convenience wrapper: ask the shared loader for a newer model.

    Returns the path to the freshly fetched model when an update exists,
    otherwise None.
    """
    loader = get_model_loader()
    return loader.check_for_update(model_type, current_path, format)