# model_tools/llama.py
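"""Mixtral-style MoE output architecture for Llama base models.

Writes a merged mixture-of-experts checkpoint by replicating each expert's
Llama MLP projections into Mixtral's per-expert layout (w1/w2/w3), copying
attention and norm weights from the base model, and saving the router gate
weights for each layer.
"""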
import json
import os
from typing import Optional

import torch
import tqdm
import transformers

from mergekit.architecture import arch_info_for_config
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions


class LlamaMoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "LlamaMoE"

    def supports_config(
        self,
        config: MoEMergeConfig,
        explain: bool = False,
        trust_remote_code: bool = False,
    ) -> bool:
        # Ensure the base model is a Llama model
        model_cfg = config.base_model.config(trust_remote_code=trust_remote_code)
        if model_cfg.model_type != "llama":
            if explain:
                print("LlamaMoE only supports Llama base models")
            return False
        return True

    def write_model(
        self,
        out_path: str,
        config: MoEMergeConfig,
        merge_options: MergeOptions,
        router_weights: list[torch.Tensor],
        shared_router_weights: Optional[list[torch.Tensor]] = None,
    ):
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        # 1. Generate and save the config.json
        out_cfg = base_cfg.to_dict()
        # Note: Most Llama MoEs use the Mixtral architecture name for
        # compatibility with loaders, so model_type must match as well.
        out_cfg["architectures"] = ["MixtralForCausalLM"]
        out_cfg["model_type"] = "mixtral"
        out_cfg["num_local_experts"] = len(config.experts)
        out_cfg["num_experts_per_tok"] = config.experts_per_token
        out_dtype = select_dtype(config, base_cfg)
        os.makedirs(out_path, exist_ok=True)
        with open(os.path.join(out_path, "config.json"), "w") as f:
            json.dump(out_cfg, f, indent=2)

        # 2. Initialize IO
        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
        # 3. Map Tensors
        for weight_info in tqdm.tqdm(
            arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights"
        ):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                # Replicate the MLP weights once per expert
                for expert_idx, expert in enumerate(config.experts):
                    # Map Llama's gate_proj/up_proj/down_proj to Mixtral's w1/w3/w2
                    expert_name = tensor_name.replace(
                        ".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1"
                    )
                    expert_name = expert_name.replace(
                        ".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2"
                    )
                    expert_name = expert_name.replace(
                        ".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3"
                    )
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(
                        weight_info,
                        expert_loader,
                        writer,
                        expert=expert,
                        output_name=expert_name,
                        out_dtype=out_dtype,
                    )
            else:
                # Copy Attention and Norms from base model
                copy_tensor_out(weight_info, base_loader, writer, out_dtype=out_dtype)

        # 4. Write Router Weights; each tensor is expected to have shape
        # (num_local_experts, hidden_size), matching Mixtral's gate layout.
        # shared_router_weights is unused here: this layout has no shared experts.
        for layer_idx, weight in enumerate(router_weights):
            writer.save_tensor(
                f"model.layers.{layer_idx}.block_sparse_moe.gate.weight",
                weight.to(dtype=out_dtype),
            )
        writer.finalize()
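

# A minimal sanity check of the Llama -> Mixtral tensor-name mapping above.
# It exercises only the string rewriting, not mergekit I/O; the sample tensor
# names are illustrative Llama checkpoint keys, not read from a real model.
if __name__ == "__main__":

    def _expert_name(tensor_name: str, expert_idx: int) -> str:
        # Same replace chain as in write_model
        name = tensor_name.replace(
            ".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1"
        )
        name = name.replace(
            ".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2"
        )
        name = name.replace(
            ".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3"
        )
        return name

    assert (
        _expert_name("model.layers.0.mlp.gate_proj.weight", 1)
        == "model.layers.0.block_sparse_moe.experts.1.w1.weight"
    )
    assert (
        _expert_name("model.layers.5.mlp.up_proj.weight", 0)
        == "model.layers.5.block_sparse_moe.experts.0.w3.weight"
    )
    print("tensor-name mapping OK")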