JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py
index 41c21d93d..ef80f72d6 100644
--- a/megatron/core/dist_checkpointing/strategies/common.py
+++ b/megatron/core/dist_checkpointing/strategies/common.py
@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
msc = MultiStorageClientFeature.import_package()
return msc.torch.load(load_path, map_location='cpu')
else:
- return torch.load(load_path, map_location='cpu')
+ return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
if MultiStorageClientFeature.is_enabled():
diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py
index ccf5242a2..9b6d3e31f 100644
--- a/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/core/dist_checkpointing/strategies/torch.py
@@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li
_restore_dict_types(x_val, templ_val)
+@dataclass
+class MCoreMetadata(Metadata):
+ """Metadata with mcore specific data."""
+
+ # holds data related to flattened_range
+ # TODO: remove when flattened_range is properly removed
+ mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor
+
+
@dataclass(frozen=True)
class MCoreSavePlan(SavePlan):
"""SavePlan with MCore specific data."""
@@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner):
def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]:
"""Merges MCore data for all plans."""
global_plan, metadata = super().create_global_plan(all_plans)
- metadata.mcore_data = dict(
+ mcore_data = dict(
ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type]
)
+ metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata))
return global_plan, metadata
def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
@@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
if sh_ten.key not in metadata.state_dict_metadata:
- raise KeyError(
- f"{sh_ten.key} from model not in state dict:"
- f" {sorted(metadata.state_dict_metadata.keys())}"
- )
+ # raise KeyError(
+ # f"{sh_ten.key} from model not in state dict:"
+ # f" {sorted(metadata.state_dict_metadata.keys())}"
+ # )
+ print(f"{sh_ten.key} from model not in state dict, will skip")
+ continue
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
expected_shape = self._expected_shape(sh_ten)
if loaded_shape != expected_shape:
@@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
tensor_metadata = self.metadata.state_dict_metadata
metadata_with_sizes = [
(tensor_metadata[key], tensor_metadata[key].size, sharded_tensor)
- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items()
+ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata
]
try:
# Temporarily set sizes to expected shapes
@@ -918,6 +930,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
+ allow_partial_load=True,
),
)
diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py
index fe26e8b43..4451f2776 100644
--- a/megatron/core/distributed/__init__.py
+++ b/megatron/core/distributed/__init__.py
@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads
from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
+
+# Backward compatibility patch for FSDP module reorganization
+import sys
+import importlib.util
+
+spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp')
+if spec:
+ custom_fsdp = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(custom_fsdp)
+ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp
+ if hasattr(custom_fsdp, 'MegatronFSDP'):
+ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index 7727efe1e..966fe652a 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear):
)
for param in self.parameters():
+ setattr(param, "parallel_mode", parallel_mode)
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, "allreduce", not self.expert_parallel)
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index 860ee64a9..80944b702 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec(
qk_l2_norm: Optional[bool] = False,
use_te_op_fuser: Optional[bool] = False,
use_kitchen: bool = False,
+ post_self_attn_layernorm: bool = False,
+ post_mlp_layernorm: bool = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
@@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec(
),
),
self_attn_bda=get_bias_dropout_add,
+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp,
pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index 6aec66e6d..6ca48b55f 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -355,6 +355,7 @@ class GPTModel(LanguageModule):
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
+ mtp_kwargs: Optional[dict] = {},
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
@@ -410,6 +411,7 @@ class GPTModel(LanguageModule):
runtime_gather_output=runtime_gather_output,
extra_block_kwargs=extra_block_kwargs,
inference_context=inference_context,
+ mtp_kwargs=mtp_kwargs,
)
def _postprocess(
@@ -431,6 +433,7 @@ class GPTModel(LanguageModule):
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
+ mtp_kwargs={},
):
"""Postprocesses decoder hidden states to generate logits or compute loss.
@@ -446,7 +449,7 @@ class GPTModel(LanguageModule):
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
- if mtp_in_postprocess:
+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
@@ -465,25 +468,37 @@ class GPTModel(LanguageModule):
if not self.post_process:
return hidden_states
- if self.mtp_process:
- mtp_labels = labels.clone()
+ if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None:
+ mtp_labels = mtp_kwargs['mtp_labels'].clone()
+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
+
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
+ else:
+ # Otherwise, roll the loss_mask to keep up with the mtp_labels
+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
for mtp_layer_number in range(self.config.mtp_num_layers):
# output
- mtp_logits, _ = self.output_layer(
- hidden_states_list[mtp_layer_number + 1],
- weight=output_weight,
- runtime_gather_output=runtime_gather_output,
+ output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()}
+ output_layer_buffers = dict(self.output_layer.named_buffers())
+ mtp_logits, _ = torch.func.functional_call(
+ self.output_layer,
+ {**output_layer_params, **output_layer_buffers},
+ (hidden_states_list[mtp_layer_number + 1],),
+ {
+ "weight": output_weight.detach() if output_weight else None,
+ "runtime_gather_output": runtime_gather_output,
+ },
)
# Calc loss for the current Multi-Token Prediction (MTP) layers.
- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
- loss_mask, num_tokens = roll_tensor(
- loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
+ new_loss_mask, num_tokens = roll_tensor(
+ loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params
)
+ loss_mask = new_loss_mask * loss_mask
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
mtp_loss = loss_mask * mtp_loss
if self.training:
diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
index a36b67364..ed8883e32 100644
--- a/megatron/core/optimizer/distrib_optimizer.py
+++ b/megatron/core/optimizer/distrib_optimizer.py
@@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# TE FusedAdam will not accumulate step for empty param groups, so we need to
# align the step across param groups.
param_group["step"] = int(step)
+ if "step" in param_group and param_group["step"] is None:
+ del param_group["step"]
# Grad scaler state.
if self.grad_scaler:
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index a40c85a88..86688c331 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -9,6 +9,7 @@ from typing import Callable, List, Optional
import numpy as np
import torch
+import torch.distributed as dist
from .utils import GlobalMemoryBuffer, is_torch_min_version
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 63ee9d1f5..b90b744c1 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -26,22 +26,22 @@ def _batched_p2p_ops(
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
+ torch.distributed.isend, tensor_send_next, next_pipeline_rank,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
)
ops.append(recv_next_op)
if len(ops) > 0:
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index c749bac43..dde8d50e7 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC):
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
nvtx_range_push(suffix="qkv")
- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+ if self.config.use_gated_attention:
+ query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states)
+ else:
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
nvtx_range_pop(suffix="qkv")
# ===================================================
@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC):
# Output. [sq, b, h]
# =================
+ if self.config.use_gated_attention:
+ nvtx_range_push(suffix="sigmoid_gate")
+ core_attn_out = core_attn_out * torch.sigmoid(gate)
+ nvtx_range_pop(suffix="sigmoid_gate")
+
nvtx_range_push(suffix="linear_proj")
output, bias = self.linear_proj(core_attn_out)
nvtx_range_pop(suffix="linear_proj")
@@ -879,19 +887,34 @@ class SelfAttention(Attention):
model_comm_pgs=model_comm_pgs,
)
- self.linear_qkv = build_module(
- submodules.linear_qkv,
- self.config.hidden_size,
- self.query_projection_size + 2 * self.kv_projection_size,
- config=self.config,
- init_method=self.config.init_method,
- gather_output=False,
- bias=self.config.add_bias_linear or self.config.add_qkv_bias,
- skip_bias_add=False,
- is_expert=False,
- tp_comm_buffer_name='qkv',
- tp_group=self.model_comm_pgs.tp,
- )
+ if self.config.use_gated_attention:
+ self.linear_qgkv = build_module(
+ submodules.linear_qkv,
+ self.config.hidden_size,
+ 2 * (self.query_projection_size + self.kv_projection_size),
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qkv',
+ tp_group=self.model_comm_pgs.tp,
+ )
+ else:
+ self.linear_qkv = build_module(
+ submodules.linear_qkv,
+ self.config.hidden_size,
+ self.query_projection_size + 2 * self.kv_projection_size,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qkv',
+ tp_group=self.model_comm_pgs.tp,
+ )
if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
@@ -1036,6 +1059,65 @@ class SelfAttention(Attention):
return query, key, value
+ # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192
+ def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None):
+ """
+ Derives `query`, `key` and `value` tensors from `hidden_states`.
+ """
+ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)]
+ mixed_qgkv, _ = self.linear_qgkv(hidden_states)
+
+ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn]
+ new_tensor_shape = mixed_qgkv.size()[:-1] + (
+ self.num_query_groups_per_partition,
+ (
+ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1)
+ * self.hidden_size_per_attention_head
+ ),
+ )
+ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape)
+
+ split_arg_list = [
+ (
+ self.num_attention_heads_per_partition
+ // self.num_query_groups_per_partition
+ * self.hidden_size_per_attention_head
+ ),
+ (
+ self.num_attention_heads_per_partition
+ // self.num_query_groups_per_partition
+ * self.hidden_size_per_attention_head
+ ),
+ self.hidden_size_per_attention_head,
+ self.hidden_size_per_attention_head,
+ ]
+
+ if SplitAlongDim is not None:
+
+ # [sq, b, ng, (np/ng + 2) * hn]
+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list)
+ else:
+
+ # [sq, b, ng, (np/ng + 2) * hn]
+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3)
+
+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
+ gate = gate.reshape(query.size(0), query.size(1), -1)
+
+ if self.q_layernorm is not None:
+ query = self.q_layernorm(query)
+
+ if self.k_layernorm is not None:
+ key = self.k_layernorm(key)
+
+ if self.config.test_mode:
+ self.run_realtime_tests()
+
+ return query, gate, key, value
+
def backward_dw(self) -> NoReturn:
"""Execute weight update operations"""
self._backward_qkv_proj()
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 235b6f6af..fbcffe278 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -566,6 +566,9 @@ def topk_routing_with_score_function(
else:
return torch.topk(scores, k=topk, dim=1)
+ from slime.utils.routing_replay import get_routing_replay_compute_topk
+ compute_topk = get_routing_replay_compute_topk(compute_topk)
+
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index 6b20b8622..459e65921 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -156,6 +156,9 @@ class TopKRouter(Router):
self.local_tokens_per_expert = None
self.expert_bias = None
+ from slime.utils.routing_replay import register_routing_replay
+ register_routing_replay(self)
+
def _maintain_float32_expert_bias(self):
"""
Maintain the expert bias in float32.
diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
index b7884e18e..f0104f861 100755
--- a/megatron/core/transformer/multi_token_prediction.py
+++ b/megatron/core/transformer/multi_token_prediction.py
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
import torch
from torch import Tensor
+import warnings
from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -105,17 +106,21 @@ def tie_output_layer_state_dict(
)
-def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None):
- """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support.
- This function extends the original roll_tensor to support Context Parallelism, which allows
- MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks,
- and tensor rolling requires communication between adjacent CP ranks to properly handle the
- boundary conditions.
+def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None):
+ """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support.
+
+ This function extends the original roll_tensor to support Context Parallelism and Packed Sequences.
+ When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires
+ communication between adjacent CP ranks to properly handle the boundary conditions.
+ When packed sequences are used, rolling is performed within each individual sequence boundary
+ to prevent mixing tokens between different packed sequences.
For CP=1 (default behavior): Uses standard torch.roll with zero padding
For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges
boundary elements between adjacent CP ranks to maintain sequence continuity.
+ For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens.
+
Args:
tensor (Tensor): The input tensor to roll.
@@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None):
dims (int): The dimension to roll (typically -1 for sequence dimension).
cp_group (ProcessGroup): The context parallelism process group. If None or size=1,
falls back to standard rolling behavior.
+ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing.
+ If provided, rolling respects sequence boundaries.
Returns:
tuple: (rolled_tensor, sum_of_rolled_tensor)
"""
+
+ if packed_seq_params is not None:
+ return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group)
+
# Standard rolling behavior when CP is not enabled (cp_group is None or size=1)
if cp_group is None or cp_group.size() == 1:
rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
@@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None):
return rolled_tensor, rolled_tensor.sum()
+def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None):
+ """Roll tensor with packed sequence support.
+
+ This function handles rolling for packed sequences by respecting sequence boundaries
+ defined in packed_seq_params.cu_seqlens. Rolling is performed within each individual
+ sequence to prevent mixing tokens between different packed sequences. When Context
+ Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata
+ so we slice out the portion of every packed sequence that lives on the current rank and
+ reuse the standard CP boundary exchange to populate the rolling window.
+
+ Args:
+ tensor (Tensor): The input tensor to roll.
+ shifts (int): The shift of the tensor (typically -1 for MTP).
+ dims (int): The dimension to roll (typically -1 for sequence dimension).
+ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing.
+ cp_group (ProcessGroup): The context parallelism process group.
+
+ Returns:
+ tuple: (rolled_tensor, sum_of_rolled_tensor)
+ """
+
+ # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once.
+ assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension."
+ assert shifts == -1, "Packed sequence roll only supports a single-token left shift."
+ cu_seqlens = packed_seq_params.cu_seqlens_q
+ assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q."
+
+ rolled_tensor = tensor.clone()
+
+ cp_size = cp_group.size() if cp_group is not None else 1
+ if cp_size == 1:
+ # CP disabled: simply roll inside each packed sequence boundary.
+ for i in range(len(cu_seqlens) - 1):
+ start_idx = cu_seqlens[i]
+ end_idx = cu_seqlens[i + 1]
+ seq_slice = tensor[..., start_idx:end_idx]
+ rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims)
+ rolled_seq[..., shifts:] = 0
+ rolled_tensor[..., start_idx:end_idx] = rolled_seq
+ return rolled_tensor, rolled_tensor.sum()
+
+ # CP enabled: each rank owns two chunks per sequence (front and mirrored tail).
+ local_rank = torch.distributed.get_rank(group=cp_group)
+ global_ranks = torch.distributed.get_process_group_ranks(group=cp_group)
+ next_rank = global_ranks[(local_rank + 1) % cp_size]
+ prev_rank = global_ranks[(local_rank - 1) % cp_size]
+
+ # iterate over each sequence individually
+ for i in range(len(cu_seqlens) - 1):
+ start_idx = cu_seqlens[i]
+ end_idx = cu_seqlens[i + 1]
+
+ # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx
+ local_start_idx = start_idx // cp_size
+ local_end_idx = end_idx // cp_size
+ tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone()
+
+ # The following code is very similar as the code in roll_tensor function
+ local_chunks = tensor_slice.chunk(2, dim=dims)
+ rolled_chunks = [
+ torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks
+ ]
+
+ tensor_send_list = []
+ tensor_recv_list = []
+ for chunk in rolled_chunks:
+ boundary = chunk.select(dims, shifts).contiguous().clone()
+ tensor_send_list.append(boundary)
+ tensor_recv_list.append(torch.empty_like(boundary))
+
+ ops = []
+ if local_rank != 0:
+ ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank))
+ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank))
+ else:
+ tensor_recv_list[1].zero_()
+
+ if local_rank != cp_size - 1:
+ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank))
+ ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank))
+ else:
+ tensor_recv_list[0].copy_(tensor_send_list[1])
+
+ for op in ops:
+ op.wait()
+
+ index = [slice(None)] * rolled_chunks[0].dim()
+ index[dims] = shifts
+ for chunk, recv in zip(rolled_chunks, tensor_recv_list):
+ chunk[tuple(index)] = recv
+
+ seq_result = torch.cat(rolled_chunks, dim=dims)
+
+ # update the rolled tensor
+ rolled_tensor[..., local_start_idx:local_end_idx] = seq_result
+
+ return rolled_tensor, rolled_tensor.sum()
class MTPLossLoggingHelper:
"""Helper class for logging MTP losses."""
@@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule):
def _get_embeddings(
self,
input_ids: torch.Tensor,
- position_ids: torch.Tensor,
embedding: Callable,
hidden_states: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ packed_seq_params: Optional[PackedSeqParams] = None,
):
"""
Preprocesses input data for the Multi-Token Prediction (MTP) layers.
@@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule):
sequence length, b is the batch size, and h is the hidden size.
"""
# Calc logits for the current Multi-Token Prediction (MTP) layers.
- input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group)
- position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group)
+ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
+
+ # Prepare/roll position ids only when applicable.
+ if position_ids is None:
+ # Fallback position ids for learned absolute embedding.
+ seq_len = input_ids.size(-1)
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+ position_ids, _ = roll_tensor(
+ position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params
+ )
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+ decoder_input = decoder_input.detach()
- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False)
return input_ids, position_ids, decoder_input, hidden_states
@@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule):
return hidden_states
def _checkpointed_forward(self, forward_func, *args, **kwargs):
+ """Wrap `forward_func` with activation checkpointing while only passing tensors.
+
+ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so
+ that checkpoint implementations never receive them directly, avoiding save_for_backward
+ issues with non-tensor inputs.
+ """
+
+ # TODO(jiajun): Is there any better implementation here?
+ positional_specs = []
+ kw_specs = []
+ tensor_args: List[torch.Tensor] = []
+
+ for arg in args:
+ if torch.is_tensor(arg):
+ positional_specs.append(('tensor', len(tensor_args)))
+ tensor_args.append(arg)
+ else:
+ positional_specs.append(('const', arg))
+
+ for key, value in kwargs.items():
+ if torch.is_tensor(value):
+ kw_specs.append((key, ('tensor', len(tensor_args))))
+ tensor_args.append(value)
+ else:
+ kw_specs.append((key, ('const', value)))
+
+ def run(*flat_tensor_args):
+ rebuilt_args = []
+ for spec_type, payload in positional_specs:
+ if spec_type == 'tensor':
+ rebuilt_args.append(flat_tensor_args[payload])
+ else:
+ rebuilt_args.append(payload)
+
+ rebuilt_kwargs = {}
+ for key, (spec_type, payload) in kw_specs:
+ if spec_type == 'tensor':
+ rebuilt_kwargs[key] = flat_tensor_args[payload]
+ else:
+ rebuilt_kwargs[key] = payload
+
+ return forward_func(*rebuilt_args, **rebuilt_kwargs)
+
+ tensor_args_tuple = tuple(tensor_args)
+
def checkpoint_handler():
- """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`"""
+ """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`."""
if self.config.fp8:
from megatron.core.extensions.transformer_engine import te_checkpoint
return te_checkpoint(
- forward_func,
+ run,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
- *args,
- **kwargs,
+ *tensor_args_tuple,
)
else:
return tensor_parallel.checkpoint(
- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values()
+ run, self.config.distribute_saved_activations, *tensor_args_tuple
)
if self.config.recompute_method == 'uniform':
@@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule):
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
assert context is None, f"multi token prediction + cross attention is not yet supported."
- assert (
- packed_seq_params is None
- ), f"multi token prediction + sequence packing is not yet supported."
input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding=embedding,
hidden_states=hidden_states,
+ packed_seq_params=packed_seq_params,
)
if self.config.recompute_granularity == 'full' and self.training:
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index d55bebe7e..1eecbbd38 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig):
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""
+ post_self_attn_layernorm: bool = False
+ post_mlp_layernorm: bool = False
+ use_gated_attention: bool = False
+
test_mode: bool = False
"""Whether to run real-time tests."""
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 84f22bdea..f0f3f8e86 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -224,6 +224,7 @@ class TransformerLayerSubmodules:
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
@@ -232,6 +233,7 @@ class TransformerLayerSubmodules:
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
@@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
+ self.post_self_attn_layernorm = build_module(
+ submodules.post_self_attn_layernorm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon,
+ )
+
# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
@@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
+ self.post_mlp_layernorm = build_module(
+ submodules.post_mlp_layernorm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon
+ )
+
self.recompute_input_layernorm = False
self.recompute_pre_mlp_layernorm = False
self.recompute_mlp = False
@@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
attention_output_with_bias[0]
)
+ attention_output, attention_output_bias = attention_output_with_bias
+ attention_output = self.post_self_attn_layernorm(attention_output)
+ attention_output_with_bias = (attention_output, attention_output_bias)
+
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
@@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer):
else:
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+ mlp_output, mlp_output_bias = mlp_output_with_bias
+ mlp_output = self.post_mlp_layernorm(mlp_output)
+ mlp_output_with_bias = (mlp_output, mlp_output_bias)
+
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index e3459c5ee..7346bf35b 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -937,8 +937,6 @@ def validate_args(args, defaults={}):
# MoE Spec check
if args.num_experts == 0:
args.num_experts = None
- if args.num_experts is not None:
- assert args.spec is None, "Model Spec must be None when using MoEs"
if args.num_experts is not None and args.moe_ffn_hidden_size is None:
args.moe_ffn_hidden_size = args.ffn_hidden_size
print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.")
@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None):
if args.is_hybrid_model:
kw_args['is_hybrid_model'] = args.is_hybrid_model
+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+ kw_args['use_gated_attention'] = args.use_gated_attention
+
# handle quantization config
# NOTE: Kitchen arguments are only added to the namespace when
# Kitchen library is available.
@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser):
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
+ group.add_argument('--post-self-attn-layernorm', action='store_true',
+ help='If set, use post self attention layernorm.')
+ group.add_argument('--post-mlp-layernorm', action='store_true',
+ help='If set, use post MLP layernorm.')
+ group.add_argument('--use-gated-attention', action='store_true',
+ help='If set, use gated attention as in Qwen3Next')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
index 5cf222ccc..d1554ca4c 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer):
f"The transformers library must be installed to use huggingface_tokenizer_provider"
)
+ if "trust_remote_code" not in kwargs:
+ kwargs["trust_remote_code"] = True
# TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs