| """ |
| Shard Generator for Helion-OSC |
| Creates placeholder or actual safetensors shard files |
| |
| This script helps you: |
| 1. Generate placeholder shards for testing |
| 2. Split a large model into 116 shards |
| 3. Verify shard integrity |
| """ |

import json
import logging
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from safetensors.torch import save_file, load_file
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ShardGenerator:
    """Generate and manage model shards"""

    def __init__(self, output_dir: str, total_shards: int = 116):
        """
        Initialize shard generator

        Args:
            output_dir: Directory to save shards
            total_shards: Total number of shards to generate
        """
        self.output_dir = Path(output_dir)
        self.total_shards = total_shards
        self.output_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Shard generator initialized")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info(f"Total shards: {self.total_shards}")

    def get_shard_name(self, shard_idx: int) -> str:
        """Get formatted shard name"""
| return f"model-{shard_idx:05d}-of-{self.total_shards:05d}.safetensors" |
| |
    def generate_placeholder_shards(
        self,
        shard_size_mb: float = 2800,
        tensor_dtype: torch.dtype = torch.bfloat16
    ):
        """
        Generate placeholder shards for testing

        Args:
            shard_size_mb: Target size per shard in MB
            tensor_dtype: Data type for tensors
        """
        logger.info("Generating placeholder shards...")
        logger.info(f"Target shard size: {shard_size_mb} MB")

        # element_size() is correct for any dtype (bf16/fp16 are 2 bytes,
        # fp32 is 4), unlike a hard-coded bfloat16 check
        bytes_per_element = torch.tensor([], dtype=tensor_dtype).element_size()
        target_bytes = shard_size_mb * 1024 * 1024
        target_elements = int(target_bytes / bytes_per_element)

        # Realistic transformer-layer shapes, plus a 1-D filler tensor that
        # pads each shard up to the requested target size
        tensor_shapes = self._generate_realistic_shapes()
        base_elements = sum(int(np.prod(shape)) for shape in tensor_shapes.values())
        filler_elements = max(target_elements - base_elements, 0)

        for shard_idx in tqdm(range(1, self.total_shards + 1), desc="Creating shards"):
            shard_name = self.get_shard_name(shard_idx)
            shard_path = self.output_dir / shard_name

            tensors = {}
            for name, shape in tensor_shapes.items():
                key = f"layer_{shard_idx}.{name}"
                tensors[key] = torch.randn(shape, dtype=tensor_dtype)
            if filler_elements > 0:
                tensors[f"layer_{shard_idx}.filler"] = torch.randn(
                    (filler_elements,), dtype=tensor_dtype
                )

            save_file(tensors, str(shard_path))

            actual_size_mb = shard_path.stat().st_size / (1024 * 1024)
            logger.debug(f"{shard_name}: {actual_size_mb:.2f} MB")

        logger.info(f"✓ Generated {self.total_shards} placeholder shards")

    def _generate_realistic_shapes(self) -> Dict[str, tuple]:
        """
        Generate realistic tensor shapes for one transformer layer

        Returns:
            Dictionary mapping tensor names to shapes
        """
        # Dimensions typical of a ~70B Llama-style model: 64 query heads of
        # dim 128 (hidden size 8192) with 8 KV heads (grouped-query attention)
        hidden_size = 8192
        intermediate_size = 28672
        num_kv_heads = 8
        head_dim = 128
        kv_size = num_kv_heads * head_dim  # 1024

        # These shapes total ~0.86B elements (~1.6 GB in bfloat16) before
        # any filler padding
        shapes = {
            "self_attn.q_proj.weight": (hidden_size, hidden_size),
            "self_attn.k_proj.weight": (kv_size, hidden_size),
            "self_attn.v_proj.weight": (kv_size, hidden_size),
            "self_attn.o_proj.weight": (hidden_size, hidden_size),
            "mlp.gate_proj.weight": (intermediate_size, hidden_size),
            "mlp.up_proj.weight": (intermediate_size, hidden_size),
            "mlp.down_proj.weight": (hidden_size, intermediate_size),
            "input_layernorm.weight": (hidden_size,),
            "post_attention_layernorm.weight": (hidden_size,),
        }

        return shapes

    def split_large_model(
        self,
        model_state_dict: Dict[str, torch.Tensor],
        max_shard_size_gb: float = 2.8
    ) -> Dict[str, str]:
        """
        Split a large model into shards

        Args:
            model_state_dict: Model weights dictionary
            max_shard_size_gb: Maximum size per shard in GB

        Returns:
            Mapping from each weight name to the shard file that contains it
        """
        logger.info("Splitting model into shards...")

        max_shard_bytes = max_shard_size_gb * 1024 ** 3

        # First pass: assign tensors to shards so the final shard count is
        # known before any file names (which embed the total) are generated
        shards: List[Dict[str, torch.Tensor]] = [{}]
        current_size = 0

        for name, tensor in model_state_dict.items():
            tensor_bytes = tensor.nelement() * tensor.element_size()

            # Start a new shard when the current one would overflow
            if current_size + tensor_bytes > max_shard_bytes and shards[-1]:
                shards.append({})
                current_size = 0

            shards[-1][name] = tensor
            current_size += tensor_bytes

        # get_shard_name() embeds the total count, so update it to match
        self.total_shards = len(shards)

        # Second pass: write the shards and record the weight map
        weight_map = {}
        for shard_idx, shard in enumerate(tqdm(shards, desc="Saving shards"), start=1):
            shard_name = self.get_shard_name(shard_idx)
            self._save_shard(shard, shard_name)
            for weight_name in shard:
                weight_map[weight_name] = shard_name

        logger.info(f"✓ Model split into {len(shards)} shards")

        # Write the index so loaders can locate each weight
        self._save_index(weight_map, len(shards))

        return weight_map
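    # Usage sketch (hypothetical checkpoint path; any state dict works):
    #   state_dict = torch.load("checkpoint.pt", map_location="cpu")
    #   weight_map = ShardGenerator("./shards").split_large_model(state_dict)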

    def _save_shard(self, tensors: Dict[str, torch.Tensor], shard_name: str):
        """Save a shard file"""
        shard_path = self.output_dir / shard_name
        save_file(tensors, str(shard_path))
        size_mb = shard_path.stat().st_size / (1024 * 1024)
        logger.info(f"Saved {shard_name} ({size_mb:.2f} MB)")

    def _save_index(self, weight_map: Dict[str, str], total_shards: int):
        """Save the weight map index file"""
        index = {
            "metadata": {
                "total_size": sum(
                    (self.output_dir / shard).stat().st_size
                    for shard in set(weight_map.values())
                ),
                "total_shards": total_shards,
                "format": "safetensors",
                "model_type": "helion-osc"
            },
            "weight_map": weight_map
        }

        index_path = self.output_dir / "model.safetensors.index.json"
        with open(index_path, 'w') as f:
            json.dump(index, f, indent=2)

        logger.info(f"Saved index to {index_path}")

    def verify_shards(self) -> bool:
        """Verify all shards can be loaded"""
        logger.info("Verifying shards...")

        all_valid = True

        for shard_idx in tqdm(range(1, self.total_shards + 1), desc="Verifying"):
            shard_name = self.get_shard_name(shard_idx)
            shard_path = self.output_dir / shard_name

            if not shard_path.exists():
                logger.error(f"Missing: {shard_name}")
                all_valid = False
                continue

            try:
                # load_file parses the header and reads every tensor, so a
                # truncated or corrupt shard raises here
                _ = load_file(str(shard_path))
            except Exception as e:
                logger.error(f"Invalid {shard_name}: {e}")
                all_valid = False

        if all_valid:
            logger.info("✓ All shards verified successfully")
        else:
            logger.error("✗ Some shards are missing or invalid")

        return all_valid
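    # To inspect one shard by hand (file name illustrative):
    #   tensors = load_file("shards/model-00001-of-00116.safetensors")
    #   print({k: tuple(v.shape) for k, v in tensors.items()})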

    def get_shard_stats(self) -> Dict:
        """Get statistics about shards"""
        stats = {
            "total_shards": self.total_shards,
            "present_shards": 0,
            "total_size_gb": 0,
            "sizes_mb": []
        }

        for shard_idx in range(1, self.total_shards + 1):
            shard_name = self.get_shard_name(shard_idx)
            shard_path = self.output_dir / shard_name

            if shard_path.exists():
                stats["present_shards"] += 1
                size_mb = shard_path.stat().st_size / (1024 * 1024)
                stats["sizes_mb"].append(size_mb)
                stats["total_size_gb"] += size_mb / 1024

        if stats["sizes_mb"]:
            stats["avg_size_mb"] = np.mean(stats["sizes_mb"])
            stats["min_size_mb"] = np.min(stats["sizes_mb"])
            stats["max_size_mb"] = np.max(stats["sizes_mb"])

        return stats
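    # Example: reading stats programmatically (directory illustrative):
    #   stats = ShardGenerator("./shards").get_shard_stats()
    #   print(f"{stats['present_shards']}/{stats['total_shards']} shards present")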


def main():
    """CLI interface"""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC Shard Generator")
    parser.add_argument(
        "output_dir",
        type=str,
        help="Output directory for shards"
    )
    parser.add_argument(
        "--action",
        choices=["generate", "verify", "stats"],
        default="generate",
        help="Action to perform"
    )
    parser.add_argument(
        "--total-shards",
        type=int,
        default=116,
        help="Total number of shards"
    )
    parser.add_argument(
        "--shard-size",
        type=float,
        default=2800,
        help="Target shard size in MB"
    )

    args = parser.parse_args()

    generator = ShardGenerator(
        output_dir=args.output_dir,
        total_shards=args.total_shards
    )

    if args.action == "generate":
        logger.info("Generating placeholder shards for testing...")
        logger.warning("Note: These are random tensors for testing only!")
        generator.generate_placeholder_shards(shard_size_mb=args.shard_size)

    elif args.action == "verify":
        generator.verify_shards()

    elif args.action == "stats":
        stats = generator.get_shard_stats()
        print("\n" + "=" * 80)
        print("SHARD STATISTICS")
        print("=" * 80)
        print(f"Total Shards: {stats['total_shards']}")
        print(f"Present Shards: {stats['present_shards']}")
        print(f"Total Size: {stats['total_size_gb']:.2f} GB")

        if stats['present_shards'] > 0:
            print(f"Average Size: {stats['avg_size_mb']:.2f} MB")
            print(f"Min Size: {stats['min_size_mb']:.2f} MB")
            print(f"Max Size: {stats['max_size_mb']:.2f} MB")

        print("=" * 80)


if __name__ == "__main__":
    main()