| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import logging |
| import os |
| from functools import partial |
| from typing import Any, Callable, ContextManager, Iterator, Optional |
|
|
| import torch |
| import torch.distributed |
| from megatron.core import parallel_state as mpu |
| from megatron.core.pipeline_parallel import get_forward_backward_func |
| from omegaconf import OmegaConf |
| from tensordict import TensorDict |
|
|
| import verl.utils.torch_functional as verl_F |
| from verl.models.mcore import get_mcore_forward_fused_no_padding_fn, get_mcore_weight_converter |
| from verl.trainer.config import CheckpointConfig |
| from verl.utils import tensordict_utils as tu |
| from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager |
| from verl.utils.dataset.dataset_utils import DatasetPadMode |
| from verl.utils.debug import log_gpu_memory_usage |
| from verl.utils.device import get_device_id, get_device_name |
| from verl.utils.megatron.pipeline_parallel import make_batch_generator |
| from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch |
| from verl.utils.megatron.router_replay_utils import ( |
| RouterReplayHelper, |
| merge_router_topk_indices, |
| pp_gather, |
| reorder_and_merge_vpp_layers, |
| set_router_replay_data, |
| ) |
| from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits |
| from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm |
| from verl.utils.megatron_utils import ( |
| check_mtp_config, |
| get_megatron_module_device, |
| get_megatron_mtp_loss, |
| load_megatron_model_to_gpu, |
| load_megatron_optimizer, |
| offload_megatron_model_to_cpu, |
| offload_megatron_optimizer, |
| patch_engine_mtp, |
| register_megatron_training_hooks, |
| unwrap_model, |
| ) |
| from verl.utils.model import extract_multi_modal_inputs, load_mcore_dist_weights |
| from verl.utils.seqlen_balancing import restore_dynamic_batch |
| from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig |
|
|
| from ..base import BaseEngine, BaseEngineCtx, EngineRegistry |
| from ..utils import postprocess_batch_func, prepare_micro_batches |
| from .utils import set_random_seed |
|
|
# Module-level logger, named after this file's path. Its level is read from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
| class MegatronEngine(BaseEngine): |
def __init__(
    self,
    model_config: HFModelConfig,
    engine_config: McoreEngineConfig,
    optimizer_config: McoreOptimizerConfig,
    checkpoint_config: CheckpointConfig,
):
    """Store the configs, initialize model parallelism and set the random seed.

    Args:
        model_config: Model configuration (HF side).
        engine_config: Megatron engine configuration; ``use_mbridge`` must be truthy.
        optimizer_config: Optimizer configuration.
        checkpoint_config: Checkpoint configuration.
    """
    super().__init__()

    self.model_config = model_config
    self.engine_config = engine_config
    self.optimizer_config = optimizer_config
    self.checkpoint_config = checkpoint_config

    cfg = self.engine_config
    assert cfg.use_mbridge, "use_mbridge must be True"
    self._init_device_mesh()

    set_random_seed(seed=cfg.seed)

    # Offload switches, read once from the engine config.
    self._is_offload_param, self._is_offload_grad, self._is_offload_optimizer = (
        cfg.param_offload,
        cfg.grad_offload,
        cfg.optimizer_offload,
    )

    self.mode = None

    self.layer_name_mapping = {
        "qkv_layer_name": "self_attention.linear_qkv.",
        "gate_proj_layer_name": "linear_fc1.",
    }
    self.weight_converter = None

    # Router replay is on for every router_replay mode other than "disabled".
    replay_enabled = cfg.router_replay.mode != "disabled"
    self.enable_routing_replay = replay_enabled
    logger.info(f"enable_routing_replay in MegatronEngine: {self.enable_routing_replay}")
    if replay_enabled:
        apply_router_replay_patch()
    # Scratch list that the forward pass fills with recorded top-k indices.
    self.mini_layer_topk_idx_list = []
|
|
def _init_device_mesh(self):
    """Initialize Megatron model-parallel state from the engine config.

    Does nothing when ``mpu`` reports that it is already initialized.
    """
    if not mpu.is_initialized():
        cfg = self.engine_config
        mpu.initialize_model_parallel(
            tensor_model_parallel_size=cfg.tensor_model_parallel_size,
            pipeline_model_parallel_size=cfg.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size=cfg.virtual_pipeline_model_parallel_size,
            use_sharp=False,
            context_parallel_size=cfg.context_parallel_size,
            expert_model_parallel_size=cfg.expert_model_parallel_size,
            expert_tensor_parallel_size=cfg.expert_tensor_parallel_size,
            nccl_communicator_config_path=None,
        )
|
|
def _build_tf_config(self):
    """Create the bridge (plus a provider on the non-vanilla path) and resolve dtypes.

    Sets ``param_dtype``, ``dtype``, ``provider``, ``vanilla_bridge``, ``bridge``,
    ``tf_config`` and ``peft_cls`` on the engine. On the non-vanilla path
    ``tf_config`` is left as ``None`` here.
    """
    from verl.utils.megatron_utils import mapping_string_to_attn_backend
    from verl.utils.torch_dtypes import PrecisionType

    check_mtp_config(self.model_config, self.engine_config)

    engine_cfg = self.engine_config
    self.param_dtype = PrecisionType.to_dtype(engine_cfg.dtype)
    self.dtype = PrecisionType.to_dtype(self.param_dtype)

    # Work on a shallow copy so the configured overrides are not mutated below.
    tf_overrides = mapping_string_to_attn_backend({**engine_cfg.override_transformer_config})
    if self.enable_routing_replay:
        tf_overrides["enable_routing_replay"] = True

    self.provider = None
    self.vanilla_bridge = engine_cfg.vanilla_mbridge

    if not self.vanilla_bridge:
        from verl.models.mcore.bridge import AutoBridge

        bridge = AutoBridge.from_hf_pretrained(
            self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
        )
        provider = bridge.to_megatron_provider(load_weights=False)

        # Parameter dtype and the matching mixed-precision flags.
        provider.params_dtype = self.param_dtype
        provider.fp16 = self.param_dtype == torch.float16
        provider.bf16 = self.param_dtype == torch.bfloat16

        # Parallelism settings are copied one-to-one from the engine config.
        for field_name in (
            "tensor_model_parallel_size",
            "pipeline_model_parallel_size",
            "expert_model_parallel_size",
            "expert_tensor_parallel_size",
            "virtual_pipeline_model_parallel_size",
            "context_parallel_size",
            "sequence_parallel",
        ):
            setattr(provider, field_name, getattr(engine_cfg, field_name))

        from megatron.core.transformer.enums import AttnBackend

        # Engine defaults; the user overrides applied next can replace any of them.
        provider.attention_backend = AttnBackend.flash
        provider.variable_seq_lengths = True
        provider.moe_token_dispatcher_type = "alltoall"
        provider.moe_router_load_balancing_type = "none"

        for option, option_value in tf_overrides.items():
            setattr(provider, option, option_value)

        provider.finalize()
        self.provider = provider
        tf_config = None
    else:
        from verl.models.mcore.mbridge import AutoBridge

        bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
        bridge.set_extra_args(**tf_overrides)
        tf_config = bridge.config
        tf_config.fp16 = self.param_dtype == torch.float16
        tf_config.bf16 = self.param_dtype == torch.bfloat16
    self.bridge = bridge

    if not self.bridge:
        self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)

    # Only rank 0 prints, and only when a transformer config exists at this point.
    if torch.distributed.get_rank() == 0 and tf_config is not None:
        print(f"TF config: {tf_config}")
    self.tf_config = tf_config

    from verl.workers.config.megatron_peft import get_peft_cls

    self.peft_cls = get_peft_cls(
        model_config=self.model_config, bridge=self.bridge, provider=self.provider, dtype=self.param_dtype
    )
|
|
def _build_megatron_module(self):
    """Build the Megatron module list and load its weights.

    Also sets ``self.is_value_model`` and replaces ``self.tf_config`` with the
    config returned by ``make_megatron_module``.

    Returns:
        The module object returned by ``make_megatron_module``.
    """
    from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
    from verl.utils.model import print_model_size

    # A model counts as a value model when its first architecture name is a
    # token- or sequence-classification head.
    architecture = self.model_config.architectures[0]
    is_value_model = any(
        marker in architecture for marker in ("ForTokenClassification", "ForSequenceClassification")
    )
    self.is_value_model = is_value_model

    engine_cfg = self.engine_config
    # DDP wrapping is skipped for forward-only engines.
    wrap_with_ddp = not engine_cfg.forward_only

    wrap_config = McoreModuleWrapperConfig(
        is_value_model=is_value_model,
        share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,
        wrap_with_ddp=wrap_with_ddp,
        use_distributed_optimizer=engine_cfg.use_distributed_optimizer,
    )
    module, self.tf_config = make_megatron_module(
        wrap_config=wrap_config,
        tf_config=self.tf_config,
        hf_config=self.model_config.hf_config,
        bridge=self.bridge,
        provider=self.provider,
        override_model_config=engine_cfg.override_mcore_model_config,
        override_ddp_config=engine_cfg.override_ddp_config,
        peft_cls=self.peft_cls,
        peft_config=self.model_config.get("lora", None),
    )
    print(f"module: {len(module)}")

    # Weight source: distributed checkpoint, vanilla bridge, or HF weights via the bridge.
    if engine_cfg.use_dist_checkpointing:
        load_mcore_dist_weights(module, engine_cfg.dist_checkpointing_path, is_value_model=is_value_model)
    elif self.vanilla_bridge:
        self.bridge.load_weights(module, self.model_config.local_path)
    else:
        allowed_mismatched_params = ["output_layer.weight"] if self.is_value_model else []
        self.bridge.load_hf_weights(
            module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
        )

    if torch.distributed.get_rank() == 0:
        print_model_size(module[0])

    if self.enable_routing_replay:
        print(f"routing replay layers: {len(RouterReplay.router_instances)}")

    return module
|
|
def _maybe_enable_fused_kernels(self):
    """Patch every model chunk with the fused forward when fused kernels are usable.

    Does nothing when ``engine_config.use_fused_kernels`` is falsy. For value
    models, or when MTP is enabled, a warning is logged and
    ``engine_config.use_fused_kernels`` is reset to ``False`` instead.
    """
    if not self.engine_config.use_fused_kernels:
        return

    if self.is_value_model or self.model_config.mtp.enable:
        # ``warning_once`` is not a stdlib ``logging.Logger`` method; it exists only
        # when some other package has attached it (NOTE(review): presumably
        # transformers - confirm). Fall back to ``warning`` so this path cannot
        # fail with AttributeError.
        warn_once = getattr(logger, "warning_once", logger.warning)
        warn_once(
            "Fused kernels are not supported for value models or when MTP is enabled in Megatron engine; disabling."
        )
        self.engine_config.use_fused_kernels = False
        return

    from verl.models.mcore.model_forward_fused import patch_fused_forward

    for model in self.module:
        patch_fused_forward(model)
|
|
def _build_optimizer(self):
    """Create the Megatron optimizer for ``self.module`` and register training hooks.

    Returns:
        The optimizer returned by ``get_megatron_optimizer``.
    """
    from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config

    uses_fp16 = self.param_dtype == torch.float16
    megatron_optim_cfg = init_megatron_optim_config(
        self.optimizer_config,
        use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
        fp16=uses_fp16,
    )
    built_optimizer = get_megatron_optimizer(model=self.module, config=megatron_optim_cfg)
    register_megatron_training_hooks(self.module, built_optimizer)
    return built_optimizer
|
|
def _build_lr_scheduler(self):
    """Create the parameter scheduler bound to ``self.optimizer``.

    Returns:
        The scheduler returned by ``get_megatron_optimizer_param_scheduler``.
    """
    from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler

    return get_megatron_optimizer_param_scheduler(optimizer=self.optimizer, config=self.optimizer_config)
|
|
@property
def is_param_offload_enabled(self) -> bool:
    """Value of the engine's ``param_offload`` setting captured at construction."""
    param_offload = self._is_offload_param
    return param_offload
|
|
@property
def is_optimizer_offload_enabled(self) -> bool:
    """Value of the engine's ``optimizer_offload`` setting captured at construction."""
    optimizer_offload = self._is_offload_optimizer
    return optimizer_offload
|
|
def is_mp_src_rank_with_outputs(self):
    """Tell whether this rank is TP rank 0, the last PP rank and CP rank 0.

    Returns:
        ``True`` only when all three conditions hold; checks stop at the first
        one that fails.
    """
    if mpu.get_tensor_model_parallel_rank() != 0:
        return False
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    last_pp_rank = mpu.get_pipeline_model_parallel_world_size() - 1
    if pp_rank != last_pp_rank:
        return False
    return mpu.get_context_parallel_rank() == 0
|
|
def initialize(self):
    """Build the model and, unless forward-only, the optimizer stack and checkpoint manager.

    Forward-only engines get ``optimizer`` and ``lr_scheduler`` set to ``None``
    and return before a checkpoint manager is created. In both paths the model
    (and optimizer, when present) is moved to CPU according to the offload flags.
    """
    self._build_tf_config()
    self.module = self._build_megatron_module()
    self._maybe_enable_fused_kernels()

    if self.model_config.mtp.enable:
        patch_engine_mtp(self.module, self.model_config)

    engine_cfg = self.engine_config
    if engine_cfg.forward_only:
        # No optimizer state exists in this mode; only the model may be offloaded.
        self.optimizer = self.lr_scheduler = None
        self.to(device="cpu", model=self._is_offload_param, optimizer=False, grad=False)
        log_gpu_memory_usage("After offload model during init (forward_only)", logger=logger)
        return

    self.optimizer = self._build_optimizer()
    self.lr_scheduler = self._build_lr_scheduler()

    # Minimal config tree handed to the checkpoint manager.
    manager_config = OmegaConf.create(
        {
            "model": {"path": self.model_config.local_path},
            "megatron": {
                "dist_ckpt_optim_fully_reshardable": engine_cfg.dist_ckpt_optim_fully_reshardable,
                "distrib_optim_fully_reshardable_mem_efficient": (
                    engine_cfg.distrib_optim_fully_reshardable_mem_efficient
                ),
            },
        }
    )

    role = "critic" if self.is_value_model else "actor"

    # The attribute name (including its spelling) is what save_checkpoint and
    # load_checkpoint read, so it is kept as is.
    self.checkpoint_mananager = MegatronCheckpointManager(
        config=manager_config,
        checkpoint_config=self.checkpoint_config,
        model_config=self.model_config.hf_config,
        transformer_config=self.tf_config,
        role=role,
        model=self.module,
        arch=self.model_config.architectures[0],
        hf_config=self.model_config.hf_config,
        param_dtype=self.param_dtype,
        share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,
        processing_class=self.model_config.get_processor(),
        optimizer=self.optimizer,
        optimizer_scheduler=self.lr_scheduler,
        use_distributed_optimizer=engine_cfg.use_distributed_optimizer,
        use_checkpoint_opt_param_scheduler=self.optimizer_config.use_checkpoint_opt_param_scheduler,
        bridge=self.bridge,
        provider=self.provider,
        peft_cls=self.peft_cls,
        use_dist_checkpointing=engine_cfg.use_dist_checkpointing,
    )

    offload_param = self._is_offload_param
    self.to(device="cpu", model=offload_param, optimizer=self._is_offload_optimizer, grad=offload_param)

    log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)
|
|
def train_mode(self, **kwargs):
    """
    Build a context manager that puts this engine into training mode.

    Keyword arguments are passed through to ``EngineTrainModeCtx``.

    Usage:
        with engine.train_mode():
            # training-mode work goes here
    """
    train_ctx = EngineTrainModeCtx(self, **kwargs)
    return train_ctx
|
|
def eval_mode(self, **kwargs):
    """
    Build a context manager that puts this engine into evaluation mode.

    Keyword arguments are passed through to ``EngineEvalModeCtx``.

    Usage:
        with engine.eval_mode():
            # evaluation-mode work goes here
    """
    eval_ctx = EngineEvalModeCtx(self, **kwargs)
    return eval_ctx
|
|
def optimizer_zero_grad(self):
    """
    Reset gradients on the optimizer and on every model chunk's gradient buffer.
    """
    self.optimizer.zero_grad()
    # Each entry of ``self.module`` also has its own grad buffer to clear.
    for model_chunk in self.module:
        model_chunk.zero_grad_buffer()
|
|
def optimizer_step(self):
    """
    Run one optimizer step.

    Returns:
        grad_norm: The gradient norm reported by ``self.optimizer.step()``.

    Raises:
        NotImplementedError: If the optimizer reports that the update did not succeed.
    """
    update_successful, grad_norm, _num_zeros_in_grad = self.optimizer.step()

    if not update_successful:
        raise NotImplementedError("Megatron optimizer step failed. This should not happen")

    return grad_norm
|
|
def lr_scheduler_step(self):
    """
    Step the learning-rate scheduler by one increment.

    Returns:
        The value returned by ``get_megatron_last_lr`` for ``self.optimizer``.
    """
    from verl.utils.megatron.optimizer import get_megatron_last_lr

    self.lr_scheduler.step(1)
    latest_lr = get_megatron_last_lr(self.optimizer)
    return latest_lr
|
|
def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
    """
    Move model parameters and/or optimizer states to the specified device.
    Note that this function executes irrespective of offload config. It serves as manual control

    Args:
        device: Target device; either the name returned by ``get_device_name()`` or "cpu".
        model: If True, move the model.
        optimizer: If True, move the optimizer states (skipped when no optimizer exists).
        grad: Passed as ``load_grad`` when loading the model onto the accelerator.
            This method does not use it when offloading to CPU.

    Raises:
        ValueError: If ``device`` is neither the accelerator name nor "cpu".
    """
    super().to(device=device, model=model, optimizer=optimizer, grad=grad)

    device_name = get_device_name()

    # An ``assert`` used to precede this chain, which made the ValueError branch
    # unreachable and removed the check entirely under ``python -O``.
    if device == device_name:
        if model:
            load_megatron_model_to_gpu(self.module, load_grad=grad)
        if optimizer and self.optimizer is not None:
            load_megatron_optimizer(self.optimizer)
    elif device == "cpu":
        if model:
            offload_megatron_model_to_cpu(self.module)
        if optimizer and self.optimizer is not None:
            offload_megatron_optimizer(self.optimizer)
    else:
        raise ValueError(f"Invalid device type: {device}")
|
|
def get_data_parallel_rank(self):
    """Return the data-parallel rank reported by ``mpu``."""
    dp_rank = mpu.get_data_parallel_rank()
    return dp_rank
|
|
def get_data_parallel_size(self):
    """Return the data-parallel world size reported by ``mpu``."""
    dp_world_size = mpu.get_data_parallel_world_size()
    return dp_world_size
|
|
def get_data_parallel_group(self):
    """Return the data-parallel process group reported by ``mpu``."""
    dp_group = mpu.get_data_parallel_group()
    return dp_group
|
|
def get_model_parallel_group(self):
    """Return the model-parallel process group reported by ``mpu``."""
    mp_group = mpu.get_model_parallel_group()
    return mp_group
|
|
def get_context_parallel_group(self):
    """Return the context-parallel process group reported by ``mpu``."""
    cp_group = mpu.get_context_parallel_group()
    return cp_group
|
|
def save_checkpoint(
    self,
    local_path: str,
    hdfs_path: Optional[str] = None,
    global_step: int = 0,
    max_ckpt_to_keep: Optional[int] = None,
    **kwargs,
) -> None:
    """
    Write a checkpoint through the checkpoint manager.

    The model is loaded onto the accelerator first when parameter offload is
    enabled or the module currently sits on CPU. It is offloaded again
    afterwards only when parameter offload is enabled.

    Args:
        local_path: Local filesystem path to save checkpoint.
        hdfs_path: Optional HDFS path, forwarded to the checkpoint manager.
        global_step: Training step number, forwarded to the checkpoint manager.
        max_ckpt_to_keep: Retention limit, forwarded to the checkpoint manager.
        **kwargs: Accepted but not used by this method.
    """
    module_on_cpu = get_megatron_module_device(self.module) == "cpu"
    if self._is_offload_param or module_on_cpu:
        load_megatron_model_to_gpu(self.module, load_grad=True)

    self.checkpoint_mananager.save_checkpoint(
        local_path=local_path,
        hdfs_path=hdfs_path,
        global_step=global_step,
        max_ckpt_to_keep=max_ckpt_to_keep,
    )
    # All ranks wait here before any of them offloads again.
    torch.distributed.barrier()

    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.module)
|
|
def load_checkpoint(
    self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
) -> None:
    """
    Restore state from a checkpoint through the checkpoint manager.

    With parameter offload enabled the model is loaded onto the accelerator for
    the restore and offloaded again afterwards. With optimizer offload enabled
    the optimizer is offloaded after the restore.

    Args:
        local_path: Local filesystem path of the checkpoint.
        hdfs_path: Optional HDFS path, forwarded to the checkpoint manager.
        del_local_after_load: Forwarded to the checkpoint manager.
        **kwargs: Accepted but not used by this method.
    """
    param_offload = self._is_offload_param
    if param_offload:
        load_megatron_model_to_gpu(self.module)

    self.checkpoint_mananager.load_checkpoint(
        local_path=local_path,
        hdfs_path=hdfs_path,
        del_local_after_load=del_local_after_load,
    )

    if param_offload:
        offload_megatron_model_to_cpu(self.module)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.optimizer)
|
|
def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
    """Run Megatron's forward/backward schedule over ``data`` split into micro-batches.

    Non-tensor entries ``sp_size``, ``batch_num_tokens``, ``dp_size`` and
    ``num_micro_batch`` are written onto ``data`` as a side effect.

    Args:
        data: Input batch; must contain ``"loss_mask"``.
        loss_function: Forwarded to ``postprocess_micro_batch_func``.
        forward_only: Forwarded to the schedule and to ``postprocess_micro_batch_func``.

    Returns:
        On the last pipeline stage, the result of ``postprocess_batch_func`` (with
        ``["model_output"]["routed_experts"]`` added while an R2 record action is
        active). On every other stage, an empty dict.
    """
    tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)

    # Token count of the loss mask, summed across the data-parallel group.
    batch_num_tokens = data["loss_mask"].sum().to(get_device_id())
    torch.distributed.all_reduce(
        batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
    )
    tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())
    tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())

    # With virtual pipeline parallelism the micro-batch count must be a multiple
    # of the per-stage micro-batch group size.
    vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
    if vpp_size is not None and vpp_size > 1:
        num_batches_divided_by = self.tf_config.microbatch_group_size_per_vp_stage
    else:
        num_batches_divided_by = None

    micro_batches, indices = prepare_micro_batches(
        data=data,
        dp_group=self.get_data_parallel_group(),
        num_batches_divided_by=num_batches_divided_by,
        same_micro_num_in_dp=True,
        min_num_micro_batch=None,
    )

    if num_batches_divided_by is not None:
        # Report the count, not the micro-batches themselves.
        assert len(micro_batches) % num_batches_divided_by == 0, (
            f"number of micro_batches {len(micro_batches)} must be divisible by num_batches_divided_by "
            f"{num_batches_divided_by} for megatron backend"
        )

    n_micro_batch = len(micro_batches)

    for micro_batch in micro_batches:
        tu.assign_non_tensor(micro_batch, num_micro_batch=n_micro_batch)

    forward_backward_func = get_forward_backward_func()

    postprocess_micro_batch_func = partial(
        self.postprocess_micro_batch_func,
        forward_only=forward_only,
        loss_function=loss_function,
    )

    tu.assign_non_tensor(data, num_micro_batch=n_micro_batch)

    forward_step = partial(self.forward_step, postprocess_micro_batch_func=postprocess_micro_batch_func)

    enable_routing_replay = tu.get_non_tensor_data(data, key="enable_routing_replay", default=False)

    if enable_routing_replay:
        # Default action is to replay on forward ...
        RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
        if forward_only and self.engine_config.router_replay.mode == "R2":
            # ... except for forward-only passes in "R2" mode, which record instead.
            RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)

    # One iterator per entry of self.module.
    batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.module))

    # seq_length and micro_batch_size are passed as fixed placeholders of 1.
    # NOTE(review): presumably unused by the schedule with variable sequence lengths - confirm.
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step,
        data_iterator=batch_generator,
        model=self.module,
        num_microbatches=n_micro_batch,
        seq_length=1,
        micro_batch_size=1,
        forward_only=forward_only,
    )

    if self.model_config.mtp.enable and self.is_mp_src_rank_with_outputs():
        # Merge the MTP loss metrics into the first micro-batch's metrics.
        metrics = get_megatron_mtp_loss(n_micro_batch)
        if "metrics" not in losses_reduced[0]:
            losses_reduced[0]["metrics"] = {}
        losses_reduced[0]["metrics"].update(metrics)

    if RouterReplayHelper.is_r2_record_action(self.tf_config):
        if self.tf_config.virtual_pipeline_model_parallel_size is not None:
            vp_size = len(self.module)
            microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
            bs = n_micro_batch
            topk_idx_td = reorder_and_merge_vpp_layers(
                self.mini_layer_topk_idx_list, bs, vp_size, microbatch_group_size_per_vp_stage
            )
        else:
            tensors = [tensor for nt in self.mini_layer_topk_idx_list for tensor in nt.unbind()]
            topk_idx_td = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
        # The recorded indices are consumed; start the next batch with an empty list.
        self.mini_layer_topk_idx_list = []

        # NOTE(review): the cast to uint8 can only hold values 0..255 - confirm
        # this covers the expert index range of every supported model.
        layers_topk_idx = pp_gather(topk_idx_td.to(torch.uint8), self.tf_config)
        use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
        if use_dynamic_bsz and indices is not None:
            layers_topk_idx = restore_dynamic_batch(layers_topk_idx, indices)

    output = {}
    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        output = postprocess_batch_func(output_lst=losses_reduced, indices=indices, data=data)
        if RouterReplayHelper.is_r2_record_action(self.tf_config):
            output["model_output"]["routed_experts"] = layers_topk_idx
    if enable_routing_replay:
        # Reset the global replay state set up above.
        RouterReplay.clear_global_indices()
        RouterReplay.clear_global_router_replay_action()
    return output
|
|
def get_per_tensor_param(self, base_sync_done=False, **kwargs):
    """Export the model's weights through the bridge.

    Args:
        base_sync_done: Combined with an unmerged-LoRA setup, selects the
            adapter-only export.
        **kwargs: Accepted but not used by this method.

    Returns:
        A pair ``(per_tensor_param, peft_config)``. ``peft_config`` stays ``None``
        unless the adapter-only export path is taken.
    """
    lora_unmerged = self.peft_cls is not None and not self.model_config.lora.get("merge", False)
    adapter_only = base_sync_done and lora_unmerged

    # Frozen parameters are loaded unless only adapter weights are exported.
    load_megatron_model_to_gpu(self.module, load_grad=False, load_frozen_params=not adapter_only)

    if self.vanilla_bridge:
        return self.bridge.export_weights(self.module), None

    if adapter_only:
        peft_config = build_peft_config_for_vllm(self.model_config.lora)
        return self.bridge.export_adapter_weights(self.module), peft_config

    exported = self.bridge.export_hf_weights(self.module)
    if lora_unmerged:
        exported = add_base_layer_suffix(exported, model_type=self.model_config.hf_config.model_type)
    return exported, None
|
|
def disable_adapter(self) -> ContextManager:
    """Return the context manager produced by ``peft_cls.disable_adapter`` for the module."""
    adapter_off_ctx = self.peft_cls.disable_adapter(self.module)
    return adapter_off_ctx
|
|
def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
    """Hook to be overridden by subclasses; the base version always raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "forward_step must be implemented in subclass"
    raise NotImplementedError(message)
|
|
def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):
    """Hook to be overridden by subclasses; the base version always raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "postprocess_micro_batch_func must be implemented in subclass"
    raise NotImplementedError(message)
|
|
|
|
class EngineEvalModeCtx(BaseEngineCtx):
    """Context manager that switches a ``MegatronEngine`` into evaluation mode."""

    def __init__(self, engine: MegatronEngine, **kwargs):
        """Bind the context to ``engine`` with mode "eval"; extra kwargs go to the base class."""
        super().__init__(engine=engine, mode="eval", **kwargs)

    def __enter__(self):
        """Run the base enter logic, then call ``eval()`` on every model chunk."""
        assert isinstance(self.engine, MegatronEngine)
        super().__enter__()
        for model_chunk in self.engine.module:
            model_chunk.eval()

    def __exit__(self, exc_type, exc_value, traceback):
        """Delegate to the base exit logic."""
        assert isinstance(self.engine, MegatronEngine)
        super().__exit__(exc_type, exc_value, traceback)
|
|
|
|
class EngineTrainModeCtx(BaseEngineCtx):
    """Context manager that switches a ``MegatronEngine`` into training mode."""

    def __init__(self, engine: MegatronEngine, **kwargs):
        """Bind the context to ``engine`` with mode "train"; extra kwargs go to the base class."""
        super().__init__(engine=engine, mode="train", **kwargs)

    def __enter__(self):
        """Run the base enter logic, then call ``train()`` on every model chunk."""
        assert isinstance(self.engine, MegatronEngine)
        super().__enter__()
        for model_chunk in self.engine.module:
            model_chunk.train()

    def __exit__(self, exc_type, exc_value, traceback):
        """Zero the engine's gradients, then delegate to the base exit logic."""
        assert isinstance(self.engine, MegatronEngine)
        self.engine.optimizer_zero_grad()
        super().__exit__(exc_type, exc_value, traceback)
|
|
|
|
| @EngineRegistry.register(model_type="language_model", backend="megatron") |
| class MegatronEngineWithLMHead(MegatronEngine): |
def prepare_model_inputs(self, batch: TensorDict):
    """Pick the model inputs out of a micro-batch.

    Args:
        batch: Micro-batch containing ``"input_ids"`` and ``"loss_mask"``;
            ``"multi_modal_inputs"`` and ``"routed_experts"`` are optional.

    Returns:
        Dict with keys ``input_ids``, ``loss_mask`` (cast to bool),
        ``multi_modal_inputs`` and ``routed_experts`` (``None`` when absent).
    """
    model_inputs = {"input_ids": batch["input_ids"]}
    model_inputs["loss_mask"] = batch["loss_mask"].to(bool)
    model_inputs["multi_modal_inputs"] = extract_multi_modal_inputs(batch.get("multi_modal_inputs", []))
    model_inputs["routed_experts"] = batch.get("routed_experts", None)
    return model_inputs
|
|
def prepare_model_outputs(self, output: dict, data: TensorDict):
    """Select the entries of the forward output that are passed on.

    Args:
        output: Forward output containing ``"log_probs"`` and, when entropy was
            requested, ``"entropy"``.
        data: Micro-batch; its non-tensor ``calculate_entropy`` flag (default
            ``False``) decides whether entropy is included.

    Returns:
        Dict with ``log_probs`` and optionally ``entropy``.
    """
    wants_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False)

    selected = {"log_probs": output["log_probs"]}
    if wants_entropy:
        selected["entropy"] = output["entropy"]
    return selected
|
|
def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micro_batch_func):
    """Run the forward pass for the next micro-batch of ``batch_iter``.

    The fused path is used only when the micro-batch requests it,
    ``engine_config.use_remove_padding`` is truthy, and the temperature is a
    single scalar. Otherwise the non-fused path computes log-probs (and entropy,
    when requested) in a logits processor.

    Args:
        batch_iter: Iterator yielding micro-batches.
        model: Model chunk to run.
        postprocess_micro_batch_func: Callable that is bound to the micro-batch
            and returned alongside the output.

    Returns:
        Tuple ``(output, partial(postprocess_micro_batch_func, data=batch))``.

    Raises:
        NotImplementedError: If the micro-batch's pad mode is not ``NO_PADDING``.
    """
    batch: TensorDict = next(batch_iter)
    batch = batch.to(get_device_id())
    use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False)
    calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False)
    pad_mode = tu.get_non_tensor_data(batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
    temperature = batch["temperature"]
    model_inputs = self.prepare_model_inputs(batch)
    input_ids = model_inputs["input_ids"]
    multi_modal_inputs = model_inputs["multi_modal_inputs"]
    loss_mask = model_inputs["loss_mask"]

    # ``warning_once`` is not a stdlib ``logging.Logger`` method; it exists only
    # when some other package has attached it (NOTE(review): presumably
    # transformers - confirm). Fall back to ``warning`` so the fallbacks below
    # cannot fail with AttributeError.
    warn_once = getattr(logger, "warning_once", logger.warning)

    unwrapped_model = unwrap_model(model)
    if hasattr(unwrapped_model, "vp_stage"):
        vp_rank = unwrapped_model.vp_stage
    else:
        vp_rank = 0

    # Routers left in the replay-backward state are switched back to
    # replay-forward before this forward pass.
    if RouterReplayHelper.is_replay_backward_action(self.tf_config, vp_rank):
        router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank)
        for router in router_instance_list:
            router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

    if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
        layers_topk_idx = model_inputs["routed_experts"]
        set_router_replay_data(layers_topk_idx, None, self.tf_config, vp_rank)

    if pad_mode == DatasetPadMode.NO_PADDING:
        label = input_ids.clone()
    else:
        raise NotImplementedError(f"Pad mode {pad_mode} is not supported for megatron engine")

    from verl.models.mcore import get_mcore_forward_no_padding_fn

    # Scalar temperature for the fused path; stays None whenever that path is not taken.
    temperature_value = None
    if use_fused_kernels:
        if not self.engine_config.use_remove_padding:
            warn_once(
                "Fused kernels require `use_remove_padding=True` for Megatron engine. Falling back to non-fused."
            )
            use_fused_kernels = False
        elif isinstance(temperature, torch.Tensor):
            if temperature.numel() != 1:
                warn_once("Fused kernels do not support per-sample temperature. Falling back to non-fused.")
                use_fused_kernels = False
            else:
                temperature_value = float(temperature.item())
        else:
            temperature_value = float(temperature)

    if use_fused_kernels:
        fused_forward_fn = get_mcore_forward_fused_no_padding_fn(self.model_config.hf_config)
        output = fused_forward_fn(
            model=model,
            input_ids=input_ids,
            labels=label,
            multi_modal_inputs=multi_modal_inputs,
            temperature=temperature_value,
            calculate_entropy=calculate_entropy,
            pad_token_id=self.model_config.tokenizer.pad_token_id,
        )
    else:
        # Non-fused path: one temperature entry per sample, expanded to match input_ids.
        if not isinstance(temperature, torch.Tensor):
            temperature = torch.tensor([temperature] * input_ids.shape[0], device=input_ids.device)

        temperature = temperature.to(torch.float32)
        assert temperature.shape[0] == input_ids.shape[0]
        temperature = verl_F.expand_as_nested(temperature, input_ids)

        forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)

        def logits_processor(logits, label, temperature):
            """Scale logits by temperature and compute log-probs (and entropy if requested)."""
            assert logits.shape[:2] == label.shape[:2]
            # Non-positive temperatures are clamped to a tiny positive value (in place).
            temperature[temperature <= 0] = 1e-8
            assert torch.all(temperature > 0).item(), f"temperature tensor must be positive. Got {temperature}"
            logits.div_(temperature.unsqueeze(dim=-1).to(logits.dtype))
            ret = {}
            if calculate_entropy:
                # Keep a copy for the log-prob computation.
                # NOTE(review): presumably because vocab_parallel_entropy modifies
                # its input in place - confirm.
                logits_bak = logits.clone()
                entropy = vocab_parallel_entropy(logits)
                ret["entropy"] = entropy
            else:
                logits_bak = logits

            log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
            ret["log_probs"] = log_probs
            return ret

        logits_processor_args = {"label": label, "temperature": temperature, "loss_mask": loss_mask}

        output = forward_fn(
            model,
            input_ids,
            multi_modal_inputs,
            logits_processor=logits_processor,
            logits_processor_args=logits_processor_args,
            vision_model=hasattr(self.model_config.hf_config, "vision_config"),
            pad_token_id=self.model_config.tokenizer.pad_token_id,
            data_format="thd" if self.engine_config.use_remove_padding else "bshd",
            mtp_enable_train=self.model_config.mtp.enable and self.model_config.mtp.enable_train,
        )

    # Router replay, record mode: collect this micro-batch's top-k indices.
    if RouterReplayHelper.is_r2_record_action(self.tf_config, vp_rank):
        merge_router_topk_indices(None, input_ids, self.mini_layer_topk_idx_list, self.tf_config, vp_rank)

    # Router replay, replay mode: after the forward pass, switch the routers to
    # the replay-backward action.
    if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
        router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank)
        for router in router_instance_list:
            router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)

    return output, partial(postprocess_micro_batch_func, data=batch)
|
|
def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):
    """Turn a micro-batch forward output into ``(scaled_loss, info_dict)``.

    Args:
        output: Raw forward output for the micro-batch.
        data: The micro-batch; ``"input_ids"`` and, when a loss function is
            given, ``"num_micro_batch"`` are read from it.
        forward_only: Must be truthy when ``loss_function`` is ``None``.
        loss_function: Callable returning ``(loss, metrics)``, or ``None``.

    Returns:
        Tuple of the loss multiplied by ``data["num_micro_batch"]`` (or a
        constant ``1.0`` tensor when there is no loss function) and a dict with
        keys ``model_output``, ``loss`` (Python float) and ``metrics``.
    """
    device = data["input_ids"].device
    model_output = self.prepare_model_outputs(output, data)

    if loss_function is None:
        assert forward_only, "forward_only must be True when loss_function is None"
        # Placeholder loss so a tensor is still returned.
        loss = torch.tensor(1.0, device=device)
        scaled_loss, metrics = loss, {}
    else:
        loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
        # Multiply by the micro-batch count.
        # NOTE(review): presumably offsets a division applied by the schedule - confirm.
        scaled_loss = loss * data["num_micro_batch"]

    info = {
        "model_output": model_output,
        "loss": loss.detach().item(),
        "metrics": metrics,
    }
    return scaled_loss, info
|
|
|
|
@EngineRegistry.register(model_type="value_model", backend="megatron")
class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
    """Megatron engine registered for ``value_model``; its model output is exposed as ``values``."""

    def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
        """Run the value-model forward pass for the next micro-batch.

        Returns:
            Tuple ``(output, partial(postprocess_micro_batch_func, data=batch))``.
        """
        micro_batch: TensorDict = next(batch_iter)
        micro_batch = micro_batch.to(get_device_id())
        inputs = self.prepare_model_inputs(micro_batch)

        from verl.models.mcore import get_mcore_forward_no_padding_fn

        hf_config = self.model_config.hf_config
        forward_fn = get_mcore_forward_no_padding_fn(hf_config)

        output = forward_fn(
            model,
            inputs["input_ids"],
            inputs["multi_modal_inputs"],
            value_model=True,
            vision_model=hasattr(hf_config, "vision_config"),
            pad_token_id=self.model_config.tokenizer.pad_token_id,
            enable_mtp=self.model_config.mtp.enable_train,
        )

        bound_postprocess = partial(postprocess_micro_batch_func, data=micro_batch)
        return output, bound_postprocess

    def prepare_model_outputs(self, output: dict | torch.Tensor, data: TensorDict):
        """Wrap the raw forward output under the ``values`` key; ``data`` is not used."""
        wrapped = {"values": output}
        return wrapped
|
|