"""
Reverse process of moe_merge.py - splits merged MoE expert weights back to individual experts.

This script takes a HF checkpoint that has been processed by moe_merge.py (where expert weights
are stacked into single tensors) and splits them back to the original format with individual
expert weights.

The process reverses the merging by:
1. Loading stacked tensors like model.layers.{i}.mlp.experts.gate_proj
2. Unstacking them back to individual experts model.layers.{i}.mlp.experts.{j}.gate_proj.weight
3. Handling all three projection types: gate_proj, up_proj, down_proj

Usage: python moe_split.py --merge_hf_path <merged_checkpoint> --split_hf_path <output_dir>
"""
| |
|
| | import os |
| | from argparse import ArgumentParser |
| | from dataclasses import dataclass |
| | from glob import glob |
| | from typing import Generator |
| |
|
| | import torch |
| | from safetensors.torch import safe_open |
| | from tqdm import tqdm |
| | from transformers import AutoConfig |
| | from veomni.models import build_tokenizer, save_model_weights |
| |
|
| |
|
| | @dataclass |
| | class StateDictIterator: |
| | filepath: str |
| |
|
| | def __iter__(self) -> Generator[tuple[str, "torch.Tensor"], None, None]: |
| | if self.filepath.endswith(".safetensors"): |
| | with safe_open(self.filepath, framework="pt", device="cpu") as f: |
| | for key in f.keys(): |
| | yield key, f.get_tensor(key) |
| |
|
| | else: |
| | state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True) |
| | for key in state_dict.keys(): |
| | yield key, state_dict[key] |
| |
|
| |
|
def main(merge_hf_path, split_hf_path):
    """Split merged (stacked) MoE expert weights back into per-expert tensors.

    Reverses moe_merge.py: each stacked tensor
    ``model.layers.{i}.mlp.experts.{proj}`` with leading dimension
    ``num_experts`` is unstacked into individual
    ``model.layers.{i}.mlp.experts.{j}.{proj}.weight`` entries.

    Args:
        merge_hf_path: Directory containing the merged HF checkpoint
            (``*.safetensors`` shards plus config/tokenizer files).
        split_hf_path: Output directory for the split checkpoint.

    Raises:
        FileNotFoundError: If no ``*.safetensors`` shards exist in
            ``merge_hf_path``.
    """
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(split_hf_path, exist_ok=True)

    config = AutoConfig.from_pretrained(merge_hf_path)
    tokenizer = build_tokenizer(merge_hf_path)

    safetensor_files = sorted(glob(os.path.join(merge_hf_path, "*.safetensors")))
    if not safetensor_files:
        # Fail fast instead of silently writing an empty checkpoint.
        raise FileNotFoundError(f"No *.safetensors shards found in {merge_hf_path}")
    state_dict_iterators = [StateDictIterator(shard_file) for shard_file in safetensor_files]
    new_state_dict = {}
    for state_dict_iterator in tqdm(state_dict_iterators, desc="Loading checkpoint shards"):
        for name, tensor in state_dict_iterator:
            new_state_dict[name] = tensor.cpu()

    num_experts = config.num_experts
    num_hidden_layers = config.num_hidden_layers
    for i in range(num_hidden_layers):
        print(f"Converting layer {i}")
        for proj_name in ["gate_proj", "up_proj", "down_proj"]:
            stacked_key = f"model.layers.{i}.mlp.experts.{proj_name}"
            if stacked_key in new_state_dict:
                stacked_tensor = new_state_dict.pop(stacked_key)
                for j in range(num_experts):
                    expert_key = f"model.layers.{i}.mlp.experts.{j}.{proj_name}.weight"
                    # .clone() gives each expert its own storage: a bare
                    # stacked_tensor[j] is a view aliasing the whole stacked
                    # tensor, which keeps it alive after the pop and breaks
                    # safetensors serialization (it rejects tensors that
                    # share storage — presumably what save_model_weights
                    # uses; confirm against veomni).
                    new_state_dict[expert_key] = stacked_tensor[j].clone()

    model_assets = [config, tokenizer]

    print("Saving to safetensors")
    save_model_weights(split_hf_path, new_state_dict, model_assets=model_assets)
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = ArgumentParser() |
| | parser.add_argument("--merge_hf_path", type=str, required=True) |
| | parser.add_argument("--split_hf_path", type=str, required=True) |
| | args = parser.parse_args() |
| | main(args.merge_hf_path, args.split_hf_path) |
| |
|