File size: 6,555 Bytes
d81d817
eac84cc
 
 
d81d817
 
 
 
 
 
 
 
 
b543f77
eac84cc
 
d81d817
 
 
 
eac84cc
 
 
d81d817
 
eac84cc
d81d817
 
 
 
 
 
eac84cc
d81d817
 
 
 
 
 
 
 
eac84cc
 
d81d817
 
 
eac84cc
d81d817
 
 
 
 
 
 
 
 
eac84cc
d81d817
 
eac84cc
d81d817
 
 
eac84cc
d81d817
 
 
 
eac84cc
d81d817
 
 
 
eac84cc
d81d817
 
eac84cc
d81d817
 
eac84cc
 
 
d81d817
 
eac84cc
d81d817
 
 
 
 
eac84cc
d81d817
dfbcfcc
 
 
 
 
 
 
 
 
 
 
d81d817
 
eac84cc
d81d817
dfbcfcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d81d817
 
eac84cc
 
d81d817
 
 
eac84cc
d81d817
 
eac84cc
 
d81d817
 
 
eac84cc
 
 
 
 
d81d817
 
 
eac84cc
d81d817
dfbcfcc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
ModelLoader - Smart Model Sharing between Spaces
Downloads models ONCE from gionuibk/NautilusModels and caches locally.
Subsequent calls use cached version (no re-download).
"""
import os
import json
from pathlib import Path
from typing import Optional, Dict, Any
from huggingface_hub import HfApi, hf_hub_download


# Hub repository that holds all trained model artifacts plus the
# best_models.json manifest consumed by ModelLoader.
HF_MODEL_REPO = "gionuibk/NautilusModels"
print("📦 ModelLoader Module Initialized") # Force Upload Hash Change
# Use HF's default cache - files are cached and only re-downloaded if changed
LOCAL_CACHE = None  # Let hf_hub_download use default cache


class ModelLoader:
    """
    Smart model loader with caching.

    - Downloads files from HF_MODEL_REPO via ``hf_hub_download``, which keeps
      an ETag-based local cache: a file is only re-fetched if it changed
      upstream, so repeated calls are fast.
    - Reads a ``best_models.json`` manifest mapping model types to their best
      ONNX/PT artifact filenames and metrics.
    """

    def __init__(self, token: Optional[str] = None):
        """
        Args:
            token: HuggingFace API token. Falls back to the ``HF_TOKEN``
                environment variable when not provided.
        """
        self.token = token or os.environ.get("HF_TOKEN")
        self.api = HfApi(token=self.token)
        # Parsed best_models.json; None until _load_manifest() has run.
        self._manifest: Optional[Dict[str, Any]] = None
        # Guards re-fetching the manifest within this process.
        self._manifest_loaded = False

    def _load_manifest(self) -> Dict[str, Any]:
        """Load best_models.json manifest from HuggingFace (cached in-process)."""
        if self._manifest_loaded:
            return self._manifest or {}

        try:
            manifest_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="best_models.json",
                repo_type="model",
                token=self.token
                # Uses HF default cache - fast if already downloaded
            )
            with open(manifest_path, 'r') as f:
                self._manifest = json.load(f)
            print(f"📋 Model manifest loaded: {list(self._manifest.keys())}")
        except Exception as e:
            # Best-effort: an unreachable manifest degrades to "no models known".
            print(f"⚠️ Could not load best_models.json: {e}")
            self._manifest = {}

        self._manifest_loaded = True
        return self._manifest

    def _pick_filename(self, model_type: str, model_info: Dict[str, Any],
                       format: str, verbose: bool = False) -> Optional[str]:
        """Choose the artifact filename: ONNX when requested and present, else PT.

        Args:
            model_type: Manifest key, used only for the fallback warning.
            model_info: Manifest entry with optional "onnx_file"/"pt_file" keys.
            format: Preferred format ("onnx" or "pt").
            verbose: When True, warn on the ONNX→PT fallback.

        Returns:
            Filename within the Hub repo, or None if neither file is listed.
        """
        if format == "onnx" and model_info.get("onnx_file"):
            return model_info["onnx_file"]
        if model_info.get("pt_file"):
            if verbose and format == "onnx":
                print(f"⚠️ ONNX not available for {model_type}, using PT")
            return model_info["pt_file"]
        return None

    def get_best_model(self, model_type: str, format: str = "onnx") -> Optional[str]:
        """
        Get path to the best model. Downloads once, then uses cache.

        Args:
            model_type: "deeplob", "trm", "lstm", etc.
            format: "onnx" or "pt"

        Returns:
            Local path to cached model file, or None if unavailable.
        """
        manifest = self._load_manifest()

        if model_type not in manifest:
            print(f"⚠️ No model found for: {model_type}")
            return None

        model_info = manifest[model_type]

        # Prefer ONNX, fallback to PT (warns on fallback).
        filename = self._pick_filename(model_type, model_info, format, verbose=True)
        if filename is None:
            print(f"⚠️ No model file for {model_type}")
            return None

        # Download (or use cached version)
        try:
            local_path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename=filename,
                repo_type="model",
                token=self.token
            )
            # Dynamic Metric Logging: prefer explicit metric fields, fall back
            # to the legacy "accuracy" key.
            metric_name = model_info.get("metric_name", "acc")
            metric_val = model_info.get("metric_value", model_info.get("accuracy", "N/A"))

            # Format value nicely: 4 decimals for ratios (<1), 2 otherwise.
            if isinstance(metric_val, (int, float)):
                val_str = f"{metric_val:.4f}" if metric_val < 1 else f"{metric_val:.2f}"
            else:
                val_str = str(metric_val)

            # Fixed: message previously printed the literal "(unknown)".
            print(f"✅ Model ready: {filename} ({metric_name}={val_str})")
            return local_path
        except Exception as e:
            print(f"❌ Failed to get {model_type}: {e}")
            return None

    def check_for_update(self, model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
        """
        Check if a newer model is available.
        Returns new path if available, or None if no update.
        """
        # hf_hub_download handles the ETag check, so re-fetching the manifest
        # is cheap; we just reset the in-process flag to force it.
        try:
            old_manifest = self._manifest.copy() if self._manifest else {}

            # Re-fetch manifest (populating self._manifest with fresh data)
            self._manifest_loaded = False
            new_manifest = self._load_manifest()

            # If the refresh failed (manifest came back empty) keep the
            # previous manifest so later get_best_model() calls still work.
            if not new_manifest and old_manifest:
                self._manifest = old_manifest
                return None

            if model_type not in new_manifest:
                return None

            model_info = new_manifest[model_type]

            # Determine filename (silent: no fallback warning here).
            filename = self._pick_filename(model_type, model_info, format)
            if filename is None:
                return None

            # Compare with current.
            # If we don't have a current path, treat as update.
            if not current_path:
                return self.get_best_model(model_type, format)

            if filename not in current_path:
                # Fixed: message previously printed the literal "(unknown)".
                print(f"🆕 New model detected: {filename} (Replacing {os.path.basename(current_path)})")
                return self.get_best_model(model_type, format)

            return None

        except Exception as e:
            print(f"⚠️ Check for update failed: {e}")
            return None

    def get_model_info(self, model_type: str) -> Optional[Dict[str, Any]]:
        """Get metadata about a model (its manifest entry), or None."""
        return self._load_manifest().get(model_type)

    def list_available_models(self) -> list:
        """List all available model types (manifest keys)."""
        return list(self._load_manifest().keys())


# Singleton instance
# Module-level cache for the process-wide ModelLoader; populated lazily by
# get_model_loader() below.
_loader: Optional[ModelLoader] = None


def get_model_loader() -> ModelLoader:
    """Return the process-wide ModelLoader, constructing it on first use."""
    global _loader
    loader = _loader
    if loader is None:
        loader = ModelLoader()
        _loader = loader
    return loader


def get_best_model(model_type: str, format: str = "onnx") -> Optional[str]:
    """Resolve the best model via the shared loader (downloads once, then cached)."""
    loader = get_model_loader()
    return loader.get_best_model(model_type, format)

def check_for_model_update(model_type: str, current_path: str, format: str = "onnx") -> Optional[str]:
    """Ask the shared loader whether a newer model exists; fetch it if so."""
    loader = get_model_loader()
    return loader.check_for_update(model_type, current_path, format)