import torch
import tqdm
import transformers
from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
from mergekit.moe.config import MoEMergeConfig
from mergekit.options import MergeOptions
from mergekit.architecture import arch_info_for_config
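
# Output architecture that writes a Llama base model plus Llama expert models
# as a Mixtral-style sparse MoE checkpoint (experts live under block_sparse_moe).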

class LlamaMoE(MoEOutputArchitecture):
    def name(self) -> str:
        return "LlamaMoE"

    def supports_config(self, config: MoEMergeConfig, explain: bool = False, trust_remote_code: bool = False) -> bool:
        # Ensure the base model and every expert are Llama models; write_model
        # assumes the Llama gate_proj/up_proj/down_proj MLP layout for all of them.
        for model_ref in [config.base_model] + [e.source_model for e in config.experts]:
            model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
            if model_cfg.model_type != "llama":
                if explain:
                    print("LlamaMoE only supports Llama base and expert models")
                return False
        return True

    def write_model(self, out_path: str, config: MoEMergeConfig, merge_options: MergeOptions, router_weights: list[torch.Tensor], shared_router_weights=None):
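        """Write the merged model to `out_path` as a Mixtral-format MoE checkpoint.

        Attention and norm tensors are copied from the base model, each expert's
        MLP projections are remapped into one expert slot, and the per-layer
        router weights are saved as the block_sparse_moe gates.
        """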
        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
        
        # 1. Generate and save the output config.json
        # Note: most Llama MoEs are written out under the Mixtral architecture name
        # for compatibility with standard loaders.
        out_cfg = transformers.MixtralConfig(**base_cfg.to_dict())
        out_cfg.architectures = ["MixtralForCausalLM"]
        out_cfg.num_local_experts = len(config.experts)
        out_cfg.num_experts_per_tok = config.experts_per_token
        out_cfg.sliding_window = None  # Llama has no sliding-window attention

        out_dtype = select_dtype(config, base_cfg)
        out_cfg.torch_dtype = out_dtype
        out_cfg.save_pretrained(out_path)
        
        # 2. Initialize IO
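        # loaders: lazy tensor loaders keyed by source model; base_loader: loader
        # for the base model; writer: shard writer targeting out_path.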
        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
        
        # 3. Map Tensors
        for weight_info in tqdm.tqdm(arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights"):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                for expert_idx, expert in enumerate(config.experts):
                    # Map Llama's gate_proj/up_proj/down_proj to Mixtral's w1/w3/w2
                    expert_name = tensor_name.replace(".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1")
                    expert_name = expert_name.replace(".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2")
                    expert_name = expert_name.replace(".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3")
                    
                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(weight_info, expert_loader, writer, expert=expert, output_name=expert_name, out_dtype=out_dtype)
            else:
                # Copy Attention and Norms from base model
                copy_tensor_out(weight_info, base_loader, writer, out_dtype=out_dtype)

        # 4. Write Router Weights
        for layer_idx, weight in enumerate(router_weights):
            writer.save_tensor(f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", weight.to(dtype=out_dtype))

        writer.finalize()
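
# Usage sketch (not part of the module): assuming LlamaMoE is registered among
# mergekit's MoE output architectures, a merge could be driven by a standard
# mergekit-moe YAML config. Model names below are placeholders.
#
#   base_model: meta-llama/Llama-2-7b-hf
#   gate_mode: hidden          # how router weights are computed
#   dtype: bfloat16
#   experts_per_token: 2
#   experts:
#     - source_model: example-org/llama-math-expert     # placeholder
#       positive_prompts: ["solve this equation"]
#     - source_model: example-org/llama-code-expert     # placeholder
#       positive_prompts: ["write a python function"]
#
# Invoked with: mergekit-moe config.yml ./llama-moe-out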