Spaces:
Running
Running
| # Copyright (C) 2025 Arcee AI | |
| # SPDX-License-Identifier: LGPL-3.0-only | |
| from typing import ClassVar, List, Optional | |
| from pydantic import BaseModel | |
| from transformers import PretrainedConfig | |
| from mergekit.architecture.base import ( | |
| ModuleArchitecture, | |
| WeightInfo, | |
| ) | |
| from mergekit.architecture.json_definitions import NAME_TO_ARCH | |
# Dense Mistral architecture definition from the JSON registry; the Mixtral
# MoE class below delegates all non-expert weights to it.
MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0]
MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture
class MixtralModuleArchitecture(ModuleArchitecture, BaseModel):
    """Architecture info for Mixtral MoE models built on a Mistral base.

    Per-layer expert MLP weights (``w1``/``w2``/``w3``) and a router gate
    replace the dense Mistral MLP weights; every other weight (attention,
    norms, embeddings) is delegated to the dense Mistral definition.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    # Number of experts per MoE layer (mirrors config.num_local_experts).
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod  # fix: was missing, so calling from_config on the class misbound arguments
    def from_config(cls, config: PretrainedConfig):
        """Build an instance from a HF config; reads ``num_local_experts``."""
        return cls(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings are identical to the dense Mistral layout.
        return MISTRAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return weights for layer ``index``: expert MLPs, the router gate,
        then all non-MLP weights from the dense Mistral definition."""
        prefix = f"model.layers.{index}"
        tensor_names = [
            f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            for expert_idx in range(self.num_local_experts)
            for param in ("w1", "w2", "w3")
        ]
        tensor_names.append(f"{prefix}.block_sparse_moe.gate.weight")
        res = [WeightInfo(name=name) for name in tensor_names]
        for weight_info in MISTRAL_MODULE_ARCH.layer_weights(index, config):
            # Dense MLP weights are superseded by the expert weights above.
            if ".mlp." in weight_info.name:
                continue
            res.append(weight_info)
        return res
# Dense Qwen3 architecture definition from the JSON registry; the Qwen3 MoE
# class below delegates all non-expert weights to it.
QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0]
QWEN3_MODULE_ARCH = QWEN3_INFO.modules["default"].architecture
class Qwen3MoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Architecture info for Qwen3 MoE models built on a dense Qwen3 base.

    Per-layer expert MLP weights (``up_proj``/``gate_proj``/``down_proj``)
    and a router gate replace the dense MLP weights; every other weight is
    delegated to the dense Qwen3 definition.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "Qwen3MoeForCausalLM"
    # Number of experts per MoE layer (mirrors config.num_experts).
    num_experts: int

    def name(self) -> str:
        return "qwen3_moe"

    @classmethod  # fix: was missing, so calling from_config on the class misbound arguments
    def from_config(cls, config: PretrainedConfig):
        """Build an instance from a HF config; reads ``num_experts``."""
        return cls(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings are identical to the dense Qwen3 layout.
        return QWEN3_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return QWEN3_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return QWEN3_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return weights for layer ``index``: expert MLPs, the router gate,
        then all non-MLP weights from the dense Qwen3 definition."""
        prefix = f"model.layers.{index}"
        tensor_names = [
            f"{prefix}.mlp.experts.{expert_idx}.{param}.weight"
            for expert_idx in range(self.num_experts)
            for param in ("up_proj", "gate_proj", "down_proj")
        ]
        tensor_names.append(f"{prefix}.mlp.gate.weight")
        res = [WeightInfo(name=name) for name in tensor_names]
        for weight_info in QWEN3_MODULE_ARCH.layer_weights(index, config):
            # Dense MLP weights are superseded by the expert weights above.
            if ".mlp." in weight_info.name:
                continue
            res.append(weight_info)
        return res
# Partial Afmoe architecture definition from the JSON registry; the Afmoe
# class below extends its per-layer weights with expert MLP entries.
AFMOE_PARTIAL_INFO = NAME_TO_ARCH["_AfmoePartialForCausalLM"][0]
AFMOE_PARTIAL_MODULE_ARCH = AFMOE_PARTIAL_INFO.modules["default"].architecture
class AfmoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Architecture info for Afmoe models.

    Unlike the Mixtral/Qwen3 variants, this builds on a *partial* base
    definition and appends expert MLP weights to its per-layer list rather
    than filtering dense MLP weights out.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "AfmoeForCausalLM"
    # Number of experts per MoE layer (mirrors config.num_experts).
    num_experts: int

    def name(self) -> str:
        return "afmoe"

    @classmethod  # fix: was missing, so calling from_config on the class misbound arguments
    def from_config(cls, config: PretrainedConfig):
        """Build an instance from a HF config; reads ``num_experts``."""
        return cls(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return AFMOE_PARTIAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return the partial base's weights for layer ``index`` plus the
        expert MLP weights for every expert."""
        res = AFMOE_PARTIAL_MODULE_ARCH.layer_weights(index, config) or []
        prefix = f"model.layers.{index}"
        for expert_idx in range(self.num_experts):
            for param in ("up_proj", "gate_proj", "down_proj"):
                res.append(
                    WeightInfo(
                        # optional=True: expert tensors are tolerated as
                        # missing — presumably some layers are dense; TODO
                        # confirm against the Afmoe checkpoint layout.
                        name=f"{prefix}.mlp.experts.{expert_idx}.{param}.weight",
                        optional=True,
                    )
                )
        return res
# Llama-based MoE architecture definition.
# Dense Llama architecture definition from the JSON registry; the Llama MoE
# class below delegates all non-expert weights to it.
LLAMA_INFO = NAME_TO_ARCH["LlamaForCausalLM"][0]
LLAMA_MODULE_ARCH = LLAMA_INFO.modules["default"].architecture
class LlamaMoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Architecture info for a Llama-based MoE model.

    Maps the dense Llama MLP into a ``block_sparse_moe`` expert array plus
    a router gate; attention and norm weights come from the dense Llama
    definition unchanged.
    """

    # Name emitted into the output config.json.
    ARCHITECTURE_NAME: ClassVar[str] = "LlamaMoeForCausalLM"
    # Number of experts per MoE layer.
    num_experts: int

    def name(self) -> str:
        return "llama_moe"

    @classmethod  # fix: was missing, so calling from_config on the class misbound arguments
    def from_config(cls, config: PretrainedConfig):
        """Build an instance from a HF config.

        Reads ``num_experts``; falls back to 8 when the key is absent
        (NOTE(review): 8 matches common Mixtral-style configs — confirm
        this default is intended for all callers).
        """
        return cls(num_experts=getattr(config, "num_experts", 8))

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama embeddings / initial norms.
        return LLAMA_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama final norm / LM head.
        return LLAMA_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return LLAMA_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return weights for layer ``index``: expert MLPs, the router gate,
        then all non-MLP weights from the dense Llama definition."""
        prefix = f"model.layers.{index}"
        res = [
            WeightInfo(
                name=f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            )
            for expert_idx in range(self.num_experts)
            for param in ("gate_proj", "up_proj", "down_proj")
        ]
        # Router (gate) weight.
        res.append(WeightInfo(name=f"{prefix}.block_sparse_moe.gate.weight"))
        # Non-MLP weights (attention, norms); dense .mlp. weights are
        # skipped because the experts above replace them.
        for weight_info in LLAMA_MODULE_ARCH.layer_weights(index, config):
            if ".mlp." not in weight_info.name:
                res.append(weight_info)
        return res