scatterbrain-sm-experimental / convert_to_stacked.py
ToastyPigeon's picture
Upload folder using huggingface_hub
9711198 verified
#!/usr/bin/env python3
"""
Convert Scatterbrain checkpoint from individual expert weights to stacked format.
This converts:
model.layers.0.mlp.experts.{i}.gate_proj.weight -> model.layers.0.mlp.expert_gate_proj
model.layers.0.mlp.experts.{i}.up_proj.weight -> model.layers.0.mlp.expert_up_proj
model.layers.0.mlp.experts.{i}.down_proj.weight -> model.layers.0.mlp.expert_down_proj
"""
import json
import os
import re
import torch
from safetensors.torch import load_file, save_file
from collections import defaultdict
def convert_checkpoint(model_dir: str, output_dir: str = None):
    """Convert a checkpoint from per-expert weights to stacked expert tensors.

    Rewrites keys of the form
    ``<prefix>.mlp.experts.{i}.{gate,up,down}_proj.weight`` into a single
    stacked tensor ``<prefix>.mlp.expert_{gate,up,down}_proj`` of shape
    ``(num_experts, *weight.shape)``, saves everything to one
    ``model.safetensors`` file, and (when converting in place) removes the
    old shard files and the sharded index.

    Args:
        model_dir: Directory containing ``config.json`` and one or more
            ``*.safetensors`` shard files.
        output_dir: Where to write ``model.safetensors``. Defaults to
            ``model_dir`` (in-place conversion, which deletes the old shards).

    Raises:
        ValueError: If no safetensor files are found, ``config.json`` lacks
            ``num_experts``, or any expert index is missing for a layer.
    """
    if output_dir is None:
        output_dir = model_dir

    # Load config to get num_experts (needed to know how many slices to stack).
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path) as f:
        config = json.load(f)
    if "num_experts" not in config:
        raise ValueError(f"'num_experts' not found in {config_path}")
    num_experts = config["num_experts"]
    print(f"Converting {num_experts} experts to stacked format...")

    # Find all safetensor shard files.
    safetensor_files = sorted(f for f in os.listdir(model_dir) if f.endswith('.safetensors'))
    if not safetensor_files:
        raise ValueError(f"No safetensor files found in {model_dir}")

    # Load all weights into one flat dict (shards have disjoint keys).
    all_weights = {}
    for sf_file in safetensor_files:
        path = os.path.join(model_dir, sf_file)
        print(f"Loading {sf_file}...")
        weights = load_file(path)
        all_weights.update(weights)

    # Pattern to match per-expert projection weights.
    expert_pattern = re.compile(r'(.+\.mlp)\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight')

    # Group expert weights by layer prefix and projection type;
    # everything that doesn't match passes through unchanged.
    expert_groups = defaultdict(lambda: defaultdict(dict))
    other_weights = {}
    for key, value in all_weights.items():
        match = expert_pattern.match(key)
        if match:
            layer_prefix = match.group(1)  # e.g., "model.layers.0.mlp"
            expert_idx = int(match.group(2))
            proj_type = match.group(3)  # gate_proj, up_proj, down_proj
            expert_groups[layer_prefix][proj_type][expert_idx] = value
        else:
            other_weights[key] = value

    # Stack expert weights along a new leading expert dimension.
    new_weights = dict(other_weights)
    for layer_prefix, proj_types in expert_groups.items():
        for proj_type, expert_weights in proj_types.items():
            # Validate that every expert index is present before indexing,
            # so a corrupt/partial checkpoint fails with a clear message
            # instead of a bare KeyError.
            missing = [i for i in range(num_experts) if i not in expert_weights]
            if missing:
                raise ValueError(
                    f"Missing expert indices {missing} for {layer_prefix}.{proj_type} "
                    f"(expected {num_experts} experts, found {len(expert_weights)})"
                )
            sorted_weights = [expert_weights[i] for i in range(num_experts)]
            stacked = torch.stack(sorted_weights, dim=0)
            new_key = f"{layer_prefix}.expert_{proj_type}"
            new_weights[new_key] = stacked
            print(f"  {new_key}: {stacked.shape}")

    # Save everything to a single safetensor file.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "model.safetensors")
    print(f"\nSaving to {output_path}...")
    save_file(new_weights, output_path)

    # When converting in place, remove the now-redundant shards and index.
    if output_dir == model_dir:
        for sf_file in safetensor_files:
            if sf_file != "model.safetensors":
                old_path = os.path.join(model_dir, sf_file)
                print(f"Removing old file: {sf_file}")
                os.remove(old_path)
        index_path = os.path.join(model_dir, "model.safetensors.index.json")
        if os.path.exists(index_path):
            print("Removing old index file...")
            os.remove(index_path)

    print("\nConversion complete!")
    print(f"New weights saved to: {output_path}")

    # Print summary.
    total_params = sum(p.numel() for p in new_weights.values())
    print(f"Total parameters: {total_params:,}")
if __name__ == "__main__":
    import sys

    # Usage: convert_to_stacked.py [model_dir] [output_dir]
    # Falls back to a hard-coded default model_dir when no args are given.
    if len(sys.argv) < 2:
        model_dir = "/home/aibox/training/scatterbrain-small-experimental"
    else:
        model_dir = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else None
    convert_checkpoint(model_dir, output_dir)