# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use the mcore transformer config to initialize the model.
from abc import ABC, abstractmethod

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
from megatron.core.models.gpt.gpt_model import GPTModel

from .config_converter import PretrainedConfig, TransformerConfig


class BaseModelInitializer(ABC):
    """Base class for model initializers."""

    def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig):
        self.tfconfig = tfconfig
        self.hf_config = hf_config

    @abstractmethod
    def get_transformer_layer_spec(self):
        """Get the transformer layer specification.

        See https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py
        """

    def get_rope_scaling_args(self) -> dict:
        """Get rope scaling kwargs for GPTModel from the HF config."""
        rope_scaling_args = {}
        # Use getattr so configs without a `rope_scaling` attribute are handled safely.
        if getattr(self.hf_config, "rope_scaling", None) is not None:
            assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now"
            rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"]
        return rope_scaling_args
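
    # Illustrative mapping (hypothetical HF config values): a config carrying
    #   rope_scaling = {"type": "linear", "factor": 4.0}
    # yields {"seq_len_interpolation_factor": 4.0}, which is passed to GPTModel below.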

    def initialize(
        self,
        pre_process: bool = True,
        post_process: bool = True,
        share_embeddings_and_output_weights: bool = False,
        value: bool = False,
        **extra_kwargs,
    ) -> GPTModel:
"""Initialize a GPT model with the given configuration.
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py
Args:
pre_process (bool): include embedding layer.
post_process (bool): including an output layer.
share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared.
value (bool): add an extra linear layer for classification or regression.
Returns:
GPTModel: An initialized GPT model instance
"""
        transformer_layer_spec = self.get_transformer_layer_spec()
        rope_scaling_args = self.get_rope_scaling_args()
        model = GPTModel(
            config=self.tfconfig,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=self.hf_config.vocab_size,
            max_sequence_length=self.hf_config.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights,
            position_embedding_type="rope",
            rotary_base=self.hf_config.rope_theta,
            **rope_scaling_args,
        )

        if post_process and value:
            from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer

            model.output_layer = LinearForLastLayer(
                input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig
            )

        return model


class DenseModel(BaseModelInitializer):
    """Initializer for dense models like Llama and Qwen2."""

    def get_transformer_layer_spec(self):
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
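
# Usage sketch (illustrative, not part of this module; assumes `tfconfig` and
# `hf_config` come from this package's config_converter and that Megatron's
# model-parallel state has already been initialized):
#
#   initializer = DenseModel(tfconfig, hf_config)
#   actor = initializer.initialize(pre_process=True, post_process=True)
#   critic = initializer.initialize(post_process=True, value=True)  # 1-dim value head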


class Qwen2MoEModel(BaseModelInitializer):
    """Initializer for Qwen2 MoE models."""

    def get_transformer_layer_spec(self):
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
        # Patch layer spec for shared experts
        for i in range(len(transformer_layer_spec.layer_specs)):
            transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True
        return transformer_layer_spec
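
    # The patched path mirrors Megatron-Core's ModuleSpec tree (sketch):
    #   spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params == {"gate": True}
    # i.e. each layer's shared expert is built with its learned gate enabled.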

    def initialize(self, freeze_moe_router: bool = True, **kwargs):
        # Qwen2 MoE defaults to freezing the router and the shared-expert gate.
        model = super().initialize(**kwargs)
        if freeze_moe_router:
            for layer in model.decoder.layers:
                layer.mlp.router.weight.requires_grad = False
                layer.mlp.shared_experts.gate_weight.requires_grad = False
        return model
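
# Note: freezing the MoE router is a common stabilization choice when fine-tuning
# MoE checkpoints; the Qwen initializers default to a frozen router, while
# MixtralModel below leaves it trainable by default.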


class MixtralModel(BaseModelInitializer):
    """Initializer for Mixtral models."""

    def get_transformer_layer_spec(self):
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
        return transformer_layer_spec

    def initialize(self, freeze_moe_router: bool = False, **kwargs):
        model = super().initialize(**kwargs)
        if freeze_moe_router:
            for layer in model.decoder.layers:
                layer.mlp.router.weight.requires_grad = False
        return model


class Qwen3MoEModel(BaseModelInitializer):
    """Initializer for Qwen3 MoE models."""

    def get_transformer_layer_spec(self):
        assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
        transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
        return transformer_layer_spec

    def initialize(self, freeze_moe_router: bool = True, **kwargs):
        # Qwen3 MoE defaults to freezing the router.
        model = super().initialize(**kwargs)
        if freeze_moe_router:
            for layer in model.decoder.layers:
                layer.mlp.router.weight.requires_grad = False
        return model


class Qwen25VLModel(BaseModelInitializer):
    """Initializer for Qwen2.5 VL models."""

    def get_transformer_layer_spec(self):
        raise NotImplementedError("VLM is not supported yet")
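

# Dispatch sketch (hypothetical helper; verl's actual architecture-to-initializer
# mapping lives outside this file). A registry keyed on the HF config's
# `architectures` entry might look like:
#
#   _MODEL_INITIALIZERS = {
#       "LlamaForCausalLM": DenseModel,
#       "Qwen2ForCausalLM": DenseModel,
#       "Qwen2MoeForCausalLM": Qwen2MoEModel,
#       "MixtralForCausalLM": MixtralModel,
#       "Qwen3MoeForCausalLM": Qwen3MoEModel,
#       "Qwen2_5_VLForConditionalGeneration": Qwen25VLModel,
#   }
#
#   def get_initializer(tfconfig, hf_config):
#       cls = _MODEL_INITIALIZERS[hf_config.architectures[0]]
#       return cls(tfconfig, hf_config)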