# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import functools
import json
import torch
import triton
import triton.language as tl
from aiter.ops.triton.utils._triton import arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd
from aiter.ops.triton.utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr


@triton.jit
def _cdiv_fn(x, y):
    """Return ceil(x / y) using integer arithmetic."""
    return (x + y - 1) // y


@triton.jit
def _load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
    """Load a 2-D tile from ``ptrs`` with optional bounds masking per axis.

    ``offset_first`` / ``offset_second`` are the index vectors along the first
    and second tile axis. Passing ``None`` for an axis disables the bounds
    check on that axis. Masked-out elements are filled with 0.0.
    """
    if offset_first is not None and offset_second is not None:
        mask = (offset_first[:, None] < boundary_first) & (
            offset_second[None, :] < boundary_second
        )
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_first is not None:
        mask = offset_first[:, None] < boundary_first
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_second is not None:
        mask = offset_second[None, :] < boundary_second
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    else:
        tensor = tl.load(ptrs)
    return tensor


@triton.jit
def _compute_alibi_block(
    alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False
):
    """Return the ALiBi bias tile ``-alibi_slope * |relative position|``.

    The result has one row per entry of ``offs_m`` and one column per entry of
    ``offs_n`` (transposed when ``transpose`` is true).
    """
    # when seqlen_k and seqlen_q are different we want the diagonal to stick to
    # the bottom right of the attention matrix
    # for causal mask we want something like this where (1 is kept and 0 is masked)
    # seqlen_q = 2 and seqlen_k = 5
    #   1 1 1 1 0
    #   1 1 1 1 1
    # seqlen_q = 5 and seqlen_k = 2
    #   0 0
    #   0 0
    #   0 0
    #   1 0
    #   1 1
    # for alibi the diagonal is 0 indicating no penalty for attending to that spot
    # and increasing penalty for attending further from the diagonal
    # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5,
    #      offs_m = [0, 1], offs_n = [0, 1, 2, 3, 4], transpose = False
    # 1. offs_m[:,None]                                        = [[0],
    #                                                             [1]]
    # 2. offs_m[:,None] + seqlen_k                             = [[5],
    #                                                             [6]]
    # 3. offs_m[:,None] + seqlen_k - seqlen_q                  = [[3],
    #                                                             [4]]
    # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[ 3, 2, 1, 0,-1],
    #                                                             [ 4, 3, 2, 1, 0]]
    # 5. -1 * alibi_slope * tl.abs(relative_pos_block)         = [[-3,-2,-1, 0,-1],
    #                                                             [-4,-3,-2,-1, 0]]
    relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
    if transpose:
        return alibi_block.T
    else:
        return alibi_block


@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    q_pe,
    k_ptrs,
    k_pe_ptrs,
    v_ptrs,
    stride_kn,
    stride_vk,
    stride_sn,
    start_m,
    seqlen_k,
    seqlen_q,
    dropout_p,
    sd_mask_ptrs,
    dropout_mask_ptrs,
    philox_seed,
    philox_ptrs,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    alibi_slope,
    descale_q,
    descale_k,
    descale_v,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRELOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DMODEL_POW2: tl.constexpr,
    BLOCK_DMODEL_PE: tl.constexpr,  # it's zero or a power of 2
    SM_SCALE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_SCORES: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    ENABLE_PIPELINING: tl.constexpr,
):
    """Run the online-softmax loop over key columns [block_min, block_max).

    Each iteration loads one BLOCK_N-wide K (and V) tile, computes the scaled
    scores in base-2 units, updates the running row maximum ``m_i`` and row sum
    ``l_i``, rescales ``acc`` and accumulates ``p @ v``.

    The pointer arguments are advanced locally only; the caller's copies are
    not modified, so the caller must advance them itself between two calls.
    ``masked_blocks`` is accepted for interface compatibility and is not read.

    Returns the updated ``(acc, l_i, m_i)``.
    """
    RCP_LN2: tl.constexpr = 1.4426950408889634
    HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0

    # loop over k, v, and update accumulator
    num_stages: tl.constexpr = (
        None if ENABLE_PIPELINING else 1
    )  # Set num_stages==1 if we want to disable pipelining
    for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages):
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        if MASK_STEPS:
            k_offs_n = start_n + tl.arange(0, BLOCK_N)
        else:
            k_offs_n = None
        k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2)
        k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k)
        if HAS_PE:
            k_pe = _load_fn(
                k_pe_ptrs,
                None,
                k_offs_n,
                (BLOCK_DMODEL + BLOCK_DMODEL_PE),
                seqlen_k,
            )
        if PRELOAD_V:
            v = _load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL)
        # We start from end of seqlen_k so only the first iteration would need
        # to be checked for padding if it is not a multiple of block_n
        # TODO: This can be optimized to only be true for the padded block.
        mask = tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1)
        if MASK_STEPS:
            # If this is the last block / iteration, we want to
            # mask if the sequence length is not a multiple of block size
            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
            # last step might get wasted but that is okay. check if this masking works For
            # that case.

            # remove the old if condition
            # if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
            # Though this will unconditionally compute mask_partial at runtime,
            # the causal for loop does not have the if-else block any more, which
            # helps instruction scheduling and register pressure.
            bound_cond = (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0)
            size_n = start_n + OFFS_N[None, :]
            mask_partial = size_n < seqlen_k
            mask = tl.where(bound_cond, mask_partial, mask)

        # compute masks (used only for the score / dropout-mask stores below)
        q_mask = OFFS_M[:, None] < seqlen_q
        k_mask = (start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k
        p_mask = q_mask & k_mask

        # Fold log2(e) into the scale so that exp2 can be used instead of exp.
        qk_scale = SM_SCALE * RCP_LN2
        # -- compute qk ----
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if HAS_PE:
            qk += tl.dot(q_pe, k_pe)
        qk += tl.dot(q, k)
        if IS_FP8:
            qk = qk * (qk_scale * descale_q * descale_k)
        else:
            qk = qk * qk_scale

        if IS_CAUSAL:
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            mask = mask & causal_mask

        qk = tl.where(mask, qk, float("-inf"))

        if alibi_slope is not None:
            # Compute the global position of each token within the sequence
            global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            global_n_positions = start_n + tl.arange(0, BLOCK_N)
            alibi_block = _compute_alibi_block(
                alibi_slope, seqlen_q, seqlen_k, global_m_positions, global_n_positions
            )
            qk += alibi_block * RCP_LN2

        # get max scores so far
        m_ij = tl.maximum(m_i, tl.max(qk, 1))

        # Compute scaled QK and softmax probabilities
        p = tl.math.exp2(qk - m_ij[:, None])

        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            rng_output = tl.rand(
                philox_seed, philox_ptrs
            )  # TODO: use tl.randint for better performance
            dropout_mask = rng_output > dropout_p
            tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask)

            # return scores with negative values for dropped vals
            sd_mask = tl.where(dropout_mask, p, -p)
            tl.store(sd_mask_ptrs, sd_mask, mask=p_mask)

            # apply dropout mask in place
            p = tl.where(dropout_mask, p, 0.0)
        elif RETURN_SCORES:
            # NOTE: the returned score is not the same as the reference because we
            # need to adjust as we find new maxes per block. We are not doing that
            tl.store(sd_mask_ptrs, p, mask=p_mask)

        # -- update output accumulator --
        # alpha is an adjustment factor for acc and li as we loop and find new maxes
        # store the diff in maxes to adjust acc and li as we discover new maxes
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]

        # -- update m_i and l_i
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        if not PRELOAD_V:
            v = _load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL)

        if IS_FP8:
            scale_p, descale_p = _compute_fp8_scaling_factors(p, FP8_MAX)
            acc += (
                tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v
            )
        else:
            acc += tl.dot(p.to(v.type.element_ty), v)

        # Advance the (local) pointers to the next BLOCK_N key columns.
        k_ptrs += BLOCK_N * stride_kn
        if HAS_PE:
            k_pe_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
        if RETURN_SCORES:
            sd_mask_ptrs += BLOCK_N * stride_sn
        if ENABLE_DROPOUT:
            dropout_mask_ptrs += BLOCK_N * stride_sn
            philox_ptrs += BLOCK_N * stride_sn

    return acc, l_i, m_i


_attn_fwd_repr = make_kernel_repr(
    "_attn_fwd",
    [
        "IS_CAUSAL",
        "NUM_Q_HEADS",
        "NUM_K_HEADS",
        "BLOCK_M",
        "BLOCK_N",
        "BLOCK_DMODEL",
        "RETURN_SCORES",
        "ENABLE_DROPOUT",
        "IS_FP8",
        "VARLEN",
        "NUM_XCD",
        "USE_INT64_STRIDES",
        "ENABLE_SINK",
    ],
)


@triton.jit(repr=_attn_fwd_repr)
def _attn_fwd(
    q_ptr: torch.Tensor,
    k_ptr: torch.Tensor,
    v_ptr: torch.Tensor,
    descale_q_ptr: torch.Tensor,
    descale_k_ptr: torch.Tensor,
    descale_v_ptr: torch.Tensor,
    out_ptr: torch.Tensor,
    alibi_slopes_ptr: torch.Tensor,
    s_dmask_ptr: torch.Tensor,
    dropout_mask_ptr: torch.Tensor,
    softmax_lse_ptr: torch.Tensor,
    sink_ptr: torch.Tensor,
    stride_qz_in,
    stride_qh_in,
    stride_qm_in,
    stride_qk_in,
    stride_kz_in,
    stride_kh_in,
    stride_kn_in,
    stride_kk_in,
    stride_vz_in,
    stride_vh_in,
    stride_vn_in,
    stride_vk_in,
    stride_descale_q_z_in,
    stride_descale_k_z_in,
    stride_descale_v_z_in,
    stride_oz_in,
    stride_oh_in,
    stride_om_in,
    stride_on_in,
    stride_alibi_z_in,
    stride_alibi_h_in,
    stride_sd_z_in,
    stride_sd_h_in,
    stride_sd_m_in,
    stride_sd_n_in,
    stride_lse_z_in,
    stride_lse_h_in,
    stride_lse_m_in,
    sm_scale,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base_in,
    SEQLEN_Q,
    SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    PRELOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DMODEL_POW2: tl.constexpr,
    BLOCK_DMODEL_PE: tl.constexpr,  # it's zero or a power of 2
    RETURN_SCORES: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    VARLEN: tl.constexpr,
    BATCH,
    NUM_XCD: tl.constexpr,
    USE_INT64_STRIDES: tl.constexpr,
    ENABLE_SINK: tl.constexpr,
):
    """Attention forward kernel: one program per (batch, query head, M block).

    Program id 0 is decomposed into a query head, a BLOCK_M-row block of
    queries and a batch index. The key range is processed in two passes of
    ``_attn_fwd_inner``: first the blocks that need no masking, then the
    trailing blocks that need padding and/or causal masking. The normalised
    output is stored to ``out_ptr`` and, when ``softmax_lse_ptr`` is not None,
    the log-sum-exp (natural log units) is stored as well.
    """
    NUM_BLOCKS = (SEQLEN_Q + BLOCK_M - 1) // BLOCK_M
    # calculate offsets
    wid = tl.program_id(
        0
    )  # workgroup id ranging: 0,1,2,...., (BATCH * NUM_Q_HEADS * NUM_BLOCKS - 1)
    # num blocks along seqlen

    off_q_head = wid % NUM_Q_HEADS
    off_q_head = remap_xcd(off_q_head, NUM_Q_HEADS, NUM_XCD)
    start_m = (wid // NUM_Q_HEADS) % NUM_BLOCKS
    off_z = (wid // (NUM_BLOCKS * NUM_Q_HEADS)) % BATCH

    # offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_POW2)
    HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0
    if HAS_PE:
        offs_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL_PE)

    # NOTE:
    # Workaround for int64 strides, In the absence of strides being int64, parts of the offset
    # computation is done in 32 bit and overflows resulting in segfaults
    # If input strides are defined as int64, it disables vectorized loads which drops perf
    # If we define new strides as stride_x = stride_x_in.to(tl.int64), that does not work
    # because strides are tl.constexpr and cannot be upcasted
    # If we define new strides as stride_x: tl.int64 = stride_x_in, segfault remains
    # The permanent solution is to enable upcasting of tl.constexpr
    # In the meantime, the following workaround provides correctness and does not drop perf
    if USE_INT64_STRIDES:
        stride_qz = tl.cast(stride_qz_in, tl.int64)
        stride_qh = tl.cast(stride_qh_in, tl.int64)
        stride_qm = tl.cast(stride_qm_in, tl.int64)
        stride_qk = tl.cast(stride_qk_in, tl.int64)
        stride_kz = tl.cast(stride_kz_in, tl.int64)
        stride_kh = tl.cast(stride_kh_in, tl.int64)
        stride_kn = tl.cast(stride_kn_in, tl.int64)
        stride_kk = tl.cast(stride_kk_in, tl.int64)
        stride_vz = tl.cast(stride_vz_in, tl.int64)
        stride_vh = tl.cast(stride_vh_in, tl.int64)
        stride_vn = tl.cast(stride_vn_in, tl.int64)
        stride_vk = tl.cast(stride_vk_in, tl.int64)
        if IS_FP8:
            stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64)
            stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64)
            stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64)
        stride_oz = tl.cast(stride_oz_in, tl.int64)
        stride_oh = tl.cast(stride_oh_in, tl.int64)
        stride_om = tl.cast(stride_om_in, tl.int64)
        stride_on = tl.cast(stride_on_in, tl.int64)
        stride_alibi_z = tl.cast(stride_alibi_z_in, tl.int64)
        stride_alibi_h = tl.cast(stride_alibi_h_in, tl.int64)

        # NOTE: philox offset is need in dropout pointer calculations
        philox_offset_base = tl.cast(philox_offset_base_in, tl.int64)
        stride_sd_z = tl.cast(stride_sd_z_in, tl.int64)
        stride_sd_h = tl.cast(stride_sd_h_in, tl.int64)
        stride_sd_m = tl.cast(stride_sd_m_in, tl.int64)
        stride_sd_n = tl.cast(stride_sd_n_in, tl.int64)
        stride_lse_z = tl.cast(stride_lse_z_in, tl.int64)
        stride_lse_h = tl.cast(stride_lse_h_in, tl.int64)
        stride_lse_m = tl.cast(stride_lse_m_in, tl.int64)
    else:
        stride_qz = stride_qz_in
        stride_qm = stride_qm_in
        stride_qk = stride_qk_in
        stride_qh = stride_qh_in
        stride_kz = stride_kz_in
        stride_kh = stride_kh_in
        stride_kn = stride_kn_in
        stride_kk = stride_kk_in
        stride_vz = stride_vz_in
        stride_vh = stride_vh_in
        stride_vn = stride_vn_in
        stride_vk = stride_vk_in
        stride_descale_q_z = stride_descale_q_z_in
        stride_descale_k_z = stride_descale_k_z_in
        stride_descale_v_z = stride_descale_v_z_in
        stride_oz = stride_oz_in
        stride_oh = stride_oh_in
        stride_om = stride_om_in
        stride_on = stride_on_in
        stride_alibi_z = stride_alibi_z_in
        stride_alibi_h = stride_alibi_h_in
        philox_offset_base = philox_offset_base_in
        stride_sd_z = stride_sd_z_in
        stride_sd_h = stride_sd_h_in
        stride_sd_m = stride_sd_m_in
        stride_sd_n = stride_sd_n_in
        stride_lse_z = stride_lse_z_in
        stride_lse_h = stride_lse_h_in
        stride_lse_m = stride_lse_m_in

    # Tell the compiler that all strides / offsets are non-negative.
    tl.assume(stride_qz_in >= 0)
    tl.assume(stride_qh_in >= 0)
    tl.assume(stride_qm_in >= 0)
    tl.assume(stride_qk_in >= 0)
    tl.assume(stride_kz_in >= 0)
    tl.assume(stride_kh_in >= 0)
    tl.assume(stride_kn_in >= 0)
    tl.assume(stride_kk_in >= 0)
    tl.assume(stride_vz_in >= 0)
    tl.assume(stride_vh_in >= 0)
    tl.assume(stride_vn_in >= 0)
    tl.assume(stride_vk_in >= 0)
    if IS_FP8:
        tl.assume(stride_descale_q_z_in >= 0)
        tl.assume(stride_descale_k_z_in >= 0)
        tl.assume(stride_descale_v_z_in >= 0)
    tl.assume(stride_oz_in >= 0)
    tl.assume(stride_oh_in >= 0)
    tl.assume(stride_om_in >= 0)
    tl.assume(stride_on_in >= 0)
    tl.assume(stride_alibi_z_in >= 0)
    tl.assume(stride_alibi_h_in >= 0)
    # NOTE: philox offset is need in dropout pointer calculations
    tl.assume(philox_offset_base_in >= 0)
    tl.assume(stride_sd_z_in >= 0)
    tl.assume(stride_sd_h_in >= 0)
    tl.assume(stride_sd_m_in >= 0)
    tl.assume(stride_sd_n_in >= 0)
    tl.assume(stride_lse_z_in >= 0)
    tl.assume(stride_lse_h_in >= 0)
    tl.assume(stride_lse_m_in >= 0)

    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)

        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
        # small for all start_m so for those we return early.
        # `>=`: a block whose first row equals seqlen_q has no valid rows, and
        # every store it would issue is masked out, so it can be skipped too.
        if start_m * BLOCK_M >= seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = SEQLEN_Q
        seqlen_k = SEQLEN_K

    n_blocks = _cdiv_fn(seqlen_k, BLOCK_N)

    # Now we compute whether we need to exit early due to causal masking.
    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
    # are completely masked, resulting in 0s written to the output, and
    # 0 written to LSE. We don't need to do any GEMMs in this case.
    # This block of code determines what N is, and if this WG is operating
    # on those M rows.
    if IS_CAUSAL:
        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
        # If seqlen_q != seqlen_k, attn scores are rectangular which means
        # the causal mask boundary is bottom right aligned, and ends at either
        # the top edge (seqlen_q < seqlen_k) or left edge.

        # This captures the decrease in n_blocks if we have a rectangular attn matrix
        n_blocks_seqlen = _cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
        )

        # This is what adjusts the block_max for the current WG, only
        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
        n_blocks = min(n_blocks, n_blocks_seqlen)

        # If we have no blocks after adjusting for seqlen deltas, this WG is part of
        # the blocks that are all 0. We exit early.
        if n_blocks <= 0:
            offs_out = (
                off_z * stride_oz
                + off_q_head * stride_oh
                + cu_seqlens_q_start * stride_om
                + offs_m[:, None] * stride_om
                + offs_d[None, :] * stride_on
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty)
            out_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL)
            tl.store(out_ptr + offs_out, acc, mask=out_mask)

            if softmax_lse_ptr is not None:
                offs_lse = (
                    off_z * stride_lse_z
                    + off_q_head * stride_lse_h
                    + cu_seqlens_q_start * stride_lse_m
                    + offs_m * stride_lse_m
                )
                # Mask with this sequence's own length (same bound as out_mask
                # above). Using the SEQLEN_Q kernel argument here would, in the
                # VARLEN case, write rows past the end of this sequence.
                lse_mask = offs_m < seqlen_q
                lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
                tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask)
                # TODO: Should dropout and return encoded softmax be handled here too?

            return

    grp_sz: tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS
    if grp_sz != 1:  # Grouped Query Attention
        off_k_head = off_q_head // grp_sz
    else:
        off_k_head = off_q_head

    # q,k,v offsets
    q_offs = (
        off_z * stride_qz
        + off_q_head * stride_qh
        + cu_seqlens_q_start * stride_qm
        + offs_m[:, None] * stride_qm
        + offs_d[None, :] * stride_qk
    )
    q_ptrs = q_ptr + q_offs
    if HAS_PE:
        q_pe_offs = (
            off_z * stride_qz
            + off_q_head * stride_qh
            + cu_seqlens_q_start * stride_qm
            + offs_m[:, None] * stride_qm
            + offs_pe[None, :] * stride_qk
        )
        q_pe_ptrs = q_ptr + q_pe_offs
    else:
        q_pe_ptrs = None

    # K is addressed as a (head_dim, BLOCK_N) tile so that q @ k needs no transpose.
    k_offs = (
        off_z * stride_kz
        + off_k_head * stride_kh
        + cu_seqlens_k_start * stride_kn
        + offs_d[:, None] * stride_kk
        + offs_n[None, :] * stride_kn
    )
    k_ptrs = k_ptr + k_offs
    if HAS_PE:
        k_pe_offs = (
            off_z * stride_kz
            + off_k_head * stride_kh
            + cu_seqlens_k_start * stride_kn
            + offs_pe[:, None] * stride_kk
            + offs_n[None, :] * stride_kn
        )
        k_pe_ptrs = k_ptr + k_pe_offs
    else:
        k_pe_ptrs = None

    v_offs = (
        off_z * stride_vz
        + off_k_head * stride_vh
        + cu_seqlens_k_start * stride_vn
        + offs_n[:, None] * stride_vn
        + offs_d[None, :] * stride_vk
    )
    v_ptrs = v_ptr + v_offs

    # alibi slopes
    if alibi_slopes_ptr is not None:
        alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offs)
    else:
        alibi_slope = None

    # s_dmask (return_scores)
    if s_dmask_ptr is not None:
        s_dmask_offs = (
            off_z * stride_sd_z
            + off_q_head * stride_sd_h
            + offs_m[:, None] * stride_sd_m
            + offs_n[None, :] * stride_sd_n
        )
        s_dmask_ptrs = s_dmask_ptr + s_dmask_offs
    else:
        s_dmask_ptrs = None

    # dropout
    if dropout_mask_ptr is not None:
        dropout_mask_offs = (
            off_z * stride_sd_z
            + off_q_head * stride_sd_h
            + offs_m[:, None] * stride_sd_m
            + offs_n[None, :] * stride_sd_n
        )
        dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs
        # RNG offsets use the same layout as the dropout mask, shifted by the base.
        philox_ptrs = (
            philox_offset_base
            + off_z * stride_sd_z
            + off_q_head * stride_sd_h
            + offs_m[:, None] * stride_sd_m
            + offs_n[None, :] * stride_sd_n
        )
    else:
        dropout_mask_ptrs = None
        philox_ptrs = None

    # Initial running max: the per-head sink value (converted to base-2 units)
    # when ENABLE_SINK, otherwise -inf.
    if ENABLE_SINK:
        RCP_LN2: tl.constexpr = 1.4426950408889634
        m_i_value = tl.load(sink_ptr + off_q_head).to(tl.float32) * RCP_LN2
    else:
        m_i_value = float("-inf")

    m_i = tl.full([BLOCK_M], m_i_value, dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32)
    if BLOCK_DMODEL == BLOCK_DMODEL_POW2:
        q_mask = offs_m[:, None] < seqlen_q
    else:
        q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL)

    if BLOCK_M >= NUM_Q_HEADS:
        q_cache_mod: tl.constexpr = ".cg"
    else:
        q_cache_mod: tl.constexpr = ""

    if HAS_PE:
        q_pe = tl.load(q_pe_ptrs, mask=q_mask, other=0.0, cache_modifier=q_cache_mod)
    else:
        q_pe = None
    q = tl.load(q_ptrs, mask=q_mask, other=0.0, cache_modifier=q_cache_mod)

    if IS_FP8:
        descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head)
        descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head)
        descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head)
    else:
        descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N

    # if CAUSAL, then determine masked_blocks and full blocks
    # Here we compute how many full and masked blocks we have.
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # There are always at least BLOCK_M // BLOCK_N masked blocks.
        # Additionally there might be one more due to dissimilar seqlens.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Padding on Q does not need to be masked in the FA loop.
        masked_blocks = padded_block_k
    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
    # In this case we might exceed n_blocks so pick the min.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Compute for full blocks. Here we set causal to false regardless of its actual
    # value because there is no masking. Similarly we do not need padding.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            q_pe,
            k_ptrs,
            k_pe_ptrs,
            v_ptrs,
            stride_kn,
            stride_vn,
            stride_sd_n,
            start_m,
            seqlen_k,
            seqlen_q,
            dropout_p,
            s_dmask_ptrs,
            dropout_mask_ptrs,
            philox_seed,
            philox_ptrs,
            block_min,
            block_max,
            0,
            0,
            0,
            alibi_slope,
            descale_q,
            descale_k,
            descale_v,
            offs_m,
            offs_n,
            PRELOAD_V,
            BLOCK_M,
            BLOCK_N,
            BLOCK_DMODEL,
            BLOCK_DMODEL_POW2,
            BLOCK_DMODEL_PE,
            sm_scale,
            False,
            MASK_STEPS=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            RETURN_SCORES=RETURN_SCORES,
            PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            ENABLE_PIPELINING=True,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    # Remaining blocks, if any, are the ones that need masking
    # (key padding and/or the causal boundary).
    if masked_blocks > 0:
        if IS_CAUSAL:
            offs_n_causal = offs_n + (seqlen_q - seqlen_k)
        else:
            offs_n_causal = 0
        # The inner function only advances its own copies of the pointers, so
        # skip past the full blocks here before the masked pass.
        k_ptrs += n_full_blocks * BLOCK_N * stride_kn
        if HAS_PE:
            k_pe_ptrs += n_full_blocks * BLOCK_N * stride_kn
        v_ptrs += n_full_blocks * BLOCK_N * stride_vn
        if RETURN_SCORES:
            s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
        if ENABLE_DROPOUT:
            dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
            # The RNG offsets must move together with the dropout-mask
            # pointers; otherwise the masked blocks would reuse the random
            # offsets of the first key columns.
            philox_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            q_pe,
            k_ptrs,
            k_pe_ptrs,
            v_ptrs,
            stride_kn,
            stride_vn,
            stride_sd_n,
            start_m,
            seqlen_k,
            seqlen_q,
            dropout_p,
            s_dmask_ptrs,
            dropout_mask_ptrs,
            philox_seed,
            philox_ptrs,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            alibi_slope,
            descale_q,
            descale_k,
            descale_v,
            offs_m,
            offs_n,
            PRELOAD_V,
            BLOCK_M,
            BLOCK_N,
            BLOCK_DMODEL,
            BLOCK_DMODEL_POW2,
            BLOCK_DMODEL_PE,
            sm_scale,
            IS_CAUSAL,
            MASK_STEPS=True,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            RETURN_SCORES=RETURN_SCORES,
            PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            ENABLE_PIPELINING=False,
        )
    # epilogue
    # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger.
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    if ENABLE_DROPOUT:
        dropout_scale = 1 / (1 - dropout_p)
        acc = acc * dropout_scale
    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
    # then we have one block with a row of all NaNs which come from computing
    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
    # and store 0s where there are NaNs as these rows should've been zeroed out.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    if IS_CAUSAL:
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full(
                (BLOCK_DMODEL_POW2,), causal_start_idx, dtype=tl.int32
            )
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))

    # write back LSE(Log Sum Exponents), the log of the normalization constant
    overflow_size = end_m_idx - seqlen_q
    if softmax_lse_ptr is not None:
        LN2: tl.constexpr = 0.6931471824645996
        # compute log-sum-exp in base 2 units
        softmax_lse = m_i + tl.math.log2(l_i)
        # convert back to natural units
        softmax_lse *= LN2

        if IS_CAUSAL:
            # zero out nans caused by -infs when doing causal
            lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
            softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse)

        # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
        # This is only true for the last M block. For others, overflow_size will be -ve
        offs_lse = (
            off_z * stride_lse_z
            + off_q_head * stride_lse_h
            + cu_seqlens_q_start * stride_lse_m
            + offs_m * stride_lse_m
        )
        if overflow_size > 0:
            boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
            lse_mask = tl.arange(0, BLOCK_M) < boundary
            tl.store(
                softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask
            )  # the log of the normalization constant
        else:
            tl.store(
                softmax_lse_ptr + offs_lse, softmax_lse
            )  # the log of the normalization constant

    # write back O
    offs_out = (
        off_z * stride_oz
        + off_q_head * stride_oh
        + cu_seqlens_q_start * stride_om
        + offs_m[:, None] * stride_om
        + offs_d[None, :] * stride_on
    )
    out_mask = tl.full([BLOCK_M, 1], 1, dtype=tl.int1)
    if overflow_size > 0:
        out_mask = out_mask & (offs_m[:, None] < seqlen_q)
    if BLOCK_DMODEL != BLOCK_DMODEL_POW2:
        out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL)
    op = acc.to(out_ptr.dtype.element_ty)
    tl.store(out_ptr + offs_out, op, mask=out_mask)


@functools.lru_cache(maxsize=1024)
def _get_config(
    enable_dropout: bool,
    dtype: torch.dtype,
    has_pe: bool = False,
):
    """Return the forward-kernel launch configuration for the given options.

    On first use the file
    ``{AITER_TRITON_CONFIGS_PATH}/{arch}-MHA-DEFAULT.json`` is read (``arch``
    comes from ``arch_info.get_arch()``) and kept on the function object, so the
    file is parsed only once per process.

    Selection order, using the ``"fwd"`` section of that file:
      1. ``has_pe`` with dropout or fp32, when a ``"pe_dropout_or_fp32"`` entry
         exists: that entry.
      2. ``has_pe`` otherwise: the hard-coded configuration below.
      3. dropout or fp32: the ``"dropout_or_fp32"`` entry.
      4. otherwise: the ``"default"`` entry.

    Results are memoised by ``lru_cache``; the same dict object is returned on
    repeated calls, so callers should copy it before modifying it.
    """
    if not hasattr(_get_config, "_config_dict"):
        dev = arch_info.get_arch()
        _get_config._config_dict = {}
        fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json"
        with open(fpath, "r") as file:
            config = json.load(file)
        _get_config._config_dict["default"] = config

    fwd_cfg = _get_config._config_dict["default"]["fwd"]
    has_dropout_or_fp32 = enable_dropout or dtype == torch.float32

    # TODO: pe + dropout is not tuned
    if has_pe and has_dropout_or_fp32 and "pe_dropout_or_fp32" in fwd_cfg:
        return fwd_cfg["pe_dropout_or_fp32"]
    elif has_pe:
        # MLA prefill (head_dim_qk=192/v=128) tuned for gfx942 (MI300X).
        # The stock "pe" config uses BLOCK_M=256, which produces too few
        # workgroups (batch*heads*cdiv(seqlen,256)) to fill the 304 CUs for the
        # short prefill seqlens seen here, leaving the GPU under-occupied.
        # Halving BLOCK_M to 128 doubles workgroup count (better occupancy) and
        # enabling 2-stage software pipelining (num_stages=2) overlaps the K/V
        # loads with the QK/PV MFMA chain. waves_per_eu=1 + num_warps=4 keeps
        # register/LDS pressure low enough to actually realize 2 pipeline stages
        # (num_stages>=3 overflows the 64KB LDS for this 192/128 head config).
        return {
            "BLOCK_M": 128,
            "BLOCK_N": 64,
            "PRELOAD_V": True,
            "waves_per_eu": 1,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
        }
    elif has_dropout_or_fp32:
        return fwd_cfg["dropout_or_fp32"]
    else:
        return fwd_cfg["default"]