| | """ |
| | Shard Generator for Helion-OSC |
| | Creates placeholder or actual safetensors shard files |
| | |
| | This script helps you: |
| | 1. Generate placeholder shards for testing |
| | 2. Split a large model into 116 shards |
| | 3. Verify shard integrity |
| | """ |
| |
|
| | import torch |
| | import json |
| | import os |
| | from pathlib import Path |
| | from typing import Dict, List, Optional |
| | import logging |
| | from tqdm import tqdm |
| | from safetensors.torch import save_file, load_file |
| | import numpy as np |
| |
|
# Configure root logging once at import time; all module output goes through
# a module-scoped logger so callers can filter by __name__.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
class ShardGenerator:
    """Generate and manage safetensors model shards.

    Supports three workflows:
      * generate_placeholder_shards -- random tensors for pipeline testing
      * split_large_model / _save_index -- shard a real state dict and write
        a ``model.safetensors.index.json`` weight map
      * verify_shards / get_shard_stats -- integrity and size reporting
    """

    def __init__(self, output_dir: str, total_shards: int = 116):
        """
        Initialize shard generator.

        Args:
            output_dir: Directory to save shards (created if missing)
            total_shards: Total number of shards to generate
        """
        self.output_dir = Path(output_dir)
        self.total_shards = total_shards
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # FIX: was an f-string with no placeholders.
        logger.info("Shard generator initialized")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info(f"Total shards: {self.total_shards}")

    def get_shard_name(self, shard_idx: int) -> str:
        """Return the canonical shard filename, e.g. ``model-00001-of-00116.safetensors``."""
        return f"model-{shard_idx:05d}-of-{self.total_shards:05d}.safetensors"

    def generate_placeholder_shards(
        self,
        shard_size_mb: float = 2800,
        tensor_dtype: torch.dtype = torch.bfloat16
    ):
        """
        Generate placeholder shards filled with random tensors for testing.

        Args:
            shard_size_mb: Target size per shard in MB. NOTE(review): this is
                advisory only -- the element budget derived from it is passed
                to _generate_realistic_shapes, which currently returns a fixed
                layer layout, so actual shard size is set by the hard-coded
                shapes rather than this parameter.
            tensor_dtype: Data type for the random tensors
        """
        logger.info("Generating placeholder shards...")
        logger.info(f"Target shard size: {shard_size_mb} MB")

        # bfloat16 is 2 bytes/element; anything else is assumed 4 (float32).
        bytes_per_element = 2 if tensor_dtype == torch.bfloat16 else 4
        target_bytes = shard_size_mb * 1024 * 1024
        num_elements = int(target_bytes / bytes_per_element)

        tensor_shapes = self._generate_realistic_shapes(num_elements)

        for shard_idx in tqdm(range(1, self.total_shards + 1), desc="Creating shards"):
            shard_name = self.get_shard_name(shard_idx)
            shard_path = self.output_dir / shard_name

            # Prefix every key with the shard index so tensor names stay
            # unique across the whole model.
            tensors = {}
            for name, shape in tensor_shapes.items():
                key = f"layer_{shard_idx}.{name}"
                tensors[key] = torch.randn(shape, dtype=tensor_dtype)

            save_file(tensors, str(shard_path))

            actual_size_mb = shard_path.stat().st_size / (1024 * 1024)
            logger.debug(f"{shard_name}: {actual_size_mb:.2f} MB")

        logger.info(f"✓ Generated {self.total_shards} placeholder shards")

    def _generate_realistic_shapes(self, total_elements: int) -> Dict[str, tuple]:
        """
        Generate realistic tensor shapes for one transformer layer.

        Args:
            total_elements: Element budget for the layer. Currently unused --
                the layout below is fixed; the parameter is kept so callers
                can later scale shapes to the budget.

        Returns:
            Dictionary mapping tensor names to shapes. The layout resembles a
            Llama-style layer with grouped-query attention (K/V projections
            are 1/8 the rows of Q).
        """
        hidden_size = 8192
        intermediate_size = 28672
        # FIX: removed unused locals num_heads / head_dim.

        shapes = {
            "self_attn.q_proj.weight": (hidden_size, hidden_size),
            "self_attn.k_proj.weight": (hidden_size // 8, hidden_size),
            "self_attn.v_proj.weight": (hidden_size // 8, hidden_size),
            "self_attn.o_proj.weight": (hidden_size, hidden_size),
            "mlp.gate_proj.weight": (intermediate_size, hidden_size),
            "mlp.up_proj.weight": (intermediate_size, hidden_size),
            "mlp.down_proj.weight": (hidden_size, intermediate_size),
            "input_layernorm.weight": (hidden_size,),
            "post_attention_layernorm.weight": (hidden_size,),
        }

        return shapes

    def split_large_model(
        self,
        model_state_dict: Dict[str, torch.Tensor],
        max_shard_size_gb: float = 2.8
    ):
        """
        Split a large model into shards by greedily packing tensors in
        iteration order until a shard would exceed the size cap. A tensor
        larger than the cap gets a shard of its own.

        Args:
            model_state_dict: Model weights dictionary
            max_shard_size_gb: Maximum size per shard in GB

        Returns:
            weight_map: mapping of tensor name -> shard filename.

        NOTE(review): shard filenames embed ``self.total_shards`` as the
        ``-of-NNNNN`` suffix, which may not equal the number of shards this
        method actually produces; fixing the suffix would require a second
        pass (or pre-counting) and would change output filenames.
        """
        logger.info("Splitting model into shards...")

        max_shard_bytes = max_shard_size_gb * 1024 ** 3

        current_shard = {}
        current_size = 0
        shard_idx = 1
        weight_map = {}

        for name, tensor in tqdm(model_state_dict.items(), desc="Processing weights"):
            tensor_bytes = tensor.nelement() * tensor.element_size()

            # Flush the current shard before this tensor would overflow it.
            if current_size + tensor_bytes > max_shard_bytes and current_shard:
                shard_name = self.get_shard_name(shard_idx)
                self._save_shard(current_shard, shard_name)

                for weight_name in current_shard.keys():
                    weight_map[weight_name] = shard_name

                current_shard = {}
                current_size = 0
                shard_idx += 1

            current_shard[name] = tensor
            current_size += tensor_bytes

        # Flush the trailing partial shard, if any.
        if current_shard:
            shard_name = self.get_shard_name(shard_idx)
            self._save_shard(current_shard, shard_name)

            for weight_name in current_shard.keys():
                weight_map[weight_name] = shard_name

        logger.info(f"✓ Model split into {shard_idx} shards")

        self._save_index(weight_map, shard_idx)

        return weight_map

    def _save_shard(self, tensors: Dict[str, torch.Tensor], shard_name: str):
        """Serialize one shard's tensors to ``output_dir / shard_name`` and log its size."""
        shard_path = self.output_dir / shard_name
        save_file(tensors, str(shard_path))
        size_mb = shard_path.stat().st_size / (1024 * 1024)
        logger.info(f"Saved {shard_name} ({size_mb:.2f} MB)")

    def _save_index(self, weight_map: Dict[str, str], total_shards: int):
        """Write the weight-map index file (``model.safetensors.index.json``).

        Args:
            weight_map: tensor name -> shard filename (shard files must exist
                on disk, since their sizes are summed for metadata)
            total_shards: number of shards actually produced

        NOTE(review): ``total_size`` here is the sum of shard *file* sizes;
        the Hugging Face index convention sums tensor bytes -- confirm which
        consumers expect before relying on it.
        """
        index = {
            "metadata": {
                "total_size": sum(
                    (self.output_dir / shard).stat().st_size
                    for shard in set(weight_map.values())
                ),
                "total_shards": total_shards,
                "format": "safetensors",
                "model_type": "helion-osc"
            },
            "weight_map": weight_map
        }

        index_path = self.output_dir / "model.safetensors.index.json"
        with open(index_path, 'w') as f:
            json.dump(index, f, indent=2)

        logger.info(f"Saved index to {index_path}")

    def verify_shards(self) -> bool:
        """Check that every expected shard exists and parses as safetensors.

        Returns:
            True if all ``total_shards`` files are present and loadable.
        """
        logger.info("Verifying shards...")

        all_valid = True

        for shard_idx in tqdm(range(1, self.total_shards + 1), desc="Verifying"):
            shard_name = self.get_shard_name(shard_idx)
            shard_path = self.output_dir / shard_name

            if not shard_path.exists():
                logger.error(f"Missing: {shard_name}")
                all_valid = False
                continue

            try:
                # Loading fully validates the safetensors header and payload.
                _ = load_file(str(shard_path))
            except Exception as e:
                logger.error(f"Invalid {shard_name}: {e}")
                all_valid = False

        if all_valid:
            logger.info("✓ All shards verified successfully")
        else:
            logger.error("✗ Some shards are missing or invalid")

        return all_valid

    def get_shard_stats(self) -> Dict:
        """Collect presence and size statistics over the expected shard set.

        Returns:
            Dict with ``total_shards``, ``present_shards``, ``total_size_gb``,
            ``sizes_mb``, and (when any shard exists) ``avg/min/max_size_mb``.
        """
        stats = {
            "total_shards": self.total_shards,
            "present_shards": 0,
            "total_size_gb": 0,
            "sizes_mb": []
        }

        for shard_idx in range(1, self.total_shards + 1):
            shard_name = self.get_shard_name(shard_idx)
            shard_path = self.output_dir / shard_name

            if shard_path.exists():
                stats["present_shards"] += 1
                size_mb = shard_path.stat().st_size / (1024 * 1024)
                stats["sizes_mb"].append(size_mb)
                stats["total_size_gb"] += size_mb / 1024

        if stats["sizes_mb"]:
            # FIX: cast numpy scalars to plain floats so the stats dict is
            # JSON-serializable and prints as native Python numbers.
            stats["avg_size_mb"] = float(np.mean(stats["sizes_mb"]))
            stats["min_size_mb"] = float(np.min(stats["sizes_mb"]))
            stats["max_size_mb"] = float(np.max(stats["sizes_mb"]))

        return stats
| |
|
| |
|
def main():
    """Command-line entry point: generate, verify, or summarize shards."""
    import argparse

    cli = argparse.ArgumentParser(description="Helion-OSC Shard Generator")
    cli.add_argument("output_dir", type=str, help="Output directory for shards")
    cli.add_argument(
        "--action",
        choices=["generate", "verify", "stats"],
        default="generate",
        help="Action to perform",
    )
    cli.add_argument(
        "--total-shards", type=int, default=116, help="Total number of shards"
    )
    cli.add_argument(
        "--shard-size", type=float, default=2800, help="Target shard size in MB"
    )
    opts = cli.parse_args()

    shard_gen = ShardGenerator(
        output_dir=opts.output_dir,
        total_shards=opts.total_shards,
    )

    if opts.action == "generate":
        logger.info("Generating placeholder shards for testing...")
        logger.warning("Note: These are random tensors for testing only!")
        shard_gen.generate_placeholder_shards(shard_size_mb=opts.shard_size)
        return

    if opts.action == "verify":
        shard_gen.verify_shards()
        return

    if opts.action == "stats":
        report = shard_gen.get_shard_stats()
        divider = "=" * 80
        print("\n" + divider)
        print("SHARD STATISTICS")
        print(divider)
        print(f"Total Shards: {report['total_shards']}")
        print(f"Present Shards: {report['present_shards']}")
        print(f"Total Size: {report['total_size_gb']:.2f} GB")

        if report['present_shards'] > 0:
            print(f"Average Size: {report['avg_size_mb']:.2f} MB")
            print(f"Min Size: {report['min_size_mb']:.2f} MB")
            print(f"Max Size: {report['max_size_mb']:.2f} MB")

        print(divider)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |