import json
import os

import torch
import tqdm
import transformers

from mergekit.architecture import arch_info_for_config
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class LlamaMoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "LlamaMoE"

    def supports_config(self, config: MoEMergeConfig, explain: bool = False, trust_remote_code: bool = False) -> bool:
        # Ensure the base model is a Llama model
        model_cfg = config.base_model.config(trust_remote_code=trust_remote_code)
        if model_cfg.model_type != "llama":
            if explain:
                print("LlamaMoE only supports Llama base models")
            return False
        return True

    def write_model(self, out_path: str, config: MoEMergeConfig, merge_options: MergeOptions, router_weights: list[torch.Tensor], shared_router_weights: list[torch.Tensor] | None = None):
        # shared_router_weights is unused here; accepted for interface compatibility
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        # 1. Generate the config.json
        out_cfg = base_cfg.to_dict()
        # Note: Most Llama MoEs use the Mixtral architecture name for compatibility with loaders
        out_cfg["architectures"] = ["MixtralForCausalLM"]
        out_cfg["num_local_experts"] = len(config.experts)
        out_cfg["num_experts_per_tok"] = config.experts_per_token
        out_dtype = select_dtype(config, base_cfg)
        out_cfg["torch_dtype"] = str(out_dtype).removeprefix("torch.")
        os.makedirs(out_path, exist_ok=True)
        with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(out_cfg, f, indent=2)

        # 2. Initialize IO
        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)

        # 3. Map Tensors
        for weight_info in tqdm.tqdm(arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights"):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                for expert_idx, expert in enumerate(config.experts):
                    # Map Llama's gate_proj/up_proj/down_proj to Mixtral's w1/w3/w2
                    expert_name = tensor_name.replace(".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1")
                    expert_name = expert_name.replace(".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2")
                    expert_name = expert_name.replace(".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3")
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(weight_info, expert_loader, writer, expert=expert, output_name=expert_name, out_dtype=out_dtype)
            else:
                # Copy Attention and Norms from base model
                copy_tensor_out(weight_info, base_loader, writer, out_dtype=out_dtype)

        # 4. Write Router Weights
        for layer_idx, weight in enumerate(router_weights):
            writer.save_tensor(f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", weight.to(dtype=out_dtype))
        writer.finalize()
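
As a standalone sanity check of the two conversion steps above, the sketch below replays the Llama-to-Mixtral tensor-name remapping on a few sample weight names and builds a placeholder router tensor with the [num_local_experts, hidden_size] shape that Mixtral's per-layer gate (a bias-free linear projection from hidden states to expert logits) expects. The layer index, sizes, and random values are illustrative only; in a real merge the router weights come from whichever gating strategy produced router_weights.

import torch

# Sample Llama weight names for one decoder layer (illustrative).
llama_names = [
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
]

expert_idx = 3
for name in llama_names:
    remapped = name.replace(".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1")
    remapped = remapped.replace(".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2")
    remapped = remapped.replace(".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3")
    print(f"{name} -> {remapped}")
# Prints, e.g.:
#   model.layers.0.mlp.gate_proj.weight -> model.layers.0.block_sparse_moe.experts.3.w1.weight

# Placeholder router weight: Mixtral's gate at each layer maps hidden states to one
# logit per expert, so each tensor saved as
# "model.layers.{i}.block_sparse_moe.gate.weight" has shape [num_local_experts, hidden_size].
# Random values are a stand-in, not a real routing initialization.
hidden_size, num_experts = 4096, 4
dummy_gate = torch.randn(num_experts, hidden_size, dtype=torch.bfloat16)
assert dummy_gate.shape == (num_experts, hidden_size)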