| |
"""
Convert a Scatterbrain checkpoint from individual per-expert weights to stacked format.

This converts, for each layer:
    model.layers.{L}.mlp.experts.{i}.gate_proj.weight -> model.layers.{L}.mlp.expert_gate_proj
    model.layers.{L}.mlp.experts.{i}.up_proj.weight   -> model.layers.{L}.mlp.expert_up_proj
    model.layers.{L}.mlp.experts.{i}.down_proj.weight -> model.layers.{L}.mlp.expert_down_proj
"""
|
|
| import json |
| import os |
| import re |
| import torch |
| from safetensors.torch import load_file, save_file |
| from collections import defaultdict |
|
|
|
|
def convert_checkpoint(model_dir: str, output_dir: str = None):
    """Convert a checkpoint from individual per-expert weights to stacked format.

    Rewrites keys of the form
        ``<prefix>.mlp.experts.{i}.{gate,up,down}_proj.weight``
    into single stacked tensors
        ``<prefix>.mlp.expert_{gate,up,down}_proj``
    of shape ``(num_experts, *weight.shape)`` and saves everything to one
    ``model.safetensors`` file in ``output_dir``.

    Args:
        model_dir: Directory containing the original ``*.safetensors`` shards
            and a ``config.json`` with a ``num_experts`` entry.
        output_dir: Destination directory. Defaults to ``model_dir``; when
            converting in place, the superseded shard files and the
            ``model.safetensors.index.json`` are deleted afterwards.

    Raises:
        ValueError: If no safetensor files are found, ``config.json`` lacks
            ``num_experts``, or any expert weight is missing for a layer.
    """
    if output_dir is None:
        output_dir = model_dir

    config_path = os.path.join(model_dir, "config.json")
    with open(config_path) as f:
        config = json.load(f)

    # Validate up front so a malformed config gives a clear message instead of
    # a bare KeyError.
    if "num_experts" not in config:
        raise ValueError(f"config.json in {model_dir} has no 'num_experts' entry")
    num_experts = config["num_experts"]
    print(f"Converting {num_experts} experts to stacked format...")

    safetensor_files = sorted(f for f in os.listdir(model_dir) if f.endswith('.safetensors'))
    if not safetensor_files:
        raise ValueError(f"No safetensor files found in {model_dir}")

    # Merge all shards into one flat {key: tensor} dict.
    all_weights = {}
    for sf_file in safetensor_files:
        path = os.path.join(model_dir, sf_file)
        print(f"Loading {sf_file}...")
        all_weights.update(load_file(path))

    # Matches keys like "model.layers.3.mlp.experts.7.up_proj.weight".
    expert_pattern = re.compile(r'(.+\.mlp)\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight')

    # expert_groups[layer_prefix][proj_type][expert_idx] -> weight tensor
    expert_groups = defaultdict(lambda: defaultdict(dict))
    other_weights = {}

    for key, value in all_weights.items():
        match = expert_pattern.match(key)
        if match:
            layer_prefix = match.group(1)
            expert_idx = int(match.group(2))
            proj_type = match.group(3)
            expert_groups[layer_prefix][proj_type][expert_idx] = value
        else:
            other_weights[key] = value

    # Non-expert weights pass through unchanged.
    new_weights = dict(other_weights)

    for layer_prefix, proj_types in expert_groups.items():
        for proj_type, expert_weights in proj_types.items():
            # Fail with an informative error instead of an opaque KeyError if
            # a shard is missing some of the expert weights.
            missing = [i for i in range(num_experts) if i not in expert_weights]
            if missing:
                raise ValueError(
                    f"Missing expert weights for {layer_prefix}.{proj_type}: "
                    f"expert indices {missing}"
                )
            # Stack experts in index order along a new leading dimension.
            stacked = torch.stack([expert_weights[i] for i in range(num_experts)], dim=0)

            new_key = f"{layer_prefix}.expert_{proj_type}"
            new_weights[new_key] = stacked
            print(f" {new_key}: {stacked.shape}")

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "model.safetensors")

    print(f"\nSaving to {output_path}...")
    save_file(new_weights, output_path)

    # In-place conversion: remove the superseded shard files and index so the
    # directory is left with only the new single-file checkpoint.
    if output_dir == model_dir:
        for sf_file in safetensor_files:
            if sf_file != "model.safetensors":
                old_path = os.path.join(model_dir, sf_file)
                print(f"Removing old file: {sf_file}")
                os.remove(old_path)

        index_path = os.path.join(model_dir, "model.safetensors.index.json")
        if os.path.exists(index_path):
            print("Removing old index file...")
            os.remove(index_path)

    print("\nConversion complete!")
    print(f"New weights saved to: {output_path}")

    total_params = sum(p.numel() for p in new_weights.values())
    print(f"Total parameters: {total_params:,}")
|
|
|
|
if __name__ == "__main__":
    import sys

    # CLI: convert.py [model_dir] [output_dir]
    cli_args = sys.argv[1:]
    model_dir = cli_args[0] if cli_args else "/home/aibox/training/scatterbrain-small-experimental"
    output_dir = cli_args[1] if len(cli_args) > 1 else None

    convert_checkpoint(model_dir, output_dir)
|
|