Trouter-Library committed on
Commit
545ba0c
·
verified ·
1 Parent(s): 1b822b1

Create inference/load_model.py

Browse files
Files changed (1) hide show
  1. inference/load_model.py +370 -0
inference/load_model.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helion-OSC Sharded Model Loader
3
+ Efficiently loads 116 safetensors shards (2.8GB each)
4
+ """
5
+
6
+ import torch
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Dict, Optional, List
11
+ import logging
12
+ from tqdm import tqdm
13
+ from safetensors.torch import load_file
14
+ from transformers import AutoConfig, AutoTokenizer
15
+ import psutil
16
+
17
+ logging.basicConfig(level=logging.INFO)  # configures the root logger when this module is imported so INFO messages are emitted
18
+ logger = logging.getLogger(__name__)  # module-level logger used by the loader class and the helper functions in this file
19
+
20
+
21
class ShardedModelLoader:
    """
    Loader for a model stored as sharded safetensors files.

    At construction time the loader reads ``config.json`` and
    ``model.safetensors.index.json`` from ``model_path``. The index's
    ``weight_map`` maps every weight name to the shard file that stores it;
    shard files are resolved relative to ``model_path``. The loader works for
    any number of shards listed in the index (the Helion-OSC checkpoint this
    file was written for ships 116 shards of about 2.8GB each).
    """

    def __init__(self, model_path: str):
        """
        Initialize the sharded model loader

        Args:
            model_path: Path to the inference directory containing
                ``config.json``, ``model.safetensors.index.json`` and shards

        Raises:
            FileNotFoundError: If the config or the index file is missing
            json.JSONDecodeError: If either file is not valid JSON
        """
        self.model_path = Path(model_path)
        self.config_path = self.model_path / "config.json"
        self.index_path = self.model_path / "model.safetensors.index.json"

        # Load configuration. JSON is read as UTF-8 explicitly so the result
        # does not depend on the platform's locale encoding.
        logger.info(f"Loading configuration from {self.config_path}")
        with open(self.config_path, 'r', encoding="utf-8") as f:
            self.config = json.load(f)

        # Load weight index
        logger.info(f"Loading weight index from {self.index_path}")
        with open(self.index_path, 'r', encoding="utf-8") as f:
            self.index = json.load(f)

        self.metadata = self.index.get("metadata", {})
        self.weight_map = self.index.get("weight_map", {})

        # The summary below is informational only, so optional keys are read
        # with .get() and a missing key never aborts construction.
        arch_info = self.config.get("architectures_info", {})
        logger.info(f"Model: {self.metadata.get('model_type', 'unknown')}")
        logger.info(f"Total shards: {self._total_shards()}")
        logger.info(f"Total size: {self.metadata.get('total_size', 0) / 1e9:.2f} GB")
        logger.info(f"Total parameters: {arch_info.get('total_parameters', 'N/A')}")
        logger.info(f"Active parameters: {arch_info.get('active_parameters', 'N/A')}")

    def _shard_files(self) -> List[str]:
        """Return the sorted, de-duplicated shard file names from the weight map"""
        return sorted(set(self.weight_map.values()))

    def _total_shards(self) -> int:
        """
        Return the shard count recorded in the index metadata

        Falls back to the number of distinct shard files referenced by the
        weight map when the metadata has no ``total_shards`` entry.
        """
        if "total_shards" in self.metadata:
            return self.metadata["total_shards"]
        return len(self._shard_files())

    def get_shard_path(self, shard_name: str) -> Path:
        """Get full path to a shard file"""
        return self.model_path / shard_name

    def get_available_memory(self) -> Dict[str, float]:
        """
        Get available system memory

        Returns:
            Dictionary with ``ram_total_gb``, ``ram_available_gb`` and
            ``ram_percent_used``; when CUDA is available it also contains
            ``gpu_{i}_total_gb`` and ``gpu_{i}_available_gb`` per device.
            Sizes are in units of 1e9 bytes.
        """
        memory = psutil.virtual_memory()
        result = {
            "ram_total_gb": memory.total / 1e9,
            "ram_available_gb": memory.available / 1e9,
            "ram_percent_used": memory.percent
        }

        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_mem = torch.cuda.get_device_properties(i).total_memory
                # "available" is the device total minus what
                # torch.cuda.memory_allocated reports for that device
                gpu_allocated = torch.cuda.memory_allocated(i)
                result[f"gpu_{i}_total_gb"] = gpu_mem / 1e9
                result[f"gpu_{i}_available_gb"] = (gpu_mem - gpu_allocated) / 1e9

        return result

    def load_shard(self, shard_name: str, device: str = "cpu") -> Dict[str, torch.Tensor]:
        """
        Load a single shard file

        Args:
            shard_name: Name of the shard file
            device: Device to load tensors to

        Returns:
            Dictionary of weight tensors

        Raises:
            FileNotFoundError: If the shard file does not exist
        """
        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.debug(f"Loading shard: {shard_name}")
        return load_file(str(shard_path), device=device)

    def load_sharded_weights(
        self,
        device: str = "cpu",
        low_memory: bool = False,
        show_progress: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Load all sharded weights

        Args:
            device: Device to load weights to
            low_memory: After each shard, drop the per-shard dictionary and
                empty the CUDA cache. The loaded tensors stay in the returned
                dictionary, so all weights are still held at the end.
            show_progress: Show progress bar

        Returns:
            Dictionary of all model weights

        Raises:
            FileNotFoundError: If a shard listed in the index is missing
        """
        logger.info("Loading sharded model weights...")

        # Check available memory (logged for the operator; not enforced)
        mem_info = self.get_available_memory()
        logger.info(f"Available RAM: {mem_info['ram_available_gb']:.2f} GB")
        if "gpu_0_available_gb" in mem_info:
            logger.info(f"Available GPU 0: {mem_info['gpu_0_available_gb']:.2f} GB")

        # Get unique shard files
        shard_files = self._shard_files()
        total_shards = len(shard_files)

        logger.info(f"Loading {total_shards} shard files...")

        all_weights: Dict[str, torch.Tensor] = {}

        # Create progress bar
        pbar = tqdm(shard_files, disable=not show_progress, desc="Loading shards")

        for shard_name in pbar:
            pbar.set_description(f"Loading {shard_name}")

            # Load shard and merge it into the combined dictionary
            shard_weights = self.load_shard(shard_name, device=device)
            all_weights.update(shard_weights)

            if low_memory:
                # Only the temporary dict is released here; its tensors are
                # still referenced by all_weights.
                del shard_weights
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        logger.info(f"Loaded {len(all_weights)} weight tensors")
        return all_weights

    def get_layer_weights(self, layer_idx: int) -> List[str]:
        """
        Get all weight keys for a specific layer

        Args:
            layer_idx: Layer index

        Returns:
            List of weight keys that start with ``model.layers.{layer_idx}.``
        """
        # The trailing dot keeps layer 1 from also matching layers 10, 11, ...
        prefix = f"model.layers.{layer_idx}."
        return [k for k in self.weight_map if k.startswith(prefix)]

    def get_shard_for_weight(self, weight_key: str) -> Optional[str]:
        """
        Get shard file name for a specific weight

        Args:
            weight_key: Weight key/name

        Returns:
            Shard file name or None
        """
        return self.weight_map.get(weight_key)

    def verify_shards(self) -> Dict[str, bool]:
        """
        Verify all shard files exist

        Only existence on disk is checked; file contents are not read.

        Returns:
            Dictionary mapping shard names to existence status
        """
        logger.info("Verifying shard files...")

        verification = {}

        for shard_name in tqdm(self._shard_files(), desc="Verifying"):
            shard_path = self.get_shard_path(shard_name)
            verification[shard_name] = shard_path.exists()

        missing = [s for s, exists in verification.items() if not exists]

        if missing:
            logger.warning(f"Missing {len(missing)} shard files:")
            for shard in missing[:10]:  # Show first 10
                logger.warning(f"  - {shard}")
            if len(missing) > 10:
                logger.warning(f"  ... and {len(missing) - 10} more")
        else:
            logger.info("✓ All shard files present")

        return verification

    def load_metadata(self) -> Dict:
        """
        Load model metadata

        Returns:
            Dictionary with the raw ``config`` and ``index`` plus summary
            fields (``total_shards``, ``total_size_gb``, ``architecture``,
            ``num_layers``, ``hidden_size``, ``vocab_size``). Missing numeric
            fields default to 0; ``total_shards`` falls back to the number of
            distinct shard files in the weight map.
        """
        return {
            "config": self.config,
            "index": self.index,
            "total_shards": self._total_shards(),
            "total_size_gb": self.metadata.get("total_size", 0) / 1e9,
            "architecture": self.config.get("architectures_info", {}),
            "num_layers": self.config.get("num_hidden_layers", 0),
            "hidden_size": self.config.get("hidden_size", 0),
            "vocab_size": self.config.get("vocab_size", 0)
        }
218
+
219
+
220
def load_full_model(
    model_path: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    low_memory: bool = False
):
    """
    Convenience function to load the full model

    Verifies that every shard listed in the weight index exists before any
    weights are loaded.

    Args:
        model_path: Path to inference directory
        device: Device to load model to (the default is chosen once, when
            this module is imported)
        low_memory: Use low memory loading

    Returns:
        Tuple of (weights dictionary, metadata dictionary)

    Raises:
        FileNotFoundError: If one or more shard files are missing
    """
    loader = ShardedModelLoader(model_path)

    # Verify shards first so a partial download fails fast
    verification = loader.verify_shards()
    missing = sum(1 for exists in verification.values() if not exists)

    if missing > 0:
        # Report the shard count the index actually lists instead of a
        # hard-coded number.
        raise FileNotFoundError(
            f"Cannot load model: {missing} shard files are missing. "
            f"Please download all {len(verification)} shard files."
        )

    # Load weights
    weights = loader.load_sharded_weights(
        device=device,
        low_memory=low_memory,
        show_progress=True
    )

    # Load metadata
    metadata = loader.load_metadata()

    return weights, metadata
259
+
260
+
261
def inspect_model(model_path: str):
    """
    Print a summary of the model and its shard status without loading weights

    The report covers the architecture description, core configuration
    values, MoE settings, storage figures and context-length settings, then
    checks which shard files are present on disk.

    Args:
        model_path: Path to inference directory
    """
    loader = ShardedModelLoader(model_path)
    rule = "=" * 80

    print("\n" + rule)
    print("HELION-OSC MODEL INSPECTION")
    print(rule)

    info = loader.load_metadata()
    arch = info['architecture']
    cfg = info['config']

    print(f"\nModel Type: {arch.get('model_description', 'N/A')}")
    print(f"Architecture: {arch.get('architecture_type', 'N/A')}")
    print(f"Total Parameters: {arch.get('total_parameters', 'N/A')}")
    print(f"Active Parameters: {arch.get('active_parameters', 'N/A')}")

    print("\nModel Configuration:")
    print(f"  Layers: {info['num_layers']}")
    print(f"  Hidden Size: {info['hidden_size']}")
    print(f"  Vocabulary Size: {info['vocab_size']}")
    print(f"  Attention Heads: {cfg.get('num_attention_heads', 'N/A')}")
    print(f"  KV Heads: {cfg.get('num_key_value_heads', 'N/A')}")

    print("\nMoE Configuration:")
    print(f"  Number of Experts: {arch.get('num_experts', 'N/A')}")
    print(f"  Experts per Token: {arch.get('experts_per_token', 'N/A')}")
    print(f"  Shared Experts: {arch.get('num_shared_experts', 'N/A')}")

    print("\nStorage Information:")
    print(f"  Total Shards: {info['total_shards']}")
    print(f"  Total Size: {info['total_size_gb']:.2f} GB")
    # Shard size, format and precision are fixed text, not read from disk
    print("  Shard Size: ~2.8 GB each")
    print("  Format: safetensors")
    print("  Precision: bfloat16")

    print("\nContext Length:")
    print(f"  Max Position Embeddings: {cfg.get('max_position_embeddings', 'N/A')}")
    print(f"  RoPE Theta: {cfg.get('rope_theta', 'N/A')}")

    print("\n" + rule)

    # Check shard presence on disk
    print("\nVerifying shard files...")
    statuses = loader.verify_shards()
    total = len(statuses)
    present = len([name for name, found in statuses.items() if found])

    print(f"\nShard Status: {present}/{total} files present")

    if present != total:
        print(f"✗ Missing {total - present} shard files")
    else:
        print("✓ All shard files are available")
319
+
320
+
321
def main():
    """
    Command-line entry point

    Parses ``model_path`` plus the ``--action``, ``--device`` and
    ``--low-memory`` options, then inspects, verifies or loads the model.
    """
    import argparse

    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="Helion-OSC Sharded Model Loader")
    cli.add_argument("model_path", type=str, help="Path to inference directory")
    cli.add_argument(
        "--action",
        choices=["inspect", "verify", "load"],
        default="inspect",
        help="Action to perform",
    )
    cli.add_argument(
        "--device",
        type=str,
        default=fallback_device,
        help="Device to load model to",
    )
    cli.add_argument("--low-memory", action="store_true", help="Use low memory mode")

    opts = cli.parse_args()

    # One guard per action; each branch returns once its work is done.
    if opts.action == "inspect":
        inspect_model(opts.model_path)
        return

    if opts.action == "verify":
        ShardedModelLoader(opts.model_path).verify_shards()
        return

    if opts.action == "load":
        logger.info("Loading full model...")
        weights, metadata = load_full_model(
            opts.model_path,
            device=opts.device,
            low_memory=opts.low_memory,
        )
        logger.info(f"Successfully loaded {len(weights)} weight tensors")
        logger.info(f"Model ready on {opts.device}")
367
+
368
+
369
+ if __name__ == "__main__":  # run the CLI only when this file is executed as a script, not when it is imported
370
+ main()