# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To support different vLLM versions, every model class is imported and registered in its own
# guarded block: a class that the installed vLLM cannot provide only skips that single entry
# instead of triggering unsupported-model issues for the others.
SUPPORTED_MOE_MODELS = []

# DeepseekV2 and DeepseekV3 live in the same vLLM module but are imported separately, so that a
# vLLM exposing only one of the two classes still gets that one registered.
try:
    from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)

try:
    from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)

try:
    from vllm.model_executor.models.mixtral import MixtralForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)

try:
    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)

try:
    from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)

try:
    from vllm.model_executor.models.qwen3_vl_moe import Qwen3MoeLLMForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(Qwen3MoeLLMForCausalLM)

try:
    from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(Qwen3NextForCausalLM)

try:
    from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration
except ImportError:
    pass
else:
    SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)
def patch_vllm_moe_model_weight_loader(model):
    """Give the fused-MoE expert weights of a vLLM model a ``weight_loader``.

    This is a workaround for loading the weights of vLLM fused MoE models; it stems from a bug
    seen with vLLM 0.8.2. All weights are supposed to carry a ``weight_loader`` attribute, but
    the MoE expert weights do not, so it is patched in here. Example of which parameters have a
    loader (True) and which do not (False):

        (True,  'model.embed_tokens.weight')
        (True,  'model.layers.0.self_attn.qkv_proj.weight')
        (True,  'model.layers.0.self_attn.qkv_proj.bias')
        (True,  'model.layers.0.self_attn.o_proj.weight')
        (True,  'model.layers.0.mlp.gate.weight')
        (True,  'model.layers.0.mlp.shared_expert.gate_up_proj.weight')
        (True,  'model.layers.0.mlp.shared_expert.down_proj.weight')
        (False, 'model.layers.0.mlp.shared_expert_gate.weight')    use default
        (False, 'model.layers.0.input_layernorm.weight')           use default
        (False, 'model.layers.0.post_attention_layernorm.weight')  use default
        (False, 'model.layers.0.mlp.experts.w13_weight')           use mlp.experts.weight_loader
        (False, 'model.layers.0.mlp.experts.w2_weight')            use mlp.experts.weight_loader

    Args:
        model: The vLLM model object. If it exposes ``runnable`` and its type name contains
            "ACLGraphWrapper", the wrapped ``model.runnable`` is patched instead.

    Returns:
        None. Matching parameters are modified in place by setting ``param.weight_loader``.
        The function is a no-op when ``SUPPORTED_MOE_MODELS`` is empty or when neither the
        model nor its inner model is an instance of a supported class.

    Raises:
        ValueError: If the model has neither a usable ``model`` nor ``language_model``
            attribute. This check runs before the supported-type check, so it applies to
            unsupported models as well.
    """
    # Early return if no MoE model class could be registered.
    if not SUPPORTED_MOE_MODELS:
        return

    original_model_type = type(model)
    # Unwrap the graph wrapper so that both the type lookup and the attribute access below
    # operate on the underlying model.
    if hasattr(model, "runnable") and "ACLGraphWrapper" in str(original_model_type):
        model = model.runnable
        original_model_type = type(model)

    # Name of the per-layer attribute holding the MoE block, keyed by exact model type.
    mlp_attr_mapping = {}
    try:
        from vllm.model_executor.models.mixtral import MixtralForCausalLM
    except ImportError:
        pass
    else:
        mlp_attr_mapping[MixtralForCausalLM] = "block_sparse_moe"
    default_mlp_attr = "mlp"

    # Get the inner model (either model.model or model.language_model).
    inner_model = getattr(model, "model", None) or getattr(model, "language_model", None)
    if inner_model is None:
        raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")

    supported_types = tuple(SUPPORTED_MOE_MODELS)
    if not isinstance(model, supported_types) and not isinstance(inner_model, supported_types):
        return

    # TODO(@leisuzz): class Qwen3MoeLLMForCausalLM is not available if VLLM version < 0.11.0,
    # will update the 'if statement' with 'isinstance' when verl commonly use VLLM version >= 0.11.0
    if type(inner_model).__name__ == "Qwen3MoeLLMForCausalLM":
        inner_model = inner_model.model  # Reassign inner_model in Qwen3-vl

    # The attribute name depends only on the model type, so resolve it once for all layers.
    mlp_attr = mlp_attr_mapping.get(original_model_type, default_mlp_attr)

    for layer in inner_model.layers:
        mlp = getattr(layer, mlp_attr, None)
        if mlp is None:
            continue

        experts = getattr(mlp, "experts", None)
        if experts is None or not hasattr(experts, "weight_loader"):
            continue

        # Patch the weight loaders of the fused expert weights only.
        for name, param in mlp.named_parameters():
            if "w13_weight" in name or "w2_weight" in name:
                param.weight_loader = experts.weight_loader
|