# model_tools/moe_defs.py
# (Hugging Face upload metadata: "Upload 8 files", commit 5f463e1)
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
from typing import ClassVar, List, Optional
from pydantic import BaseModel
from transformers import PretrainedConfig
from mergekit.architecture.base import (
ModuleArchitecture,
WeightInfo,
)
from mergekit.architecture.json_definitions import NAME_TO_ARCH
# Base (dense) Mistral architecture definition, reused by the Mixtral MoE
# variant below for all non-MLP weights (attention, norms, embeddings).
MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0]
MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture
class MixtralModuleArchitecture(ModuleArchitecture, BaseModel):
    """Mixtral MoE module architecture.

    Reuses the dense Mistral definition for everything except the per-layer
    MLP: each layer instead carries ``num_local_experts`` block-sparse expert
    MLPs (w1/w2/w3) plus a routing gate.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    # Number of experts per MoE layer, read from the model config.
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Build directly from the HF config's expert count.
        return MixtralModuleArchitecture(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings etc. are identical to dense Mistral.
        return MISTRAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Final norm / LM head are identical to dense Mistral.
        return MISTRAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Weights for layer *index*: expert MLPs, router gate, then the
        inherited non-MLP Mistral weights."""
        layer = f"model.layers.{index}"
        # w1/w2/w3 for every expert, in expert-major order.
        moe_names = [
            layer + f".block_sparse_moe.experts.{expert}.{proj}.weight"
            for expert in range(self.num_local_experts)
            for proj in ("w1", "w2", "w3")
        ]
        moe_names.append(layer + ".block_sparse_moe.gate.weight")
        weights = [WeightInfo(name=n) for n in moe_names]
        # Inherit attention and norm weights; the dense ``.mlp.`` weights are
        # replaced by the experts above, so they are filtered out.
        weights.extend(
            info
            for info in MISTRAL_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in info.name
        )
        return weights
# Base (dense) Qwen3 architecture definition, reused by the Qwen3 MoE
# variant below for all non-MLP weights.
QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0]
QWEN3_MODULE_ARCH = QWEN3_INFO.modules["default"].architecture
class Qwen3MoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Qwen3 MoE module architecture.

    Dense Qwen3 everywhere except the per-layer MLP, which is replaced by
    ``num_experts`` routed expert MLPs (up/gate/down projections) and a gate.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "Qwen3MoeForCausalLM"
    # Number of routed experts per MoE layer.
    num_experts: int

    def name(self) -> str:
        return "qwen3_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Expert count comes straight from the HF config.
        return Qwen3MoeModuleArchitecture(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Same embeddings as dense Qwen3.
        return QWEN3_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Same final norm / head as dense Qwen3.
        return QWEN3_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return QWEN3_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Weights for layer *index*: expert MLPs, router gate, then the
        inherited non-MLP Qwen3 weights."""
        layer = f"model.layers.{index}"
        out: List[WeightInfo] = [
            WeightInfo(name=f"{layer}.mlp.experts.{expert}.{proj}.weight")
            for expert in range(self.num_experts)
            for proj in ("up_proj", "gate_proj", "down_proj")
        ]
        # Router gate for this layer.
        out.append(WeightInfo(name=layer + ".mlp.gate.weight"))
        # Keep attention/norm weights; drop the dense ``.mlp.`` weights that
        # the experts above replace.
        for info in QWEN3_MODULE_ARCH.layer_weights(index, config):
            if ".mlp." not in info.name:
                out.append(info)
        return out
# Partial AFMoE architecture definition; the class below extends it with the
# per-expert MLP weights.
AFMOE_PARTIAL_INFO = NAME_TO_ARCH["_AfmoePartialForCausalLM"][0]
AFMOE_PARTIAL_MODULE_ARCH = AFMOE_PARTIAL_INFO.modules["default"].architecture
class AfmoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """AFMoE module architecture.

    Extends the partial AFMoE definition with ``num_experts`` optional
    per-expert MLP weights (up/gate/down projections) on every layer.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "AfmoeForCausalLM"
    # Number of experts per MoE layer, read from the model config.
    num_experts: int

    def name(self) -> str:
        return "afmoe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        return AfmoeModuleArchitecture(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Delegates to the partial AFMoE definition.
        return AFMOE_PARTIAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return AFMOE_PARTIAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Weights for layer *index*: the partial AFMoE weights plus the
        optional per-expert MLP weights."""
        # Copy the upstream list before appending: the original code mutated
        # the list returned by AFMOE_PARTIAL_MODULE_ARCH.layer_weights in
        # place, which would corrupt it if the upstream definition ever
        # caches or shares that list.
        res = list(AFMOE_PARTIAL_MODULE_ARCH.layer_weights(index, config) or [])
        prefix = f"model.layers.{index}"
        for expert_idx in range(self.num_experts):
            for param in ("up_proj", "gate_proj", "down_proj"):
                res.append(
                    WeightInfo(
                        name=prefix + f".mlp.experts.{expert_idx}.{param}.weight",
                        # Expert weights may be absent in partially-converted
                        # checkpoints.
                        optional=True,
                    )
                )
        return res
# Base (dense) Llama architecture definition from the registry, reused by the
# LlamaMoe variant below for all non-MLP weights.
LLAMA_INFO = NAME_TO_ARCH["LlamaForCausalLM"][0]
LLAMA_MODULE_ARCH = LLAMA_INFO.modules["default"].architecture
class LlamaMoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Llama-based MoE module architecture.

    Dense Llama attention/norms/embeddings, with each layer's MLP replaced by
    a Mixtral-style block-sparse expert array plus a router gate.
    """

    # Architecture name written to the output config.json.
    ARCHITECTURE_NAME: ClassVar[str] = "LlamaMoeForCausalLM"
    # Number of experts per MoE layer; defaults to 8 when the source config
    # has no ``num_experts`` key (see ``from_config``).
    num_experts: int

    def name(self) -> str:
        return "llama_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Read ``num_experts`` from the model config, falling back to 8.
        return LlamaMoeModuleArchitecture(num_experts=getattr(config, "num_experts", 8))

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama embeddings.
        return LLAMA_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama final norm / LM head.
        return LLAMA_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return LLAMA_MODULE_ARCH.num_layers_config_key()

    def layer_weights(self, index: int, config: PretrainedConfig) -> Optional[List[WeightInfo]]:
        """Weights for layer *index*: expert MLPs, router gate, then the
        inherited non-MLP Llama weights."""
        layer = f"model.layers.{index}"
        # Dense MLP projections mapped into a per-expert array.
        collected: List[WeightInfo] = [
            WeightInfo(name=f"{layer}.block_sparse_moe.experts.{expert}.{proj}.weight")
            for expert in range(self.num_experts)
            for proj in ("gate_proj", "up_proj", "down_proj")
        ]
        # Router (gate) weight for this layer.
        collected.append(WeightInfo(name=layer + ".block_sparse_moe.gate.weight"))
        # Attention and norm weights come from dense Llama; the original
        # ``.mlp.`` weights are skipped since the experts replace them.
        collected.extend(
            info
            for info in LLAMA_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in info.name
        )
        return collected