diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index 41c21d93d..ef80f72d6 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): msc = MultiStorageClientFeature.import_package() return msc.torch.load(load_path, map_location='cpu') else: - return torch.load(load_path, map_location='cpu') + return torch.load(load_path, map_location='cpu', weights_only=False) except FileNotFoundError as e: err_msg = f'Common file {load_path} does not exist' if MultiStorageClientFeature.is_enabled(): diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index 5a1ea308d..aa701237f 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): def _validate_global_shapes(self, metadata, sharded_tensors): for sh_ten in sharded_tensors: if sh_ten.key not in metadata.state_dict_metadata: - raise KeyError( - f"{sh_ten.key} from model not in state dict:" - f" {sorted(metadata.state_dict_metadata.keys())}" - ) + # raise KeyError( + # f"{sh_ten.key} from model not in state dict:" + # f" {sorted(metadata.state_dict_metadata.keys())}" + # ) + print(f"{sh_ten.key} from model not in state dict, will skip") + continue loaded_shape = metadata.state_dict_metadata[sh_ten.key].size expected_shape = self._expected_shape(sh_ten) if loaded_shape != expected_shape: @@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): tensor_metadata = self.metadata.state_dict_metadata metadata_with_sizes = [ (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) - for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() + for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata ] try: # Temporarily set sizes to expected shapes @@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): planner=MCoreLoadPlanner( shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, + allow_partial_load=True, ), ) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index acb93ef78..d239db4ab 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): ) for param in self.parameters(): + setattr(param, "parallel_mode", parallel_mode) if is_expert: # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) @@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): if HAVE_TE and is_te_min_version("1.9.0.dev0"): + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + class _FakeInt4QuantizationSTE(torch.autograd.Function): + @staticmethod + def forward(ctx, x, group_size): + m, n = x.shape + block_size_m, block_size_n = 1, group_size + + + m_padded = ceil_div(m, block_size_m) * block_size_m + n_padded = ceil_div(n, block_size_n) * block_size_n + + x_padded = torch.zeros( + (m_padded, n_padded), + dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + + x_view = x_padded.view( + m_padded // block_size_m, + block_size_m, + n_padded // block_size_n, + block_size_n + ) + + x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) + q_max = 7 + x_scale = x_max / q_max + + x_scale = x_scale.clamp(min=1e-5) + + x_div = x_view / x_scale + x_round = torch.round(x_div) + + x_q_clamped = x_round.clamp(-q_max, q_max) + + x_dequant_view = x_q_clamped * x_scale + + x_dequant_full = x_dequant_view.view_as(x_padded) + x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) + + return x_out + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + def fake_int4_quantization_ste(x, group_size): + x_out = _FakeInt4QuantizationSTE.apply(x, group_size) + + if hasattr(x, 'main_grad'): + x_out.main_grad = x.main_grad + + return x_out class TEGroupedLinear(te.pytorch.GroupedLinear): """ @@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): _is_first_microbatch = ( None if self.disable_parameter_transpose_cache else self.is_first_microbatch ) + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) self.is_first_microbatch = False @@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): return out return out, None + def _get_weight_tensors(self): + """Get the weight tensors of the module.""" + weight_tensors = super()._get_weight_tensors() + + if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": + group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) + + weight_tensors = [ + fake_int4_quantization_ste(w, group_size) + for w in weight_tensors + ] + + return weight_tensors + def _encode_extra_state(self, state): # TE 2.0 changed the format of extra_state to be a byte tensor if is_te_min_version("2.0.0"): diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py index 1fd5dcfae..c9aeef1f0 100644 --- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py +++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py @@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( SIN, emb_dim: tl.constexpr, k_dim: tl.constexpr, + k_dim_ceil: tl.constexpr, v_dim: tl.constexpr, head_num: tl.constexpr, batch_size, @@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads - kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads - mask = kv_off < head_num * stride_kv_nheads - k_in_off = kv_off + tl.arange(0, k_dim)[None, :] - v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] - k = tl.load(KV_ptr + k_in_off, mask=mask) - v = tl.load(KV_ptr + v_in_off, mask=mask) + KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads + ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H + kj_range = tl.arange(0, k_dim_ceil)[None, :] + mask_k = (ki_range < head_num) & (kj_range < k_dim) + mask_v = ki_range < head_num + k_off = ki_range * stride_kv_nheads + kj_range + if v_dim > 0: + v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] + v = tl.load(KV_ptr + v_off, mask=mask_v) + else: + v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) + k = tl.load(KV_ptr + k_off, mask=mask_k) - K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads - V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads + K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads + V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads - k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] - v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] - tl.store(K_ptr + k_out_off, k, mask=mask) - tl.store(V_ptr + v_out_off, v, mask=mask) + k_out_off = ki_range * stride_k_nheads + kj_range + tl.store(K_ptr + k_out_off, k, mask=mask_k) + if v_dim > 0: + v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] + tl.store(V_ptr + v_out_off, v, mask=mask_v) EMB = K_POS_EMB + pid_m * stride_emb_seq # x1 = t[..., 0::2], x2 = t[..., 1::2] @@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H + mask_x = x_range < head_num x_left_off = ( - tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + x_range * stride_k_nheads + k_dim + tl.arange(0, emb_dim // 2)[None, :] ) x_right_off = x_left_off + emb_dim // 2 - tl.store(K_ptr + x_left_off, x_left, mask=mask) - tl.store(K_ptr + x_right_off, x_right, mask=mask) + tl.store(K_ptr + x_left_off, x_left, mask=mask_x) + tl.store(K_ptr + x_right_off, x_right, mask=mask_x) @triton.autotune( @@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( SIN, emb_dim: tl.constexpr, k_dim: tl.constexpr, + k_dim_ceil: tl.constexpr, v_dim: tl.constexpr, head_num: tl.constexpr, batch_size, @@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( else: token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) - dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads - dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads - mask = dkv_off < head_num * stride_dkv_nheads - dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] - dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] - - dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads - dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads - dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] - dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] - dk = tl.load(dK_ptr + dk_in_off, mask=mask) - dv = tl.load(dV_ptr + dv_in_off, mask=mask) - tl.store(dKV_ptr + dk_out_off, dk, mask=mask) - tl.store(dKV_ptr + dv_out_off, dv, mask=mask) + dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads + ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H + kj_range = tl.arange(0, k_dim_ceil)[None, :] + mask_k = (ki_range < head_num) & (kj_range < k_dim) + mask_v = ki_range < head_num + dk_out_off = ki_range * stride_dkv_nheads + kj_range + + dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads + dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads + dk_in_off = ki_range * stride_dk_nheads + kj_range + + dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) + tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) + + if v_dim > 0: + dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] + dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] + dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) + tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) if pid_head == 0: x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): - dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads - x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads + x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads mask = x_off < head_num * stride_dk_nheads x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] x_right_off = x_left_off + emb_dim // 2 @@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) o_value = kv.new_empty(total_seqlen, nheads, v_dim) + k_dim_ceil = triton.next_power_of_2(k_dim) grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) rotary_fwd_kv_kernel[grid]( @@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): sin, emb_dim, k_dim, + k_dim_ceil, v_dim, nheads, batch_size, @@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) + k_dim_ceil = triton.next_power_of_2(ctx.k_dim) grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) rotary_bwd_kv_kernel[grid]( @@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): sin, ctx.emb_dim, ctx.k_dim, + k_dim_ceil, ctx.v_dim, nheads, batch_size, diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index 13d74aa52..060898a7a 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): assert ( column_parallel_linear is not None ), "column_parallel_linear cannot be None when not using fused linear cross entropy." - logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) + # output + output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} + output_layer_buffers = dict(column_parallel_linear.named_buffers()) + logits, _ = torch.func.functional_call( + column_parallel_linear, + {**output_layer_params, **output_layer_buffers}, + (hidden,), + col_linear_kwargs, + ) return self.compute_language_model_loss(labels, logits) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index e21127b87..712793853 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( use_kitchen: bool = False, use_te_activation_func: bool = False, fallback_to_eager_attn: bool = False, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). @@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( mlp=mlp, sharded_state_dict_keys_map=sharded_state_dict_keys_map, normalization=normalization, + post_self_attn_layernorm=post_self_attn_layernorm, + post_mlp_layernorm=post_mlp_layernorm, ) @@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( mlp: ModuleSpec, sharded_state_dict_keys_map: Optional[dict] = None, normalization: Optional[str] = None, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Helper function to get module spec for TransformerLayer""" @@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( input_layernorm=input_layernorm, self_attention=attention, self_attn_bda=get_bias_dropout_add, + post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, pre_mlp_layernorm=pre_mlp_layernorm, mlp=mlp, mlp_bda=get_bias_dropout_add, + post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, sharded_state_dict_keys_map=sharded_state_dict_keys_map, ), ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index a1230568c..1fd52f65a 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -446,6 +446,7 @@ class GPTModel(LanguageModule): *, inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[Tensor] = None, + mtp_kwargs: Optional[dict] = {}, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoder and finally into the post @@ -508,6 +509,7 @@ class GPTModel(LanguageModule): runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs, inference_context=inference_context, + mtp_kwargs=mtp_kwargs, ) def _postprocess( @@ -529,6 +531,7 @@ class GPTModel(LanguageModule): runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, + mtp_kwargs={}, ): """Postprocesses decoder hidden states to generate logits or compute loss. @@ -543,7 +546,8 @@ class GPTModel(LanguageModule): output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess: + + if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -563,13 +567,18 @@ class GPTModel(LanguageModule): return hidden_states # Skip when mtp_num_layers is None or 0 - if self.config.mtp_num_layers: - mtp_labels = labels.clone() + if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: + mtp_labels = mtp_kwargs['mtp_labels'].clone() + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) hidden_states = hidden_states_list[0] if loss_mask is None: # if loss_mask is not provided, use all ones as loss_mask loss_mask = torch.ones_like(mtp_labels) + else: + # Otherwise, roll the loss_mask to keep up with the mtp_labels + loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) for mtp_layer_number in range(self.config.mtp_num_layers): # Calc loss for the current Multi-Token Prediction (MTP) layers. mtp_labels, _ = roll_tensor( @@ -595,7 +604,7 @@ class GPTModel(LanguageModule): sequence_parallel_enabled=self.output_layer.sequence_parallel, column_parallel_linear=self.output_layer, col_linear_kwargs={ - 'weight': output_weight, + 'weight': output_weight.detach() if output_weight else None, 'runtime_gather_output': runtime_gather_output, }, ) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 6e093f96f..eac21a3ea 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) + if "step" in param_group and param_group["step"] is None: + del param_group["step"] # Grad scaler state. if self.grad_scaler: @@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): if key == 'padding': tensors[key] = LocalNonpersistentObject(tensors[key]) continue + if key == 'step': + continue assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( tensors[key].shape, gbuf_local_start, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index a273002b9..4f821cfd5 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -11,6 +11,7 @@ from typing import Callable, List, Optional import numpy as np import torch +import torch.distributed as dist from .utils import GlobalMemoryBuffer, is_torch_min_version diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index ac839c21f..f18309217 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( ops = [] if tensor_send_prev is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, ) ops.append(send_prev_op) if tensor_recv_prev is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, ) ops.append(recv_prev_op) if tensor_send_next is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, next_pipeline_rank, group + torch.distributed.isend, tensor_send_next, next_pipeline_rank, ) ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, ) ops.append(recv_next_op) if len(ops) > 0: diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 28cff06f5..58dc4bb70 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -587,6 +587,9 @@ def topk_routing_with_score_function( else: return torch.topk(scores, k=topk, dim=1) + from slime.utils.routing_replay import get_routing_replay_compute_topk + compute_topk = get_routing_replay_compute_topk(compute_topk) + if score_function == "softmax": if use_pre_softmax: scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 16fc9d9af..517944f25 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -201,6 +201,9 @@ class TopKRouter(Router): self.global_tokens_per_expert = None self.ga_steps = None + from slime.utils.routing_replay import register_routing_replay + register_routing_replay(self) + def _maintain_float32_expert_bias(self): """ Maintain the expert bias in float32. diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index a8f4abfcd..f33f6f05e 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union import torch from torch import Tensor +import warnings from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): cp_group=self.cp_group, packed_seq_params=packed_seq_params, ) - position_ids, _ = roll_tensor( - position_ids, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) + if position_ids is not None: + position_ids, _ = roll_tensor( + position_ids, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) # embedding decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + decoder_input = decoder_input.detach() - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) return input_ids, position_ids, decoder_input, hidden_states @@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): return hidden_states def _checkpointed_forward(self, forward_func, *args, **kwargs): + """Wrap `forward_func` with activation checkpointing while only passing tensors. + + Non-tensor arguments (e.g., configuration objects, None) are captured via closure so + that checkpoint implementations never receive them directly, avoiding save_for_backward + issues with non-tensor inputs. + """ + + # TODO(jiajun): Is there any better implementation here? + positional_specs = [] + kw_specs = [] + tensor_args: List[torch.Tensor] = [] + + for arg in args: + if torch.is_tensor(arg): + positional_specs.append(('tensor', len(tensor_args))) + tensor_args.append(arg) + else: + positional_specs.append(('const', arg)) + + for key, value in kwargs.items(): + if torch.is_tensor(value): + kw_specs.append((key, ('tensor', len(tensor_args)))) + tensor_args.append(value) + else: + kw_specs.append((key, ('const', value))) + + def run(*flat_tensor_args): + rebuilt_args = [] + for spec_type, payload in positional_specs: + if spec_type == 'tensor': + rebuilt_args.append(flat_tensor_args[payload]) + else: + rebuilt_args.append(payload) + + rebuilt_kwargs = {} + for key, (spec_type, payload) in kw_specs: + if spec_type == 'tensor': + rebuilt_kwargs[key] = flat_tensor_args[payload] + else: + rebuilt_kwargs[key] = payload + + return forward_func(*rebuilt_args, **rebuilt_kwargs) + + tensor_args_tuple = tuple(tensor_args) + def checkpoint_handler(): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" if self.config.fp8: @@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, parallel_state.get_tensor_model_parallel_group(), - *args, - **kwargs, + *tensor_args_tuple, ) else: return tensor_parallel.checkpoint( - forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() + run, self.config.distribute_saved_activations, *tensor_args_tuple ) if self.config.recompute_method == 'uniform': diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index e2705bd9f..a0aa109b5 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): attention_output_gate: bool = False """Whether to apply output gate to the attention layers.""" + post_self_attn_layernorm: bool = False + post_mlp_layernorm: bool = False + test_mode: bool = False """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 3ea405770..5a42001b9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -223,6 +223,7 @@ class TransformerLayerSubmodules: input_layernorm: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp + post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp @@ -231,6 +232,7 @@ class TransformerLayerSubmodules: pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp + post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) @@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): # [Module 3: BiasDropoutFusion] self.self_attn_bda = build_module(submodules.self_attn_bda) + self.post_self_attn_layernorm = build_module( + submodules.post_self_attn_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn self.pre_cross_attn_layernorm = build_module( submodules.pre_cross_attn_layernorm, @@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): self.is_moe_layer = isinstance(self.mlp, MoELayer) + self.post_mlp_layernorm = build_module( + submodules.post_mlp_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon + ) + self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False @@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): attention_output_with_bias[0] ) + attention_output, attention_output_bias = attention_output_with_bias + attention_output = self.post_self_attn_layernorm(attention_output) + attention_output_with_bias = (attention_output, attention_output_bias) + # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") @@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + mlp_output, mlp_output_bias = mlp_output_with_bias + mlp_output = self.post_mlp_layernorm(mlp_output) + mlp_output_with_bias = (mlp_output, mlp_output_bias) + if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index b267c8a81..83736acdc 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['inference_sampling_seed'] = args.seed + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm + # handle quantization config # NOTE: Kitchen arguments are only added to the namespace when # Kitchen library is available. @@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): action='store_true', help='If set, use original BERT residula connection ' 'ordering.') + group.add_argument('--post-self-attn-layernorm', action='store_true', + help='If set, use post self attention layernorm.') + group.add_argument('--post-mlp-layernorm', action='store_true', + help='If set, use post MLP layernorm.') + group.add_argument('--use-gated-attention', action='store_true', + help='If set, use gated attention as in Qwen3Next') group.add_argument('--openai-gelu', action='store_true', help='Use OpenAIs GeLU implementation. This option' 'should not be used unless for backward compatibility' diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py index 13b7526ca..6c590f653 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py @@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there self._tokenizer = transformers.AutoTokenizer.from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, + trust_remote_code=True, **kwargs, ) self._vocab = self._tokenizer.get_vocab()