# Uploaded by kunhunjon: "Upload sharded model (9x2GB shards, continuous batching, neuronxcc 2.21)"
# Commit d3b192b (verified)
#!/usr/bin/env python3
"""
Script to reconstruct the original model file from shards
"""
import json
import hashlib
from pathlib import Path
def _copy_and_hash(src_path, f_out, chunk_size=8 * 1024 * 1024):
    """Stream the file at *src_path* into the open binary file *f_out*.

    Returns a ``(sha256_hexdigest, bytes_copied)`` tuple for the copied data.
    Reading in fixed-size chunks keeps peak memory at ``chunk_size`` rather
    than the size of the whole shard.
    """
    hasher = hashlib.sha256()
    copied = 0
    with open(src_path, "rb") as f_in:
        while True:
            chunk = f_in.read(chunk_size)
            if not chunk:
                break
            hasher.update(chunk)
            f_out.write(chunk)
            copied += len(chunk)
    return hasher.hexdigest(), copied


def reconstruct_file(shards_dir="."):
    """Reassemble the original file described by a ``*.shards.json`` manifest.

    The manifest is the first ``*.shards.json`` in *shards_dir* (by sorted
    name). It must provide ``original_file``, ``file_size``, ``num_shards``
    and a ``shards`` list whose entries carry ``index``, ``filename`` and
    ``sha256``.

    Shards are concatenated in ``index`` order. Each shard's SHA-256 is
    checked as it is streamed, and the total byte count is checked against
    ``file_size``.

    Data is written to ``<original_file>.partial`` and renamed to
    ``original_file`` only after every check passes. A failed run therefore
    never leaves a truncated file under the final name and never clobbers an
    existing one.

    Args:
        shards_dir: Directory containing the manifest and the shard files.
            ``original_file`` is resolved relative to the current working
            directory (or used as-is if absolute), not relative to this
            directory.

    Returns:
        True on success. False if no manifest is found, a shard is missing,
        a shard's hash does not match, or the reassembled size is wrong.
    """
    shards_dir = Path(shards_dir)

    # Sort so the manifest choice is deterministic when several exist.
    metadata_files = sorted(shards_dir.glob("*.shards.json"))
    if not metadata_files:
        print("Error: No shards metadata file found")
        return False

    metadata_path = metadata_files[0]
    print(f"Loading metadata: {metadata_path}")
    with open(metadata_path, "r") as f:
        metadata = json.load(f)

    output_file = metadata["original_file"]
    expected_size = metadata["file_size"]
    num_shards = metadata["num_shards"]
    print(f"Reconstructing: {output_file}")
    print(f" Expected size: {expected_size / (1024**3):.2f} GB")
    print(f" Number of shards: {num_shards}")

    # NOTE(review): original_file and the shard filenames come straight from
    # the manifest and are not sanitised (absolute paths / ".." are honoured);
    # only run this on manifests you trust.
    output_path = Path(output_file)
    partial_path = output_path.with_name(output_path.name + ".partial")
    # Concatenate by declared index rather than trusting the list order.
    shards = sorted(metadata["shards"], key=lambda s: s["index"])

    success = False
    try:
        total = 0
        with open(partial_path, "wb") as f_out:
            for shard_info in shards:
                shard_path = shards_dir / shard_info["filename"]
                print(f" Processing shard {shard_info['index'] + 1}/{num_shards}: {shard_info['filename']}")

                if not shard_path.exists():
                    print(f"Error: Shard not found: {shard_path}")
                    return False

                # Hash while copying so each shard is read exactly once.
                chunk_hash, copied = _copy_and_hash(shard_path, f_out)
                total += copied
                if chunk_hash != shard_info["sha256"]:
                    print(f"Error: Hash mismatch for {shard_info['filename']}")
                    print(f" Expected: {shard_info['sha256']}")
                    print(f" Got: {chunk_hash}")
                    return False

        if total != expected_size:
            print(f"Error: Size mismatch for {output_file}")
            print(f" Expected: {expected_size}")
            print(f" Got: {total}")
            return False

        # Atomic on the same filesystem; overwrites any existing output.
        partial_path.replace(output_path)
        success = True
    finally:
        # Runs after the `with` has closed f_out, so unlinking is safe on
        # every platform; covers early returns and exceptions alike.
        if not success and partial_path.exists():
            partial_path.unlink()

    print(f"\n✓ Reconstruction complete: {output_file}")
    return True
if __name__ == "__main__":
    import sys

    # Optional first argument: directory holding the manifest and shards
    # (defaults to the current directory).
    shards_dir = sys.argv[1] if len(sys.argv) > 1 else "."
    success = reconstruct_file(shards_dir)
    # sys.exit rather than the bare exit() builtin: exit() is injected by the
    # `site` module and is missing under `python -S` or in frozen/embedded
    # interpreters.
    sys.exit(0 if success else 1)