# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# There are bugs in mcore 0.12 that we need to patch:
# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` produces wrong results
#    when packed_seq_params is not None.


def apply_patch():
    import megatron.core
    import torch
    import torch.nn.functional as F
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.transformer.multi_latent_attention import (
        MLASelfAttention,
        MultiLatentAttention,
        apply_rotary_pos_emb,
        deprecate_inference_params,
        gather_from_sequence_parallel_region,
        gather_from_tensor_model_parallel_region,
        scatter_to_sequence_parallel_region,
    )
    from packaging import version

    mcore_ge_013 = version.parse(megatron.core.__version__) >= version.parse("0.13.0")

    def patch_get_query_key_value_tensors(
        self,
        hidden_states,
        key_value_states=None,
        position_ids=None,
        packed_seq_params=None,
        inference_context=None,
        *,
        inference_params=None,
    ):
        """Derives `query`, `key` and `value` tensors from `hidden_states`."""
        # s = sequence length, b = batch size, h = hidden size, n = num attention heads
        # Attention heads [s, b, n*h]
        assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"

        inference_context = deprecate_inference_params(inference_context, inference_params)

        # =========================================
        # Prepare RoPE and seqlen related params
        # =========================================
        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
            inference_context, None, hidden_states, self.config, packed_seq_params
        )

        # rotary_pos_emb: [s, b, 1, 64]
        mscale = 1.0
        if self.config.rope_type == "rope":
            packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd"
            try:
                # In case of TypeError: RotaryEmbedding.forward() got an unexpected
                # keyword argument 'packed_seq'
                rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
            except TypeError:
                rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
        else:
            rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)

        # =========================================
        # QKV down projection and layernorm
        # =========================================
        if self.config.q_lora_rank is not None:
            # if linear_q_down_proj is ColumnParallelLinear:
            #     q_compressed: [s, b, q_lora_rank / TP]
            # elif linear_q_down_proj is Linear:
            #     q_compressed: [s / TP, b, q_lora_rank]
            q_compressed, _ = self.linear_q_down_proj(hidden_states)

            # When the output is sharded (ColumnParallelLinear), two things are needed to make
            # it identical to a normal Linear:
            #   1. Manually gather the output to restore the output dim q_lora_rank;
            #   2. Scatter the sequence back to s / TP if sequence-parallel, since it was
            #      gathered by ColumnParallelLinear.
            if q_compressed.size(-1) != self.config.q_lora_rank:
                q_compressed = gather_from_tensor_model_parallel_region(q_compressed)
                if self.config.sequence_parallel:
                    q_compressed = scatter_to_sequence_parallel_region(q_compressed)

            q_compressed = self.q_layernorm(q_compressed)
        else:
            q_compressed = hidden_states

        # if linear_kv_down_proj is ColumnParallelLinear:
        #     kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]
        # elif linear_kv_down_proj is Linear:
        #     kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]
        kv_combined, _ = self.linear_kv_down_proj(hidden_states)
        if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:
            # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]
            kv_combined = gather_from_tensor_model_parallel_region(kv_combined)
            # kv_compressed: [s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
            kv_compressed, k_pos_emb = torch.split(
                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
            )
            if self.config.sequence_parallel:
                # kv_compressed: [s / TP, b, kv_lora_rank]
                kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)
        else:
            # kv_compressed: [s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]
            kv_compressed, k_pos_emb = torch.split(
                kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
            )
            if parallel_state.get_tensor_model_parallel_world_size() > 1:
                # k_pos_emb: [s, b, qk_pos_emb_head_dim]
                k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)

        kv_compressed = self.kv_layernorm(kv_compressed)

        # =========================================
        # QKV up projection and RoPE apply
        # =========================================
        def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):
            if self.config.q_lora_rank is not None:
                q, _ = self.linear_q_up_proj(q_compressed)
            else:
                # hidden_states: [s, b, 2048], q: [s, b, n * 192]
                q, _ = self.linear_q_proj(q_compressed)

            q_len, bsz, _ = q.size()

            # q: [s, b, n, 192]
            q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)

            # kv: [s, b, 2048]
            kv, _ = self.linear_kv_up_proj(kv_compressed)

            # kv: [s, b, n, 256]
            kv = kv.view(
                q_len,
                bsz,
                self.num_attention_heads_per_partition,
                self.config.qk_head_dim + self.config.v_head_dim,
            )

            cp_size = parallel_state.get_context_parallel_world_size()
            if inference_context is not None:
                # add offset to the sequence start for inference
                sequence_start = inference_context.sequence_len_offset
                sequence_end = sequence_start + q_len
                rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
            elif packed_seq_params is None or cp_size == 1:
                # Shorten rotary_pos_emb to the sequence length when inference_params
                # is not provided. This makes sure we can run forward directly with
                # any sequence length. During training, the sequence length is always
                # the full rotary_pos_emb length, except for sequence packing + CP.
                # When sequence packing and context parallel are both enabled,
                # rotary_pos_emb is not split along the sequence, so it may exceed the
                # sequence length on this CP rank; we need the full rotary_pos_emb to
                # cover the full sequence, so we do not shorten it here.
                rotary_pos_emb = rotary_pos_emb[0:q_len]

            # [s, b, 64] -> [s, b, 1, 64]
            k_pos_emb = torch.unsqueeze(k_pos_emb, 2)

            # q_no_pe: [s, b, n, 128], q_pos_emb: [s, b, n, 64]
            q_no_pe, q_pos_emb = torch.split(
                q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1
            )

            # k_no_pe: [s, b, n, 128], value: [s, b, n, 128]
            k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)

            if packed_seq_params is not None:
                cu_seqlens_q = packed_seq_params.cu_seqlens_q
                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
                q_pos_emb = q_pos_emb.squeeze(1)
                k_pos_emb = k_pos_emb.squeeze(1)
                q_no_pe = q_no_pe.squeeze(1)
                k_no_pe = k_no_pe.squeeze(1)
                value = value.squeeze(1)
            else:
                cu_seqlens_q = cu_seqlens_kv = None

            # q_pos_emb: [s, b, n, 64], k_pos_emb: [s, b, 1, 64]
            q_pos_emb = apply_rotary_pos_emb(
                q_pos_emb,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=cu_seqlens_q,
                mscale=mscale,
            )
            k_pos_emb = apply_rotary_pos_emb(
                k_pos_emb,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=cu_seqlens_kv,
                mscale=mscale,
            )

            # query: [s, b, n, 192]
            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)

            if packed_seq_params is not None:
                k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)
                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
            else:
                # key: [s, b, n, 192]
                k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
                key = torch.cat([k_no_pe, k_pos_emb], dim=-1)

            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

            return query, key, value

        if self.recompute_up_proj:
            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()
            query, key, value = self.qkv_up_checkpoint.checkpoint(
                qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb
            )
        else:
            query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)

        return query, key, value

    def patch_forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_context=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        position_ids=None,
        sequence_len_offset=None,
        *,
        inference_params=None,
        **kwargs,
    ):
        """Forward pass for multi-latent attention."""
        assert attention_bias is None, "Attention bias should not be passed into MLA."
        assert rotary_pos_cos is None and rotary_pos_sin is None, "MLA does not support Flash Decoding"

        # hidden_states: [sq, b, h]
        inference_context = deprecate_inference_params(inference_context, inference_params)

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        # query: [96, 1, 16, 128], key: [96, 1, 16, 128], value: [96, 1, 16, 128]
        qkv = self.get_query_key_value_tensors(
            hidden_states,
            key_value_states,
            position_ids,
            packed_seq_params,
            inference_context=inference_context,
        )
        query, key, value = qkv[:3]
        q_compressed = None
        # kv_compressed = None
        if len(qkv) > 4:
            q_compressed = qkv[3]
            # kv_compressed = qkv[4]

        # ===================================================
        # Adjust key, value for inference
        # ===================================================
        # rotary_pos_emb = None
        if mcore_ge_013:
            query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference(
                inference_context, query, key, value, rotary_pos_emb=None
            )
        else:
            query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
                inference_context, query, key, value, rotary_pos_emb=None
            )

        # TODO: Currently, TE can only accept contiguous tensors for MLA
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # ==================================
        # core attention computation
        # ==================================
        # Need corresponding TE change
        non_dsa_thd_qkv_format = (
            packed_seq_params
            and packed_seq_params.qkv_format == "thd"
            and getattr(self.config, "experimental_attention_variant", None) is None
        )
        v_dim = value.shape[-1]
        if non_dsa_thd_qkv_format and query.shape[-1] != v_dim:
            value = F.pad(value, [0, query.shape[-1] - v_dim])
            self.core_attention.hidden_size_per_attention_head_v = value.shape[-1]

        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query, key, value, attention_mask, packed_seq_params=packed_seq_params
            )
        else:
            extra_kwargs = {}
            if getattr(self.config, "experimental_attention_variant", None) == "dsa":
                # For dsa we need to pass in the original hidden states and the compressed
                # query representation.
                extra_kwargs["x"] = hidden_states
                extra_kwargs["qr"] = q_compressed
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                packed_seq_params=packed_seq_params,
                attn_mask_type=attn_mask_type,
                **extra_kwargs,
            )

        if non_dsa_thd_qkv_format:
            if core_attn_out.ndim == 2:
                core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-1], -1, value.shape[-1])
            if query.shape[-1] != v_dim:
                core_attn_out = core_attn_out[..., :v_dim]
            # reshape to same output shape as unpacked case
            # (t, np, hn) -> (t, b=1, h=np*hn)
            # t is the pack size = sum (sq_i)
            # note that batch is a dummy dimension in the packed case
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        if self.recompute_up_proj:
            assert self.qkv_up_checkpoint is not None
            self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out)
            self.qkv_up_checkpoint = None

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.linear_proj(core_attn_out)

        return output, bias

    # This patch targets mcore 0.12 MLA behavior only.
    # For newer mcore, upstream MLA already has packed-seq + CP handling and
    # overriding it with the legacy implementation can break RoPE shapes.
    if not mcore_ge_013:
        MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
        MultiLatentAttention.forward = patch_forward


def apply_patch_mbridge():
    try:
        from megatron.core.utils import get_tensor_model_parallel_group_if_none
    except ImportError:
        import warnings

        import megatron.core.utils
        import torch
        from megatron.core import parallel_state

        def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True):
            """Issue a deprecation warning if tp_group is None and return the default tp group."""
            if not torch.distributed.is_initialized():
                return None

            if tp_group is None:
                if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                    warnings.warn(
                        "Warning: tp_group is None, using default tp group. Passing tp_group will be mandatory soon",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                if is_expert:
                    tp_group = parallel_state.get_expert_tensor_parallel_group(
                        check_initialized=check_initialized
                    )
                else:
                    tp_group = parallel_state.get_tensor_model_parallel_group(
                        check_initialized=check_initialized
                    )
            return tp_group

        megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none


def apply_patch_megatron_v012_with_torch_v28():
    # Error due to missing serialization_format in _write_item of megatron v012;
    # resolved by using megatron v013's implementation.
    import inspect
    import logging
    import os
    from pathlib import Path

    import megatron.core
    import torch
    from megatron.core.dist_checkpointing.strategies.async_utils import _disable_gc
    from megatron.core.dist_checkpointing.strategies.filesystem_async import _process_memory
    from packaging import version
    from torch import multiprocessing as mp
    from torch.distributed.checkpoint.filesystem import _write_item

    if (
        version.parse(torch.__version__).base_version != "2.8.0"
        or version.parse(megatron.core.__version__).base_version != "0.12.1"
    ):
        return

    WriteBucket = tuple[Path, str, tuple[list, list]]

    @staticmethod
    @_disable_gc()
    def write_preloaded_data_patch(
        transform_list,
        local_proc_idx: int,
        write_bucket: WriteBucket,
        results_queue: mp.SimpleQueue,
        count_queue: mp.JoinableQueue,
        use_fsync: bool,
        **kwargs,
    ) -> None:
        """
        Performs actual data saving to storage.

        Args:
            local_proc_idx (int): index of a local process that performs writing
            write_bucket (WriteBucket): data to write to storage
            results_queue (mp.Queue): queue to return the write results
                to the proxy checkpoint process.
            count_queue (mp.JoinableQueue): queue to mark the worker task as completed
            use_fsync (bool): if True, calls os.fsync at the end of saving

        Returns:
            None, the write results are put into `results_queue`.
        """
        logger = logging.getLogger(__name__)
        logger.debug(f"{local_proc_idx} started")
        mem_before = _process_memory()
        use_msc = kwargs.get("use_msc", False)

        local_results = []
        try:
            file_name, storage_key, (bytes_data, tensor_data) = write_bucket
            extra_kwargs = {}
            if "serialization_format" in inspect.signature(_write_item).parameters:
                from torch.distributed.checkpoint.filesystem import SerializationFormat

                extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE
            if use_msc:
                import multistorageclient as msc

                open_file = msc.open
            else:
                open_file = open

            with open_file(file_name, "wb") as stream:
                for write_item, data in bytes_data:
                    local_results.append(
                        _write_item(*transform_list, stream, data, write_item, storage_key, **extra_kwargs)
                    )

                for write_item, tensor in tensor_data:
                    assert tensor.is_cpu
                    local_results.append(
                        _write_item(*transform_list, stream, tensor, write_item, storage_key, **extra_kwargs)
                    )

                if use_fsync:
                    if use_msc:
                        stream.fsync()
                    else:
                        os.fsync(stream.fileno())
            local_output = (local_proc_idx, local_results)
        except Exception as e:
            logger.debug(f"{local_proc_idx} failed")
            local_output = (local_proc_idx, e)  # type: ignore[assignment]

        results_queue.put(local_output)

        # Signal this process is done.
        count_queue.get()
        count_queue.task_done()

        mem_after = _process_memory()
        logger.debug(
            f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
        )

    from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync

    FileSystemWriterAsync.write_preloaded_data = write_preloaded_data_patch
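
# A minimal usage sketch (the `verl_mcore_patches` module name below is illustrative and not
# part of this file): the three entry points above are intended to be called once, before any
# Megatron-Core model construction or distributed checkpointing runs, e.g.
#
#     from verl_mcore_patches import (
#         apply_patch,
#         apply_patch_mbridge,
#         apply_patch_megatron_v012_with_torch_v28,
#     )
#
#     apply_patch()                               # MLA packed-seq fix; only applied when mcore < 0.13
#     apply_patch_mbridge()                       # backfill get_tensor_model_parallel_group_if_none if missing
#     apply_patch_megatron_v012_with_torch_v28()  # dist-ckpt fix; returns early unless torch 2.8.0 + mcore 0.12.1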