JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py
index 41c21d93d..ef80f72d6 100644
--- a/megatron/core/dist_checkpointing/strategies/common.py
+++ b/megatron/core/dist_checkpointing/strategies/common.py
@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
msc = MultiStorageClientFeature.import_package()
return msc.torch.load(load_path, map_location='cpu')
else:
- return torch.load(load_path, map_location='cpu')
+ return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
if MultiStorageClientFeature.is_enabled():
diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py
index 5a1ea308d..aa701237f 100644
--- a/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/core/dist_checkpointing/strategies/torch.py
@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
if sh_ten.key not in metadata.state_dict_metadata:
- raise KeyError(
- f"{sh_ten.key} from model not in state dict:"
- f" {sorted(metadata.state_dict_metadata.keys())}"
- )
+ # raise KeyError(
+ # f"{sh_ten.key} from model not in state dict:"
+ # f" {sorted(metadata.state_dict_metadata.keys())}"
+ # )
+ print(f"{sh_ten.key} from model not in state dict, will skip")
+ continue
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
expected_shape = self._expected_shape(sh_ten)
if loaded_shape != expected_shape:
@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
tensor_metadata = self.metadata.state_dict_metadata
metadata_with_sizes = [
(tensor_metadata[key], tensor_metadata[key].size, sharded_tensor)
- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items()
+ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata
]
try:
# Temporarily set sizes to expected shapes
@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
+ allow_partial_load=True,
),
)
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index acb93ef78..d239db4ab 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear):
)
for param in self.parameters():
+ setattr(param, "parallel_mode", parallel_mode)
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, "allreduce", not self.expert_parallel)
@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
if HAVE_TE and is_te_min_version("1.9.0.dev0"):
+ def ceil_div(x: int, y: int) -> int:
+ return (x + y - 1) // y
+
+ class _FakeInt4QuantizationSTE(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, group_size):
+ m, n = x.shape
+ block_size_m, block_size_n = 1, group_size
+
+
+ m_padded = ceil_div(m, block_size_m) * block_size_m
+ n_padded = ceil_div(n, block_size_n) * block_size_n
+
+ x_padded = torch.zeros(
+ (m_padded, n_padded),
+ dtype=x.dtype, device=x.device
+ )
+ x_padded[:m, :n] = x
+
+ x_view = x_padded.view(
+ m_padded // block_size_m,
+ block_size_m,
+ n_padded // block_size_n,
+ block_size_n
+ )
+
+ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True)
+ q_max = 7
+ x_scale = x_max / q_max
+
+ x_scale = x_scale.clamp(min=1e-5)
+
+ x_div = x_view / x_scale
+ x_round = torch.round(x_div)
+
+ x_q_clamped = x_round.clamp(-q_max, q_max)
+
+ x_dequant_view = x_q_clamped * x_scale
+
+ x_dequant_full = x_dequant_view.view_as(x_padded)
+ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype)
+
+ return x_out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None
+
+ def fake_int4_quantization_ste(x, group_size):
+ x_out = _FakeInt4QuantizationSTE.apply(x, group_size)
+
+ if hasattr(x, 'main_grad'):
+ x_out.main_grad = x.main_grad
+
+ return x_out
class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
+
out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
return out
return out, None
+ def _get_weight_tensors(self):
+ """Get the weight tensors of the module."""
+ weight_tensors = super()._get_weight_tensors()
+
+ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1":
+ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128"))
+
+ weight_tensors = [
+ fake_int4_quantization_ste(w, group_size)
+ for w in weight_tensors
+ ]
+
+ return weight_tensors
+
def _encode_extra_state(self, state):
# TE 2.0 changed the format of extra_state to be a byte tensor
if is_te_min_version("2.0.0"):
diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py
index 1fd5dcfae..c9aeef1f0 100644
--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py
+++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py
@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel(
SIN,
emb_dim: tl.constexpr,
k_dim: tl.constexpr,
+ k_dim_ceil: tl.constexpr,
v_dim: tl.constexpr,
head_num: tl.constexpr,
batch_size,
@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel(
cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))
sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))
- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads
- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads
- mask = kv_off < head_num * stride_kv_nheads
- k_in_off = kv_off + tl.arange(0, k_dim)[None, :]
- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :]
- k = tl.load(KV_ptr + k_in_off, mask=mask)
- v = tl.load(KV_ptr + v_in_off, mask=mask)
+ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads
+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
+ kj_range = tl.arange(0, k_dim_ceil)[None, :]
+ mask_k = (ki_range < head_num) & (kj_range < k_dim)
+ mask_v = ki_range < head_num
+ k_off = ki_range * stride_kv_nheads + kj_range
+ if v_dim > 0:
+ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
+ v = tl.load(KV_ptr + v_off, mask=mask_v)
+ else:
+ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty)
+ k = tl.load(KV_ptr + k_off, mask=mask_k)
- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads
- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads
+ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads
+ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads
- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :]
- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :]
- tl.store(K_ptr + k_out_off, k, mask=mask)
- tl.store(V_ptr + v_out_off, v, mask=mask)
+ k_out_off = ki_range * stride_k_nheads + kj_range
+ tl.store(K_ptr + k_out_off, k, mask=mask_k)
+ if v_dim > 0:
+ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :]
+ tl.store(V_ptr + v_out_off, v, mask=mask_v)
EMB = K_POS_EMB + pid_m * stride_emb_seq
# x1 = t[..., 0::2], x2 = t[..., 1::2]
@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel(
x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)
x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)
+ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
+ mask_x = x_range < head_num
x_left_off = (
- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads
+ x_range * stride_k_nheads
+ k_dim
+ tl.arange(0, emb_dim // 2)[None, :]
)
x_right_off = x_left_off + emb_dim // 2
- tl.store(K_ptr + x_left_off, x_left, mask=mask)
- tl.store(K_ptr + x_right_off, x_right, mask=mask)
+ tl.store(K_ptr + x_left_off, x_left, mask=mask_x)
+ tl.store(K_ptr + x_right_off, x_right, mask=mask_x)
@triton.autotune(
@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel(
SIN,
emb_dim: tl.constexpr,
k_dim: tl.constexpr,
+ k_dim_ceil: tl.constexpr,
v_dim: tl.constexpr,
head_num: tl.constexpr,
batch_size,
@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel(
else:
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)
- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
- mask = dkv_off < head_num * stride_dkv_nheads
- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :]
- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :]
-
- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads
- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads
- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :]
- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
- dk = tl.load(dK_ptr + dk_in_off, mask=mask)
- dv = tl.load(dV_ptr + dv_in_off, mask=mask)
- tl.store(dKV_ptr + dk_out_off, dk, mask=mask)
- tl.store(dKV_ptr + dv_out_off, dv, mask=mask)
+ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads
+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
+ kj_range = tl.arange(0, k_dim_ceil)[None, :]
+ mask_k = (ki_range < head_num) & (kj_range < k_dim)
+ mask_v = ki_range < head_num
+ dk_out_off = ki_range * stride_dkv_nheads + kj_range
+
+ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads
+ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads
+ dk_in_off = ki_range * stride_dk_nheads + kj_range
+
+ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k)
+ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k)
+
+ if v_dim > 0:
+ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
+ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
+ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v)
+ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v)
if pid_head == 0:
x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)):
- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads
- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim
+ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads
+ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads
mask = x_off < head_num * stride_dk_nheads
x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :]
x_right_off = x_left_off + emb_dim // 2
@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim)
o_value = kv.new_empty(total_seqlen, nheads, v_dim)
+ k_dim_ceil = triton.next_power_of_2(k_dim)
grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
rotary_fwd_kv_kernel[grid](
@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
sin,
emb_dim,
k_dim,
+ k_dim_ceil,
v_dim,
nheads,
batch_size,
@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim)
d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim)
+ k_dim_ceil = triton.next_power_of_2(ctx.k_dim)
grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
rotary_bwd_kv_kernel[grid](
@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
sin,
ctx.emb_dim,
ctx.k_dim,
+ k_dim_ceil,
ctx.v_dim,
nheads,
batch_size,
diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py
index 13d74aa52..060898a7a 100644
--- a/megatron/core/models/common/language_module/language_module.py
+++ b/megatron/core/models/common/language_module/language_module.py
@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule):
assert (
column_parallel_linear is not None
), "column_parallel_linear cannot be None when not using fused linear cross entropy."
- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs)
+ # output
+ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()}
+ output_layer_buffers = dict(column_parallel_linear.named_buffers())
+ logits, _ = torch.func.functional_call(
+ column_parallel_linear,
+ {**output_layer_params, **output_layer_buffers},
+ (hidden,),
+ col_linear_kwargs,
+ )
return self.compute_language_model_loss(labels, logits)
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index e21127b87..712793853 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec(
use_kitchen: bool = False,
use_te_activation_func: bool = False,
fallback_to_eager_attn: bool = False,
+ post_self_attn_layernorm: bool = False,
+ post_mlp_layernorm: bool = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec(
mlp=mlp,
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
normalization=normalization,
+ post_self_attn_layernorm=post_self_attn_layernorm,
+ post_mlp_layernorm=post_mlp_layernorm,
)
@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend(
mlp: ModuleSpec,
sharded_state_dict_keys_map: Optional[dict] = None,
normalization: Optional[str] = None,
+ post_self_attn_layernorm: bool = False,
+ post_mlp_layernorm: bool = False,
) -> ModuleSpec:
"""Helper function to get module spec for TransformerLayer"""
@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend(
input_layernorm=input_layernorm,
self_attention=attention,
self_attn_bda=get_bias_dropout_add,
+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp,
pre_mlp_layernorm=pre_mlp_layernorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
),
)
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index a1230568c..1fd52f65a 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -446,6 +446,7 @@ class GPTModel(LanguageModule):
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
+ mtp_kwargs: Optional[dict] = {},
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
@@ -508,6 +509,7 @@ class GPTModel(LanguageModule):
runtime_gather_output=runtime_gather_output,
extra_block_kwargs=extra_block_kwargs,
inference_context=inference_context,
+ mtp_kwargs=mtp_kwargs,
)
def _postprocess(
@@ -529,6 +531,7 @@ class GPTModel(LanguageModule):
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
+ mtp_kwargs={},
):
"""Postprocesses decoder hidden states to generate logits or compute loss.
@@ -543,7 +546,8 @@ class GPTModel(LanguageModule):
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
- if mtp_in_postprocess:
+
+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None:
hidden_states = self.mtp(
input_ids=input_ids,
position_ids=position_ids,
@@ -563,13 +567,18 @@ class GPTModel(LanguageModule):
return hidden_states
# Skip when mtp_num_layers is None or 0
- if self.config.mtp_num_layers:
- mtp_labels = labels.clone()
+ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None:
+ mtp_labels = mtp_kwargs['mtp_labels'].clone()
+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
+
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
+ else:
+ # Otherwise, roll the loss_mask to keep up with the mtp_labels
+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
for mtp_layer_number in range(self.config.mtp_num_layers):
# Calc loss for the current Multi-Token Prediction (MTP) layers.
mtp_labels, _ = roll_tensor(
@@ -595,7 +604,7 @@ class GPTModel(LanguageModule):
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
- 'weight': output_weight,
+ 'weight': output_weight.detach() if output_weight else None,
'runtime_gather_output': runtime_gather_output,
},
)
diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
index 6e093f96f..eac21a3ea 100644
--- a/megatron/core/optimizer/distrib_optimizer.py
+++ b/megatron/core/optimizer/distrib_optimizer.py
@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# TE FusedAdam will not accumulate step for empty param groups, so we need to
# align the step across param groups.
param_group["step"] = int(step)
+ if "step" in param_group and param_group["step"] is None:
+ del param_group["step"]
# Grad scaler state.
if self.grad_scaler:
@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
if key == 'padding':
tensors[key] = LocalNonpersistentObject(tensors[key])
continue
+ if key == 'step':
+ continue
assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), (
tensors[key].shape,
gbuf_local_start,
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index a273002b9..4f821cfd5 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -11,6 +11,7 @@ from typing import Callable, List, Optional
import numpy as np
import torch
+import torch.distributed as dist
from .utils import GlobalMemoryBuffer, is_torch_min_version
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index ac839c21f..f18309217 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -26,22 +26,22 @@ def _batched_p2p_ops(
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
+ torch.distributed.isend, tensor_send_next, next_pipeline_rank,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
)
ops.append(recv_next_op)
if len(ops) > 0:
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 28cff06f5..58dc4bb70 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -587,6 +587,9 @@ def topk_routing_with_score_function(
else:
return torch.topk(scores, k=topk, dim=1)
+ from slime.utils.routing_replay import get_routing_replay_compute_topk
+ compute_topk = get_routing_replay_compute_topk(compute_topk)
+
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index 16fc9d9af..517944f25 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -201,6 +201,9 @@ class TopKRouter(Router):
self.global_tokens_per_expert = None
self.ga_steps = None
+ from slime.utils.routing_replay import register_routing_replay
+ register_routing_replay(self)
+
def _maintain_float32_expert_bias(self):
"""
Maintain the expert bias in float32.
diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
index a8f4abfcd..f33f6f05e 100755
--- a/megatron/core/transformer/multi_token_prediction.py
+++ b/megatron/core/transformer/multi_token_prediction.py
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
import torch
from torch import Tensor
+import warnings
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule):
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
- position_ids, _ = roll_tensor(
- position_ids,
- shifts=-1,
- dims=-1,
- cp_group=self.cp_group,
- packed_seq_params=packed_seq_params,
- )
+ if position_ids is not None:
+ position_ids, _ = roll_tensor(
+ position_ids,
+ shifts=-1,
+ dims=-1,
+ cp_group=self.cp_group,
+ packed_seq_params=packed_seq_params,
+ )
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+ decoder_input = decoder_input.detach()
- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False)
return input_ids, position_ids, decoder_input, hidden_states
@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule):
return hidden_states
def _checkpointed_forward(self, forward_func, *args, **kwargs):
+ """Wrap `forward_func` with activation checkpointing while only passing tensors.
+
+ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so
+ that checkpoint implementations never receive them directly, avoiding save_for_backward
+ issues with non-tensor inputs.
+ """
+
+ # TODO(jiajun): Is there any better implementation here?
+ positional_specs = []
+ kw_specs = []
+ tensor_args: List[torch.Tensor] = []
+
+ for arg in args:
+ if torch.is_tensor(arg):
+ positional_specs.append(('tensor', len(tensor_args)))
+ tensor_args.append(arg)
+ else:
+ positional_specs.append(('const', arg))
+
+ for key, value in kwargs.items():
+ if torch.is_tensor(value):
+ kw_specs.append((key, ('tensor', len(tensor_args))))
+ tensor_args.append(value)
+ else:
+ kw_specs.append((key, ('const', value)))
+
+ def run(*flat_tensor_args):
+ rebuilt_args = []
+ for spec_type, payload in positional_specs:
+ if spec_type == 'tensor':
+ rebuilt_args.append(flat_tensor_args[payload])
+ else:
+ rebuilt_args.append(payload)
+
+ rebuilt_kwargs = {}
+ for key, (spec_type, payload) in kw_specs:
+ if spec_type == 'tensor':
+ rebuilt_kwargs[key] = flat_tensor_args[payload]
+ else:
+ rebuilt_kwargs[key] = payload
+
+ return forward_func(*rebuilt_args, **rebuilt_kwargs)
+
+ tensor_args_tuple = tuple(tensor_args)
+
def checkpoint_handler():
"""Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`"""
if self.config.fp8:
@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule):
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
- *args,
- **kwargs,
+ *tensor_args_tuple,
)
else:
return tensor_parallel.checkpoint(
- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values()
+ run, self.config.distribute_saved_activations, *tensor_args_tuple
)
if self.config.recompute_method == 'uniform':
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index e2705bd9f..a0aa109b5 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig):
attention_output_gate: bool = False
"""Whether to apply output gate to the attention layers."""
+ post_self_attn_layernorm: bool = False
+ post_mlp_layernorm: bool = False
+
test_mode: bool = False
"""Whether to run real-time tests."""
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 3ea405770..5a42001b9 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -223,6 +223,7 @@ class TransformerLayerSubmodules:
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
@@ -231,6 +232,7 @@ class TransformerLayerSubmodules:
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
+ self.post_self_attn_layernorm = build_module(
+ submodules.post_self_attn_layernorm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon,
+ )
+
# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
self.is_moe_layer = isinstance(self.mlp, MoELayer)
+ self.post_mlp_layernorm = build_module(
+ submodules.post_mlp_layernorm,
+ config=self.config,
+ hidden_size=self.config.hidden_size,
+ eps=self.config.layernorm_epsilon
+ )
+
self.recompute_input_layernorm = False
self.recompute_pre_mlp_layernorm = False
self.recompute_mlp = False
@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
attention_output_with_bias[0]
)
+ attention_output, attention_output_bias = attention_output_with_bias
+ attention_output = self.post_self_attn_layernorm(attention_output)
+ attention_output_with_bias = (attention_output, attention_output_bias)
+
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push(suffix="self_attn_bda")
@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
else:
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+ mlp_output, mlp_output_bias = mlp_output_with_bias
+ mlp_output = self.post_mlp_layernorm(mlp_output)
+ mlp_output_with_bias = (mlp_output, mlp_output_bias)
+
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index b267c8a81..83736acdc 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['inference_sampling_seed'] = args.seed
+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+
# handle quantization config
# NOTE: Kitchen arguments are only added to the namespace when
# Kitchen library is available.
@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser):
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
+ group.add_argument('--post-self-attn-layernorm', action='store_true',
+ help='If set, use post self attention layernorm.')
+ group.add_argument('--post-mlp-layernorm', action='store_true',
+ help='If set, use post MLP layernorm.')
+ group.add_argument('--use-gated-attention', action='store_true',
+ help='If set, use gated attention as in Qwen3Next')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
index 13b7526ca..6c590f653 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer):
# TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
- trust_remote_code=trust_remote_code,
+ trust_remote_code=True,
**kwargs,
)
self._vocab = self._tokenizer.get_vocab()