File size: 12,239 Bytes
545ba0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5025ae
545ba0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
"""
Helion-OSC Sharded Model Loader
Efficiently loads 116 safetensors shards (2.8GB each)
"""

import torch
import json
import os
from pathlib import Path
from typing import Dict, Optional, List
import logging
from tqdm import tqdm
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer
import psutil

# Logger used by every function and method in this module.
logger = logging.getLogger(__name__)
# NOTE: this configures the *root* logger as a side effect of importing the
# module, so any program that imports this loader also gets INFO-level output
# (unless it configured logging before the import).
logging.basicConfig(level=logging.INFO)


class ShardedModelLoader:
    """
    Loader for sharded safetensors model files.

    Reads ``model_config.json`` and ``model.safetensors.index.json`` from a
    model directory and loads the shard files named in the index's
    ``weight_map``. Written for the 116-shard (~2.8 GB each) Helion-OSC
    checkpoint, but nothing below depends on the shard count or size.
    """

    def __init__(self, model_path: str):
        """
        Initialize the sharded model loader.

        Args:
            model_path: Path to the inference directory containing shards.

        Raises:
            FileNotFoundError: If ``model_config.json`` or the index file is
                missing.
            json.JSONDecodeError: If either file is not valid JSON.
        """
        self.model_path = Path(model_path)
        self.config_path = self.model_path / "model_config.json"
        self.index_path = self.model_path / "model.safetensors.index.json"

        # Load configuration. JSON is UTF-8; don't depend on the locale's
        # default text encoding (which differs on e.g. Windows).
        logger.info(f"Loading configuration from {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Load weight index
        logger.info(f"Loading weight index from {self.index_path}")
        with open(self.index_path, 'r', encoding='utf-8') as f:
            self.index = json.load(f)

        # `or {}` also tolerates an explicit null in the JSON.
        self.metadata = self.index.get("metadata") or {}
        self.weight_map = self.index.get("weight_map") or {}

        # architectures_info is optional (load_metadata() treats it that way),
        # so a missing section must not abort construction with a KeyError.
        arch_info = self.config.get("architectures_info") or {}

        logger.info(f"Model: {self.metadata.get('model_type', 'unknown')}")
        logger.info(f"Total shards: {self._total_shards()}")
        logger.info(f"Total size: {self.metadata.get('total_size', 0) / 1e9:.2f} GB")
        logger.info(f"Total parameters: {arch_info.get('total_parameters', 'N/A')}")
        logger.info(f"Active parameters: {arch_info.get('active_parameters', 'N/A')}")

    def _total_shards(self) -> int:
        """Return the declared shard count, else count distinct shard files in the weight map."""
        # The index metadata may not declare total_shards; the weight map
        # always names every shard file, so it is a reliable fallback.
        declared = self.metadata.get("total_shards")
        if declared is not None:
            return declared
        return len(set(self.weight_map.values()))

    def get_shard_path(self, shard_name: str) -> Path:
        """Get full path to a shard file (relative to the model directory)."""
        return self.model_path / shard_name

    def get_available_memory(self) -> Dict[str, float]:
        """
        Report system RAM and, when CUDA is available, per-GPU memory in GB.

        Returns:
            Dict with ``ram_total_gb``, ``ram_available_gb``,
            ``ram_percent_used`` and, per visible GPU ``i``,
            ``gpu_{i}_total_gb`` / ``gpu_{i}_available_gb``.
        """
        memory = psutil.virtual_memory()
        result = {
            "ram_total_gb": memory.total / 1e9,
            "ram_available_gb": memory.available / 1e9,
            "ram_percent_used": memory.percent
        }

        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_mem = torch.cuda.get_device_properties(i).total_memory
                # NOTE: memory_allocated() counts only tensors held by this
                # process, so "available" here ignores other processes and
                # the caching allocator's reserved-but-unused blocks.
                gpu_allocated = torch.cuda.memory_allocated(i)
                result[f"gpu_{i}_total_gb"] = gpu_mem / 1e9
                result[f"gpu_{i}_available_gb"] = (gpu_mem - gpu_allocated) / 1e9

        return result

    def load_shard(self, shard_name: str, device: str = "cpu") -> Dict[str, torch.Tensor]:
        """
        Load a single shard file.

        Args:
            shard_name: Name of the shard file
            device: Device to load tensors to

        Returns:
            Dictionary of weight tensors stored in that shard

        Raises:
            FileNotFoundError: If the shard file does not exist.
        """
        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.debug(f"Loading shard: {shard_name}")
        return load_file(str(shard_path), device=device)

    def load_sharded_weights(
        self,
        device: str = "cpu",
        low_memory: bool = False,
        show_progress: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Load every shard referenced by the weight map into one dictionary.

        Args:
            device: Device to load weights to
            low_memory: After each shard, release cached CUDA memory. Note that
                all tensors are still kept in the returned dict, so this does
                not reduce the final footprint.
            show_progress: Show progress bar

        Returns:
            Dictionary of all model weights (tensor name -> tensor). If a name
            appears in more than one shard, the shard loaded later wins and a
            warning is logged.

        Raises:
            FileNotFoundError: If a referenced shard file is missing.
        """
        logger.info("Loading sharded model weights...")

        # Check available memory
        mem_info = self.get_available_memory()
        logger.info(f"Available RAM: {mem_info['ram_available_gb']:.2f} GB")
        if "gpu_0_available_gb" in mem_info:
            logger.info(f"Available GPU 0: {mem_info['gpu_0_available_gb']:.2f} GB")

        # Unique shard files, in a deterministic order
        shard_files = sorted(set(self.weight_map.values()))
        total_shards = len(shard_files)

        logger.info(f"Loading {total_shards} shard files...")

        all_weights = {}

        # Context manager guarantees the progress bar is closed even if a
        # shard fails to load.
        with tqdm(shard_files, disable=not show_progress, desc="Loading shards") as pbar:
            for shard_name in pbar:
                pbar.set_description(f"Loading {shard_name}")

                shard_weights = self.load_shard(shard_name, device=device)

                # dict.update() would overwrite silently; make clashes visible.
                duplicates = all_weights.keys() & shard_weights.keys()
                if duplicates:
                    logger.warning(
                        f"{shard_name} redefines {len(duplicates)} tensor(s) already "
                        f"loaded (later shard wins), e.g. {sorted(duplicates)[:5]}"
                    )
                all_weights.update(shard_weights)

                if low_memory:
                    # The tensors remain referenced from all_weights; this only
                    # drops the per-shard dict and returns cached, unused CUDA
                    # blocks to the driver.
                    del shard_weights
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        # Sanity check: every tensor the index promises should have been loaded.
        not_loaded = self.weight_map.keys() - all_weights.keys()
        if not_loaded:
            logger.warning(
                f"{len(not_loaded)} tensor(s) listed in the weight map were not "
                f"found in their shards, e.g. {sorted(not_loaded)[:5]}"
            )

        logger.info(f"Loaded {len(all_weights)} weight tensors")
        return all_weights

    def get_layer_weights(self, layer_idx: int) -> List[str]:
        """
        Get all weight keys for a specific layer.

        Args:
            layer_idx: Layer index

        Returns:
            List of weight keys starting with ``model.layers.{layer_idx}.``
            (the trailing dot keeps layer 1 from matching layer 10).
        """
        prefix = f"model.layers.{layer_idx}."
        return [k for k in self.weight_map if k.startswith(prefix)]

    def get_shard_for_weight(self, weight_key: str) -> Optional[str]:
        """
        Get shard file name for a specific weight.

        Args:
            weight_key: Weight key/name

        Returns:
            Shard file name, or None if the key is not in the weight map
        """
        return self.weight_map.get(weight_key)

    def verify_shards(self) -> Dict[str, bool]:
        """
        Check that every shard file referenced by the weight map exists.

        Returns:
            Dictionary mapping shard names to existence status
        """
        logger.info("Verifying shard files...")

        shard_files = set(self.weight_map.values())
        verification = {}

        for shard_name in tqdm(sorted(shard_files), desc="Verifying"):
            verification[shard_name] = self.get_shard_path(shard_name).exists()

        missing = [s for s, exists in verification.items() if not exists]

        if missing:
            logger.warning(f"Missing {len(missing)} shard files:")
            for shard in missing[:10]:  # Show first 10
                logger.warning(f"  - {shard}")
            if len(missing) > 10:
                logger.warning(f"  ... and {len(missing) - 10} more")
        else:
            logger.info("✓ All shard files present")

        return verification

    def load_metadata(self) -> Dict:
        """
        Collect config and index information into one summary dict.

        Missing config keys default to 0 (or ``{}`` for the architecture
        section); ``total_shards`` falls back to the number of distinct shard
        files in the weight map when the index metadata does not declare it.
        """
        return {
            "config": self.config,
            "index": self.index,
            "total_shards": self._total_shards(),
            "total_size_gb": self.metadata.get("total_size", 0) / 1e9,
            "architecture": self.config.get("architectures_info", {}),
            "num_layers": self.config.get("num_hidden_layers", 0),
            "hidden_size": self.config.get("hidden_size", 0),
            "vocab_size": self.config.get("vocab_size", 0)
        }


def load_full_model(
    model_path: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    low_memory: bool = False
):
    """
    Convenience function: verify all shards exist, then load the full model.

    Args:
        model_path: Path to inference directory
        device: Device to load model to (default chosen once, at import time)
        low_memory: Use low memory loading

    Returns:
        Tuple ``(weights, metadata)``: the dict of all weight tensors and the
        summary dict from ``ShardedModelLoader.load_metadata()``.

    Raises:
        FileNotFoundError: If any shard referenced by the index is missing.
    """
    loader = ShardedModelLoader(model_path)

    # Verify shards first so we fail fast instead of partway through loading.
    verification = loader.verify_shards()
    missing = [name for name, exists in verification.items() if not exists]

    if missing:
        total = len(verification)
        # Report the real shard count from the index rather than a fixed number.
        raise FileNotFoundError(
            f"Cannot load model: {len(missing)} of {total} shard files are missing "
            f"(first missing: {missing[0]}). "
            f"Please download all {total} shard files."
        )

    # Load weights
    weights = loader.load_sharded_weights(
        device=device,
        low_memory=low_memory,
        show_progress=True
    )

    # Load metadata
    metadata = loader.load_metadata()

    return weights, metadata


def inspect_model(model_path: str):
    """
    Print a summary of the model's structure without loading any weights.

    Only the JSON config/index files are read; shard files are checked for
    existence but not opened.

    Args:
        model_path: Path to inference directory
    """
    loader = ShardedModelLoader(model_path)
    rule = "=" * 80

    print("\n" + rule)
    print("HELION-OSC MODEL INSPECTION")
    print(rule)

    metadata = loader.load_metadata()
    arch = metadata['architecture']
    config = metadata['config']

    print(f"\nModel Type: {arch.get('model_description', 'N/A')}")
    print(f"Architecture: {arch.get('architecture_type', 'N/A')}")
    print(f"Total Parameters: {arch.get('total_parameters', 'N/A')}")
    print(f"Active Parameters: {arch.get('active_parameters', 'N/A')}")

    print("\nModel Configuration:")
    print(f"  Layers: {metadata['num_layers']}")
    print(f"  Hidden Size: {metadata['hidden_size']}")
    print(f"  Vocabulary Size: {metadata['vocab_size']}")
    print(f"  Attention Heads: {config.get('num_attention_heads', 'N/A')}")
    print(f"  KV Heads: {config.get('num_key_value_heads', 'N/A')}")

    print("\nMoE Configuration:")
    print(f"  Number of Experts: {arch.get('num_experts', 'N/A')}")
    print(f"  Experts per Token: {arch.get('experts_per_token', 'N/A')}")
    print(f"  Shared Experts: {arch.get('num_shared_experts', 'N/A')}")

    total_shards = metadata['total_shards']
    total_size_gb = metadata['total_size_gb']
    print("\nStorage Information:")
    print(f"  Total Shards: {total_shards}")
    print(f"  Total Size: {total_size_gb:.2f} GB")
    # Derive the average shard size from the index instead of assuming 2.8 GB;
    # guard against missing size/count information (and division by zero).
    if total_shards and total_size_gb:
        print(f"  Shard Size: ~{total_size_gb / total_shards:.2f} GB each")
    else:
        print("  Shard Size: N/A")
    print("  Format: safetensors")
    # NOTE(review): assumes model_config.json may carry a `torch_dtype` key;
    # falls back to the value this report previously hard-coded.
    print(f"  Precision: {config.get('torch_dtype', 'bfloat16')}")

    print("\nContext Length:")
    print(f"  Max Position Embeddings: {config.get('max_position_embeddings', 'N/A')}")
    print(f"  RoPE Theta: {config.get('rope_theta', 'N/A')}")

    print("\n" + rule)

    # Verify shards
    print("\nVerifying shard files...")
    verification = loader.verify_shards()
    present = sum(1 for exists in verification.values() if exists)
    total = len(verification)

    print(f"\nShard Status: {present}/{total} files present")

    if total == 0:
        # An empty weight map is not "all present" - nothing can be loaded.
        print("✗ No shard files are listed in the weight index")
    elif present == total:
        print("✓ All shard files are available")
    else:
        print(f"✗ Missing {total - present} shard files")


def main():
    """
    Command-line interface: inspect, verify, or load a sharded model.

    Returns:
        Process exit status: 0 on success, 1 when ``--action verify`` finds
        missing shard files, so that scripts and CI can act on the result.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC Sharded Model Loader")
    parser.add_argument(
        "model_path",
        type=str,
        help="Path to inference directory"
    )
    parser.add_argument(
        "--action",
        choices=["inspect", "verify", "load"],
        default="inspect",
        help="Action to perform"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to load model to"
    )
    parser.add_argument(
        "--low-memory",
        action="store_true",
        help="Use low memory mode"
    )

    args = parser.parse_args()

    if args.action == "inspect":
        inspect_model(args.model_path)

    elif args.action == "verify":
        loader = ShardedModelLoader(args.model_path)
        verification = loader.verify_shards()
        # verify_shards() only logs; surface missing shards via the exit status.
        if not all(verification.values()):
            return 1

    elif args.action == "load":
        logger.info("Loading full model...")
        weights, metadata = load_full_model(
            args.model_path,
            device=args.device,
            low_memory=args.low_memory
        )
        logger.info(f"Successfully loaded {len(weights)} weight tensors")
        logger.info(f"Model ready on {args.device}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())