from typing import Callable

import torch
from megatron.core import parallel_state
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.multi_token_prediction import (
    MTPLossAutoScaler,
    MTPLossLoggingHelper,
    roll_tensor,
)

try:
    from megatron.core.utils import unwrap_model
except ImportError:
    from verl.utils.megatron_utils import unwrap_model


def _get_patching_model(model: torch.nn.Module):
    """Unwrap the given module and return the underlying GPTModel, or None if unsupported."""
    model = unwrap_model(model)
    if isinstance(model, GPTModel):
        return model

    if not (hasattr(model, "language_model") and isinstance(model.language_model, GPTModel)):
        print(f"Model {model.__class__.__name__} is not supported for fused forward")
        return None

    return model.language_model


def patch_postprocess(model: torch.nn.Module):
    """Replace GPTModel._postprocess with the MTP-aware implementation below, keeping a backup."""
    model = _get_patching_model(model)
    if model is not None:
        model._postprocess_backup = model._postprocess
        model._postprocess = _megatron_gptmodel_postprocess.__get__(model, model.__class__)


def unpatch_postprocess(model: torch.nn.Module):
    """Restore the original GPTModel._postprocess saved by patch_postprocess."""
    model = _get_patching_model(model)
    if model is not None and hasattr(model, "_postprocess_backup"):
        model._postprocess = model._postprocess_backup
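
# Usage sketch (illustrative only; ``actor_module`` is a hypothetical wrapped Megatron
# module, not defined in this file). The patch is applied around the forward pass that
# needs the MTP-aware postprocess and then reverted:
#
#   patch_postprocess(actor_module)
#   try:
#       output = actor_module(input_ids, position_ids, attention_mask, labels=labels)
#   finally:
#       unpatch_postprocess(actor_module)

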
def _megatron_gptmodel_postprocess(
    self,
    hidden_states,
    input_ids,
    position_ids,
    labels,
    rotary_pos_emb,
    rotary_pos_cos,
    rotary_pos_sin,
    mtp_in_postprocess=None,
    loss_mask=None,
    decoder_input=None,
    attention_mask=None,
    inference_params=None,
    packed_seq_params=None,
    sequence_len_offset=None,
    runtime_gather_output=None,
    extra_block_kwargs=None,
    inference_context=None,
):
| """Postprocesses decoder hidden states to generate logits or compute loss. |
| |
| Applies Multi-Token Prediction if enabled, generates output logits through |
| the output layer, and computes language model loss when labels are provided. |
| """ |
|
|
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()

    if mtp_in_postprocess and labels is not None:
        hidden_states = self.mtp(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=self.embedding,
            **(extra_block_kwargs or {}),
        )

    if not self.post_process:
        return hidden_states

    # Compute the loss of each MTP head on progressively shifted labels and fold it
    # into the main-branch gradient via MTPLossAutoScaler.
    if self.config.mtp_num_layers and labels is not None:
        mtp_labels = labels.clone()

        # The decoder output stacks the main hidden states and one chunk per MTP layer
        # along the sequence dimension; split them apart here.
        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
        hidden_states = hidden_states_list[0]
        if loss_mask is None:
            # If no loss mask is provided, compute the loss over all label positions.
            loss_mask = torch.ones_like(mtp_labels)
        for mtp_layer_number in range(self.config.mtp_num_layers):
            # Each additional MTP depth predicts one token further ahead, so the labels
            # and the loss mask are rolled left by one position per depth.
            mtp_labels, _ = roll_tensor(
                mtp_labels,
                shifts=-1,
                dims=-1,
                cp_group=self.cp_group,
                packed_seq_params=packed_seq_params,
            )
            loss_mask, num_tokens = roll_tensor(
                loss_mask,
                shifts=-1,
                dims=-1,
                cp_group=self.cp_group,
                packed_seq_params=packed_seq_params,
            )

            # Project this MTP depth's hidden states through the output layer and
            # compute its per-token language model loss in a single call.
            mtp_loss = self.compute_output_layer_and_language_model_loss(
                hidden_states_list[mtp_layer_number + 1],
                labels=mtp_labels,
                weight=self.shared_embedding_or_output_weight(),
                sequence_parallel_enabled=self.output_layer.sequence_parallel,
                column_parallel_linear=self.output_layer,
                col_linear_kwargs={
                    "weight": output_weight,
                    "runtime_gather_output": runtime_gather_output,
                },
            )

            mtp_loss = loss_mask * mtp_loss
            if self.training:
                # Track the per-layer MTP loss for logging, averaged over the
                # data-parallel (and context-parallel) group.
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    mtp_layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                )
            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
            if self.config.calculate_per_token_loss:
                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
            else:
                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)

    logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)

    # [s, b, vocab] -> [b, s, vocab]
    return logits.transpose(0, 1).contiguous()
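
# A minimal single-process sketch (not executed by this patch) of the shift that
# roll_tensor applies to the MTP labels above; roll_tensor additionally handles
# context parallelism and packed sequences:
#
#   >>> import torch
#   >>> labels = torch.tensor([[11, 12, 13, 14]])
#   >>> torch.roll(labels, shifts=-1, dims=-1)
#   tensor([[12, 13, 14, 11]])
#
# Each MTP depth therefore trains against targets shifted one token further into the
# future than the main head; the loss mask is rolled the same way, and the returned
# num_tokens is used to normalize the loss.

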
def patch_mtp_layer_get_embeddings(model: torch.nn.Module):
    """Patch the _get_embeddings method of MultiTokenPredictionLayer."""
    from megatron.core.models.gpt.gpt_model import GPTModel
    from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

    model = _get_patching_model(model)

    target_layers = []

    if isinstance(model, GPTModel):
        # GPTModel keeps its MTP layers under model.mtp.layers.
        if hasattr(model, "mtp") and hasattr(model.mtp, "layers"):
            for layer in model.mtp.layers:
                if isinstance(layer, MultiTokenPredictionLayer):
                    target_layers.append(layer)
    elif hasattr(model, "layers"):
        # Fallback: scan a flat .layers container for MTP layers.
        for layer in model.layers:
            if isinstance(layer, MultiTokenPredictionLayer):
                target_layers.append(layer)

    if target_layers:
        for layer in target_layers:
            layer._get_embeddings_backup = layer._get_embeddings
            layer._get_embeddings = _patched_get_embeddings_for_detach.__get__(layer, layer.__class__)
        print(f"Found and patched {len(target_layers)} MTP layer(s) in the actor module")
        return True
    else:
        print("No MTP layers found to patch in the actor module")
        return False


def unpatch_mtp_layer_get_embeddings(model: torch.nn.Module):
    """Unpatch the _get_embeddings method of MultiTokenPredictionLayer."""
    from megatron.core.models.gpt.gpt_model import GPTModel
    from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer

    model = _get_patching_model(model)

    target_layers = []

    if isinstance(model, GPTModel):
        if hasattr(model, "mtp") and hasattr(model.mtp, "layers"):
            for layer in model.mtp.layers:
                if isinstance(layer, MultiTokenPredictionLayer):
                    target_layers.append(layer)
    elif hasattr(model, "layers"):
        for layer in model.layers:
            if isinstance(layer, MultiTokenPredictionLayer):
                target_layers.append(layer)

    unpatched_count = 0
    for layer in target_layers:
        if hasattr(layer, "_get_embeddings_backup"):
            layer._get_embeddings = layer._get_embeddings_backup
            delattr(layer, "_get_embeddings_backup")
            unpatched_count += 1

    if unpatched_count > 0:
        print(f"Unpatched {unpatched_count} MTP layer(s)")
        return True
    return False
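
# Usage sketch (an assumption about how the two patches might be combined, not
# prescribed by this file; ``actor_module`` is hypothetical). The MTP-layer patch
# detaches the inputs of each MTP layer while the postprocess patch computes the
# MTP losses, so the MTP heads train on shifted targets without backpropagating
# into the main trunk:
#
#   patch_postprocess(actor_module)
#   patch_mtp_layer_get_embeddings(actor_module)
#   ...  # run forward/backward
#   unpatch_mtp_layer_get_embeddings(actor_module)
#   unpatch_postprocess(actor_module)

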
def _patched_get_embeddings_for_detach(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    embedding: Callable,
    hidden_states: torch.Tensor,
    packed_seq_params=None,
):
| """ |
| Patched version of _get_embeddings method for MultiTokenPredictionLayer. |
| |
| This is a modified version that you can customize according to your needs. |
| The original implementation is preserved below with modifications. |
| """ |
|
|
    from megatron.core.transformer.multi_token_prediction import roll_tensor
    from megatron.core.utils import make_viewless_tensor

    # Shift the tokens and positions left by one so this MTP layer predicts one token
    # further ahead than its predecessor.
    input_ids, _ = roll_tensor(
        input_ids,
        shifts=-1,
        dims=-1,
        cp_group=self.cp_group,
        packed_seq_params=packed_seq_params,
    )
    position_ids, _ = roll_tensor(
        position_ids,
        shifts=-1,
        dims=-1,
        cp_group=self.cp_group,
        packed_seq_params=packed_seq_params,
    )

    # Embed the shifted tokens for this MTP layer.
    decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)

    hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

    # Detach so that gradients from the MTP branch do not flow back through these
    # inputs into the embedding table or the main decoder trunk.
    decoder_input = decoder_input.detach()
    hidden_states = hidden_states.detach()

    return input_ids, position_ids, decoder_input, hidden_states
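
# Illustrative sketch (not part of the patch) of why the detach above isolates the MTP
# branch: gradients stop at a detached tensor, so the MTP loss cannot update anything
# upstream of these inputs through this path.
#
#   >>> import torch
#   >>> x = torch.ones(3, requires_grad=True)
#   >>> y = (x.detach() * 2.0).sum()   # analogous to hidden_states.detach()
#   >>> y.requires_grad
#   False
#
# Parameters inside the MTP layer itself still receive gradients, because only its
# inputs are detached.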