# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

# MindSpeed is an optional dependency: keep this module importable without it
# and report the missing package when the engine is actually constructed.
try:
    from mindspeed.megatron_adaptor import repatch
except ImportError:
    repatch = None

from verl.trainer.config import CheckpointConfig
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig

from ..base import EngineRegistry
from ..megatron import MegatronEngineWithLMHead

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


@EngineRegistry.register(model_type="language_model", backend="megatron", device="npu")
class MindspeedEngineWithLMHead(MegatronEngineWithLMHead):
    """Megatron LM-head engine registered for the ``npu`` device.

    Behaves like :class:`MegatronEngineWithLMHead`, and additionally calls
    MindSpeed's ``repatch`` once the base engine has been initialised.
    """

    def __init__(
        self,
        model_config: HFModelConfig,
        engine_config: McoreEngineConfig,
        optimizer_config: McoreOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        """Initialise the base engine, then re-apply MindSpeed patches.

        Args:
            model_config: Forwarded unchanged to the base engine.
            engine_config: Forwarded to the base engine; its
                ``override_transformer_config`` and ``context_parallel_size``
                are also read here to build the ``repatch`` options.
            optimizer_config: Forwarded unchanged to the base engine.
            checkpoint_config: Forwarded unchanged to the base engine.

        Raises:
            ImportError: If ``mindspeed`` is not installed, i.e. ``repatch``
                could not be imported at module load time.
        """
        # Fail fast with an actionable message instead of hitting
        # "'NoneType' object is not callable" after the base engine has
        # already been set up.
        if repatch is None:
            raise ImportError(
                "MindspeedEngineWithLMHead requires the 'mindspeed' package "
                "(mindspeed.megatron_adaptor.repatch could not be imported)."
            )

        super().__init__(model_config, engine_config, optimizer_config, checkpoint_config)

        # NOTE: when the engine config already carries an
        # "override_transformer_config" mapping, the very same object is
        # returned and updated in place below, so these keys become visible to
        # any later reader of the engine config.
        repatch_config = self.engine_config.get("override_transformer_config", {})
        repatch_config["use_flash_attn"] = True
        if self.engine_config.context_parallel_size > 1:
            repatch_config["context_parallel_size"] = self.engine_config.context_parallel_size

        # NOTE(review): presumably re-applies MindSpeed's Megatron patches
        # using these options -- confirm against the MindSpeed documentation.
        repatch(repatch_config)