diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index 41c21d93d..ef80f72d6 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): msc = MultiStorageClientFeature.import_package() return msc.torch.load(load_path, map_location='cpu') else: - return torch.load(load_path, map_location='cpu') + return torch.load(load_path, map_location='cpu', weights_only=False) except FileNotFoundError as e: err_msg = f'Common file {load_path} does not exist' if MultiStorageClientFeature.is_enabled(): diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index ccf5242a2..9b6d3e31f 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li _restore_dict_types(x_val, templ_val) +@dataclass +class MCoreMetadata(Metadata): + """Metadata with mcore specific data.""" + + # holds data related to flattened_range + # TODO: remove when flattened_range is properly removed + mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor + + @dataclass(frozen=True) class MCoreSavePlan(SavePlan): """SavePlan with MCore specific data.""" @@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner): def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: """Merges MCore data for all plans.""" global_plan, metadata = super().create_global_plan(all_plans) - metadata.mcore_data = dict( + mcore_data = dict( ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type] ) + metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata)) return global_plan, metadata def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: @@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): def _validate_global_shapes(self, metadata, sharded_tensors): for sh_ten in sharded_tensors: if sh_ten.key not in metadata.state_dict_metadata: - raise KeyError( - f"{sh_ten.key} from model not in state dict:" - f" {sorted(metadata.state_dict_metadata.keys())}" - ) + # raise KeyError( + # f"{sh_ten.key} from model not in state dict:" + # f" {sorted(metadata.state_dict_metadata.keys())}" + # ) + print(f"{sh_ten.key} from model not in state dict, will skip") + continue loaded_shape = metadata.state_dict_metadata[sh_ten.key].size expected_shape = self._expected_shape(sh_ten) if loaded_shape != expected_shape: @@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): tensor_metadata = self.metadata.state_dict_metadata metadata_with_sizes = [ (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) - for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() + for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata ] try: # Temporarily set sizes to expected shapes @@ -918,6 +930,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): planner=MCoreLoadPlanner( shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, + allow_partial_load=True, ), ) diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py index fe26e8b43..4451f2776 100644 --- a/megatron/core/distributed/__init__.py +++ b/megatron/core/distributed/__init__.py @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig + +# Backward compatibility patch for FSDP module reorganization +import sys +import importlib.util + +spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') +if spec: + custom_fsdp = importlib.util.module_from_spec(spec) + spec.loader.exec_module(custom_fsdp) + sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp + if hasattr(custom_fsdp, 'MegatronFSDP'): + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 7727efe1e..966fe652a 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear): ) for param in self.parameters(): + setattr(param, "parallel_mode", parallel_mode) if is_expert: # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 860ee64a9..80944b702 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec( qk_l2_norm: Optional[bool] = False, use_te_op_fuser: Optional[bool] = False, use_kitchen: bool = False, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). @@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec( ), ), self_attn_bda=get_bias_dropout_add, + post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, + post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, sharded_state_dict_keys_map={ "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 6aec66e6d..6ca48b55f 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -355,6 +355,7 @@ class GPTModel(LanguageModule): *, inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[Tensor] = None, + mtp_kwargs: Optional[dict] = {}, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoeder and finally into the post @@ -410,6 +411,7 @@ class GPTModel(LanguageModule): runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs, inference_context=inference_context, + mtp_kwargs=mtp_kwargs, ) def _postprocess( @@ -431,6 +433,7 @@ class GPTModel(LanguageModule): runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, + mtp_kwargs={}, ): """Postprocesses decoder hidden states to generate logits or compute loss. @@ -446,7 +449,7 @@ class GPTModel(LanguageModule): if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess: + if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -465,25 +468,37 @@ class GPTModel(LanguageModule): if not self.post_process: return hidden_states - if self.mtp_process: - mtp_labels = labels.clone() + if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None: + mtp_labels = mtp_kwargs['mtp_labels'].clone() + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) hidden_states = hidden_states_list[0] if loss_mask is None: # if loss_mask is not provided, use all ones as loss_mask loss_mask = torch.ones_like(mtp_labels) + else: + # Otherwise, roll the loss_mask to keep up with the mtp_labels + loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) for mtp_layer_number in range(self.config.mtp_num_layers): # output - mtp_logits, _ = self.output_layer( - hidden_states_list[mtp_layer_number + 1], - weight=output_weight, - runtime_gather_output=runtime_gather_output, + output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()} + output_layer_buffers = dict(self.output_layer.named_buffers()) + mtp_logits, _ = torch.func.functional_call( + self.output_layer, + {**output_layer_params, **output_layer_buffers}, + (hidden_states_list[mtp_layer_number + 1],), + { + "weight": output_weight.detach() if output_weight else None, + "runtime_gather_output": runtime_gather_output, + }, ) # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) - loss_mask, num_tokens = roll_tensor( - loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + new_loss_mask, num_tokens = roll_tensor( + loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params ) + loss_mask = new_loss_mask * loss_mask mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) mtp_loss = loss_mask * mtp_loss if self.training: diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index a36b67364..ed8883e32 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) + if "step" in param_group and param_group["step"] is None: + del param_group["step"] # Grad scaler state. if self.grad_scaler: diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index a40c85a88..86688c331 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -9,6 +9,7 @@ from typing import Callable, List, Optional import numpy as np import torch +import torch.distributed as dist from .utils import GlobalMemoryBuffer, is_torch_min_version diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index 63ee9d1f5..b90b744c1 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( ops = [] if tensor_send_prev is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, ) ops.append(send_prev_op) if tensor_recv_prev is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, ) ops.append(recv_prev_op) if tensor_send_next is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, next_pipeline_rank, group + torch.distributed.isend, tensor_send_next, next_pipeline_rank, ) ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, ) ops.append(recv_next_op) if len(ops) > 0: diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index c749bac43..dde8d50e7 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC): # Get the query, key and value tensors based on the type of attention - # self or cross attn. nvtx_range_push(suffix="qkv") - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + if self.config.use_gated_attention: + query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states) + else: + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") # =================================================== @@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC): # Output. [sq, b, h] # ================= + if self.config.use_gated_attention: + nvtx_range_push(suffix="sigmoid_gate") + core_attn_out = core_attn_out * torch.sigmoid(gate) + nvtx_range_pop(suffix="sigmoid_gate") + nvtx_range_push(suffix="linear_proj") output, bias = self.linear_proj(core_attn_out) nvtx_range_pop(suffix="linear_proj") @@ -879,19 +887,34 @@ class SelfAttention(Attention): model_comm_pgs=model_comm_pgs, ) - self.linear_qkv = build_module( - submodules.linear_qkv, - self.config.hidden_size, - self.query_projection_size + 2 * self.kv_projection_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear or self.config.add_qkv_bias, - skip_bias_add=False, - is_expert=False, - tp_comm_buffer_name='qkv', - tp_group=self.model_comm_pgs.tp, - ) + if self.config.use_gated_attention: + self.linear_qgkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + 2 * (self.query_projection_size + self.kv_projection_size), + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + tp_group=self.model_comm_pgs.tp, + ) + else: + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + tp_group=self.model_comm_pgs.tp, + ) if submodules.q_layernorm is not None: self.q_layernorm = build_module( @@ -1036,6 +1059,65 @@ class SelfAttention(Attention): return query, key, value + # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192 + def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] + mixed_qgkv, _ = self.linear_qgkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] + new_tensor_shape = mixed_qgkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) + * self.hidden_size_per_attention_head + ), + ) + mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) + else: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + gate = gate.reshape(query.size(0), query.size(1), -1) + + if self.q_layernorm is not None: + query = self.q_layernorm(query) + + if self.k_layernorm is not None: + key = self.k_layernorm(key) + + if self.config.test_mode: + self.run_realtime_tests() + + return query, gate, key, value + def backward_dw(self) -> NoReturn: """Execute weight update operations""" self._backward_qkv_proj() diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 235b6f6af..fbcffe278 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -566,6 +566,9 @@ def topk_routing_with_score_function( else: return torch.topk(scores, k=topk, dim=1) + from slime.utils.routing_replay import get_routing_replay_compute_topk + compute_topk = get_routing_replay_compute_topk(compute_topk) + if score_function == "softmax": if use_pre_softmax: scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 6b20b8622..459e65921 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -156,6 +156,9 @@ class TopKRouter(Router): self.local_tokens_per_expert = None self.expert_bias = None + from slime.utils.routing_replay import register_routing_replay + register_routing_replay(self) + def _maintain_float32_expert_bias(self): """ Maintain the expert bias in float32. diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index b7884e18e..f0104f861 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union import torch from torch import Tensor +import warnings from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -105,17 +106,21 @@ def tie_output_layer_state_dict( ) -def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): - """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support. - This function extends the original roll_tensor to support Context Parallelism, which allows - MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, - and tensor rolling requires communication between adjacent CP ranks to properly handle the - boundary conditions. +def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): + """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support. + + This function extends the original roll_tensor to support Context Parallelism and Packed Sequences. + When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires + communication between adjacent CP ranks to properly handle the boundary conditions. + When packed sequences are used, rolling is performed within each individual sequence boundary + to prevent mixing tokens between different packed sequences. For CP=1 (default behavior): Uses standard torch.roll with zero padding For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges boundary elements between adjacent CP ranks to maintain sequence continuity. + For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens. + Args: tensor (Tensor): The input tensor to roll. @@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): dims (int): The dimension to roll (typically -1 for sequence dimension). cp_group (ProcessGroup): The context parallelism process group. If None or size=1, falls back to standard rolling behavior. + packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. + If provided, rolling respects sequence boundaries. Returns: tuple: (rolled_tensor, sum_of_rolled_tensor) """ + + if packed_seq_params is not None: + return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) + # Standard rolling behavior when CP is not enabled (cp_group is None or size=1) if cp_group is None or cp_group.size() == 1: rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) @@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): return rolled_tensor, rolled_tensor.sum() +def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): + """Roll tensor with packed sequence support. + + This function handles rolling for packed sequences by respecting sequence boundaries + defined in packed_seq_params.cu_seqlens. Rolling is performed within each individual + sequence to prevent mixing tokens between different packed sequences. When Context + Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata + so we slice out the portion of every packed sequence that lives on the current rank and + reuse the standard CP boundary exchange to populate the rolling window. + + Args: + tensor (Tensor): The input tensor to roll. + shifts (int): The shift of the tensor (typically -1 for MTP). + dims (int): The dimension to roll (typically -1 for sequence dimension). + packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. + cp_group (ProcessGroup): The context parallelism process group. + + Returns: + tuple: (rolled_tensor, sum_of_rolled_tensor) + """ + + # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once. + assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension." + assert shifts == -1, "Packed sequence roll only supports a single-token left shift." + cu_seqlens = packed_seq_params.cu_seqlens_q + assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q." + + rolled_tensor = tensor.clone() + + cp_size = cp_group.size() if cp_group is not None else 1 + if cp_size == 1: + # CP disabled: simply roll inside each packed sequence boundary. + for i in range(len(cu_seqlens) - 1): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + seq_slice = tensor[..., start_idx:end_idx] + rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) + rolled_seq[..., shifts:] = 0 + rolled_tensor[..., start_idx:end_idx] = rolled_seq + return rolled_tensor, rolled_tensor.sum() + + # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). + local_rank = torch.distributed.get_rank(group=cp_group) + global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) + next_rank = global_ranks[(local_rank + 1) % cp_size] + prev_rank = global_ranks[(local_rank - 1) % cp_size] + + # iterate over each sequence individually + for i in range(len(cu_seqlens) - 1): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + + # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx + local_start_idx = start_idx // cp_size + local_end_idx = end_idx // cp_size + tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() + + # The following code is very similar as the code in roll_tensor function + local_chunks = tensor_slice.chunk(2, dim=dims) + rolled_chunks = [ + torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks + ] + + tensor_send_list = [] + tensor_recv_list = [] + for chunk in rolled_chunks: + boundary = chunk.select(dims, shifts).contiguous().clone() + tensor_send_list.append(boundary) + tensor_recv_list.append(torch.empty_like(boundary)) + + ops = [] + if local_rank != 0: + ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) + ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) + else: + tensor_recv_list[1].zero_() + + if local_rank != cp_size - 1: + ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) + ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) + else: + tensor_recv_list[0].copy_(tensor_send_list[1]) + + for op in ops: + op.wait() + + index = [slice(None)] * rolled_chunks[0].dim() + index[dims] = shifts + for chunk, recv in zip(rolled_chunks, tensor_recv_list): + chunk[tuple(index)] = recv + + seq_result = torch.cat(rolled_chunks, dim=dims) + + # update the rolled tensor + rolled_tensor[..., local_start_idx:local_end_idx] = seq_result + + return rolled_tensor, rolled_tensor.sum() class MTPLossLoggingHelper: """Helper class for logging MTP losses.""" @@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule): def _get_embeddings( self, input_ids: torch.Tensor, - position_ids: torch.Tensor, embedding: Callable, hidden_states: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, ): """ Preprocesses input data for the Multi-Token Prediction (MTP) layers. @@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule): sequence length, b is the batch size, and h is the hidden size. """ # Calc logits for the current Multi-Token Prediction (MTP) layers. - input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group) - position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group) + input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + + # Prepare/roll position ids only when applicable. + if position_ids is None: + # Fallback position ids for learned absolute embedding. + seq_len = input_ids.size(-1) + position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + position_ids, _ = roll_tensor( + position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params + ) # embedding decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + decoder_input = decoder_input.detach() - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) return input_ids, position_ids, decoder_input, hidden_states @@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule): return hidden_states def _checkpointed_forward(self, forward_func, *args, **kwargs): + """Wrap `forward_func` with activation checkpointing while only passing tensors. + + Non-tensor arguments (e.g., configuration objects, None) are captured via closure so + that checkpoint implementations never receive them directly, avoiding save_for_backward + issues with non-tensor inputs. + """ + + # TODO(jiajun): Is there any better implementation here? + positional_specs = [] + kw_specs = [] + tensor_args: List[torch.Tensor] = [] + + for arg in args: + if torch.is_tensor(arg): + positional_specs.append(('tensor', len(tensor_args))) + tensor_args.append(arg) + else: + positional_specs.append(('const', arg)) + + for key, value in kwargs.items(): + if torch.is_tensor(value): + kw_specs.append((key, ('tensor', len(tensor_args)))) + tensor_args.append(value) + else: + kw_specs.append((key, ('const', value))) + + def run(*flat_tensor_args): + rebuilt_args = [] + for spec_type, payload in positional_specs: + if spec_type == 'tensor': + rebuilt_args.append(flat_tensor_args[payload]) + else: + rebuilt_args.append(payload) + + rebuilt_kwargs = {} + for key, (spec_type, payload) in kw_specs: + if spec_type == 'tensor': + rebuilt_kwargs[key] = flat_tensor_args[payload] + else: + rebuilt_kwargs[key] = payload + + return forward_func(*rebuilt_args, **rebuilt_kwargs) + + tensor_args_tuple = tuple(tensor_args) + def checkpoint_handler(): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`.""" if self.config.fp8: from megatron.core.extensions.transformer_engine import te_checkpoint return te_checkpoint( - forward_func, + run, self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, parallel_state.get_tensor_model_parallel_group(), - *args, - **kwargs, + *tensor_args_tuple, ) else: return tensor_parallel.checkpoint( - forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() + run, self.config.distribute_saved_activations, *tensor_args_tuple ) if self.config.recompute_method == 'uniform': @@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule): [s, b, h], and optionally the updated context tensor if cross-attention is used. """ assert context is None, f"multi token prediction + cross attention is not yet supported." - assert ( - packed_seq_params is None - ), f"multi token prediction + sequence packing is not yet supported." input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( input_ids=input_ids, position_ids=position_ids, embedding=embedding, hidden_states=hidden_states, + packed_seq_params=packed_seq_params, ) if self.config.recompute_granularity == 'full' and self.training: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index d55bebe7e..1eecbbd38 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig): qk_layernorm: bool = False """Whether to apply `normalization` type of normalization to the query and key embeddings.""" + post_self_attn_layernorm: bool = False + post_mlp_layernorm: bool = False + use_gated_attention: bool = False + test_mode: bool = False """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 84f22bdea..f0f3f8e86 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -224,6 +224,7 @@ class TransformerLayerSubmodules: input_layernorm: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp + post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp @@ -232,6 +233,7 @@ class TransformerLayerSubmodules: pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp + post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) @@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): # [Module 3: BiasDropoutFusion] self.self_attn_bda = build_module(submodules.self_attn_bda) + self.post_self_attn_layernorm = build_module( + submodules.post_self_attn_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn self.pre_cross_attn_layernorm = build_module( submodules.pre_cross_attn_layernorm, @@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): # [Module 9: BiasDropoutFusion] self.mlp_bda = build_module(submodules.mlp_bda) + self.post_mlp_layernorm = build_module( + submodules.post_mlp_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon + ) + self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False @@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): attention_output_with_bias[0] ) + attention_output, attention_output_bias = attention_output_with_bias + attention_output = self.post_self_attn_layernorm(attention_output) + attention_output_with_bias = (attention_output, attention_output_bias) + # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") @@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): else: mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + mlp_output, mlp_output_bias = mlp_output_with_bias + mlp_output = self.post_mlp_layernorm(mlp_output) + mlp_output_with_bias = (mlp_output, mlp_output_bias) + if self.recompute_pre_mlp_layernorm: # discard the output of the pre-mlp layernorm and register the recompute # as a gradient hook of mlp_output_with_bias[0] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index e3459c5ee..7346bf35b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -937,8 +937,6 @@ def validate_args(args, defaults={}): # MoE Spec check if args.num_experts == 0: args.num_experts = None - if args.num_experts is not None: - assert args.spec is None, "Model Spec must be None when using MoEs" if args.num_experts is not None and args.moe_ffn_hidden_size is None: args.moe_ffn_hidden_size = args.ffn_hidden_size print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") @@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None): if args.is_hybrid_model: kw_args['is_hybrid_model'] = args.is_hybrid_model + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm + kw_args['use_gated_attention'] = args.use_gated_attention + # handle quantization config # NOTE: Kitchen arguments are only added to the namespace when # Kitchen library is available. @@ -1488,6 +1490,12 @@ def _add_network_size_args(parser): action='store_true', help='If set, use original BERT residula connection ' 'ordering.') + group.add_argument('--post-self-attn-layernorm', action='store_true', + help='If set, use post self attention layernorm.') + group.add_argument('--post-mlp-layernorm', action='store_true', + help='If set, use post MLP layernorm.') + group.add_argument('--use-gated-attention', action='store_true', + help='If set, use gated attention as in Qwen3Next') group.add_argument('--openai-gelu', action='store_true', help='Use OpenAIs GeLU implementation. This option' 'should not be used unless for backward compatibility' diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py index 5cf222ccc..d1554ca4c 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py @@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer): f"The transformers library must be installed to use huggingface_tokenizer_provider" ) + if "trust_remote_code" not in kwargs: + kwargs["trust_remote_code"] = True # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there self._tokenizer = transformers.AutoTokenizer.from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs