WildnerveAI commited on
Commit
ff8d4e7
·
verified ·
1 Parent(s): 6ac39a9

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils/adapter_layer.py +371 -0
  2. utils/load_model_weights.py +691 -0
utils/adapter_layer.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys
import json
import logging
import pydantic  # required
import importlib.util  # required
from typing import Dict, Any, Optional, List, Tuple
from service_registry import registry, MODEL, PRETRAINED_MODEL, TOKENIZER

# Force low memory usage mode (read by downstream model-loading code).
os.environ["LOW_MEMORY_MODE"] = "1"

# Log versions and fail fast if missing
logger = logging.getLogger(__name__)
logger.info(f"Using pydantic v{pydantic.__version__}")

# Add proper codecarbon import handling: carbon tracking is optional, so a
# missing package degrades to a no-op tracker instead of crashing import.
try:
    import codecarbon
    codecarbon_available = True
    logger.info(f"Using codecarbon v{codecarbon.__version__}")
except ImportError:
    codecarbon_available = False
    logger.warning("codecarbon is not available - carbon tracking disabled")

    # Create dummy class for compatibility: mirrors the small subset of the
    # EmissionsTracker API (start/stop) that callers are expected to use.
    class DummyEmissionsTracker:
        def __init__(self, *args, **kwargs): pass
        def start(self): return self
        def stop(self): return 0.0

    # Shadow the module name so `codecarbon.EmissionsTracker` keeps working
    # for code that references it unconditionally.
    class codecarbon:
        __version__ = "unavailable"
        EmissionsTracker = DummyEmissionsTracker

print(f"Successfully using installed dependencies - pydantic: {pydantic.__version__}, codecarbon: {'available' if codecarbon_available else 'unavailable'}")
36
+
37
# MEMORY OPTIMIZATION: Show current memory usage
def log_memory_usage():
    """Log and return this process's resident set size (RSS) in MiB.

    Returns:
        The RSS in MiB as a float, or 0 when psutil is unavailable or the
        query fails for any reason. Never raises.
    """
    try:
        import psutil  # optional dependency; absence is tolerated
        process = psutil.Process(os.getpid())
        memory_info = process.memory_info()
        memory_mb = memory_info.rss / 1024 / 1024
        logger.info(f"Current memory usage: {memory_mb:.2f} MB")
        return memory_mb
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Exception still covers ImportError and any
        # psutil failure, preserving the best-effort contract.
        return 0
48
+
49
# Import dependency helpers
def is_module_available(module_name):
    """Return True if *module_name* can be imported, without importing it.

    BUG FIX: the original returned True whenever find_spec() did not raise,
    but find_spec() returns None (it does not raise) for a missing top-level
    module — so the original reported nearly every name as available.

    Args:
        module_name: Dotted module name to probe, e.g. "json" or "pkg.sub".

    Returns:
        True when an import spec exists for the module, False otherwise.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # ImportError: a parent package is missing.
        # ValueError: the parent's __spec__ is None (e.g. __main__).
        return False
56
+
57
# More robust import for PromptAnalyzer: fall back to a tiny keyword-based
# router when the project module is unavailable, so the adapter still starts.
try:
    from model_List import PromptAnalyzer
    logger.info("Successfully imported PromptAnalyzer")
except ImportError as e:
    logger.error(f"Error importing PromptAnalyzer: {e}")

    # Create a minimal PromptAnalyzer class with the same analyze_prompt()
    # contract: returns a (model_name, confidence) tuple.
    class PromptAnalyzer:
        def __init__(self, **kwargs):
            self.logger = logging.getLogger(__name__)
            # Static keyword lists; only "programming" is consulted below.
            self.predefined_topics = {
                "programming": ["python", "java", "code"],
                "general": ["weather", "hello", "chat"]
            }

        def analyze_prompt(self, prompt: str):
            # Simple keyword-based routing: any programming keyword routes to
            # the custom model with higher confidence, otherwise the
            # pretrained model is chosen.
            prompt_lower = prompt.lower()
            for tech_word in self.predefined_topics.get("programming", []):
                if tech_word in prompt_lower:
                    return "model_Custm", 0.8
            return "model_PrTr", 0.6
79
+
80
# MEMORY OPTIMIZATION: Create basic PromptAnalyzer without loading models
class BasicPromptAnalyzer:
    """Keyword-only prompt router that needs no neural model in memory.

    analyze_prompt() returns a (model_name, confidence) pair: prompts that
    mention a programming keyword go to the custom model, everything else
    to the pretrained one.
    """

    def __init__(self, **kwargs):
        self.logger = logging.getLogger(__name__)
        # Static topic vocabulary; only "programming" drives routing.
        self.predefined_topics = {
            "programming": ["python", "java", "code"],
            "general": ["weather", "hello", "chat"],
        }

    def analyze_prompt(self, prompt: str):
        """Route *prompt* by case-insensitive keyword matching."""
        lowered = prompt.lower()
        programming_words = self.predefined_topics.get("programming", [])
        if any(word in lowered for word in programming_words):
            return "model_Custm", 0.8
        return "model_PrTr", 0.6
96
+
97
class WildnerveModelAdapter:
    """Ultra-lightweight adapter layer for HF inference endpoints.

    Construction does no heavy work: the tokenizer and model are loaded
    lazily on the first generate() call, with a fallback chain of
    model_Custm -> model_PrTr -> generated stub module.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        # Guard flag: set True after the first load attempt (success or not)
        # so loading is never retried per-request.
        self.model_loaded = False
        logger.info(f"Creating adapter with path: {model_path}")

        # Safe verification of model file existence
        self._verify_model_files()

    def _verify_model_files(self):
        """Verify model files exist without loading them"""
        script_dir = os.path.dirname(os.path.abspath(__file__))
        model_files = ["model_Custm.py", "model_PrTr.py"]

        # Map module name -> absolute path for each model source file found
        # next to this script; consulted later by _lazy_load_model().
        self.available_models = {}
        for filename in model_files:
            filepath = os.path.join(script_dir, filename)
            if os.path.exists(filepath):
                module_name = filename.replace('.py', '')
                self.available_models[module_name] = filepath
                # NOTE(review): this log message looks garbled in the source
                # ("(unknown)" instead of an interpolated filename) — confirm
                # against the original commit before relying on it.
                logger.info(f"Found model file: (unknown)")

        if not self.available_models:
            logger.warning("No model files found - will use stub implementation")
            # Create stub file if needed
            stub_path = os.path.join(script_dir, "model_stub.py")
            if not os.path.exists(stub_path):
                try:
                    with open(stub_path, "w") as f:
                        # Minimal torch module written to disk; loaded later
                        # by _lazy_load_model() as a last-resort model.
                        f.write("""
# Minimal stub model
import torch.nn as nn
class Wildnerve_tlm01(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.is_stub = True
        for key, value in kwargs.items():
            setattr(self, key, value)
    def generate(self, prompt=None, **kwargs):
        return f"Stub model response for: {prompt[:30]}..."
""")
                    logger.info("Created stub model file")
                except Exception as e:
                    logger.error(f"Failed to create stub model: {e}")

    def generate(self, text_input, max_length=None, **kwargs):
        """Generate text - with lazy model loading.

        Args:
            text_input: Prompt string, forwarded verbatim to the model.
            max_length: Optional generation length cap, passed through.
            **kwargs: Extra generation options, passed through unchanged.

        Returns:
            The raw model response, or a short fallback string when no model
            could be loaded. Never raises.
        """
        try:
            # 1. Load model if not already loaded
            if not self.model_loaded:
                logger.info("Loading model for first request")
                self._lazy_load_model()

            # 2. Let the model handle inference directly with NO pattern matching or rules
            if self.model:
                try:
                    logger.info(f"Sending prompt directly to neural model: {type(self.model).__name__}")
                    model_response = self.model.generate(
                        prompt=text_input,
                        max_length=max_length,
                        **kwargs
                    )

                    # Log response for debugging but don't intercept or alter it
                    logger.info(f"Model generated response of length {len(model_response) if isinstance(model_response, str) else 'unknown'}")

                    # Return the raw model response - let the model shine (or fail naturally)
                    return model_response

                except Exception as e:
                    # Only log the error but don't substitute with rule-based responses
                    logger.error(f"Neural model inference error: {e}")
                    # Continue to basic fallback only if the model completely failed
            else:
                logger.warning("No model available - only basic response possible")

            # 3. Minimal fallback ONLY if model couldn't be loaded or threw exception
            if self.tokenizer:
                return f"The model couldn't be properly initialized. Your input: '{text_input[:30]}...'"
            return f"No language model available to process: '{text_input[:30]}...'"

        except Exception as e:
            logger.error(f"Critical error in generate method: {e}")
            return f"An error occurred processing your request: {str(e)}"

    def _lazy_load_model(self):
        """Try to load a model on demand, with multiple fallback options.

        Order of attempts: download weights, then model_Custm (with weights),
        then model_PrTr, then the generated stub. Sets self.model_loaded True
        in all cases (see finally) so loading is attempted only once.
        """
        try:
            logger.info("Attempting to load model on first request")

            # First initialize tokenizer if not already done
            self._initialize_minimal_tokenizer()

            # Download and load model weights first with better logging
            try:
                from load_model_weights import download_model_files, load_weights_into_model, verify_token

                # First verify token is available
                token_verified = verify_token()
                logger.info(f"HF Token verification: {token_verified}")

                # Get weights from HF repository with more robust error reporting
                logger.info("Downloading model weights...")
                try:
                    # Try multiple repositories in priority order
                    repositories = [
                        "EvolphTech/Weights",
                        "Wildnerve/tlm-0.05Bx12",
                        "Wildnerve/tlm",
                        "EvolphTech/Checkpoints"
                    ]

                    # NOTE(review): if every repository call raises,
                    # weight_files stays None; the later membership test
                    # `"transformer" in weight_files` would then raise
                    # TypeError (caught by the model_Custm except) — confirm
                    # whether an empty dict was intended here.
                    weight_files = None
                    for repo in repositories:
                        logger.info(f"Attempting to download weights from {repo}...")
                        try:
                            weight_files = download_model_files(repo_id_base=repo)
                            if weight_files and "transformer" in weight_files:
                                logger.info(f"Successfully downloaded weights from {repo}")
                                break
                        except Exception as repo_error:
                            logger.warning(f"Failed to download from {repo}: {repo_error}")

                    # Add detailed logging about weight files
                    if weight_files:
                        logger.info(f"Download returned {len(weight_files)} weight files: {list(weight_files.keys())}")
                    else:
                        logger.warning("No weight files were returned from download_model_files")

                except Exception as e:
                    logger.error(f"Error downloading weights: {str(e)}")
                    weight_files = {}
            except ImportError:
                logger.error("Could not import load_model_weights - missing dependencies?")
                weight_files = {}

            # Rest of model loading code (unchanged)
            # Try to load model_Custm first
            if "model_Custm" in self.available_models:
                try:
                    logger.info("Trying to load model_Custm")
                    # Load the module straight from its file path so it works
                    # regardless of sys.path configuration.
                    model_custm_spec = importlib.util.spec_from_file_location(
                        "model_Custm",
                        self.available_models["model_Custm"]
                    )
                    model_custm = importlib.util.module_from_spec(model_custm_spec)
                    model_custm_spec.loader.exec_module(model_custm)

                    if hasattr(model_custm, "Wildnerve_tlm01"):
                        logger.info("Creating Wildnerve_tlm01 from model_Custm")
                        model_class = getattr(model_custm, "Wildnerve_tlm01")

                        # Create model with safer config handling
                        try:
                            # Import config handling
                            from config import app_config
                            # Ensure config_data exists if app_config is a dict
                            if isinstance(app_config, dict) and "TRANSFORMER_CONFIG" in app_config:
                                if isinstance(app_config["TRANSFORMER_CONFIG"], dict) and "config_data" not in app_config["TRANSFORMER_CONFIG"]:
                                    # Self-reference: config_data points back at
                                    # the TRANSFORMER_CONFIG dict itself.
                                    app_config["TRANSFORMER_CONFIG"]["config_data"] = app_config["TRANSFORMER_CONFIG"]
                                    logger.info("Added config_data attribute to TRANSFORMER_CONFIG dictionary")
                        except Exception as config_error:
                            logger.warning(f"Config handling error: {config_error}")

                        self.model = model_class(
                            tokenizer=self.tokenizer,
                            vocab_size=50257,  # GPT-2 vocab size
                            specialization="general",
                            embedding_dim=768,
                            num_heads=12,
                            hidden_dim=768,
                            num_layers=2,  # Reduced for memory efficiency
                            output_size=50257,  # Match GPT-2 vocab
                            dropout=0.1,
                            max_seq_length=128  # Reduced for memory
                        )

                        # Enhanced weight loading with detailed path information
                        if "transformer" in weight_files and weight_files["transformer"]:
                            weight_path = weight_files["transformer"]
                            logger.info(f"Loading weights from {weight_path}")
                            logger.info(f"Weight file exists: {os.path.exists(weight_path)}")
                            logger.info(f"Weight file size: {os.path.getsize(weight_path) / 1024 / 1024:.2f} MB")

                            # strict=False: tolerate missing/unexpected keys.
                            success = load_weights_into_model(self.model, weight_path, strict=False)
                            if success:
                                logger.info("✅ Successfully loaded transformer weights")
                            else:
                                logger.warning("❌ Failed to load transformer weights")
                        else:
                            logger.warning("❌ No transformer weights found in weight_files")

                        logger.info("Successfully created custom model")
                        self.model_loaded = True
                        return
                except Exception as e:
                    logger.error(f"Failed to load model_Custm: {e}")

            # Try model_PrTr next
            if "model_PrTr" in self.available_models:
                try:
                    logger.info("Trying to load model_PrTr")
                    model_prtr_spec = importlib.util.spec_from_file_location(
                        "model_PrTr",
                        self.available_models["model_PrTr"]
                    )
                    model_prtr = importlib.util.module_from_spec(model_prtr_spec)
                    model_prtr_spec.loader.exec_module(model_prtr)

                    if hasattr(model_prtr, "Wildnerve_tlm01"):
                        logger.info("Creating Wildnerve_tlm01 from model_PrTr")
                        model_class = getattr(model_prtr, "Wildnerve_tlm01")
                        self.model = model_class(
                            tokenizer=self.tokenizer,
                            model_name="gpt2"
                        )
                        logger.info("Successfully created pretrained model")
                        self.model_loaded = True
                        return
                except Exception as e:
                    logger.error(f"Failed to load model_PrTr: {e}")

            # Try stub model as last resort
            try:
                logger.info("Trying to load model_stub")
                script_dir = os.path.dirname(os.path.abspath(__file__))
                stub_path = os.path.join(script_dir, "model_stub.py")

                if os.path.exists(stub_path):
                    stub_spec = importlib.util.spec_from_file_location("model_stub", stub_path)
                    model_stub = importlib.util.module_from_spec(stub_spec)
                    stub_spec.loader.exec_module(model_stub)

                    if hasattr(model_stub, "Wildnerve_tlm01"):
                        logger.info("Creating stub model")
                        model_class = getattr(model_stub, "Wildnerve_tlm01")
                        self.model = model_class(
                            tokenizer=self.tokenizer,
                            specialization="stub"
                        )
                        logger.warning("Using STUB model - limited functionality")
                        self.model_loaded = True
                        return
            except Exception as e:
                logger.error(f"Failed to load stub model: {e}")

            logger.error("All model loading attempts failed")

        except Exception as e:
            logger.error(f"Error in _lazy_load_model: {e}")
        finally:
            # Always mark as loaded to avoid repeated attempts
            self.model_loaded = True

    def _initialize_minimal_tokenizer(self):
        """Initialize just the tokenizer, not the model"""
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

            # Fix for GPT-2 tokenizer: set pad_token to eos_token
            # (GPT-2 ships without a pad token, which breaks padded batches).
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                logger.info("Set GPT-2 pad_token to eos_token")

            logger.info("Initialized minimal tokenizer")
        except Exception as e:
            logger.error(f"Failed to initialize tokenizer: {e}")
369
+
370
+ # Add import for inspect at the top
371
+ import inspect
utils/load_model_weights.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Functions for downloading model weights from Hugging Face repositories.
3
+ """
4
+ import os
5
+ import sys
6
+ import time
7
+ import logging
8
+ import traceback
9
+ import torch # Add missing torch import
10
+ from pathlib import Path
11
+ from typing import Dict, Optional, Tuple, List, Any, Union
12
+ from urllib.error import HTTPError
13
+ from huggingface_hub import hf_hub_download, HfFileSystem, HfApi
14
+
15
# Add the current directory to Python's path to ensure modules are found
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Configure logging (single module-level logger; the original assigned it
# twice — once here and once after the token block — which is redundant).
logger = logging.getLogger(__name__)

# Try local direct import first with fallback to a minimal version
try:
    from model_repo_config import get_repo_config
    logger.info("Successfully imported model_repo_config")
except ImportError:
    logger.warning("model_repo_config module not found, using minimal implementation")

    # Define minimal version inline as fallback
    class MinimalRepoConfig:
        """Minimal repository config for fallback.

        Mirrors the attributes and methods of the real config object that
        download_model_files() and load_model_weights() rely on.
        """
        def __init__(self):
            self.repo_id = "EvolphTech/Weights"
            self.cache_dir = "/tmp/tlm_cache"
            self.weight_locations = ["Wildnerve-tlm01-0.05Bx12.bin", "model.bin", "pytorch_model.bin"]
            self.snn_weight_locations = ["stdp_model_epoch_30.bin", "snn_model.bin"]
            self.default_repo = "EvolphTech/Weights"
            self.alternative_paths = ["Wildnerve/tlm-0.05Bx12", "Wildnerve/tlm", "EvolphTech/Checkpoints"]
            logger.info("Using minimal repository config")

        def get_auth_token(self):
            """Get authentication token from environment"""
            return os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN")

        def save_download_status(self, success, files):
            """Minimal implementation that just logs"""
            logger.info(f"Download status: success={success}, files={len(files) if files else 0}")

    def get_repo_config():
        """Get minimal repository config"""
        return MinimalRepoConfig()

# BUG FIX: the original injected the literal placeholder "your_token_here"
# into HF_TOKEN when it was unset. That fake value was then sent as a Bearer
# token on every Hugging Face request, guaranteeing 401s, and made
# verify_token() report a token as present. Leaving HF_TOKEN unset lets the
# anonymous/public download path work and verify_token() report honestly.
58
+
59
def verify_token():
    """Verify the HF token is available and properly formatted.

    Reads HF_TOKEN (falling back to HF_API_TOKEN) from the environment and,
    when present, makes a best-effort validation call to the Hugging Face
    whoami endpoint. Returns True whenever a token exists — even if the
    validation request fails — and False only when no token is set.
    """
    token = os.environ.get("HF_TOKEN", os.environ.get("HF_API_TOKEN"))
    if token:
        token_length = len(token)
        # Preview only the first/last 5 chars so the token is never logged whole.
        token_preview = token[:5] + "..." + token[-5:] if token_length > 10 else "too_short"
        logger.info(f"HF Token found: length={token_length}, preview={token_preview}")

        # Test if token works against a public Hugging Face API endpoint
        try:
            import requests  # imported lazily; absence is tolerated below
            headers = {"Authorization": f"Bearer {token}"}
            test_url = "https://huggingface.co/api/whoami"
            response = requests.get(test_url, headers=headers, timeout=10)
            if response.status_code == 200:
                user_info = response.json()
                logger.info(f"Token validated for user: {user_info.get('name', 'unknown')}")
                return True
            else:
                logger.warning(f"Token validation failed: {response.status_code} - {response.text[:100]}")
        except Exception as e:
            # Network failures / missing requests are non-fatal by design.
            logger.warning(f"Error testing token: {e}")

        # Even if test fails, return True if we have a token
        return True
    else:
        logger.error("❌ HF Token not found in environment variables!")
        return False

# Call this early in the script or application startup
# (module-level side effect: performs one network request at import time
# when a token is configured).
token_verified = verify_token()
90
+
91
def verify_repository(repo_id: str, token: Optional[str] = None) -> Tuple[bool, List[str]]:
    """
    Verify that a repository exists and is accessible.

    Args:
        repo_id: Repository ID to verify
        token: Optional Hugging Face API token

    Returns:
        (success, files): Tuple of success flag and list of files
    """
    try:
        # Try to list the repository contents
        api = HfApi()
        logger.info(f"Verifying access to repository: {repo_id}")

        try:
            files = api.list_repo_files(repo_id, token=token)
            logger.info(f"Repository {repo_id} is accessible")
            logger.info(f"Found {len(files)} files in repository")
            return True, files

        except Exception as e:
            # Classify the failure from the error text since hf_hub raises
            # different exception types across versions.
            # NOTE(review): string matching on error messages is brittle —
            # confirm it still matches the installed huggingface_hub version.
            error_msg = str(e).lower()

            if "not found" in error_msg or "404" in error_msg:
                logger.error(f"Repository {repo_id} not found. Please check the name.")
                return False, []
            elif "unauthorized" in error_msg or "permission" in error_msg or "401" in error_msg:
                if token:
                    logger.error(f"Authentication failed for repository {repo_id} despite token")
                else:
                    logger.error(f"No token provided for private repository {repo_id}")
                return False, []
            else:
                logger.error(f"Error accessing repository {repo_id}: {e}")
                return False, []
    except Exception as e:
        logger.error(f"Unexpected error verifying repository {repo_id}: {e}")
        return False, []
131
+
132
def download_file(repo_id: str, file_path: str, cache_dir: str, token: Optional[str] = None) -> Optional[str]:
    """
    Download a file from a Hugging Face repository with retry logic.

    Args:
        repo_id: Repository ID
        file_path: Path to the file within the repository
        cache_dir: Directory to save the file
        token: Optional Hugging Face API token

    Returns:
        Path to the downloaded file if successful, None otherwise
    """
    max_retries = 3
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            logger.info(f"Downloading {file_path} from {repo_id} (attempt {attempt}/{max_retries})...")
            # Retries force a fresh download instead of reusing a possibly
            # corrupt cached copy.
            local_path = hf_hub_download(
                repo_id=repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                force_download=attempt > 1,
                token=token,
            )
        except Exception as e:
            logger.warning(f"Failed to download {file_path} from {repo_id} (attempt {attempt}/{max_retries}): {e}")
            if attempt == max_retries:
                return None
            time.sleep(1)  # Wait before retry
        else:
            logger.info(f"Successfully downloaded {file_path} to {local_path}")
            return local_path
163
+
164
def check_for_local_weights():
    """Check if weights are available locally.

    Checks, in order: a previous positive result (env flags), explicit
    weight paths from the environment, then a list of well-known filesystem
    locations. On success the TLM_*/MODEL_WEIGHTS_FOUND env vars are set so
    subsequent calls and other modules can skip the scan. Returns a bool.
    """
    # First check if we've already found weights (avoid redundant checks)
    already_found = (
        os.environ.get("MODEL_WEIGHTS_FOUND") == "true"
        or os.environ.get("USING_LOCAL_WEIGHTS") == "true"
    )
    if already_found:
        logger.info("Using previously found local weights")
        return True

    # Check for transformer weights advertised via the environment.
    transformer_weights = os.environ.get("TLM_TRANSFORMER_WEIGHTS")
    if transformer_weights and os.path.exists(transformer_weights):
        logger.info(f"Found transformer weights locally at: {transformer_weights}")

        # SNN weights are optional; log when present.
        snn_weights = os.environ.get("TLM_SNN_WEIGHTS")
        if snn_weights and os.path.exists(snn_weights):
            logger.info(f"Found SNN weights locally at: {snn_weights}")

        # Remember the positive result for future calls.
        os.environ["MODEL_WEIGHTS_FOUND"] = "true"
        os.environ["USING_LOCAL_WEIGHTS"] = "true"
        return True

    # Fall back to scanning well-known transformer weight locations.
    transformer_paths = [
        "/app/Weights/Transformer/Wildnerve-tlm01-0.05Bx12.bin",
        "/app/Weights/Wildnerve-tlm01-0.05Bx12.bin",
        "/app/weights/Wildnerve-tlm01-0.05Bx12.bin",
        "./Weights/Transformer/Wildnerve-tlm01-0.05Bx12.bin",
        "./Weights/Wildnerve-tlm01-0.05Bx12.bin",
    ]
    path = next((p for p in transformer_paths if os.path.exists(p)), None)
    if path is None:
        return False

    logger.info(f"Found transformer weights at: {path}")
    os.environ["TLM_TRANSFORMER_WEIGHTS"] = path
    os.environ["MODEL_WEIGHTS_FOUND"] = "true"

    # Scan the matching SNN weight locations (optional companion file).
    snn_paths = [
        "/app/Weights/SNN/stdp_model_epoch_30.bin",
        "/app/Weights/stdp_model_epoch_30.bin",
        "/app/weights/stdp_model_epoch_30.bin",
        "./Weights/SNN/stdp_model_epoch_30.bin",
        "./Weights/stdp_model_epoch_30.bin",
    ]
    snn_path = next((p for p in snn_paths if os.path.exists(p)), None)
    if snn_path is not None:
        logger.info(f"Found SNN weights at: {snn_path}")
        os.environ["TLM_SNN_WEIGHTS"] = snn_path

    return True
219
+
220
def load_model_weights(model=None):
    """Load model weights from local files or download from repository.

    Prefers weights already present on disk (as reported by
    check_for_local_weights, via the TLM_* env vars); otherwise delegates to
    download_model_files() using the configured repository. The *model*
    argument is accepted for interface compatibility but unused.
    """
    logger.info("Checking for local model weights...")

    if check_for_local_weights():
        # Local weights win: return their env-recorded paths directly.
        logger.info("Using local weights, skipping repository download")
        return {
            "transformer": os.environ.get("TLM_TRANSFORMER_WEIGHTS"),
            "snn": os.environ.get("TLM_SNN_WEIGHTS"),
        }

    # Only attempt to download if no local weights
    logger.info("No local weights found, attempting to download from repository")
    config = get_repo_config()
    return download_model_files(config.repo_id, None, config.cache_dir)
241
+
242
def download_model_files(repo_id_base: str, sub_dir: Optional[str] = None,
                         cache_dir: Optional[str] = None) -> Dict[str, str]:
    """
    Download model files from a Hugging Face repository.

    Resolution order: weights already on disk at well-known paths; then the
    primary Wildnerve repo; then the requested repo; then alternatives; then
    the configured default. Downloads config.json, transformer weights, and
    (when transformer weights exist) SNN weights.

    Args:
        repo_id_base: Base repository ID
        sub_dir: Optional subdirectory within the repository
        cache_dir: Optional cache directory

    Returns:
        Dictionary of downloaded files (file_type: local_path)
    """
    # Get global configuration
    config = get_repo_config()

    # Use provided cache_dir or fall back to config's cache_dir
    cache_dir = cache_dir or config.cache_dir

    # Get authentication token if available
    token = config.get_auth_token()

    # Dictionary to store downloaded file paths
    downloaded_files = {}

    # FIRST: Check if weights exist locally in the current directory or app directory
    local_weight_paths = [
        "./Wildnerve-tlm01-0.05Bx12.bin",
        "./weights/Wildnerve-tlm01-0.05Bx12.bin",
        "./pytorch_model.bin",
        "./model.bin",
        "/app/Wildnerve-tlm01-0.05Bx12.bin",  # For HF Spaces environment
        "/app/weights/Wildnerve-tlm01-0.05Bx12.bin",
        "/app/pytorch_model.bin"
    ]

    # Look for local weights first
    logger.info("Checking for local model weights...")
    for weight_path in local_weight_paths:
        if os.path.exists(weight_path):
            logger.info(f"Found local weights: {weight_path}")
            downloaded_files["transformer"] = weight_path
            # Try to find a config file too
            local_config_paths = [
                os.path.join(os.path.dirname(weight_path), "config.json"),
                "./config.json",
                "/app/config.json"
            ]
            for config_path in local_config_paths:
                if os.path.exists(config_path):
                    downloaded_files["config"] = config_path
                    break

            # Set environment variables
            os.environ["TLM_TRANSFORMER_WEIGHTS"] = downloaded_files["transformer"]
            if "config" in downloaded_files:
                os.environ["TLM_CONFIG_PATH"] = downloaded_files["config"]

            # Return early since we found local weights
            logger.info(f"Using local weights: {weight_path}")
            return downloaded_files

    # If no local weights, continue with normal HF download procedure
    logger.info("No local weights found, attempting to download from repository")

    # Create full repository path (with subdir if provided)
    repo_id = repo_id_base
    if sub_dir:
        # Remove any trailing slashes from repo_id and leading slashes from sub_dir
        repo_id = repo_id_base.rstrip('/') + '/' + sub_dir.lstrip('/')

    # First try the primary Wildnerve model repository
    # NOTE(review): this hardcoded repo is tried before the caller-supplied
    # repo_id_base, so the argument is effectively a second choice.
    wildnerve_repo = "Wildnerve/tlm-0.05Bx12"
    logger.info(f"Trying primary Wildnerve model repository: {wildnerve_repo}")

    success, files = verify_repository(wildnerve_repo, token)
    if success:
        repo_id = wildnerve_repo
    else:
        # Verify repository exists and is accessible
        success, files = verify_repository(repo_id, token)
        if not success:
            # Try alternatives
            logger.info(f"Primary repository {repo_id} not accessible, trying alternatives")

            # Try Wildnerve model repo variants first
            wildnerve_variants = ["Wildnerve/tlm", "EvolphTech/Checkpoints"]
            for wildnerve_alt in wildnerve_variants:
                logger.info(f"Trying Wildnerve alternative: {wildnerve_alt}")
                success, files = verify_repository(wildnerve_alt, token)
                if success:
                    repo_id = wildnerve_alt
                    break

            # If still not successful, try other fallbacks
            if not success:
                for alt_repo in config.alternative_paths:
                    logger.info(f"Trying alternative repository: {alt_repo}")
                    success, files = verify_repository(alt_repo, token)
                    if success:
                        repo_id = alt_repo
                        break

            # Use default if all alternatives fail
            if not success:
                repo_id = config.default_repo
                success, files = verify_repository(repo_id, token)

    # Dictionary to store downloaded file paths
    # NOTE(review): redundant re-initialization — downloaded_files is already
    # empty here (the local-weights branch returns early).
    downloaded_files = {}

    # Download configuration if available
    try:
        logger.info(f"Downloading config from {repo_id}...")
        config_path = download_file(repo_id, "config.json", cache_dir, token)
        if config_path:
            downloaded_files["config"] = config_path
        else:
            logger.warning("Will use default config values")
    except Exception as e:
        logger.warning(f"Error downloading config: {e}")

    # Download transformer weights
    logger.info(f"Downloading transformer weights from {repo_id}...")
    transformer_path = None

    # First try the specific Wildnerve model file name
    wildnerve_paths = ["Wildnerve-tlm01-0.05Bx12.bin", "model.bin", "pytorch_model.bin"]
    for path in wildnerve_paths:
        logger.info(f"Trying Wildnerve model path: {path}")
        transformer_path = download_file(repo_id, path, cache_dir, token)
        if transformer_path:
            downloaded_files["transformer"] = transformer_path
            break

    # If that doesn't work, try the standard paths
    if not transformer_path:
        for path in config.weight_locations:
            transformer_path = download_file(repo_id, path, cache_dir, token)
            if transformer_path:
                downloaded_files["transformer"] = transformer_path
                break
            # NOTE(review): this log fires AFTER the download attempt for
            # `path` (and is skipped on success), so its message is
            # misleading about ordering.
            logger.info(f"Trying path: {path}")

    if not transformer_path:
        logger.warning("No transformer weights found, trying public BERT model as fallback")
        try:
            # Try to download BERT weights
            transformer_path = download_file(config.default_repo, "pytorch_model.bin", cache_dir, token)
            if transformer_path:
                downloaded_files["transformer"] = transformer_path
                logger.info("Successfully downloaded fallback BERT model")
            else:
                # Additional fallbacks to try
                for alt_repo in ["bert-base-uncased", "distilbert-base-uncased"]:
                    transformer_path = download_file(alt_repo, "pytorch_model.bin", cache_dir, token)
                    if transformer_path:
                        downloaded_files["transformer"] = transformer_path
                        logger.info(f"Successfully downloaded fallback model from {alt_repo}")
                        break
        except Exception as e:
            logger.error(f"Failed to download fallback model: {e}")

    # Download SNN weights if transformer weights were found
    if "transformer" in downloaded_files:
        logger.info(f"Downloading SNN weights from {repo_id}...")
        snn_path = None

        for path in config.snn_weight_locations:
            snn_path = download_file(repo_id, path, cache_dir, token)
            if snn_path:
                downloaded_files["snn"] = snn_path
                break
            # NOTE(review): same after-the-fact "Trying path" log as above.
            logger.info(f"Trying path: {path}")

    # Set environment variables for other modules to use
    if "transformer" in downloaded_files:
        os.environ["TLM_TRANSFORMER_WEIGHTS"] = downloaded_files["transformer"]
    if "snn" in downloaded_files:
        os.environ["TLM_SNN_WEIGHTS"] = downloaded_files["snn"]

    # Save download status
    config.save_download_status(bool(downloaded_files), downloaded_files)

    return downloaded_files
427
+
428
def find_expanded_weights(base_weight_path, target_dim=768):
    """
    Locate a pre-expanded variant of a weights file, if one exists.

    The expanded file is expected to be named
    ``<stem>_expanded_<target_dim><ext>`` and is searched for in a set of
    common writable directories, then the original file's directory, and
    finally relative to the current working directory.

    Args:
        base_weight_path: Path to the original weights file.
        target_dim: Target embedding dimension encoded in the expanded name.

    Returns:
        Path to the expanded weights if found, otherwise None.
    """
    if not base_weight_path:
        return None

    stem, extension = os.path.splitext(os.path.basename(base_weight_path))
    candidate_name = f"{stem}_expanded_{target_dim}{extension}"

    # Common writable locations first, then wherever the base file lives.
    search_dirs = [
        "/tmp",
        "/tmp/tlm_data",
        os.environ.get("TLM_DATA_DIR", "/tmp/tlm_data"),
        os.path.dirname(base_weight_path),
    ]

    for directory in search_dirs:
        if not directory:
            continue
        candidate = os.path.join(directory, candidate_name)
        if os.path.exists(candidate):
            logger.info(f"Found expanded weights at {candidate}")
            return candidate

    # Last resort: bare filename resolved against the current working directory.
    if os.path.exists(candidate_name):
        return candidate_name

    return None
473
+
474
def load_weights_into_model(model, weights_path: str, strict: bool = False) -> bool:
    """
    Load weights from a file into a model, with several fallback strategies.

    The loading cascade is:
      1. Prefer a pre-expanded weights file if one is found next to
         ``weights_path`` (see ``find_expanded_weights``).
      2. Unwrap checkpoints that nest the weights under ``model_state_dict``
         or ``state_dict``.
      3. For Wildnerve-tlm01 models whose keys don't match the checkpoint,
         remap keys via a substring-based mapping table and load non-strictly.
      4. For Hugging Face-style checkpoints (all keys prefixed ``bert.``,
         ``roberta.`` or ``model.``), adapt keys via exact then partial
         matching, falling back to non-strict loading of the raw state dict.
      5. Otherwise load directly, retrying non-strictly if a strict load fails.

    Args:
        model: The model to load weights into
        weights_path: Path to the weights file
        strict: Whether to strictly enforce that the keys in the weights file match the model

    Returns:
        bool: True if weights were successfully loaded, False otherwise
    """
    try:
        logger.info(f"Loading weights from {weights_path}")

        # Try expanded weights first: if a dimension-expanded copy exists,
        # load that instead of the original file.
        expanded_path = find_expanded_weights(weights_path)
        if expanded_path:
            logger.info(f"Using expanded weights: {expanded_path}")
            weights_path = expanded_path

        # Load the state dictionary on CPU to avoid requiring a GPU here.
        # NOTE(review): torch.load unpickles arbitrary data — only use with
        # trusted checkpoint files.
        state_dict = torch.load(weights_path, map_location="cpu")

        # If state_dict has nested structure (a full training checkpoint),
        # extract the actual model weights.
        if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        elif isinstance(state_dict, dict) and "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        # Special handling for Wildnerve-tlm01-0.05Bx12 model.
        # NOTE(review): detection relies on the class's repr containing
        # "Wildnerve-tlm01" — confirm the model class actually renders this.
        if "Wildnerve-tlm01" in str(model.__class__):
            logger.info("Detected Wildnerve-tlm01 model, applying special weight loading")

            # Check if keys need to be remapped
            model_keys = dict(model.named_parameters())
            state_dict_keys = set(state_dict.keys())

            # Check key alignment: only remap when NO model key appears in
            # the checkpoint as-is.
            if not any(k in state_dict_keys for k in model_keys.keys()):
                logger.info("Wildnerve model keys don't match state dict keys, attempting remapping")

                # Create mapping for common Wildnerve model patterns:
                # target model key -> candidate source-key substrings, in
                # priority order.
                key_mappings = {
                    "embedding.weight": ["embeddings.word_embeddings.weight", "embedding.weight", "word_embeddings.weight"],
                    "pos_encoder.pe": ["position_embeddings.weight", "pos_encoder.pe", "pe"],
                    "transformer_encoder": ["encoder.layer", "transformer.encoder", "transformer_encoder"],
                    "classifier.weight": ["output.weight", "classifier.weight", "lm_head.weight"],
                    "classifier.bias": ["output.bias", "classifier.bias", "lm_head.bias"]
                }

                # Apply mappings. Matching is substring-based, so a later
                # sd_key hit can overwrite an earlier one for the same target.
                adapted_state_dict = {}
                for target_key, source_keys in key_mappings.items():
                    for source_key in source_keys:
                        for sd_key in state_dict_keys:
                            if source_key in sd_key:
                                if target_key not in model_keys:
                                    # Find a target key that's close enough
                                    # (shares the first dotted component).
                                    for mk in model_keys:
                                        if target_key.split('.')[0] in mk:
                                            adapted_state_dict[mk] = state_dict[sd_key]
                                            break
                                else:
                                    adapted_state_dict[target_key] = state_dict[sd_key]

                # Try to load the remapped weights; on failure we fall
                # through to the generic handling below.
                if adapted_state_dict:
                    logger.info(f"Attempting to load with {len(adapted_state_dict)} remapped keys")
                    try:
                        missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
                        logger.info(f"Loaded remapped weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
                        return True
                    except Exception as e:
                        logger.error(f"Error loading remapped weights: {e}")

        # Special handling for transformer models from Hugging Face.
        # NOTE(review): all() is vacuously True for an empty state dict, which
        # would route an empty checkpoint into this branch — confirm intended.
        if all(k.startswith("bert.") or k.startswith("roberta.") or k.startswith("model.") for k in state_dict.keys()):
            # Try to adapt the state dict keys to match our model
            logger.info("Adapting pretrained Hugging Face transformer weights")
            adapted_state_dict = {}

            # Map expected model keys to state dict keys
            key_mappings = {
                # Common mappings for transformer models
                "embedding.weight": ["embeddings.word_embeddings.weight", "bert.embeddings.word_embeddings.weight"],
                "pos_encoder.pe": ["embeddings.position_embeddings.weight", "bert.embeddings.position_embeddings.weight"],
                "transformer_encoder": ["encoder.layer", "bert.encoder.layer"],
                "classifier.weight": ["cls.predictions.decoder.weight", "bert.pooler.dense.weight"],
                "classifier.bias": ["cls.predictions.decoder.bias", "bert.pooler.dense.bias"]
            }

            # Try to map keys from state dict to model
            model_keys = dict(model.named_parameters())

            # First try exact matches (first source key present wins).
            for target_key, source_keys in key_mappings.items():
                for source_key in source_keys:
                    if source_key in state_dict:
                        adapted_state_dict[target_key] = state_dict[source_key]
                        break

            # If we have very few matches (<10% of model keys), try partial
            # matches based on shared dotted-name components.
            if len(adapted_state_dict) < len(model_keys) * 0.1:
                logger.info("Using partial key matching for weights")
                for model_key in model_keys:
                    for sd_key in state_dict:
                        # Skip keys already matched
                        if model_key in adapted_state_dict:
                            continue

                        # Try to find common substrings in the key names
                        key_parts = model_key.split('.')
                        sd_parts = sd_key.split('.')

                        # Check for common parts like "attention", "layer", etc.
                        # NOTE(review): a single shared component (e.g. "weight")
                        # is enough to match here — this is intentionally loose.
                        common_parts = set(key_parts) & set(sd_parts)
                        if len(common_parts) > 0:
                            adapted_state_dict[model_key] = state_dict[sd_key]
                            break

            # If we still don't have many matches (<50%), try direct loading
            # of the original checkpoint with non-strict mode instead.
            if len(adapted_state_dict) < len(model_keys) * 0.5:
                logger.warning(f"Could not adapt many keys ({len(adapted_state_dict)}/{len(model_keys)})")
                logger.warning("Attempting to load original state dict with non-strict mode")
                try:
                    # Load with non-strict mode to allow partial loading
                    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
                    logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
                    return True
                except Exception as e:
                    logger.error(f"Error loading original state dict: {e}")
                    return False
            else:
                # Load adapted state dict
                logger.info(f"Loading adapted state dict with {len(adapted_state_dict)} keys")
                try:
                    missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
                    logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
                    return True
                except Exception as e:
                    logger.error(f"Error loading adapted state dict: {e}")
                    return False
        else:
            # Standard loading: respect the caller's strict flag first.
            try:
                missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
                logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
                return True
            except Exception as e:
                logger.error(f"Error loading state dict: {e}")

                # Try non-strict loading if strict failed
                if strict:
                    logger.info("Attempting non-strict loading")
                    try:
                        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
                        logger.info(f"Loaded weights with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")
                        return True
                    except Exception as ne:
                        logger.error(f"Non-strict loading also failed: {ne}")

                return False
    except Exception as e:
        # Catch-all so weight loading never crashes the caller; failure is
        # signalled via the False return value.
        logger.error(f"Failed to load weights: {e}")
        return False
640
+
641
def list_model_files(repo_id: str, token: Optional[str] = None) -> List[str]:
    """
    List model weight files in a Hugging Face repository.

    Only files with common PyTorch weight extensions (``.bin``, ``.pt``,
    ``.pth``) are returned.

    Args:
        repo_id: Repository ID (e.g. "org/model-name")
        token: Optional Hugging Face API token (needed for private repos)

    Returns:
        List of repository-relative file paths; an empty list if the
        repository cannot be listed for any reason.
    """
    try:
        api = HfApi()
        files = api.list_repo_files(repo_id, token=token)

        # str.endswith accepts a tuple, so one call covers every weight suffix.
        model_files = [f for f in files if f.endswith(('.bin', '.pt', '.pth'))]
        logger.info(f"Found {len(model_files)} model files in {repo_id}")

        return model_files
    except Exception as e:
        # Listing failures (missing repo, auth, network) are non-fatal here;
        # callers treat an empty list as "nothing found".
        logger.error(f"Error listing model files in {repo_id}: {e}")
        return []
664
+
665
if __name__ == "__main__":
    # Script entry point: download model weights and report what was fetched.
    logging.basicConfig(level=logging.INFO)

    # Command-line interface (imported lazily; only needed when run as a script)
    import argparse

    arg_parser = argparse.ArgumentParser(description="Download model weights")
    for flag, help_text in (
        ("--repo-id", "Repository ID"),
        ("--sub-dir", "Subdirectory within repository"),
        ("--cache-dir", "Cache directory"),
    ):
        arg_parser.add_argument(flag, type=str, default=None, help=help_text)
    cli_args = arg_parser.parse_args()

    # Resolve repository ID: CLI flag > environment variable > configured default.
    chosen_repo = cli_args.repo_id or os.environ.get("MODEL_REPO") or get_repo_config().repo_id
    downloaded = download_model_files(chosen_repo, cli_args.sub_dir, cli_args.cache_dir)

    # Summarize the outcome for each weight category.
    print("\nDownload Results:")
    if "transformer" in downloaded:
        print(f"Transformer weights: {downloaded['transformer']}")
    else:
        print("⚠️ No transformer weights downloaded")

    if "snn" in downloaded:
        print(f"SNN weights: {downloaded['snn']}")
    else:
        print("⚠️ No SNN weights downloaded")