| |
| |
| |
| |
| @@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): |
| msc = MultiStorageClientFeature.import_package() |
| return msc.torch.load(load_path, map_location='cpu') |
| else: |
| - return torch.load(load_path, map_location='cpu') |
| + return torch.load(load_path, map_location='cpu', weights_only=False) |
| except FileNotFoundError as e: |
| err_msg = f'Common file {load_path} does not exist' |
| if MultiStorageClientFeature.is_enabled(): |
| |
| |
| |
| |
| @@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li |
| _restore_dict_types(x_val, templ_val) |
| |
| |
| +@dataclass |
| +class MCoreMetadata(Metadata): |
| + """Metadata with mcore specific data.""" |
| + |
| + # holds data related to flattened_range |
| + # TODO: remove when flattened_range is properly removed |
| + mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor |
| + |
| + |
| @dataclass(frozen=True) |
| class MCoreSavePlan(SavePlan): |
| """SavePlan with MCore specific data.""" |
| @@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner): |
| def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: |
| """Merges MCore data for all plans.""" |
| global_plan, metadata = super().create_global_plan(all_plans) |
| - metadata.mcore_data = dict( |
| + mcore_data = dict( |
| ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type] |
| ) |
| + metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata)) |
| return global_plan, metadata |
| |
| def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: |
| @@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): |
| def _validate_global_shapes(self, metadata, sharded_tensors): |
| for sh_ten in sharded_tensors: |
| if sh_ten.key not in metadata.state_dict_metadata: |
| - raise KeyError( |
| - f"{sh_ten.key} from model not in state dict:" |
| - f" {sorted(metadata.state_dict_metadata.keys())}" |
| - ) |
| + # raise KeyError( |
| + # f"{sh_ten.key} from model not in state dict:" |
| + # f" {sorted(metadata.state_dict_metadata.keys())}" |
| + # ) |
| + print(f"{sh_ten.key} from model not in state dict, will skip") |
| + continue |
| loaded_shape = metadata.state_dict_metadata[sh_ten.key].size |
| expected_shape = self._expected_shape(sh_ten) |
| if loaded_shape != expected_shape: |
| @@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): |
| tensor_metadata = self.metadata.state_dict_metadata |
| metadata_with_sizes = [ |
| (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) |
| - for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() |
| + for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata |
| ] |
| try: |
| # Temporarily set sizes to expected shapes |
| @@ -918,6 +930,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): |
| planner=MCoreLoadPlanner( |
| shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, |
| allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, |
| + allow_partial_load=True, |
| ), |
| ) |
| |
| |
| |
| |
| |
| @@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads |
| from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel |
| from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel |
| from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig |
| + |
| +# Backward compatibility patch for FSDP module reorganization |
| +import sys |
| +import importlib.util |
| + |
| +spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') |
| +if spec: |
| + custom_fsdp = importlib.util.module_from_spec(spec) |
| + spec.loader.exec_module(custom_fsdp) |
| + sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp |
| + if hasattr(custom_fsdp, 'MegatronFSDP'): |
| + custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP |
| |
| |
| |
| |
| @@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear): |
| ) |
| |
| for param in self.parameters(): |
| + setattr(param, "parallel_mode", parallel_mode) |
| if is_expert: |
| # Reduce the gradient on the expert_data_parallel group for expert linear layers |
| setattr(param, "allreduce", not self.expert_parallel) |
| |
| |
| |
| |
| @@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec( |
| qk_l2_norm: Optional[bool] = False, |
| use_te_op_fuser: Optional[bool] = False, |
| use_kitchen: bool = False, |
| + post_self_attn_layernorm: bool = False, |
| + post_mlp_layernorm: bool = False, |
| ) -> ModuleSpec: |
| """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). |
| |
| @@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec( |
| ), |
| ), |
| self_attn_bda=get_bias_dropout_add, |
| + post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, |
| pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, |
| mlp=mlp, |
| mlp_bda=get_bias_dropout_add, |
| + post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, |
| sharded_state_dict_keys_map={ |
| "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", |
| "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", |
| |
| |
| |
| |
| @@ -355,6 +355,7 @@ class GPTModel(LanguageModule): |
| *, |
| inference_params: Optional[BaseInferenceContext] = None, |
| loss_mask: Optional[Tensor] = None, |
| + mtp_kwargs: Optional[dict] = {}, |
| ) -> Tensor: |
| """Forward function of the GPT Model This function passes the input tensors |
| through the embedding layer, and then the decoeder and finally into the post |
| @@ -410,6 +411,7 @@ class GPTModel(LanguageModule): |
| runtime_gather_output=runtime_gather_output, |
| extra_block_kwargs=extra_block_kwargs, |
| inference_context=inference_context, |
| + mtp_kwargs=mtp_kwargs, |
| ) |
| |
| def _postprocess( |
| @@ -431,6 +433,7 @@ class GPTModel(LanguageModule): |
| runtime_gather_output=None, |
| extra_block_kwargs=None, |
| inference_context=None, |
| + mtp_kwargs={}, |
| ): |
| """Postprocesses decoder hidden states to generate logits or compute loss. |
| |
| @@ -446,7 +449,7 @@ class GPTModel(LanguageModule): |
| if self.share_embeddings_and_output_weights: |
| output_weight = self.shared_embedding_or_output_weight() |
| |
| - if mtp_in_postprocess: |
| + if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: |
| hidden_states = self.mtp( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| @@ -465,25 +468,37 @@ class GPTModel(LanguageModule): |
| if not self.post_process: |
| return hidden_states |
| |
| - if self.mtp_process: |
| - mtp_labels = labels.clone() |
| + if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None: |
| + mtp_labels = mtp_kwargs['mtp_labels'].clone() |
| + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) |
| + |
| hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) |
| hidden_states = hidden_states_list[0] |
| if loss_mask is None: |
| # if loss_mask is not provided, use all ones as loss_mask |
| loss_mask = torch.ones_like(mtp_labels) |
| + else: |
| + # Otherwise, roll the loss_mask to keep up with the mtp_labels |
| + loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) |
| for mtp_layer_number in range(self.config.mtp_num_layers): |
| # output |
| - mtp_logits, _ = self.output_layer( |
| - hidden_states_list[mtp_layer_number + 1], |
| - weight=output_weight, |
| - runtime_gather_output=runtime_gather_output, |
| + output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()} |
| + output_layer_buffers = dict(self.output_layer.named_buffers()) |
| + mtp_logits, _ = torch.func.functional_call( |
| + self.output_layer, |
| + {**output_layer_params, **output_layer_buffers}, |
| + (hidden_states_list[mtp_layer_number + 1],), |
| + { |
| + "weight": output_weight.detach() if output_weight else None, |
| + "runtime_gather_output": runtime_gather_output, |
| + }, |
| ) |
| # Calc loss for the current Multi-Token Prediction (MTP) layers. |
| - mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) |
| - loss_mask, num_tokens = roll_tensor( |
| - loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group |
| + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) |
| + new_loss_mask, num_tokens = roll_tensor( |
| + loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params |
| ) |
| + loss_mask = new_loss_mask * loss_mask |
| mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) |
| mtp_loss = loss_mask * mtp_loss |
| if self.training: |
| |
| |
| |
| |
| @@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): |
| # TE FusedAdam will not accumulate step for empty param groups, so we need to |
| # align the step across param groups. |
| param_group["step"] = int(step) |
| + if "step" in param_group and param_group["step"] is None: |
| + del param_group["step"] |
| |
| # Grad scaler state. |
| if self.grad_scaler: |
| |
| |
| |
| |
| @@ -9,6 +9,7 @@ from typing import Callable, List, Optional |
| |
| import numpy as np |
| import torch |
| +import torch.distributed as dist |
| |
| from .utils import GlobalMemoryBuffer, is_torch_min_version |
| |
| |
| |
| |
| |
| @@ -26,22 +26,22 @@ def _batched_p2p_ops( |
| ops = [] |
| if tensor_send_prev is not None: |
| send_prev_op = torch.distributed.P2POp( |
| - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group |
| + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, |
| ) |
| ops.append(send_prev_op) |
| if tensor_recv_prev is not None: |
| recv_prev_op = torch.distributed.P2POp( |
| - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group |
| + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, |
| ) |
| ops.append(recv_prev_op) |
| if tensor_send_next is not None: |
| send_next_op = torch.distributed.P2POp( |
| - torch.distributed.isend, tensor_send_next, next_pipeline_rank, group |
| + torch.distributed.isend, tensor_send_next, next_pipeline_rank, |
| ) |
| ops.append(send_next_op) |
| if tensor_recv_next is not None: |
| recv_next_op = torch.distributed.P2POp( |
| - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group |
| + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, |
| ) |
| ops.append(recv_next_op) |
| if len(ops) > 0: |
| |
| |
| |
| |
| @@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC): |
| # Get the query, key and value tensors based on the type of attention - |
| # self or cross attn. |
| nvtx_range_push(suffix="qkv") |
| - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) |
| + if self.config.use_gated_attention: |
| + query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states) |
| + else: |
| + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) |
| nvtx_range_pop(suffix="qkv") |
| |
| # |
| @@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC): |
| # Output. [sq, b, h] |
| # |
| |
| + if self.config.use_gated_attention: |
| + nvtx_range_push(suffix="sigmoid_gate") |
| + core_attn_out = core_attn_out * torch.sigmoid(gate) |
| + nvtx_range_pop(suffix="sigmoid_gate") |
| + |
| nvtx_range_push(suffix="linear_proj") |
| output, bias = self.linear_proj(core_attn_out) |
| nvtx_range_pop(suffix="linear_proj") |
| @@ -879,19 +887,34 @@ class SelfAttention(Attention): |
| model_comm_pgs=model_comm_pgs, |
| ) |
| |
| - self.linear_qkv = build_module( |
| - submodules.linear_qkv, |
| - self.config.hidden_size, |
| - self.query_projection_size + 2 * self.kv_projection_size, |
| - config=self.config, |
| - init_method=self.config.init_method, |
| - gather_output=False, |
| - bias=self.config.add_bias_linear or self.config.add_qkv_bias, |
| - skip_bias_add=False, |
| - is_expert=False, |
| - tp_comm_buffer_name='qkv', |
| - tp_group=self.model_comm_pgs.tp, |
| - ) |
| + if self.config.use_gated_attention: |
| + self.linear_qgkv = build_module( |
| + submodules.linear_qkv, |
| + self.config.hidden_size, |
| + 2 * (self.query_projection_size + self.kv_projection_size), |
| + config=self.config, |
| + init_method=self.config.init_method, |
| + gather_output=False, |
| + bias=self.config.add_bias_linear or self.config.add_qkv_bias, |
| + skip_bias_add=False, |
| + is_expert=False, |
| + tp_comm_buffer_name='qkv', |
| + tp_group=self.model_comm_pgs.tp, |
| + ) |
| + else: |
| + self.linear_qkv = build_module( |
| + submodules.linear_qkv, |
| + self.config.hidden_size, |
| + self.query_projection_size + 2 * self.kv_projection_size, |
| + config=self.config, |
| + init_method=self.config.init_method, |
| + gather_output=False, |
| + bias=self.config.add_bias_linear or self.config.add_qkv_bias, |
| + skip_bias_add=False, |
| + is_expert=False, |
| + tp_comm_buffer_name='qkv', |
| + tp_group=self.model_comm_pgs.tp, |
| + ) |
| |
| if submodules.q_layernorm is not None: |
| self.q_layernorm = build_module( |
| @@ -1036,6 +1059,65 @@ class SelfAttention(Attention): |
| |
| return query, key, value |
| |
| + # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192 |
| + def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None): |
| + """ |
| + Derives `query`, `key` and `value` tensors from `hidden_states`. |
| + """ |
| + # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] |
| + mixed_qgkv, _ = self.linear_qgkv(hidden_states) |
| + |
| + # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] |
| + new_tensor_shape = mixed_qgkv.size()[:-1] + ( |
| + self.num_query_groups_per_partition, |
| + ( |
| + 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) |
| + * self.hidden_size_per_attention_head |
| + ), |
| + ) |
| + mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) |
| + |
| + split_arg_list = [ |
| + ( |
| + self.num_attention_heads_per_partition |
| + // self.num_query_groups_per_partition |
| + * self.hidden_size_per_attention_head |
| + ), |
| + ( |
| + self.num_attention_heads_per_partition |
| + // self.num_query_groups_per_partition |
| + * self.hidden_size_per_attention_head |
| + ), |
| + self.hidden_size_per_attention_head, |
| + self.hidden_size_per_attention_head, |
| + ] |
| + |
| + if SplitAlongDim is not None: |
| + |
| + # [sq, b, ng, (np/ng + 2) * hn] |
| + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] |
| + (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) |
| + else: |
| + |
| + # [sq, b, ng, (np/ng + 2) * hn] |
| + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] |
| + (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) |
| + |
| + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] |
| + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) |
| + gate = gate.reshape(query.size(0), query.size(1), -1) |
| + |
| + if self.q_layernorm is not None: |
| + query = self.q_layernorm(query) |
| + |
| + if self.k_layernorm is not None: |
| + key = self.k_layernorm(key) |
| + |
| + if self.config.test_mode: |
| + self.run_realtime_tests() |
| + |
| + return query, gate, key, value |
| + |
| def backward_dw(self) -> NoReturn: |
| """Execute weight update operations""" |
| self._backward_qkv_proj() |
| |
| |
| |
| |
| @@ -566,6 +566,9 @@ def topk_routing_with_score_function( |
| else: |
| return torch.topk(scores, k=topk, dim=1) |
| |
| + from slime.utils.routing_replay import get_routing_replay_compute_topk |
| + compute_topk = get_routing_replay_compute_topk(compute_topk) |
| + |
| if score_function == "softmax": |
| if use_pre_softmax: |
| scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) |
| |
| |
| |
| |
| @@ -156,6 +156,9 @@ class TopKRouter(Router): |
| self.local_tokens_per_expert = None |
| self.expert_bias = None |
| |
| + from slime.utils.routing_replay import register_routing_replay |
| + register_routing_replay(self) |
| + |
| def _maintain_float32_expert_bias(self): |
| """ |
| Maintain the expert bias in float32. |
| |
| |
| |
| |
| @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union |
| |
| import torch |
| from torch import Tensor |
| +import warnings |
| |
| from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel |
| from megatron.core.dist_checkpointing.mapping import ShardedStateDict |
| @@ -105,17 +106,21 @@ def tie_output_layer_state_dict( |
| ) |
| |
| |
| -def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): |
| - """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support. |
| |
| - This function extends the original roll_tensor to support Context Parallelism, which allows |
| - MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, |
| - and tensor rolling requires communication between adjacent CP ranks to properly handle the |
| - boundary conditions. |
| +def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): |
| + """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support. |
| + |
| + This function extends the original roll_tensor to support Context Parallelism and Packed Sequences. |
| + When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires |
| + communication between adjacent CP ranks to properly handle the boundary conditions. |
| + When packed sequences are used, rolling is performed within each individual sequence boundary |
| + to prevent mixing tokens between different packed sequences. |
| |
| For CP=1 (default behavior): Uses standard torch.roll with zero padding |
| For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges |
| boundary elements between adjacent CP ranks to maintain sequence continuity. |
| + For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens. |
| + |
| |
| Args: |
| tensor (Tensor): The input tensor to roll. |
| @@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): |
| dims (int): The dimension to roll (typically -1 for sequence dimension). |
| cp_group (ProcessGroup): The context parallelism process group. If None or size=1, |
| falls back to standard rolling behavior. |
| + packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. |
| + If provided, rolling respects sequence boundaries. |
| Returns: |
| tuple: (rolled_tensor, sum_of_rolled_tensor) |
| """ |
| + |
| + if packed_seq_params is not None: |
| + return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) |
| + |
| # Standard rolling behavior when CP is not enabled (cp_group is None or size=1) |
| if cp_group is None or cp_group.size() == 1: |
| rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) |
| @@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): |
| |
| return rolled_tensor, rolled_tensor.sum() |
| |
| +def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): |
| + """Roll tensor with packed sequence support. |
| + |
| + This function handles rolling for packed sequences by respecting sequence boundaries |
| + defined in packed_seq_params.cu_seqlens. Rolling is performed within each individual |
| + sequence to prevent mixing tokens between different packed sequences. When Context |
| + Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata |
| + so we slice out the portion of every packed sequence that lives on the current rank and |
| + reuse the standard CP boundary exchange to populate the rolling window. |
| + |
| + Args: |
| + tensor (Tensor): The input tensor to roll. |
| + shifts (int): The shift of the tensor (typically -1 for MTP). |
| + dims (int): The dimension to roll (typically -1 for sequence dimension). |
| + packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. |
| + cp_group (ProcessGroup): The context parallelism process group. |
| + |
| + Returns: |
| + tuple: (rolled_tensor, sum_of_rolled_tensor) |
| + """ |
| + |
| + # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once. |
| + assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension." |
| + assert shifts == -1, "Packed sequence roll only supports a single-token left shift." |
| + cu_seqlens = packed_seq_params.cu_seqlens_q |
| + assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q." |
| + |
| + rolled_tensor = tensor.clone() |
| + |
| + cp_size = cp_group.size() if cp_group is not None else 1 |
| + if cp_size == 1: |
| + # CP disabled: simply roll inside each packed sequence boundary. |
| + for i in range(len(cu_seqlens) - 1): |
| + start_idx = cu_seqlens[i] |
| + end_idx = cu_seqlens[i + 1] |
| + seq_slice = tensor[..., start_idx:end_idx] |
| + rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) |
| + rolled_seq[..., shifts:] = 0 |
| + rolled_tensor[..., start_idx:end_idx] = rolled_seq |
| + return rolled_tensor, rolled_tensor.sum() |
| + |
| + # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). |
| + local_rank = torch.distributed.get_rank(group=cp_group) |
| + global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) |
| + next_rank = global_ranks[(local_rank + 1) % cp_size] |
| + prev_rank = global_ranks[(local_rank - 1) % cp_size] |
| + |
| + # iterate over each sequence individually |
| + for i in range(len(cu_seqlens) - 1): |
| + start_idx = cu_seqlens[i] |
| + end_idx = cu_seqlens[i + 1] |
| + |
| + # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx |
| + local_start_idx = start_idx // cp_size |
| + local_end_idx = end_idx // cp_size |
| + tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() |
| + |
| + # The following code is very similar as the code in roll_tensor function |
| + local_chunks = tensor_slice.chunk(2, dim=dims) |
| + rolled_chunks = [ |
| + torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks |
| + ] |
| + |
| + tensor_send_list = [] |
| + tensor_recv_list = [] |
| + for chunk in rolled_chunks: |
| + boundary = chunk.select(dims, shifts).contiguous().clone() |
| + tensor_send_list.append(boundary) |
| + tensor_recv_list.append(torch.empty_like(boundary)) |
| + |
| + ops = [] |
| + if local_rank != 0: |
| + ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) |
| + ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) |
| + else: |
| + tensor_recv_list[1].zero_() |
| + |
| + if local_rank != cp_size - 1: |
| + ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) |
| + ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) |
| + else: |
| + tensor_recv_list[0].copy_(tensor_send_list[1]) |
| + |
| + for op in ops: |
| + op.wait() |
| + |
| + index = [slice(None)] * rolled_chunks[0].dim() |
| + index[dims] = shifts |
| + for chunk, recv in zip(rolled_chunks, tensor_recv_list): |
| + chunk[tuple(index)] = recv |
| + |
| + seq_result = torch.cat(rolled_chunks, dim=dims) |
| + |
| + # update the rolled tensor |
| + rolled_tensor[..., local_start_idx:local_end_idx] = seq_result |
| + |
| + return rolled_tensor, rolled_tensor.sum() |
| |
| class MTPLossLoggingHelper: |
| """Helper class for logging MTP losses.""" |
| @@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule): |
| def _get_embeddings( |
| self, |
| input_ids: torch.Tensor, |
| - position_ids: torch.Tensor, |
| embedding: Callable, |
| hidden_states: torch.Tensor, |
| + position_ids: Optional[torch.Tensor] = None, |
| + packed_seq_params: Optional[PackedSeqParams] = None, |
| ): |
| """ |
| Preprocesses input data for the Multi-Token Prediction (MTP) layers. |
| @@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule): |
| sequence length, b is the batch size, and h is the hidden size. |
| """ |
| # Calc logits for the current Multi-Token Prediction (MTP) layers. |
| - input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group) |
| - position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group) |
| + input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) |
| + |
| + # Prepare/roll position ids only when applicable. |
| + if position_ids is None: |
| + # Fallback position ids for learned absolute embedding. |
| + seq_len = input_ids.size(-1) |
| + position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) |
| + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) |
| + |
| + position_ids, _ = roll_tensor( |
| + position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params |
| + ) |
| # embedding |
| decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) |
| + decoder_input = decoder_input.detach() |
| |
| - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) |
| + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) |
| |
| return input_ids, position_ids, decoder_input, hidden_states |
| |
| @@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule): |
| return hidden_states |
| |
| def _checkpointed_forward(self, forward_func, *args, **kwargs): |
| + """Wrap `forward_func` with activation checkpointing while only passing tensors. |
| + |
| + Non-tensor arguments (e.g., configuration objects, None) are captured via closure so |
| + that checkpoint implementations never receive them directly, avoiding save_for_backward |
| + issues with non-tensor inputs. |
| + """ |
| + |
| + # TODO(jiajun): Is there any better implementation here? |
| + positional_specs = [] |
| + kw_specs = [] |
| + tensor_args: List[torch.Tensor] = [] |
| + |
| + for arg in args: |
| + if torch.is_tensor(arg): |
| + positional_specs.append(('tensor', len(tensor_args))) |
| + tensor_args.append(arg) |
| + else: |
| + positional_specs.append(('const', arg)) |
| + |
| + for key, value in kwargs.items(): |
| + if torch.is_tensor(value): |
| + kw_specs.append((key, ('tensor', len(tensor_args)))) |
| + tensor_args.append(value) |
| + else: |
| + kw_specs.append((key, ('const', value))) |
| + |
| + def run(*flat_tensor_args): |
| + rebuilt_args = [] |
| + for spec_type, payload in positional_specs: |
| + if spec_type == 'tensor': |
| + rebuilt_args.append(flat_tensor_args[payload]) |
| + else: |
| + rebuilt_args.append(payload) |
| + |
| + rebuilt_kwargs = {} |
| + for key, (spec_type, payload) in kw_specs: |
| + if spec_type == 'tensor': |
| + rebuilt_kwargs[key] = flat_tensor_args[payload] |
| + else: |
| + rebuilt_kwargs[key] = payload |
| + |
| + return forward_func(*rebuilt_args, **rebuilt_kwargs) |
| + |
| + tensor_args_tuple = tuple(tensor_args) |
| + |
| def checkpoint_handler(): |
| - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" |
| + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`.""" |
| if self.config.fp8: |
| from megatron.core.extensions.transformer_engine import te_checkpoint |
| |
| return te_checkpoint( |
| - forward_func, |
| + run, |
| self.config.distribute_saved_activations, |
| tensor_parallel.random.get_cuda_rng_tracker, |
| parallel_state.get_tensor_model_parallel_group(), |
| - *args, |
| - **kwargs, |
| + *tensor_args_tuple, |
| ) |
| else: |
| return tensor_parallel.checkpoint( |
| - forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() |
| + run, self.config.distribute_saved_activations, *tensor_args_tuple |
| ) |
| |
| if self.config.recompute_method == 'uniform': |
| @@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule): |
| [s, b, h], and optionally the updated context tensor if cross-attention is used. |
| """ |
| assert context is None, f"multi token prediction + cross attention is not yet supported." |
| - assert ( |
| - packed_seq_params is None |
| - ), f"multi token prediction + sequence packing is not yet supported." |
| |
| input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| embedding=embedding, |
| hidden_states=hidden_states, |
| + packed_seq_params=packed_seq_params, |
| ) |
| |
| if self.config.recompute_granularity == 'full' and self.training: |
| |
| |
| |
| |
| @@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig): |
| qk_layernorm: bool = False |
| """Whether to apply `normalization` type of normalization to the query and key embeddings.""" |
| |
| + post_self_attn_layernorm: bool = False |
| + post_mlp_layernorm: bool = False |
| + use_gated_attention: bool = False |
| + |
| test_mode: bool = False |
| """Whether to run real-time tests.""" |
| |
| |
| |
| |
| |
| @@ -224,6 +224,7 @@ class TransformerLayerSubmodules: |
| input_layernorm: Union[ModuleSpec, type] = IdentityOp |
| self_attention: Union[ModuleSpec, type] = IdentityOp |
| self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp |
| + post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp |
| |
| pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp |
| cross_attention: Union[ModuleSpec, type] = IdentityOp |
| @@ -232,6 +233,7 @@ class TransformerLayerSubmodules: |
| pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp |
| mlp: Union[ModuleSpec, type] = IdentityOp |
| mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp |
| + post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp |
| |
| # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method |
| sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) |
| @@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| # [Module 3: BiasDropoutFusion] |
| self.self_attn_bda = build_module(submodules.self_attn_bda) |
| |
| + self.post_self_attn_layernorm = build_module( |
| + submodules.post_self_attn_layernorm, |
| + config=self.config, |
| + hidden_size=self.config.hidden_size, |
| + eps=self.config.layernorm_epsilon, |
| + ) |
| + |
| # [Module 4: Post SelfAttention] Optional Layernorm after self-attn |
| self.pre_cross_attn_layernorm = build_module( |
| submodules.pre_cross_attn_layernorm, |
| @@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| # [Module 9: BiasDropoutFusion] |
| self.mlp_bda = build_module(submodules.mlp_bda) |
| |
| + self.post_mlp_layernorm = build_module( |
| + submodules.post_mlp_layernorm, |
| + config=self.config, |
| + hidden_size=self.config.hidden_size, |
| + eps=self.config.layernorm_epsilon |
| + ) |
| + |
| self.recompute_input_layernorm = False |
| self.recompute_pre_mlp_layernorm = False |
| self.recompute_mlp = False |
| @@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| attention_output_with_bias[0] |
| ) |
| |
| + attention_output, attention_output_bias = attention_output_with_bias |
| + attention_output = self.post_self_attn_layernorm(attention_output) |
| + attention_output_with_bias = (attention_output, attention_output_bias) |
| + |
| # TODO: could we move `bias_dropout_add_exec_handler` itself |
| # inside the module provided in the `bias_dropout_add_spec` module? |
| nvtx_range_push(suffix="self_attn_bda") |
| @@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): |
| else: |
| mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) |
| |
| + mlp_output, mlp_output_bias = mlp_output_with_bias |
| + mlp_output = self.post_mlp_layernorm(mlp_output) |
| + mlp_output_with_bias = (mlp_output, mlp_output_bias) |
| + |
| if self.recompute_pre_mlp_layernorm: |
| # discard the output of the pre-mlp layernorm and register the recompute |
| # as a gradient hook of mlp_output_with_bias[0] |
| |
| |
| |
| |
| @@ -937,8 +937,6 @@ def validate_args(args, defaults={}): |
| # MoE Spec check |
| if args.num_experts == 0: |
| args.num_experts = None |
| - if args.num_experts is not None: |
| - assert args.spec is None, "Model Spec must be None when using MoEs" |
| if args.num_experts is not None and args.moe_ffn_hidden_size is None: |
| args.moe_ffn_hidden_size = args.ffn_hidden_size |
| print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") |
| @@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None): |
| if args.is_hybrid_model: |
| kw_args['is_hybrid_model'] = args.is_hybrid_model |
| |
| + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm |
| + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm |
| + kw_args['use_gated_attention'] = args.use_gated_attention |
| + |
| # handle quantization config |
| # NOTE: Kitchen arguments are only added to the namespace when |
| # Kitchen library is available. |
| @@ -1488,6 +1490,12 @@ def _add_network_size_args(parser): |
| action='store_true', |
| help='If set, use original BERT residula connection ' |
| 'ordering.') |
| + group.add_argument('--post-self-attn-layernorm', action='store_true', |
| + help='If set, use post self attention layernorm.') |
| + group.add_argument('--post-mlp-layernorm', action='store_true', |
| + help='If set, use post MLP layernorm.') |
| + group.add_argument('--use-gated-attention', action='store_true', |
| + help='If set, use gated attention as in Qwen3Next') |
| group.add_argument('--openai-gelu', action='store_true', |
| help='Use OpenAIs GeLU implementation. This option' |
| 'should not be used unless for backward compatibility' |
| |
| |
| |
| |
| @@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer): |
| f"The transformers library must be installed to use huggingface_tokenizer_provider" |
| ) |
| |
| + if "trust_remote_code" not in kwargs: |
| + kwargs["trust_remote_code"] = True |
| # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there |
| self._tokenizer = transformers.AutoTokenizer.from_pretrained( |
| pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs |
|
|