# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Registry module for model architecture components.
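
Typical usage (sketch; assumes a local HuggingFace checkpoint and, for model
initialization, an already-initialized Megatron-Core parallel state)::

    from transformers import AutoConfig

    hf_config = AutoConfig.from_pretrained("path/to/checkpoint")  # hypothetical path
    tfconfig = hf_to_mcore_config(hf_config, torch.bfloat16)
    model = init_mcore_model(tfconfig, hf_config)
    forward_fn = get_mcore_forward_fn(hf_config)
    weight_converter = get_mcore_weight_converter(hf_config, torch.bfloat16)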
"""

from enum import Enum
from typing import Callable, Dict, Optional, Type

import torch
import torch.nn as nn

from .config_converter import (
PretrainedConfig,
TransformerConfig,
hf_to_mcore_config_dense,
hf_to_mcore_config_dpskv3,
hf_to_mcore_config_llama4,
hf_to_mcore_config_mixtral,
hf_to_mcore_config_qwen2_5_vl,
hf_to_mcore_config_qwen2moe,
hf_to_mcore_config_qwen3moe,
)
from .model_forward import (
gptmodel_forward,
)
from .model_initializer import (
BaseModelInitializer,
DenseModel,
MixtralModel,
Qwen2MoEModel,
Qwen3MoEModel,
Qwen25VLModel,
)
from .weight_converter import (
McoreToHFWeightConverterDense,
McoreToHFWeightConverterMixtral,
McoreToHFWeightConverterQwen2Moe,
McoreToHFWeightConverterQwen3Moe,
)


class SupportedModel(Enum):
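    """HuggingFace architecture names supported by the Mcore backend (inline comments note test status)."""
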
LLAMA = "LlamaForCausalLM" # tested
QWEN2 = "Qwen2ForCausalLM" # tested
QWEN2_MOE = "Qwen2MoeForCausalLM" # pending
DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested
MIXTRAL = "MixtralForCausalLM" # tested
QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported
LLAMA4 = "Llama4ForConditionalGeneration" # not tested
QWEN3 = "Qwen3ForCausalLM" # tested
QWEN3_MOE = "Qwen3MoeForCausalLM" # not tested


# Registry for model configuration converters
MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
SupportedModel.LLAMA: hf_to_mcore_config_dense,
SupportedModel.QWEN2: hf_to_mcore_config_dense,
SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,
SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,
SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
SupportedModel.LLAMA4: hf_to_mcore_config_llama4,
SupportedModel.QWEN3: hf_to_mcore_config_dense,
SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
}


# Registry for model initializers
MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
SupportedModel.LLAMA: DenseModel,
SupportedModel.QWEN2: DenseModel,
SupportedModel.QWEN2_MOE: Qwen2MoEModel,
SupportedModel.MIXTRAL: MixtralModel,
SupportedModel.DEEPSEEK_V3: DenseModel,
SupportedModel.QWEN2_5_VL: Qwen25VLModel,
SupportedModel.LLAMA4: DenseModel,
SupportedModel.QWEN3: DenseModel,
SupportedModel.QWEN3_MOE: Qwen3MoEModel,
}


# Registry for model forward functions
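# All supported architectures currently reuse the generic GPTModel forward pass.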
MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
SupportedModel.LLAMA: gptmodel_forward,
SupportedModel.QWEN2: gptmodel_forward,
SupportedModel.QWEN2_MOE: gptmodel_forward,
SupportedModel.MIXTRAL: gptmodel_forward,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
SupportedModel.QWEN2_5_VL: gptmodel_forward,
SupportedModel.LLAMA4: gptmodel_forward,
SupportedModel.QWEN3: gptmodel_forward,
SupportedModel.QWEN3_MOE: gptmodel_forward,
}


# Registry for model weight converters
MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {
SupportedModel.LLAMA: McoreToHFWeightConverterDense,
SupportedModel.QWEN2: McoreToHFWeightConverterDense,
SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
SupportedModel.QWEN3: McoreToHFWeightConverterDense,
SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
}
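# Models absent from MODEL_WEIGHT_CONVERTER_REGISTRY (DeepseekV3, Qwen2.5-VL,
# Llama4) do not yet have an Mcore-to-HF weight converter.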


def get_supported_model(model_type: str) -> SupportedModel:
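    """Map a HuggingFace architecture name (e.g. "Qwen2ForCausalLM") to a SupportedModel member."""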
try:
return SupportedModel(model_type)
except ValueError as err:
supported_models = [e.value for e in SupportedModel]
        raise NotImplementedError(
            f"Model Type: {model_type} not supported. Supported models: {supported_models}"
        ) from err


def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
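    """Convert a HuggingFace PretrainedConfig into a Megatron-Core TransformerConfig."""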
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype)


def init_mcore_model(
tfconfig: TransformerConfig,
hf_config: PretrainedConfig,
pre_process: bool = True,
    post_process: Optional[bool] = None,
*,
share_embeddings_and_output_weights: bool = False,
value: bool = False,
**extra_kwargs, # may be used for vlm and moe
) -> nn.Module:
"""
Initialize a Mcore model.
Args:
tfconfig: The transformer config.
hf_config: The HuggingFace config.
pre_process: Optional pre-processing function.
post_process: Optional post-processing function.
share_embeddings_and_output_weights: Whether to share embeddings and output weights.
value: Whether to use value.
**extra_kwargs: Additional keyword arguments.
Returns:
The initialized model.
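
    Example (sketch; assumes Megatron-Core model-parallel state is already
    initialized, with ``tfconfig`` from ``hf_to_mcore_config`` and ``hf_config``
    from ``transformers.AutoConfig``)::

        model = init_mcore_model(tfconfig, hf_config, pre_process=True, post_process=True)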
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
initializer_cls = MODEL_INITIALIZER_REGISTRY[model]
initializer = initializer_cls(tfconfig, hf_config)
    return initializer.initialize(
        pre_process=pre_process,
        post_process=post_process,
        share_embeddings_and_output_weights=share_embeddings_and_output_weights,
        value=value,
        **extra_kwargs,
    )


def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
    """
    Get the forward function for a given model architecture.
    """
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
return MODEL_FORWARD_REGISTRY[model]


def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
    """
    Get the weight converter for a given model architecture.

    The converter is instantiated with both the HuggingFace config and the
    Megatron-Core transformer config derived from it.
    """
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
tfconfig = hf_to_mcore_config(hf_config, dtype)
return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)