File size: 11,356 Bytes
1faccd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.multi_token_prediction import (
    MTPLossAutoScaler,
    MTPLossLoggingHelper,
    roll_tensor,
)

try:
    from megatron.core.utils import unwrap_model
except ImportError:
    from verl.utils.megatron_utils import unwrap_model


def _get_patching_model(model: torch.nn.Module):
    """Return the ``GPTModel`` that the patches in this module should target.

    The input is first passed through ``unwrap_model``. The result is used
    directly when it is a ``GPTModel``; otherwise its ``language_model``
    attribute is used when that attribute exists and is a ``GPTModel``.

    Returns:
        The resolved ``GPTModel`` instance, or ``None`` (after printing a
        notice) when neither candidate is a ``GPTModel``.
    """
    unwrapped = unwrap_model(model)
    if isinstance(unwrapped, GPTModel):
        return unwrapped

    # Some wrappers expose the GPT model as `.language_model`.
    inner = getattr(unwrapped, "language_model", None)
    if isinstance(inner, GPTModel):
        return inner

    print(f"Model {unwrapped.__class__.__name__} is not a supported for fused forward")
    return None


def patch_postprocess(model: torch.nn.Module):
    """Replace ``_postprocess`` on the underlying GPTModel with the patched version.

    The model is resolved through ``_get_patching_model``; nothing happens when
    that returns ``None``. The current ``_postprocess`` is saved on the model as
    ``_postprocess_backup`` so that ``unpatch_postprocess`` can restore it.

    The call is idempotent: if ``_postprocess`` is already bound to
    ``_megatron_gptmodel_postprocess`` the function returns without touching
    the backup. Without this guard a second call would overwrite the backup
    with the patched method and the original could never be restored.
    """
    model = _get_patching_model(model)
    if model is None:
        return

    current = model._postprocess
    # Bound methods expose the underlying function as `__func__`.
    if getattr(current, "__func__", None) is _megatron_gptmodel_postprocess:
        return

    model._postprocess_backup = current
    # Bind the module-level function to this instance so `self` is the model.
    model._postprocess = _megatron_gptmodel_postprocess.__get__(model, model.__class__)


def unpatch_postprocess(model: torch.nn.Module):
    """Restore the ``_postprocess`` method saved by ``patch_postprocess``.

    The model is resolved through ``_get_patching_model``. The call is a no-op
    when no model is resolved or when no ``_postprocess_backup`` is present
    (i.e. the model was never patched or has already been unpatched), instead
    of raising ``AttributeError``. After restoring, the backup attribute is
    removed so no stale reference is left behind, mirroring
    ``unpatch_mtp_layer_get_embeddings``.
    """
    model = _get_patching_model(model)
    if model is None or not hasattr(model, "_postprocess_backup"):
        return

    model._postprocess = model._postprocess_backup
    delattr(model, "_postprocess_backup")


# Copied from https://github.com/NVIDIA/Megatron-LM/blob/23e092f41ec8bc659020e401ddac9576c1cfed7e/megatron/core/models/gpt/gpt_model.py
# Patches the postprocess method of GPTModel to support advanced features like MTP, 1F1B overlap, etc.
def _megatron_gptmodel_postprocess(
    self,
    hidden_states,
    input_ids,
    position_ids,
    labels,
    rotary_pos_emb,
    rotary_pos_cos,
    rotary_pos_sin,
    mtp_in_postprocess=None,
    loss_mask=None,
    decoder_input=None,
    attention_mask=None,
    inference_params=None,
    packed_seq_params=None,
    sequence_len_offset=None,
    runtime_gather_output=None,
    extra_block_kwargs=None,
    inference_context=None,
):
    """Postprocess decoder hidden states into logits, attaching MTP losses.

    Intended to be bound to a ``GPTModel`` instance as its ``_postprocess``
    method (see ``patch_postprocess``). The steps are:

    1. If ``mtp_in_postprocess`` is truthy and ``labels`` is given, run
       ``self.mtp`` on the hidden states.
    2. If ``self.post_process`` is false, return the hidden states as they are
       at that point.
    3. If ``self.config.mtp_num_layers`` is non-zero and ``labels`` is given,
       split the hidden states along dim 0 into ``1 + mtp_num_layers`` chunks.
       The first chunk is kept for the output layer; every further chunk
       yields one masked MTP loss that is attached to the kept hidden states
       through ``MTPLossAutoScaler.apply``.
    4. Project the hidden states through ``self.output_layer`` and return the
       logits with dims 0 and 1 swapped.

    Note that this function never returns a loss value: ``labels`` is only
    used to decide whether the MTP path runs and to compute the MTP losses.

    ``decoder_input`` and ``inference_context`` are accepted but not read by
    this body.

    Returns:
        ``hidden_states`` when ``self.post_process`` is false, otherwise
        ``logits.transpose(0, 1).contiguous()``.
    """

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()

    # The MTP block is only run when labels are available.
    if mtp_in_postprocess and labels is not None:
        hidden_states = self.mtp(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=self.embedding,
            **(extra_block_kwargs or {}),
        )

    # Without post-processing on this model, hand the hidden states back untouched.
    if not self.post_process:
        return hidden_states

    # Skip when mtp_num_layers is None or 0
    if self.config.mtp_num_layers and labels is not None:
        # Work on a copy so the caller's labels are not modified by the rolls below.
        mtp_labels = labels.clone()

        # Chunk 0 feeds the output layer; chunk i + 1 is used for MTP layer i.
        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
        hidden_states = hidden_states_list[0]
        if loss_mask is None:
            # if loss_mask is not provided, use all ones as loss_mask
            loss_mask = torch.ones_like(mtp_labels)
        for mtp_layer_number in range(self.config.mtp_num_layers):
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            # Labels and mask are reassigned each iteration, so the shift of -1
            # accumulates from one MTP layer to the next.
            mtp_labels, _ = roll_tensor(
                mtp_labels,
                shifts=-1,
                dims=-1,
                cp_group=self.cp_group,
                packed_seq_params=packed_seq_params,
            )
            # NOTE(review): num_tokens is presumably the count of unmasked
            # tokens after rolling - confirm against roll_tensor.
            loss_mask, num_tokens = roll_tensor(
                loss_mask,
                shifts=-1,
                dims=-1,
                cp_group=self.cp_group,
                packed_seq_params=packed_seq_params,
            )

            # Compute mtp loss without storing logits to save memory.
            mtp_loss = self.compute_output_layer_and_language_model_loss(
                hidden_states_list[mtp_layer_number + 1],
                labels=mtp_labels,
                weight=self.shared_embedding_or_output_weight(),
                sequence_parallel_enabled=self.output_layer.sequence_parallel,
                column_parallel_linear=self.output_layer,
                col_linear_kwargs={
                    "weight": output_weight,
                    "runtime_gather_output": runtime_gather_output,
                },
            )

            # Zero out the loss at masked positions.
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                # TODO(shifangx): remove the use of parallel_state here
                # after moving loss logging to loss_func in pretrain_gpt.py
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    mtp_layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                )
            # The configured scaling factor is split evenly across the MTP layers.
            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
            # NOTE(review): MTPLossAutoScaler presumably passes hidden_states
            # through unchanged and injects the MTP loss gradient in backward -
            # confirm in Megatron-LM.
            if self.config.calculate_per_token_loss:
                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
            else:
                # Not per-token: normalise the loss by num_tokens here.
                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)

    logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
    # [s b h] => [b s h]
    return logits.transpose(0, 1).contiguous()


def patch_mtp_layer_get_embeddings(model: torch.nn.Module):
    """Patch ``_get_embeddings`` on every ``MultiTokenPredictionLayer`` of ``model``.

    The model is resolved through ``_get_patching_model``. For a ``GPTModel``
    the layers are taken from ``model.mtp.layers`` (when present); otherwise
    from ``model.layers`` (when present). Each matching layer gets its current
    ``_get_embeddings`` saved as ``_get_embeddings_backup`` and replaced by
    ``_patched_get_embeddings_for_detach`` bound to that layer.

    Layers whose ``_get_embeddings`` is already the patched function are left
    alone. Without this guard a repeated call would overwrite the backup with
    the patched method, making it impossible to restore the original.

    Returns:
        ``True`` if at least one MTP layer was found, ``False`` otherwise.
    """
    from megatron.core.models.gpt.gpt_model import GPTModel
    from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

    # Unwrap the model to get the actual GPTModel
    model = _get_patching_model(model)

    if isinstance(model, GPTModel):
        # MTP layers live under `model.mtp.layers` when the model has MTP.
        mtp_block = getattr(model, "mtp", None)
        candidates = getattr(mtp_block, "layers", ())
    else:
        candidates = getattr(model, "layers", ())

    target_layers = [layer for layer in candidates if isinstance(layer, MultiTokenPredictionLayer)]

    if not target_layers:
        print("No MTP layers found to patch in any of the actor modules")
        return False

    for layer in target_layers:
        current = layer._get_embeddings
        # Bound methods expose the underlying function as `__func__`; skip
        # layers that are already patched so the backup keeps the original.
        if getattr(current, "__func__", None) is _patched_get_embeddings_for_detach:
            continue
        layer._get_embeddings_backup = current
        layer._get_embeddings = _patched_get_embeddings_for_detach.__get__(layer, layer.__class__)

    print(f"Found and patched {len(target_layers)} MTP layer(s) in any of the actor modules")
    return True


def unpatch_mtp_layer_get_embeddings(model: torch.nn.Module):
    """Restore the original ``_get_embeddings`` on patched MTP layers.

    The model is resolved through ``_get_patching_model``. For a ``GPTModel``
    the layers are taken from ``model.mtp.layers`` (when present); otherwise
    from ``model.layers`` (when present). Every ``MultiTokenPredictionLayer``
    carrying a ``_get_embeddings_backup`` attribute has that backup reinstated
    and the backup attribute removed.

    Returns:
        ``True`` if at least one layer was restored, ``False`` otherwise.
    """
    from megatron.core.models.gpt.gpt_model import GPTModel
    from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

    # Unwrap the model to get the actual GPTModel
    model = _get_patching_model(model)

    if isinstance(model, GPTModel):
        # MTP layers live under `model.mtp.layers` when the model has MTP.
        mtp_block = getattr(model, "mtp", None)
        candidates = getattr(mtp_block, "layers", ())
    else:
        candidates = getattr(model, "layers", ())

    restored = 0
    for layer in candidates:
        if not isinstance(layer, MultiTokenPredictionLayer):
            continue
        # Only layers that were actually patched carry a backup.
        if not hasattr(layer, "_get_embeddings_backup"):
            continue
        layer._get_embeddings = layer._get_embeddings_backup
        delattr(layer, "_get_embeddings_backup")
        restored += 1

    if restored == 0:
        return False

    print(f"Unpatched {restored} MTP layer(s)")
    return True


def _patched_get_embeddings_for_detach(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    embedding: Callable,
    hidden_states: torch.Tensor,
    packed_seq_params=None,
):
    """Replacement for ``MultiTokenPredictionLayer._get_embeddings`` that detaches its outputs.

    ``input_ids`` and ``position_ids`` are each rolled by ``-1`` along their
    last dimension with ``roll_tensor``. The rolled ids are embedded through
    ``embedding``; ``hidden_states`` is passed through ``make_viewless_tensor``.
    Both resulting tensors are then detached, so autograd does not propagate
    back through them.

    Returns:
        A tuple ``(input_ids, position_ids, decoder_input, hidden_states)``
        holding the rolled ids, the detached embedding of the rolled ids and
        the detached hidden states.
    """
    from megatron.core.transformer.multi_token_prediction import roll_tensor
    from megatron.core.utils import make_viewless_tensor

    # Both id tensors are shifted with the same roll settings.
    roll_kwargs = {
        "shifts": -1,
        "dims": -1,
        "cp_group": self.cp_group,
        "packed_seq_params": packed_seq_params,
    }
    input_ids, _ = roll_tensor(input_ids, **roll_kwargs)
    position_ids, _ = roll_tensor(position_ids, **roll_kwargs)

    # Embed the shifted ids and cut the result out of the autograd graph.
    decoder_input = embedding(input_ids=input_ids, position_ids=position_ids).detach()

    # Same for the incoming hidden states, after making them viewless.
    hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True).detach()

    return input_ids, position_ids, decoder_input, hidden_states