# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# There are bugs in mcore 0.12 that we need to patch:
# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` produces wrong results when packed_seq_params is not None
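#
# A minimal usage sketch (assumption: the patches must be applied before any Megatron
# model or async checkpoint writer is constructed; `patches` below is a hypothetical
# import name for this module):
#
#     from patches import apply_patch, apply_patch_mbridge, apply_patch_megatron_v012_with_torch_v28
#     apply_patch()                               # fix MLA packed-sequence handling on mcore 0.12
#     apply_patch_mbridge()                       # backport get_tensor_model_parallel_group_if_none
#     apply_patch_megatron_v012_with_torch_v28()  # fix async checkpoint writes on torch 2.8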
def apply_patch():
import megatron.core
import torch
import torch.nn.functional as F
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MultiLatentAttention,
apply_rotary_pos_emb,
deprecate_inference_params,
gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from packaging import version
mcore_ge_013 = version.parse(megatron.core.__version__) >= version.parse("0.13.0")
def patch_get_query_key_value_tensors(
self,
hidden_states,
key_value_states=None,
position_ids=None,
packed_seq_params=None,
inference_context=None,
*,
inference_params=None,
):
"""
Derives `query`, `key` and `value` tensors from `hidden_states`.
"""
# s = sequence length, b = batch size, h = hidden size, n = num attention heads
# Attention heads [s, b, n*h]
assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"
inference_context = deprecate_inference_params(inference_context, inference_params)
# =========================================
# Prepare RoPE and seqlen related params
# =========================================
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_context, None, hidden_states, self.config, packed_seq_params
)
# rotary_pos_emb:[s, b, 1, 64]
mscale = 1.0
if self.config.rope_type == "rope":
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd"
try:
# Older RotaryEmbedding.forward() does not accept the `packed_seq` keyword and raises
# TypeError, so fall back to calling it without that argument.
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
except TypeError:
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
else:
rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)
# =========================================
# QKV down projection and layernorm
# =========================================
if self.config.q_lora_rank is not None:
# if linear_q_down_proj is ColumnParallelLinear:
# q_compressed: [s, b, q_lora_rank / TP]
# elif linear_q_down_proj is Linear:
# q_compressed: [s / TP, b, q_lora_rank]
q_compressed, _ = self.linear_q_down_proj(hidden_states)
# When the output is sharded (ColumnParallelLinear), two steps are needed to make it
# identical to a plain Linear:
# 1. Manually gather the output to restore the full q_lora_rank dimension;
# 2. Scatter the sequence back to s / TP if sequence parallelism is enabled, since
# ColumnParallelLinear gathered it.
if q_compressed.size(-1) != self.config.q_lora_rank:
q_compressed = gather_from_tensor_model_parallel_region(q_compressed)
if self.config.sequence_parallel:
q_compressed = scatter_to_sequence_parallel_region(q_compressed)
q_compressed = self.q_layernorm(q_compressed)
else:
q_compressed = hidden_states
# if linear_kv_down_proj is ColumnParallelLinear:
# kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]
# elif linear_kv_down_proj is Linear:
# kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]
kv_combined, _ = self.linear_kv_down_proj(hidden_states)
if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:
# kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]
kv_combined = gather_from_tensor_model_parallel_region(kv_combined)
# kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
kv_compressed, k_pos_emb = torch.split(
kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
)
if self.config.sequence_parallel:
# kv_compressed:[s / TP, b, kv_lora_rank]
kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)
else:
# kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]
kv_compressed, k_pos_emb = torch.split(
kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
)
if parallel_state.get_tensor_model_parallel_world_size() > 1:
# k_pos_emb: [s, b, qk_pos_emb_head_dim]
k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)
kv_compressed = self.kv_layernorm(kv_compressed)
# =========================================
# QKV up projection and RoPE apply
# =========================================
def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):
if self.config.q_lora_rank is not None:
q, _ = self.linear_q_up_proj(q_compressed)
else:
# hidden_states:[s, b, 2048], q: [s, b, n * 192]
q, _ = self.linear_q_proj(q_compressed)
q_len, bsz, _ = q.size()
# q: [s, b, n, 192]
q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)
# kv: [s, b, 2048]
kv, _ = self.linear_kv_up_proj(kv_compressed)
# kv: [s, b, n, 256]
kv = kv.view(
q_len,
bsz,
self.num_attention_heads_per_partition,
self.config.qk_head_dim + self.config.v_head_dim,
)
cp_size = parallel_state.get_context_parallel_world_size()
if inference_context is not None:
# add offset to the sequence start for inference
sequence_start = inference_context.sequence_len_offset
sequence_end = sequence_start + q_len
rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
elif packed_seq_params is None or cp_size == 1:
# Shorten rotary_pos_emb to the sequence length when inference_context
# is not provided, so that forward can run directly with any sequence
# length. During training the sequence length equals the full
# rotary_pos_emb length, except with sequence packing + CP: there
# rotary_pos_emb is not split across CP ranks, so it can exceed the
# local sequence length on this rank, yet the full rotary_pos_emb is
# still needed to cover the whole packed sequence, so it is not
# shortened in that case.
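# For example (illustrative numbers only): with a packed batch of total length
# T = 8192 and cp_size = 2, this CP rank may hold q_len = 4096 tokens whose
# positions can fall anywhere in [0, 8192), so rotary_pos_emb must stay at
# length 8192 rather than being truncated to q_len.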
rotary_pos_emb = rotary_pos_emb[0:q_len]
# [s, b, 64] -> [s, b, 1, 64]
k_pos_emb = torch.unsqueeze(k_pos_emb, 2)
# q: [s, b, n, 128], q_pos_emb: [s, b, n, 64]
q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)
# k_no_pe: [s, b, n, 128], value: [s, b, n, 128]
k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)
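# In thd (packed) format the batch dimension is a dummy dimension of size 1, so the
# tensors are squeezed before RoPE and cu_seqlens marks the sequence boundaries inside
# the pack. Illustrative shapes under the example config used in the comments above:
# q_pos_emb [t, n, 64], k_pos_emb [t, 1, 64], q_no_pe [t, n, 128], value [t, n, 128].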
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
q_pos_emb = q_pos_emb.squeeze(1)
k_pos_emb = k_pos_emb.squeeze(1)
q_no_pe = q_no_pe.squeeze(1)
k_no_pe = k_no_pe.squeeze(1)
value = value.squeeze(1)
else:
cu_seqlens_q = cu_seqlens_kv = None
# q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64]
q_pos_emb = apply_rotary_pos_emb(
q_pos_emb,
rotary_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_q,
mscale=mscale,
)
k_pos_emb = apply_rotary_pos_emb(
k_pos_emb,
rotary_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_kv,
mscale=mscale,
)
# query: [s, b, n, 192]
query = torch.cat([q_no_pe, q_pos_emb], dim=-1)
if packed_seq_params is not None:
k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)
key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
else:
# key: [s, b, n, 192]
k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
return query, key, value
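# When recompute_up_proj is enabled, the up-projection + RoPE above runs under
# CheckpointWithoutOutput: its activations are freed once the attention output is
# produced (see discard_output_and_register_recompute in forward) and recomputed during
# the backward pass, trading extra compute for activation memory.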
if self.recompute_up_proj:
self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()
query, key, value = self.qkv_up_checkpoint.checkpoint(
qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb
)
else:
query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)
return query, key, value
def patch_forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_context=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
position_ids=None,
sequence_len_offset=None,
*,
inference_params=None,
**kwargs,
):
"""Forward pass for multi-latent attention"""
assert attention_bias is None, "Attention bias should not be passed into MLA."
assert rotary_pos_cos is None and rotary_pos_sin is None, "MLA does not support Flash Decoding"
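# Compared to upstream mcore, this override appears to differ mainly in that it handles
# both return signatures of _adjust_key_value_for_inference (mcore >= 0.13 returns an
# extra value), pads `value` for thd-format inputs so TE accepts it, and forwards the
# hidden states / compressed query to the experimental "dsa" attention variant.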
# hidden_states: [sq, b, h]
inference_context = deprecate_inference_params(inference_context, inference_params)
# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
qkv = self.get_query_key_value_tensors(
hidden_states,
key_value_states,
position_ids,
packed_seq_params,
inference_context=inference_context,
)
query, key, value = qkv[:3]
q_compressed = None
# kv_compressed = None
if len(qkv) > 4:
q_compressed = qkv[3]
# kv_compressed = qkv[4]
# ===================================================
# Adjust key, value for inference
# ===================================================
# rotary_pos_emb = None
if mcore_ge_013:
query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference(
inference_context, query, key, value, rotary_pos_emb=None
)
else:
query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
inference_context, query, key, value, rotary_pos_emb=None
)
# TODO: Currently, TE can only accept contiguous tensors for MLA
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# ==================================
# core attention computation
# ==================================
# Need corresponding TE change
non_dsa_thd_qkv_format = (
packed_seq_params
and packed_seq_params.qkv_format == "thd"
and getattr(self.config, "experimental_attention_variant", None) is None
)
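# For thd-format inputs, TE's fused attention is assumed to require query, key and value
# to share the same last dimension, so value (v_head_dim) is zero-padded up to the query
# head dim here and the pad is sliced off again after attention. With the illustrative
# dims above: value 128 -> padded to 192 (128 qk_head_dim + 64 qk_pos_emb_head_dim) ->
# output sliced back to 128.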
v_dim = value.shape[-1]
if non_dsa_thd_qkv_format and query.shape[-1] != v_dim:
value = F.pad(value, [0, query.shape[-1] - v_dim])
self.core_attention.hidden_size_per_attention_head_v = value.shape[-1]
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query, key, value, attention_mask, packed_seq_params=packed_seq_params
)
else:
extra_kwargs = {}
if getattr(self.config, "experimental_attention_variant", None) == "dsa":
# For dsa we need to pass in the original hidden states and the compressed
# query representation.
extra_kwargs["x"] = hidden_states
extra_kwargs["qr"] = q_compressed
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
packed_seq_params=packed_seq_params,
attn_mask_type=attn_mask_type,
**extra_kwargs,
)
if non_dsa_thd_qkv_format:
if core_attn_out.ndim == 2:
core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-1], -1, value.shape[-1])
if query.shape[-1] != v_dim:
core_attn_out = core_attn_out[..., :v_dim]
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
if self.recompute_up_proj:
assert self.qkv_up_checkpoint is not None
self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out)
self.qkv_up_checkpoint = None
# =================
# Output. [sq, b, h]
# =================
output, bias = self.linear_proj(core_attn_out)
return output, bias
# This patch targets mcore 0.12 MLA behavior only.
# For newer mcore, upstream MLA already has packed-seq + CP handling and
# overriding it with the legacy implementation can break RoPE shapes.
if not mcore_ge_013:
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
MultiLatentAttention.forward = patch_forward
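# `apply_patch_mbridge` presumably covers code paths (e.g. mbridge) that import
# `get_tensor_model_parallel_group_if_none` from megatron.core.utils; on mcore versions
# where that helper does not exist, an equivalent fallback is installed below.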
def apply_patch_mbridge():
try:
from megatron.core.utils import get_tensor_model_parallel_group_if_none
except ImportError:
import warnings
import megatron.core.utils
import torch
from megatron.core import parallel_state
def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True):
"""Issue a deprecation warning if tp_group is None and return the default tp group."""
if not torch.distributed.is_initialized():
return None
if tp_group is None:
if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
warnings.warn(
"Warning: tp_group is None, using default tp group. Passing tp_group will be mandatory soon",
DeprecationWarning,
stacklevel=2,
)
if is_expert:
tp_group = parallel_state.get_expert_tensor_parallel_group(check_initialized=check_initialized)
else:
tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=check_initialized)
return tp_group
megatron.core.utils.get_tensor_model_parallel_group_if_none = get_tensor_model_parallel_group_if_none
def apply_patch_megatron_v012_with_torch_v28():
# megatron-core 0.12's async checkpoint writer calls torch's `_write_item` without the
# `serialization_format` argument that torch 2.8 expects, which raises an error;
# resolved by backporting megatron-core 0.13's implementation of `write_preloaded_data`.
import inspect
import logging
import os
from pathlib import Path
import megatron.core
import torch
from megatron.core.dist_checkpointing.strategies.async_utils import _disable_gc
from megatron.core.dist_checkpointing.strategies.filesystem_async import _process_memory
from packaging import version
from torch import multiprocessing as mp
from torch.distributed.checkpoint.filesystem import _write_item
if (
version.parse(torch.__version__).base_version != "2.8.0"
or version.parse(megatron.core.__version__).base_version != "0.12.1"
):
return
WriteBucket = tuple[Path, str, tuple[list, list]]
@staticmethod
@_disable_gc()
def write_preloaded_data_patch(
transform_list,
local_proc_idx: int,
write_bucket: WriteBucket,
results_queue: mp.SimpleQueue,
count_queue: mp.JoinableQueue,
use_fsync: bool,
**kwargs,
) -> None:
"""
Performs actual data saving to storage.
Args:
local_proc_idx (int): index of a local process that performs writing
write_bucket (WriteBucket): data to write to storage
results_queue (mp.Queue): queue to return the write results
to the proxy checkpoint process.
count_queue (mp.JoinableQueue): queue used to mark the worker task as completed
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None, the write results are put into `results_queue`
"""
logger = logging.getLogger(__name__)
logger.debug(f"{local_proc_idx} started")
mem_before = _process_memory()
use_msc = kwargs.get("use_msc", False)
local_results = []
try:
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
extra_kwargs = {}
if "serialization_format" in inspect.signature(_write_item).parameters:
from torch.distributed.checkpoint.filesystem import SerializationFormat
extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE
if use_msc:
import multistorageclient as msc
open_file = msc.open
else:
open_file = open
with open_file(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(
_write_item(*transform_list, stream, data, write_item, storage_key, **extra_kwargs)
)
for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(
_write_item(*transform_list, stream, tensor, write_item, storage_key, **extra_kwargs)
)
if use_fsync:
if use_msc:
stream.fsync()
else:
os.fsync(stream.fileno())
local_output = (local_proc_idx, local_results)
except Exception as e:
logger.debug(f"{local_proc_idx} failed")
local_output = (local_proc_idx, e) # type: ignore[assignment]
results_queue.put(local_output)
# Signal this process is done.
count_queue.get()
count_queue.task_done()
mem_after = _process_memory()
logger.debug(f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}")
from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
FileSystemWriterAsync.write_preloaded_data = write_preloaded_data_patch