# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from functools import partial
from typing import Any, Callable, ContextManager, Iterator, Optional
import torch
import torch.distributed
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
from omegaconf import OmegaConf
from tensordict import TensorDict
import verl.utils.torch_functional as verl_F
from verl.models.mcore import get_mcore_forward_fused_no_padding_fn, get_mcore_weight_converter
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager
from verl.utils.dataset.dataset_utils import DatasetPadMode
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import get_device_id, get_device_name
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch
from verl.utils.megatron.router_replay_utils import (
RouterReplayHelper,
merge_router_topk_indices,
pp_gather,
reorder_and_merge_vpp_layers,
set_router_replay_data,
)
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm
from verl.utils.megatron_utils import (
check_mtp_config,
get_megatron_module_device,
get_megatron_mtp_loss,
load_megatron_model_to_gpu,
load_megatron_optimizer,
offload_megatron_model_to_cpu,
offload_megatron_optimizer,
patch_engine_mtp,
register_megatron_training_hooks,
unwrap_model,
)
from verl.utils.model import extract_multi_modal_inputs, load_mcore_dist_weights
from verl.utils.seqlen_balancing import restore_dynamic_batch
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
from ..base import BaseEngine, BaseEngineCtx, EngineRegistry
from ..utils import postprocess_batch_func, prepare_micro_batches
from .utils import set_random_seed
# Module logger; verbosity is controlled by the VERL_LOGGING_LEVEL environment variable (default "WARN").
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
class MegatronEngine(BaseEngine):
    """Megatron-Core backed engine built through Megatron-Bridge or the vanilla mbridge.

    The engine owns:
      * ``self.module``: a list of Megatron model chunks (one per virtual pipeline stage);
      * the optimizer, LR scheduler and checkpoint manager (skipped when
        ``engine_config.forward_only`` is set);
      * CPU offload bookkeeping for parameters and optimizer state;
      * optional MoE router replay (R2 / R3 modes).

    Subclasses implement ``forward_step`` and ``postprocess_micro_batch_func`` for a
    concrete head.
    """

    def __init__(
        self,
        model_config: HFModelConfig,
        engine_config: McoreEngineConfig,
        optimizer_config: McoreOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        super().__init__()

        self.model_config = model_config
        self.engine_config = engine_config
        self.optimizer_config = optimizer_config
        self.checkpoint_config = checkpoint_config
        assert self.engine_config.use_mbridge, "use_mbridge must be True"

        self._init_device_mesh()

        set_random_seed(seed=self.engine_config.seed)

        self._is_offload_param = self.engine_config.param_offload
        self._is_offload_grad = self.engine_config.grad_offload
        self._is_offload_optimizer = self.engine_config.optimizer_offload

        self.mode = None

        self.layer_name_mapping = {
            "qkv_layer_name": "self_attention.linear_qkv.",
            "gate_proj_layer_name": "linear_fc1.",
        }
        self.weight_converter = None

        # Router replay configuration for MoE models
        self.enable_routing_replay = self.engine_config.router_replay.mode != "disabled"
        logger.info(f"enable_routing_replay in MegatronEngine: {self.enable_routing_replay}")
        if self.enable_routing_replay:
            apply_router_replay_patch()
        # Routing decisions captured per micro-batch in R2 record mode (filled in forward_step,
        # drained in forward_backward_batch). Always initialized so the attribute exists.
        self.mini_layer_topk_idx_list = []

    def _init_device_mesh(self):
        """Initialize Megatron model-parallel groups once per process (no-op if already done)."""
        # TODO: set different parallelism for actor, critic, ref
        if mpu.is_initialized():
            return

        mpu.initialize_model_parallel(
            tensor_model_parallel_size=self.engine_config.tensor_model_parallel_size,
            pipeline_model_parallel_size=self.engine_config.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size=self.engine_config.virtual_pipeline_model_parallel_size,
            use_sharp=False,
            context_parallel_size=self.engine_config.context_parallel_size,
            expert_model_parallel_size=self.engine_config.expert_model_parallel_size,
            expert_tensor_parallel_size=self.engine_config.expert_tensor_parallel_size,
            nccl_communicator_config_path=None,
        )

    def _build_tf_config(self):
        """Build the transformer config (vanilla mbridge) or a finalized Megatron-Bridge provider.

        Sets ``self.param_dtype``, ``self.dtype``, ``self.bridge``, ``self.provider``,
        ``self.tf_config`` (``None`` on the Megatron-Bridge path until the model is built)
        and ``self.peft_cls``.
        """
        from verl.utils.megatron_utils import mapping_string_to_attn_backend
        from verl.utils.torch_dtypes import PrecisionType

        check_mtp_config(self.model_config, self.engine_config)

        self.param_dtype = PrecisionType.to_dtype(self.engine_config.dtype)
        self.dtype = PrecisionType.to_dtype(self.param_dtype)

        override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})
        if self.enable_routing_replay:
            override_transformer_config["enable_routing_replay"] = True

        self.provider = None
        self.vanilla_bridge = self.engine_config.vanilla_mbridge
        if self.vanilla_bridge:
            from verl.models.mcore.mbridge import AutoBridge

            bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
            bridge.set_extra_args(**override_transformer_config)
            tf_config = bridge.config
            tf_config.fp16 = self.param_dtype == torch.float16
            tf_config.bf16 = self.param_dtype == torch.bfloat16
        else:
            from verl.models.mcore.bridge import AutoBridge

            # Use Megatron-Bridge to convert HF config to Megatron config
            bridge = AutoBridge.from_hf_pretrained(
                self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
            )
            # Get Megatron provider and configure it
            provider = bridge.to_megatron_provider(load_weights=False)

            # In case of invalid overrides, we need to make sure some critical params are set correctly
            provider.params_dtype = self.param_dtype

            # Ensure dtype settings propagate to Megatron-Bridge/TE
            provider.fp16 = self.param_dtype == torch.float16
            provider.bf16 = self.param_dtype == torch.bfloat16

            # Pass distributed info
            provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
            provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
            provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
            provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
            provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
            provider.context_parallel_size = self.engine_config.context_parallel_size
            provider.sequence_parallel = self.engine_config.sequence_parallel

            # Match verl implementation (need variable_seq_lengths)
            from megatron.core.transformer.enums import AttnBackend

            provider.attention_backend = AttnBackend.flash
            provider.variable_seq_lengths = True
            provider.moe_token_dispatcher_type = "alltoall"
            provider.moe_router_load_balancing_type = "none"

            # Apply transformer config overrides (these win over the defaults set above)
            for key, value in override_transformer_config.items():
                setattr(provider, key, value)

            provider.finalize()
            self.provider = provider
            tf_config = None  # Will be set after model creation

        self.bridge = bridge

        # NOTE(review): bridge is assigned in both branches above, so this fallback appears unreachable.
        if not self.bridge:
            self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)

        if torch.distributed.get_rank() == 0 and tf_config is not None:
            print(f"TF config: {tf_config}")
        self.tf_config = tf_config

        from verl.workers.config.megatron_peft import get_peft_cls

        self.peft_cls = get_peft_cls(
            model_config=self.model_config, bridge=self.bridge, provider=self.provider, dtype=self.param_dtype
        )

    def _build_megatron_module(self):
        """Instantiate the Megatron model chunks and load their weights.

        Returns:
            The list of model chunks produced by ``make_megatron_module``. Also sets
            ``self.is_value_model`` and replaces ``self.tf_config`` with the updated config.
        """
        from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
        from verl.utils.model import print_model_size

        # TODO: add more cases
        is_value_model = (
            "ForTokenClassification" in self.model_config.architectures[0]
            or "ForSequenceClassification" in self.model_config.architectures[0]
        )

        self.is_value_model = is_value_model

        # DDP wrapping is only needed when gradients are computed.
        wrap_with_ddp = not self.engine_config.forward_only

        wrap_config = McoreModuleWrapperConfig(
            # True for *ForTokenClassification / *ForSequenceClassification architectures
            is_value_model=is_value_model,
            share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,
            wrap_with_ddp=wrap_with_ddp,
            use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
        )
        module, updated_tf_config = make_megatron_module(
            wrap_config=wrap_config,
            tf_config=self.tf_config,
            hf_config=self.model_config.hf_config,
            bridge=self.bridge,
            provider=self.provider,
            override_model_config=self.engine_config.override_mcore_model_config,
            override_ddp_config=self.engine_config.override_ddp_config,
            peft_cls=self.peft_cls,
            peft_config=self.model_config.get("lora", None),
        )
        self.tf_config = updated_tf_config
        print(f"module: {len(module)}")

        if self.engine_config.use_dist_checkpointing:
            load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)
        elif self.vanilla_bridge:
            self.bridge.load_weights(module, self.model_config.local_path)
        else:
            # A value head replaces the LM output layer, so its shape may not match the HF checkpoint.
            allowed_mismatched_params = ["output_layer.weight"] if self.is_value_model else []
            self.bridge.load_hf_weights(
                module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
            )

        if torch.distributed.get_rank() == 0:
            print_model_size(module[0])

        if self.enable_routing_replay:
            print(f"routing replay layers: {len(RouterReplay.router_instances)}")

        return module

    def _maybe_enable_fused_kernels(self):
        """Patch every model chunk with the fused forward, unless unsupported (value model or MTP)."""
        if not self.engine_config.use_fused_kernels:
            return

        if self.is_value_model or self.model_config.mtp.enable:
            logger.warning_once(
                "Fused kernels are not supported for value models or when MTP is enabled in Megatron engine; disabling."
            )
            self.engine_config.use_fused_kernels = False
            return

        from verl.models.mcore.model_forward_fused import patch_fused_forward

        for model in self.module:
            patch_fused_forward(model)

    def _build_optimizer(self):
        """Create the Megatron optimizer for ``self.module`` and register training hooks."""
        from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config

        optim_config_megatron = init_megatron_optim_config(
            self.optimizer_config,
            use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
            fp16=self.param_dtype == torch.float16,
        )
        optimizer = get_megatron_optimizer(model=self.module, config=optim_config_megatron)
        register_megatron_training_hooks(self.module, optimizer)
        return optimizer

    def _build_lr_scheduler(self):
        """Create the optimizer parameter scheduler bound to ``self.optimizer``."""
        from verl.utils.megatron.optimizer import get_megatron_optimizer_param_scheduler

        optimizer_scheduler = get_megatron_optimizer_param_scheduler(
            optimizer=self.optimizer, config=self.optimizer_config
        )
        return optimizer_scheduler

    @property
    def is_param_offload_enabled(self) -> bool:
        return self._is_offload_param

    @property
    def is_optimizer_offload_enabled(self) -> bool:
        return self._is_offload_optimizer

    def is_mp_src_rank_with_outputs(self):
        """True on TP rank 0 of the last pipeline stage with CP rank 0 (the rank holding final outputs)."""
        return (
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
            and mpu.get_context_parallel_rank() == 0
        )

    def initialize(self):
        """Build config, model, and (unless forward-only) optimizer, scheduler and checkpoint manager.

        Finishes by offloading model / optimizer / grads to CPU according to the offload config.
        """
        self._build_tf_config()

        self.module = self._build_megatron_module()

        self._maybe_enable_fused_kernels()

        if self.model_config.mtp.enable:
            patch_engine_mtp(self.module, self.model_config)

        # For forward_only, we don't need optimizer, lr_scheduler, checkpoint_mananager
        if self.engine_config.forward_only:
            self.optimizer = None
            self.lr_scheduler = None
            self.to(device="cpu", model=self._is_offload_param, optimizer=False, grad=False)
            log_gpu_memory_usage("After offload model during init (forward_only)", logger=logger)
            return

        self.optimizer = self._build_optimizer()
        self.lr_scheduler = self._build_lr_scheduler()

        full_reshardable = self.engine_config.dist_ckpt_optim_fully_reshardable
        mem_eff = self.engine_config.distrib_optim_fully_reshardable_mem_efficient
        tmp_config = OmegaConf.create(
            {
                "model": {"path": self.model_config.local_path},
                "megatron": {
                    "dist_ckpt_optim_fully_reshardable": full_reshardable,
                    "distrib_optim_fully_reshardable_mem_efficient": mem_eff,
                },
            }
        )

        role = "actor" if not self.is_value_model else "critic"

        # NOTE: the attribute name keeps its historical spelling ("mananager") for compatibility.
        self.checkpoint_mananager = MegatronCheckpointManager(
            config=tmp_config,
            checkpoint_config=self.checkpoint_config,
            model_config=self.model_config.hf_config,
            transformer_config=self.tf_config,
            role=role,
            model=self.module,
            arch=self.model_config.architectures[0],
            hf_config=self.model_config.hf_config,
            param_dtype=self.param_dtype,
            share_embeddings_and_output_weights=self.model_config.share_embeddings_and_output_weights,
            processing_class=self.model_config.get_processor(),
            optimizer=self.optimizer,
            optimizer_scheduler=self.lr_scheduler,
            use_distributed_optimizer=self.engine_config.use_distributed_optimizer,
            use_checkpoint_opt_param_scheduler=self.optimizer_config.use_checkpoint_opt_param_scheduler,
            bridge=self.bridge,
            provider=self.provider,
            peft_cls=self.peft_cls,
            use_dist_checkpointing=self.engine_config.use_dist_checkpointing,
        )

        self.to(
            device="cpu",
            model=self._is_offload_param,
            optimizer=self._is_offload_optimizer,
            grad=self._is_offload_param,
        )

        log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)

    def train_mode(self, **kwargs):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        return EngineTrainModeCtx(self, **kwargs)

    def eval_mode(self, **kwargs):
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        return EngineEvalModeCtx(self, **kwargs)

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        self.optimizer.zero_grad()
        # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        for chunk in self.module:
            # if use distributed optimizer, zero grad buffer will be handled by optimizer
            chunk.zero_grad_buffer()

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.

        Raises:
            NotImplementedError: if the Megatron optimizer reports an unsuccessful update.
        """
        update_successful, grad_norm, _num_zeros_in_grad = self.optimizer.step()
        # On success nothing else is needed: the all-gather already runs inside optimizer.step in new megatron.
        if not update_successful:
            raise NotImplementedError("Megatron optimizer step failed. This should not happen")
        return grad_norm

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        from verl.utils.megatron.optimizer import get_megatron_last_lr

        self.lr_scheduler.step(1)
        return get_megatron_last_lr(self.optimizer)

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Note that this function executes irrespective of offload config. It serves as manual control.

        Args:
            device: Target device identifier (the accelerator device name or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states (ignored when there is no optimizer).
            grad: If True, also load gradients when moving the model to the accelerator. On the
                CPU path this flag is only forwarded to the base class.
        """
        super().to(device=device, model=model, optimizer=optimizer, grad=grad)

        device_name = get_device_name()

        assert device in (device_name, "cpu")
        if device == device_name:
            if model:
                load_megatron_model_to_gpu(self.module, load_grad=grad)
            if optimizer and self.optimizer is not None:
                load_megatron_optimizer(self.optimizer)
        elif device == "cpu":
            if model:
                offload_megatron_model_to_cpu(self.module)
            if optimizer and self.optimizer is not None:
                offload_megatron_optimizer(self.optimizer)
        else:
            raise ValueError(f"Invalid device type: {device}")

    def get_data_parallel_rank(self):
        return mpu.get_data_parallel_rank()

    def get_data_parallel_size(self):
        return mpu.get_data_parallel_world_size()

    def get_data_parallel_group(self):
        return mpu.get_data_parallel_group()

    def get_model_parallel_group(self):
        return mpu.get_model_parallel_group()

    def get_context_parallel_group(self):
        return mpu.get_context_parallel_group()

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        The model is temporarily loaded onto the accelerator if it is offloaded or currently on
        CPU, and moved back to CPU afterwards in the same cases.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        origin_module_device = get_megatron_module_device(self.module)
        # Track whether we moved the model so it can be put back where it was.
        moved_to_device = self._is_offload_param or origin_module_device == "cpu"
        if moved_to_device:
            load_megatron_model_to_gpu(self.module, load_grad=True)
        self.checkpoint_mananager.save_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
        )
        torch.distributed.barrier()
        if moved_to_device:
            offload_megatron_model_to_cpu(self.module)

    def load_checkpoint(
        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
    ) -> None:
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        if self._is_offload_param:
            load_megatron_model_to_gpu(self.module)
        self.checkpoint_mananager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )
        if self._is_offload_param:
            offload_megatron_model_to_cpu(self.module)
        if self._is_offload_optimizer:
            offload_megatron_optimizer(self.optimizer)

    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
        """Run Megatron's pipeline schedule over ``data`` split into micro-batches.

        Args:
            data: this DP rank's batch. Must contain ``loss_mask``; non-tensor flags such as
                ``enable_routing_replay`` and ``use_dynamic_bsz`` are read from it.
            loss_function: per-micro-batch loss, forwarded to ``postprocess_micro_batch_func``;
                may be ``None`` only when ``forward_only`` is True.
            forward_only: if True, skip the backward pass.

        Returns:
            On the last pipeline stage, the output of ``postprocess_batch_func`` (with
            ``model_output["routed_experts"]`` added in R2 record mode); ``{}`` on other stages.
        """
        tu.assign_non_tensor(data, sp_size=self.engine_config.context_parallel_size)

        # compute num_tokens in global batch for loss normalization
        batch_num_tokens = data["loss_mask"].sum().to(get_device_id())
        torch.distributed.all_reduce(
            batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
        )
        tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())
        tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())

        vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        if vpp_size is not None and vpp_size > 1:
            num_batches_divided_by = self.tf_config.microbatch_group_size_per_vp_stage
        else:
            num_batches_divided_by = None

        micro_batches, indices = prepare_micro_batches(
            data=data,
            dp_group=self.get_data_parallel_group(),
            num_batches_divided_by=num_batches_divided_by,
            same_micro_num_in_dp=True,
            min_num_micro_batch=None,
        )
        if num_batches_divided_by is not None:
            # Report the count, not the micro-batch contents themselves.
            assert len(micro_batches) % num_batches_divided_by == 0, (
                f"micro_batches {len(micro_batches)} must be divisible by num_batches_divided_by "
                f"{num_batches_divided_by} for megatron backend"
            )

        # Every micro-batch carries the total count; the per-micro-batch postprocess scales the loss by it.
        n_micro_batch = len(micro_batches)
        for micro_batch in micro_batches:
            tu.assign_non_tensor(micro_batch, num_micro_batch=n_micro_batch)

        forward_backward_func = get_forward_backward_func()

        postprocess_micro_batch_func = partial(
            self.postprocess_micro_batch_func,
            forward_only=forward_only,
            loss_function=loss_function,
        )

        tu.assign_non_tensor(data, num_micro_batch=n_micro_batch)

        forward_step = partial(self.forward_step, postprocess_micro_batch_func=postprocess_micro_batch_func)

        enable_routing_replay = tu.get_non_tensor_data(data, key="enable_routing_replay", default=False)
        if enable_routing_replay:
            # Set to REPLAY mode: for R3 mode or actor update phase in R2 mode
            RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
            if forward_only and self.engine_config.router_replay.mode == "R2":
                # In R2 mode, forward_only calls (e.g., compute_log_probs) need to record routing information
                RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)

        try:
            # batch should be a list of batches inside micro-batches
            batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.module))

            # TODO: we may use the new schedule instead
            # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
            losses_reduced = forward_backward_func(
                forward_step_func=forward_step,
                data_iterator=batch_generator,
                model=self.module,
                num_microbatches=n_micro_batch,
                seq_length=1,  # the communication shape is obtained via p2p comm
                micro_batch_size=1,  # the communication shape is obtained via p2p comm
                forward_only=forward_only,
            )

            if self.model_config.mtp.enable and self.is_mp_src_rank_with_outputs():
                # add mtp_losses
                metrics = get_megatron_mtp_loss(n_micro_batch)
                if "metrics" not in losses_reduced[0]:
                    losses_reduced[0]["metrics"] = {}
                losses_reduced[0]["metrics"].update(metrics)

            if RouterReplayHelper.is_r2_record_action(self.tf_config):
                # Merge the routing decisions recorded in forward_step across micro-batches.
                if self.tf_config.virtual_pipeline_model_parallel_size is not None:
                    vp_size = len(self.module)
                    microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
                    bs = n_micro_batch
                    topk_idx_td = reorder_and_merge_vpp_layers(
                        self.mini_layer_topk_idx_list, bs, vp_size, microbatch_group_size_per_vp_stage
                    )
                else:
                    tensors = [tensor for nt in self.mini_layer_topk_idx_list for tensor in nt.unbind()]
                    topk_idx_td = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
                self.mini_layer_topk_idx_list = []
                layers_topk_idx = pp_gather(topk_idx_td.to(torch.uint8), self.tf_config)
                use_dynamic_bsz = tu.get_non_tensor_data(data=data, key="use_dynamic_bsz", default=True)
                if use_dynamic_bsz and indices is not None:
                    layers_topk_idx = restore_dynamic_batch(layers_topk_idx, indices)

            output = {}
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                output = postprocess_batch_func(output_lst=losses_reduced, indices=indices, data=data)
                if RouterReplayHelper.is_r2_record_action(self.tf_config):
                    output["model_output"]["routed_experts"] = layers_topk_idx
            return output
        finally:
            # Always reset global replay state so a failed step cannot leak it into the next call.
            if enable_routing_replay:
                RouterReplay.clear_global_indices()
                RouterReplay.clear_global_router_replay_action()

    def get_per_tensor_param(self, base_sync_done=False, **kwargs):
        """Export weights for syncing to the rollout engine.

        Args:
            base_sync_done: when True and LoRA is not merged, export only adapter weights.

        Returns:
            ``(per_tensor_param, peft_config)``, where ``per_tensor_param`` is whatever the bridge
            export returns, and ``peft_config`` is set only in the adapter-only case.
        """
        peft_config = None
        non_merge_lora_sync = self.peft_cls is not None and not self.model_config.lora.get("merge", False)
        adapter_only = base_sync_done and non_merge_lora_sync
        # when lora adapter only, we only load adapter weights when base sync is done, otherwise load all weights
        load_megatron_model_to_gpu(self.module, load_grad=False, load_frozen_params=not adapter_only)
        if self.vanilla_bridge:
            per_tensor_param = self.bridge.export_weights(self.module)
        elif adapter_only:
            # Only export adapter weights
            peft_config = build_peft_config_for_vllm(self.model_config.lora)
            per_tensor_param = self.bridge.export_adapter_weights(self.module)
        else:
            per_tensor_param = self.bridge.export_hf_weights(self.module)
            if non_merge_lora_sync:
                per_tensor_param = add_base_layer_suffix(
                    per_tensor_param, model_type=self.model_config.hf_config.model_type
                )
        return per_tensor_param, peft_config

    def disable_adapter(self) -> ContextManager:
        """Return the PEFT context manager that temporarily disables the adapter on ``self.module``."""
        return self.peft_cls.disable_adapter(self.module)

    def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
        raise NotImplementedError("forward_step must be implemented in subclass")

    def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):
        raise NotImplementedError("postprocess_micro_batch_func must be implemented in subclass")
class EngineEvalModeCtx(BaseEngineCtx):
    """Context that switches every Megatron model chunk of the engine to ``eval()`` on entry."""

    def __init__(self, engine: MegatronEngine, **kwargs):
        super().__init__(engine=engine, mode="eval", **kwargs)

    def __enter__(self):
        engine = self.engine
        assert isinstance(engine, MegatronEngine)
        super().__enter__()
        # The Megatron module is a list with one model chunk per virtual-pipeline stage;
        # all of them have to change mode together.
        for chunk in engine.module:
            chunk.eval()

    def __exit__(self, exc_type, exc_value, traceback):
        engine = self.engine
        assert isinstance(engine, MegatronEngine)
        super().__exit__(exc_type, exc_value, traceback)
class EngineTrainModeCtx(BaseEngineCtx):
    """Context that switches every Megatron model chunk to ``train()`` on entry.

    On exit, gradients are zeroed (when the engine has an optimizer) before delegating to
    the base context.
    """

    def __init__(self, engine: MegatronEngine, **kwargs):
        super().__init__(engine=engine, mode="train", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, MegatronEngine)
        super().__enter__()
        # mcore module is a list of model chunk in each vpp stage
        for module in self.engine.module:
            module.train()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, MegatronEngine)
        # Forward-only engines are initialized with optimizer=None; there is nothing to zero and
        # calling optimizer_zero_grad() would raise AttributeError on exit.
        if self.engine.optimizer is not None:
            self.engine.optimizer_zero_grad()
        super().__exit__(exc_type, exc_value, traceback)
@EngineRegistry.register(model_type="language_model", backend="megatron")
class MegatronEngineWithLMHead(MegatronEngine):
    """Megatron engine for causal LM heads: per-token log-probs and, optionally, entropy."""

    def prepare_model_inputs(self, batch: TensorDict):
        """Extract input_ids, a boolean loss_mask, multi-modal inputs and optional routed_experts."""
        input_ids = batch["input_ids"]
        loss_mask = batch["loss_mask"].to(bool)
        multi_modal_inputs = extract_multi_modal_inputs(batch.get("multi_modal_inputs", []))
        # Present only when routing decisions are to be replayed (router replay).
        routed_experts = batch.get("routed_experts", None)

        return {
            "input_ids": input_ids,
            "loss_mask": loss_mask,
            "multi_modal_inputs": multi_modal_inputs,
            "routed_experts": routed_experts,
        }

    def prepare_model_outputs(self, output: dict, data: TensorDict):
        """Select log_probs (and entropy when ``calculate_entropy`` is set on ``data``) from the forward output."""
        calculate_entropy = tu.get_non_tensor_data(data, key="calculate_entropy", default=False)
        log_prob = output["log_probs"]
        model_output = {"log_probs": log_prob}
        if calculate_entropy:
            entropy = output["entropy"]
            model_output["entropy"] = entropy

        return model_output

    def forward_step(self, batch_iter: Iterator[TensorDict], model, postprocess_micro_batch_func):
        """Megatron forward-step callback for one micro-batch.

        Returns ``(output, loss_func)``, where ``loss_func`` is ``postprocess_micro_batch_func``
        bound to this micro-batch. The router-replay state transitions at the start and end of
        this method are order-sensitive.
        """
        batch: TensorDict = next(batch_iter)
        batch = batch.to(get_device_id())
        use_fused_kernels = tu.get_non_tensor_data(batch, key="use_fused_kernels", default=False)
        calculate_entropy = tu.get_non_tensor_data(batch, key="calculate_entropy", default=False)
        pad_mode = tu.get_non_tensor_data(batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        temperature = batch["temperature"]

        model_inputs = self.prepare_model_inputs(batch)
        input_ids = model_inputs["input_ids"]
        multi_modal_inputs = model_inputs["multi_modal_inputs"]
        loss_mask = model_inputs["loss_mask"]

        # Virtual-pipeline stage of this chunk; chunks without vp_stage are treated as stage 0.
        unwrapped_model = unwrap_model(model)
        if hasattr(unwrapped_model, "vp_stage"):
            vp_rank = unwrapped_model.vp_stage
        else:
            vp_rank = 0

        # Routers left in REPLAY_BACKWARD by the previous micro-batch (see the end of this method)
        # are switched back to REPLAY_FORWARD before this micro-batch's forward.
        if RouterReplayHelper.is_replay_backward_action(self.tf_config, vp_rank):
            router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank)
            for router in router_instance_list:
                router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

        # In forward-replay mode, feed the recorded expert indices for this micro-batch to the routers.
        if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
            layers_topk_idx = model_inputs["routed_experts"]
            set_router_replay_data(layers_topk_idx, None, self.tf_config, vp_rank)

        # Only NO_PADDING is supported. Labels are the input ids themselves; presumably the
        # forward fn / log-prob helper handles the next-token shift — TODO confirm.
        if pad_mode == DatasetPadMode.NO_PADDING:
            label = input_ids.clone()
        else:
            raise NotImplementedError(f"Pad mode {pad_mode} is not supported for megatron engine")

        from verl.models.mcore import get_mcore_forward_no_padding_fn

        # Fused kernels need use_remove_padding and a single scalar temperature; otherwise fall back.
        if use_fused_kernels:
            if not self.engine_config.use_remove_padding:
                logger.warning_once(
                    "Fused kernels require `use_remove_padding=True` for Megatron engine. Falling back to non-fused."
                )
                use_fused_kernels = False
            elif isinstance(temperature, torch.Tensor):
                if temperature.numel() != 1:
                    logger.warning_once(
                        "Fused kernels do not support per-sample temperature. Falling back to non-fused."
                    )
                    use_fused_kernels = False
                else:
                    temperature_value = float(temperature.item())
            else:
                temperature_value = float(temperature)

        if use_fused_kernels:
            # temperature_value is always bound here: every branch that keeps fused kernels sets it.
            fused_forward_fn = get_mcore_forward_fused_no_padding_fn(self.model_config.hf_config)
            output = fused_forward_fn(
                model=model,
                input_ids=input_ids,
                labels=label,
                multi_modal_inputs=multi_modal_inputs,
                temperature=temperature_value,
                calculate_entropy=calculate_entropy,
                pad_token_id=self.model_config.tokenizer.pad_token_id,
            )
        else:
            # Normalize temperature to one float32 value per sample, then expand it to the jagged token layout.
            if not isinstance(temperature, torch.Tensor):
                temperature = torch.tensor([temperature] * input_ids.shape[0], device=input_ids.device)

            temperature = temperature.to(torch.float32)
            assert temperature.shape[0] == input_ids.shape[0]
            temperature = verl_F.expand_as_nested(temperature, input_ids)  # (bsz, j1)

            forward_fn = get_mcore_forward_no_padding_fn(self.model_config.hf_config)

            def logits_processor(logits, label, temperature):
                """Scale logits by temperature in place; compute optional entropy and log-probs of ``label``."""
                assert logits.shape[:2] == label.shape[:2]
                # avoid non-positive temperature such as padding
                temperature[temperature <= 0] = 1e-8
                assert torch.all(temperature > 0).item(), f"temperature tensor must be positive. Got {temperature}"
                logits.div_(temperature.unsqueeze(dim=-1).to(logits.dtype))
                ret = {}
                if calculate_entropy:
                    # Log-probs are computed on a copy so the entropy computation cannot affect them
                    # (presumably one of the vocab-parallel ops works in place — the clone is kept for correctness).
                    logits_bak = logits.clone()
                    # A rank-0 hint recommending fused kernels (to avoid this clone's extra memory) is
                    # disabled until the fused kernel is optimized for triton>=3.3.
                    entropy = vocab_parallel_entropy(logits)
                    ret["entropy"] = entropy
                else:
                    logits_bak = logits

                log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
                ret["log_probs"] = log_probs
                return ret

            # NOTE(review): "loss_mask" is not a parameter of logits_processor; presumably forward_fn
            # consumes or strips it before calling the processor — verify.
            logits_processor_args = {"label": label, "temperature": temperature, "loss_mask": loss_mask}

            output = forward_fn(
                model,
                input_ids,
                multi_modal_inputs,
                logits_processor=logits_processor,
                logits_processor_args=logits_processor_args,
                vision_model=hasattr(self.model_config.hf_config, "vision_config"),
                pad_token_id=self.model_config.tokenizer.pad_token_id,
                data_format="thd" if self.engine_config.use_remove_padding else "bshd",
                mtp_enable_train=self.model_config.mtp.enable and self.model_config.mtp.enable_train,
            )

        # Router replay: record routing decisions for R2 mode
        if RouterReplayHelper.is_r2_record_action(self.tf_config, vp_rank):
            merge_router_topk_indices(None, input_ids, self.mini_layer_topk_idx_list, self.tf_config, vp_rank)

        # Router replay: switch to backward replay mode for next backward pass
        if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
            router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank)
            for router in router_instance_list:
                router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)

        return output, partial(postprocess_micro_batch_func, data=batch)

    def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool, loss_function):
        """Compute the micro-batch loss and package model outputs.

        Returns ``(scaled_loss, {"model_output", "loss", "metrics"})``. Without a loss function
        (forward-only), a constant dummy loss of 1.0 is returned.
        """
        # For memory efficiency
        # We move calculation of entropy to compute_log_probs, forward_only == True
        device = data["input_ids"].device
        model_output = self.prepare_model_outputs(output, data)

        if loss_function is not None:
            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
            # scale loss by num_micro_batch because megatron will scale loss
            # by n_micro_batch inside pp schedule
            scaled_loss = loss * data["num_micro_batch"]
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=device)
            scaled_loss = loss
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }

        # return loss and stats
        return scaled_loss, output
@EngineRegistry.register(model_type="value_model", backend="megatron")
class MegatronEngineWithValueHead(MegatronEngineWithLMHead):
    """Megatron engine for value-head models; the forward output is returned under the "values" key."""

    def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
        """Run the value-model forward on one micro-batch and bind the postprocess callback to it."""
        micro_batch: TensorDict = next(batch_iter).to(get_device_id())
        inputs = self.prepare_model_inputs(micro_batch)

        from verl.models.mcore import get_mcore_forward_no_padding_fn

        hf_config = self.model_config.hf_config
        value_forward = get_mcore_forward_no_padding_fn(hf_config)
        # NOTE(review): the LM-head path passes `mtp_enable_train=` (gated on mtp.enable) while this
        # passes `enable_mtp=` ungated — verify which keyword the forward fn actually expects.
        value_output = value_forward(
            model,
            inputs["input_ids"],
            inputs["multi_modal_inputs"],
            value_model=True,
            vision_model=hasattr(hf_config, "vision_config"),
            pad_token_id=self.model_config.tokenizer.pad_token_id,
            enable_mtp=self.model_config.mtp.enable_train,
        )
        return value_output, partial(postprocess_micro_batch_func, data=micro_batch)

    def prepare_model_outputs(self, output: dict | torch.Tensor, data: TensorDict):
        """Wrap the raw forward output as ``{"values": output}``; ``data`` is unused."""
        return dict(values=output)