# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# use mcore transformer config to initialize the model
import inspect
from abc import ABC, abstractmethod
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from .config_converter import PretrainedConfig, TransformerConfig
class BaseModelInitializer(ABC):
    """Base class for model initializers.

    Builds a Megatron-Core ``GPTModel`` from a Megatron ``TransformerConfig`` plus the
    matching HuggingFace config. Subclasses supply the decoder layer spec via
    ``get_transformer_layer_spec``.
    """

    def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):
        """Store the configs and probe the installed Megatron-Core for ``vp_stage`` support.

        Args:
            tfconfig: Megatron-Core transformer config the model is built from.
            hf_config: HuggingFace config; ``vocab_size``, ``max_position_embeddings``,
                ``rope_theta`` and (if present) ``rope_scaling`` are read from it.
        """
        self.tfconfig = tfconfig
        self.hf_config = hf_config
        # The installed Megatron-Core may not accept ``vp_stage`` in the layer-spec builder,
        # so check once and only forward the argument when it is part of the signature.
        self.has_vp_stage = "vp_stage" in inspect.signature(get_gpt_decoder_block_spec).parameters

    @abstractmethod
    def get_transformer_layer_spec(self, vp_stage=None):
        """Get the transformer layer specification.

        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py

        Args:
            vp_stage: virtual pipeline stage of the model chunk being built, or ``None``.
        """

    def get_rope_scaling_args(self) -> dict:
        """Translate ``hf_config.rope_scaling`` into ``GPTModel`` keyword arguments.

        Returns:
            ``{"seq_len_interpolation_factor": factor}`` when the HF config has a non-null
            ``rope_scaling`` entry, otherwise an empty dict.

        Raises:
            KeyError: if ``rope_scaling`` is set but has no ``"factor"`` key.
        """
        # getattr works for any attribute-style config object; a membership test
        # (`"rope_scaling" in config`) requires the config to support iteration.
        rope_scaling = getattr(self.hf_config, "rope_scaling", None)
        if rope_scaling is None:
            return {}
        # NOTE(review): only the factor is forwarded; the scaling type (e.g. "linear")
        # is not validated, so non-linear schemes are treated as linear interpolation.
        return {"seq_len_interpolation_factor": rope_scaling["factor"]}

    def initialize(
        self,
        pre_process: bool = True,
        post_process: bool = True,
        share_embeddings_and_output_weights: bool = False,
        value: bool = False,
        **extra_kwargs,
    ) -> GPTModel:
        """Initialize a GPT model with the given configuration.

        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py

        Args:
            pre_process (bool): include the embedding layer.
            post_process (bool): include the output layer.
            share_embeddings_and_output_weights (bool): share input embedding and output logit weights.
            value (bool): when ``post_process`` is also set, replace the output layer with a
                single-output linear layer for classification or regression.
            **extra_kwargs: optional ``vp_stage`` (virtual pipeline stage) and ``mtp_block_spec``
                (multi-token-prediction block spec). Other keys are ignored here, which lets
                subclass-specific options be passed through the same call.

        Returns:
            GPTModel: An initialized GPT model instance.
        """
        vp_stage = extra_kwargs.get("vp_stage", None)
        transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage)
        rope_scaling_args = self.get_rope_scaling_args()
        mtp_block_spec = extra_kwargs.get("mtp_block_spec", None)
        # NOTE(review): assumes GPTModel accepts ``vp_stage`` whenever the layer-spec
        # builder does (support is probed only on get_gpt_decoder_block_spec).
        vp_stage_args = {"vp_stage": vp_stage} if self.has_vp_stage else {}
        model = GPTModel(
            config=self.tfconfig,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=self.hf_config.vocab_size,
            max_sequence_length=self.hf_config.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights,
            position_embedding_type="rope",
            rotary_base=self.hf_config.rope_theta,
            mtp_block_spec=mtp_block_spec,
            **rope_scaling_args,
            **vp_stage_args,
        )
        if post_process and value:
            from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer

            model.output_layer = LinearForLastLayer(
                input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig
            )
        return model
class DenseModel(BaseModelInitializer):
    """Model initializer for dense (non-MoE) decoder models such as Llama and Qwen2."""

    def get_transformer_layer_spec(self, vp_stage=None):
        """Build the Transformer Engine decoder block spec.

        Args:
            vp_stage: virtual pipeline stage; forwarded only when the installed
                ``get_gpt_decoder_block_spec`` accepts it (see ``has_vp_stage``).
        """
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        spec_kwargs = {"vp_stage": vp_stage} if self.has_vp_stage else {}
        return get_gpt_decoder_block_spec(
            self.tfconfig,
            use_transformer_engine=True,
            **spec_kwargs,
        )
class Qwen2MoEModel(BaseModelInitializer):
    """Initializer for Qwen2 MoE models."""

    def get_transformer_layer_spec(self, vp_stage=None):
        """Build the decoder block spec and enable the shared-expert gate on every layer.

        Args:
            vp_stage: virtual pipeline stage; forwarded only when the installed
                Megatron-Core layer-spec builder accepts it.
        """
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage}
        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)
        # Patch layer spec for shared experts: turn on the ``gate`` param of the
        # shared-experts submodule (presumably mirroring Qwen2-MoE's gated shared expert).
        for layer_spec in transformer_layer_spec.layer_specs:
            layer_spec.submodules.mlp.submodules.shared_experts.params["gate"] = True
        return transformer_layer_spec

    def initialize(self, *args, **kwargs):
        """Build the model via ``BaseModelInitializer.initialize``, then optionally freeze routers.

        Positional and keyword arguments are forwarded unchanged, so the signature stays
        compatible with the base class. The extra keyword ``freeze_moe_router``
        (Qwen default: ``True``) disables gradients on each MoE router weight.
        """
        model = super().initialize(*args, **kwargs)
        if kwargs.get("freeze_moe_router", True):
            for layer in model.decoder.layers:
                # Skip layers whose MLP has no router (e.g. dense layers).
                if hasattr(layer.mlp, "router"):
                    layer.mlp.router.weight.requires_grad = False
        return model
class MixtralModel(BaseModelInitializer):
    """Initializer for Mixtral models."""

    def get_transformer_layer_spec(self, vp_stage=None):
        """Build the Transformer Engine decoder block spec.

        Args:
            vp_stage: virtual pipeline stage; forwarded only when the installed
                Megatron-Core layer-spec builder accepts it.
        """
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage}
        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)

    def initialize(self, *args, **kwargs):
        """Build the model via ``BaseModelInitializer.initialize``, then optionally freeze routers.

        Positional and keyword arguments are forwarded unchanged, so the signature stays
        compatible with the base class. The extra keyword ``freeze_moe_router``
        (Mixtral default: ``False``) disables gradients on each MoE router weight.
        """
        model = super().initialize(*args, **kwargs)
        if kwargs.get("freeze_moe_router", False):
            for layer in model.decoder.layers:
                # Skip layers whose MLP has no router (e.g. dense layers).
                if hasattr(layer.mlp, "router"):
                    layer.mlp.router.weight.requires_grad = False
        return model
class Qwen3MoEModel(BaseModelInitializer):
    """Initializer for Qwen3 MoE models."""

    def get_transformer_layer_spec(self, vp_stage=None):
        """Build the Transformer Engine decoder block spec.

        Args:
            vp_stage: virtual pipeline stage; forwarded only when the installed
                Megatron-Core layer-spec builder accepts it.
        """
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage}
        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)

    def initialize(self, *args, **kwargs):
        """Build the model via ``BaseModelInitializer.initialize``, then optionally freeze routers.

        Positional and keyword arguments are forwarded unchanged, so the signature stays
        compatible with the base class. The extra keyword ``freeze_moe_router``
        (Qwen default: ``True``) disables gradients on each MoE router weight.
        """
        model = super().initialize(*args, **kwargs)
        if kwargs.get("freeze_moe_router", True):
            for layer in model.decoder.layers:
                # Skip layers whose MLP has no router (e.g. dense layers).
                if hasattr(layer.mlp, "router"):
                    layer.mlp.router.weight.requires_grad = False
        return model
class DeepseekV3Model(BaseModelInitializer):
    """Initializer for DeepseekV3 models (MoE, optionally with multi-token prediction)."""

    def get_transformer_layer_spec(self, vp_stage=None):
        """Build the Transformer Engine decoder block spec.

        Args:
            vp_stage: virtual pipeline stage; forwarded only when the installed
                Megatron-Core layer-spec builder accepts it.
        """
        extra_kwargs = {} if not self.has_vp_stage else {"vp_stage": vp_stage}
        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **extra_kwargs)

    def get_rope_scaling_args(self) -> dict:
        """Return no RoPE scaling kwargs; ``hf_config.rope_scaling`` is deliberately ignored.

        NOTE(review): presumably DeepseekV3's RoPE scaling is carried by ``tfconfig``
        instead — confirm against the config converter.
        """
        return {}

    def initialize(self, *args, **kwargs):
        """Build the model, adding an MTP block spec when configured and optionally freezing routers.

        Positional and keyword arguments are forwarded to ``BaseModelInitializer.initialize``.
        Extra keywords read here:
            vp_stage: virtual pipeline stage, also used for the MTP block spec.
            freeze_moe_router: (default ``True``) sets
                ``tfconfig.moe_router_load_balancing_type = "none"`` — this mutates the
                shared ``tfconfig`` — and disables gradients on each MoE router weight.
        """
        vp_stage = kwargs.get("vp_stage", None)
        freeze_moe_router = kwargs.get("freeze_moe_router", True)
        if freeze_moe_router:
            self.tfconfig.moe_router_load_balancing_type = "none"
        # MTP
        if self.tfconfig.mtp_num_layers is not None and self.tfconfig.mtp_num_layers > 0:
            transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage)
            # Like the decoder spec builder, the MTP spec builder may not accept
            # ``vp_stage`` in every Megatron-Core version; probe its own signature.
            mtp_kwargs = {}
            if "vp_stage" in inspect.signature(get_gpt_mtp_block_spec).parameters:
                mtp_kwargs["vp_stage"] = vp_stage
            kwargs["mtp_block_spec"] = get_gpt_mtp_block_spec(
                self.tfconfig, transformer_layer_spec, use_transformer_engine=True, **mtp_kwargs
            )
        model = super().initialize(*args, **kwargs)
        if freeze_moe_router:
            for layer in model.decoder.layers:
                # Dense layers have no router; only MoE layers are frozen.
                if hasattr(layer.mlp, "router"):
                    layer.mlp.router.weight.requires_grad = False
        return model
class Qwen25VLModel(BaseModelInitializer):
    """Initializer for Qwen2.5-VL models: vision encoder, projector and language decoder."""

    def get_transformer_layer_spec(self, vp_stage=None):
        """Return the Transformer Engine decoder block spec for the language model."""
        spec_kwargs = {"vp_stage": vp_stage} if self.has_vp_stage else {}
        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True, **spec_kwargs)

    def initialize(
        self,
        pre_process=None,
        post_process=None,
        share_embeddings_and_output_weights=False,
        value=False,
        **extra_kwargs,
    ):
        """Build a ``Qwen2_5VLModel`` (counterpart of HF Qwen2_5_VLForConditionalGeneration).

        Args:
            pre_process: forwarded to ``Qwen2_5VLModel``.
            post_process: forwarded to ``Qwen2_5VLModel``; together with ``value`` decides
                whether the language model's output layer is replaced.
            share_embeddings_and_output_weights: forwarded as
                ``language_share_embeddings_and_output_weights``.
            value: when truthy (and ``post_process`` is truthy), replace the language
                model's output layer with a single-output ``LinearForLastLayer``.
            **extra_kwargs: accepted for interface compatibility and not read here.
                NOTE(review): a ``vp_stage`` passed here is not forwarded to
                ``get_transformer_layer_spec`` — confirm VPP is unsupported for this model.

        Returns:
            The constructed ``Qwen2_5VLModel``.
        """
        from copy import deepcopy

        lang_config = self.tfconfig
        hf_cfg = self.hf_config
        language_layer_spec = self.get_transformer_layer_spec()

        from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
        from megatron.core.models.gpt.moe_module_specs import MLPSubmodules
        from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec

        from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config

        # The vision tower gets its own copy of the language config, with the
        # pipeline-parallel size forced to 1.
        vit_config = get_vision_model_config(deepcopy(lang_config))
        vit_config.pipeline_model_parallel_size = 1
        vit_config.first_pipeline_num_layers = None

        projector_config = get_vision_projection_config(
            deepcopy(lang_config),
            vit_config.hidden_size,
            spatial_merge_size=hf_cfg.vision_config.spatial_merge_size,
        )
        projector_spec = MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear)
        vit_layer_spec = get_vit_layer_with_transformer_engine_spec()

        model = Qwen2_5VLModel(
            language_transformer_config=lang_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=hf_cfg.vocab_size,
            language_max_sequence_length=hf_cfg.max_position_embeddings,
            vision_transformer_config=vit_config,
            vision_transformer_layer_spec=vit_layer_spec,
            vision_projection_config=projector_config,
            vision_projection_layer_spec=projector_spec,
            vision_projection_type="mlp",
            language_rotary_base=hf_cfg.rope_theta,
            pre_process=pre_process,
            post_process=post_process,
            add_decoder=True,
            add_encoder=True,
            parallel_output=True,
            language_share_embeddings_and_output_weights=share_embeddings_and_output_weights,
        )

        if not (post_process and value):
            return model

        from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer

        model.language_model.output_layer = LinearForLastLayer(
            input_size=lang_config.hidden_size, output_size=1, config=lang_config
        )
        return model
|