# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
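"""Runtime patches for Megatron-Core's Multi-Token Prediction (MTP) path.

Two monkey patches are provided:

* ``patch_postprocess`` / ``unpatch_postprocess`` swap ``GPTModel._postprocess``
  for a variant that attaches the MTP loss to the hidden states through
  ``MTPLossAutoScaler`` and returns batch-first logits.
* ``patch_mtp_layer_get_embeddings`` / ``unpatch_mtp_layer_get_embeddings``
  swap ``MultiTokenPredictionLayer._get_embeddings`` for a variant that
  detaches its inputs, so the MTP loss does not back-propagate into the main
  model.
"""
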
from typing import Callable
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.multi_token_prediction import (
MTPLossAutoScaler,
MTPLossLoggingHelper,
roll_tensor,
)

try:
from megatron.core.utils import unwrap_model
except ImportError:
from verl.utils.megatron_utils import unwrap_model


def _get_patching_model(model: torch.nn.Module):
model = unwrap_model(model)
if isinstance(model, GPTModel):
return model
if not (hasattr(model, "language_model") and isinstance(model.language_model, GPTModel)):
print(f"Model {model.__class__.__name__} is not a supported for fused forward")
return None
return model.language_model


def patch_postprocess(model: torch.nn.Module):
model = _get_patching_model(model)
if model is not None:
model._postprocess_backup = model._postprocess
model._postprocess = _megatron_gptmodel_postprocess.__get__(model, model.__class__)


def unpatch_postprocess(model: torch.nn.Module):
    model = _get_patching_model(model)
    if model is not None and hasattr(model, "_postprocess_backup"):
        model._postprocess = model._postprocess_backup
        del model._postprocess_backup
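

# Both patch helpers rely on the descriptor protocol: ``func.__get__(obj, cls)``
# turns a plain function into a method bound to ``obj``, so assigning the result
# to an instance attribute overrides the method for that single instance while
# leaving the class untouched. A minimal illustration of the idiom (hypothetical
# names, not part of this module):
#
#     class Greeter:
#         def greet(self):
#             return "hello"
#
#     g = Greeter()
#     g.greet = (lambda self: "HELLO").__get__(g, Greeter)
#     assert g.greet() == "HELLO"          # this instance is patched
#     assert Greeter().greet() == "hello"  # other instances are not
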
# Copied from https://github.com/NVIDIA/Megatron-LM/blob/23e092f41ec8bc659020e401ddac9576c1cfed7e/megatron/core/models/gpt/gpt_model.py
# Patches the _postprocess method of GPTModel to support advanced features such as MTP and 1F1B overlap.
def _megatron_gptmodel_postprocess(
self,
hidden_states,
input_ids,
position_ids,
labels,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
mtp_in_postprocess=None,
loss_mask=None,
decoder_input=None,
attention_mask=None,
inference_params=None,
packed_seq_params=None,
sequence_len_offset=None,
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
):
"""Postprocesses decoder hidden states to generate logits or compute loss.
Applies Multi-Token Prediction if enabled, generates output logits through
the output layer, and computes language model loss when labels are provided.
"""
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
if mtp_in_postprocess and labels is not None:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
**(extra_block_kwargs or {}),
)
if not self.post_process:
return hidden_states
# Skip when mtp_num_layers is None or 0
if self.config.mtp_num_layers and labels is not None:
mtp_labels = labels.clone()
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
            # Compute the loss for the current Multi-Token Prediction (MTP) layer.
mtp_labels, _ = roll_tensor(
mtp_labels,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
loss_mask, num_tokens = roll_tensor(
loss_mask,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
# Compute mtp loss without storing logits to save memory.
mtp_loss = self.compute_output_layer_and_language_model_loss(
hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
"weight": output_weight,
"runtime_gather_output": runtime_gather_output,
},
)
mtp_loss = loss_mask * mtp_loss
if self.training:
# TODO(shifangx): remove the use of parallel_state here
# after moving loss logging to loss_func in pretrain_gpt.py
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
# [s b h] => [b s h]
return logits.transpose(0, 1).contiguous()
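

# ``MTPLossAutoScaler.apply`` above uses the standard auxiliary-loss autograd
# trick: the forward pass returns the activation unchanged, and the backward
# pass injects the auxiliary loss's gradient alongside the main loss. A minimal
# sketch of the pattern (illustrative only and unused here; the real
# MTPLossAutoScaler additionally tracks a configurable loss scale):
class _AuxLossAttacherSketch(torch.autograd.Function):
    """Documents the MTPLossAutoScaler pattern; not called by this module."""

    @staticmethod
    def forward(ctx, activation, aux_loss):
        ctx.save_for_backward(aux_loss)
        return activation

    @staticmethod
    def backward(ctx, grad_output):
        (aux_loss,) = ctx.saved_tensors
        # The auxiliary loss receives a unit gradient, so it back-propagates
        # through its own graph in addition to the main loss flowing through
        # ``grad_output``.
        return grad_output, torch.ones_like(aux_loss)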


def patch_mtp_layer_get_embeddings(model: torch.nn.Module):
    """Patch the _get_embeddings method of MultiTokenPredictionLayer"""
    from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

    # Unwrap the model to get the underlying GPTModel
    model = _get_patching_model(model)
    # Collect all MultiTokenPredictionLayer instances
target_layers = []
if isinstance(model, GPTModel):
# Check if GPTModel has MTP and find the layers
if hasattr(model, "mtp") and hasattr(model.mtp, "layers"):
for layer in model.mtp.layers:
if isinstance(layer, MultiTokenPredictionLayer):
target_layers.append(layer)
elif hasattr(model, "layers"):
# Check if any layer in the model is MultiTokenPredictionLayer
for layer in model.layers:
if isinstance(layer, MultiTokenPredictionLayer):
target_layers.append(layer)
if target_layers:
for layer in target_layers:
layer._get_embeddings_backup = layer._get_embeddings
layer._get_embeddings = _patched_get_embeddings_for_detach.__get__(layer, layer.__class__)
print(f"Found and patched {len(target_layers)} MTP layer(s) in any of the actor modules")
return True
else:
print("No MTP layers found to patch in any of the actor modules")
return False


def unpatch_mtp_layer_get_embeddings(model: torch.nn.Module):
    """Unpatch the _get_embeddings method of MultiTokenPredictionLayer"""
    from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

    # Unwrap the model to get the underlying GPTModel
    model = _get_patching_model(model)
    # Collect all MultiTokenPredictionLayer instances
target_layers = []
if isinstance(model, GPTModel):
# Check if GPTModel has MTP and find the layers
if hasattr(model, "mtp") and hasattr(model.mtp, "layers"):
for layer in model.mtp.layers:
if isinstance(layer, MultiTokenPredictionLayer):
target_layers.append(layer)
elif hasattr(model, "layers"):
# Check if any layer in the model is MultiTokenPredictionLayer
for layer in model.layers:
if isinstance(layer, MultiTokenPredictionLayer):
target_layers.append(layer)
unpatched_count = 0
for layer in target_layers:
if hasattr(layer, "_get_embeddings_backup"):
layer._get_embeddings = layer._get_embeddings_backup
delattr(layer, "_get_embeddings_backup")
unpatched_count += 1
if unpatched_count > 0:
print(f"Unpatched {unpatched_count} MTP layer(s)")
return True
return False
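

# ``roll_tensor`` (imported at the top of this file) is Megatron's shift
# helper: like ``torch.roll`` along the sequence dimension, but aware of
# context parallelism and packed sequences, and (as used here) the
# wrapped-around tail position is masked out rather than recycled. The
# plain-PyTorch intuition, for illustration only:
#
#     >>> import torch
#     >>> torch.roll(torch.tensor([[1, 2, 3, 4]]), shifts=-1, dims=-1)
#     tensor([[2, 3, 4, 1]])   # roll_tensor would mask that trailing 1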


def _patched_get_embeddings_for_detach(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
embedding: Callable,
hidden_states: torch.Tensor,
packed_seq_params=None,
):
"""
Patched version of _get_embeddings method for MultiTokenPredictionLayer.
This is a modified version that you can customize according to your needs.
The original implementation is preserved below with modifications.
"""
# You can modify the logic here as needed
# For example, you could:
# - Change the shift amount in roll_tensor
# - Apply custom transformations to input_ids or position_ids
# - Add debugging information
# - Modify the embedding computation
# Original logic with custom modifications
from megatron.core.transformer.multi_token_prediction import roll_tensor
from megatron.core.utils import make_viewless_tensor
# Calc logits for the current Multi-Token Prediction (MTP) layers.
input_ids, _ = roll_tensor(
input_ids,
        shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
position_ids, _ = roll_tensor(
position_ids,
        shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
    # Recompute the embeddings for the shifted tokens.
    decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
    # Detach both tensors so the MTP head trains on them without sending
    # gradients back into the main model.
decoder_input = decoder_input.detach()
hidden_states = hidden_states.detach()
return input_ids, position_ids, decoder_input, hidden_states
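

def _example_mtp_forward(model: torch.nn.Module, batch: dict) -> torch.Tensor:
    """Illustrative sketch only; nothing in verl calls this.

    Shows the intended lifecycle of the two patches around a single forward
    pass. The batch keys and the forward signature below are assumptions
    about the caller, not a fixed interface of this module.
    """
    patch_postprocess(model)
    patch_mtp_layer_get_embeddings(model)
    try:
        # With the patches active, the forward pass returns batch-first
        # logits while the scaled MTP loss rides along on the autograd graph.
        logits = model(
            input_ids=batch["input_ids"],
            position_ids=batch["position_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
    finally:
        unpatch_mtp_layer_get_embeddings(model)
        unpatch_postprocess(model)
    return logits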