| |
| |
|
|
| import functools |
| import json |
| import torch |
| import triton |
| import triton.language as tl |
|
|
| from aiter.ops.triton.utils._triton import arch_info |
| from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH |
| from aiter.ops.triton.utils._triton.pid_preprocessing import remap_xcd |
| from aiter.ops.triton.utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors |
| from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr |
|
|
|
|
@triton.jit
def _cdiv_fn(x, y):
    """Return ``x / y`` rounded up, computed as ``(x + y - 1) // y``."""
    biased_numerator = x + y - 1
    return biased_numerator // y
|
|
|
|
@triton.jit
def _load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
    """Load a 2-D tile from ``ptrs``, bounds-masking only the requested axes.

    An offset passed as ``None`` disables masking along that axis.  Elements
    whose offset is not strictly below the matching boundary are replaced by
    ``0.0``.  When both offsets are ``None`` the load is unmasked.
    """
    if offset_first is not None:
        # First axis is checked; offsets are broadcast as a column.
        first_in_bounds = offset_first[:, None] < boundary_first
        if offset_second is not None:
            second_in_bounds = offset_second[None, :] < boundary_second
            tile = tl.load(ptrs, mask=first_in_bounds & second_in_bounds, other=0.0)
        else:
            tile = tl.load(ptrs, mask=first_in_bounds, other=0.0)
    elif offset_second is not None:
        # Only the second axis is checked; offsets are broadcast as a row.
        second_in_bounds = offset_second[None, :] < boundary_second
        tile = tl.load(ptrs, mask=second_in_bounds, other=0.0)
    else:
        tile = tl.load(ptrs)
    return tile
|
|
|
|
@triton.jit
def _compute_alibi_block(
    alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False
):
    """Build the ALiBi bias tile for one (query-block, key-block) pair.

    Each entry is ``-alibi_slope * |m + seqlen_k - seqlen_q - n|`` where ``m``
    comes from ``offs_m`` (broadcast along rows) and ``n`` from ``offs_n``
    (broadcast along columns).  The ``seqlen_k - seqlen_q`` shift aligns the
    last query row with the last key column when the two lengths differ.

    Returns the tile indexed ``[m, n]``, or its transpose when ``transpose``
    is true.
    """
    query_pos = offs_m[:, None]
    key_pos = offs_n[None, :]
    signed_distance = query_pos + seqlen_k - seqlen_q - key_pos
    bias = -1 * alibi_slope * tl.abs(signed_distance)
    if transpose:
        return bias.T
    else:
        return bias
|
|
|
|
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    q_pe,
    k_ptrs,
    k_pe_ptrs,
    v_ptrs,
    stride_kn,
    stride_vk,
    stride_sn,
    start_m,
    seqlen_k,
    seqlen_q,
    dropout_p,
    sd_mask_ptrs,
    dropout_mask_ptrs,
    philox_seed,
    philox_ptrs,
    block_min,
    block_max,
    offs_n_causal,
    masked_blocks,
    n_extra_tokens,
    alibi_slope,
    descale_q,
    descale_k,
    descale_v,
    OFFS_M: tl.constexpr,
    OFFS_N: tl.constexpr,
    PRELOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DMODEL_POW2: tl.constexpr,
    BLOCK_DMODEL_PE: tl.constexpr,
    SM_SCALE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    MASK_STEPS: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    RETURN_SCORES: tl.constexpr,
    PADDED_HEAD: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    ENABLE_PIPELINING: tl.constexpr,
):
    """Run the online-softmax attention loop over key columns [block_min, block_max).

    For every BLOCK_N-wide step the function loads a K tile (and the optional
    positional-embedding K tile and the V tile), forms the scaled scores
    ``q @ k`` in base-2 units, applies padding/causal masking and the optional
    ALiBi bias, and folds the result into the running softmax state.

    Running state (taken as arguments and returned updated):
        acc: un-normalised output accumulator.
        l_i: running sum of ``exp2(score - m_i)`` per query row.
        m_i: running row maximum of the base-2 scaled scores.

    Other arguments, as used in the body:
        q, q_pe: query tile and optional positional-embedding query tile
            (``q_pe`` is only read when ``BLOCK_DMODEL_PE > 0``).
        k_ptrs, k_pe_ptrs, v_ptrs: tile pointers for the first key block to
            process; advanced locally by ``BLOCK_N * stride`` per step.
        stride_kn, stride_vk: per-key-position strides used to advance the
            K (and K-PE) and V pointers respectively.
        stride_sn: per-key-position stride used to advance ``sd_mask_ptrs``,
            ``dropout_mask_ptrs`` and ``philox_ptrs``.
        start_m: query-block index, used only to build ALiBi row positions.
        dropout_p, philox_seed, philox_ptrs: dropout probability and the
            seed/offsets handed to ``tl.rand``.
        offs_n_causal: per-column shift added to ``start_n`` to obtain the
            causal boundary; only read when ``IS_CAUSAL``.
        n_extra_tokens: non-zero when the last key block is partially filled.
        descale_q, descale_k, descale_v: FP8 de-scaling factors, only applied
            when ``IS_FP8``.
        OFFS_M, OFFS_N: row/column offset vectors of the tile.
        MASK_STEPS: when true, loads are bounds-checked along the key axis and
            the trailing partial block is masked.
        masked_blocks: accepted for interface symmetry; not read in this body.

    Returns:
        The updated ``(acc, l_i, m_i)``.  Pointer advances made here are local
        to this function and are not returned.
    """
    # log2(e): scores are kept in base-2 units so exp2/log2 can be used.
    RCP_LN2: tl.constexpr = 1.4426950408889634
    HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0

    # num_stages=None leaves the loop's staging to tl.range's default; 1 is
    # requested explicitly when pipelining is disabled.
    num_stages: tl.constexpr = (
        None if ENABLE_PIPELINING else 1
    )
    for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages):
        # Key-axis offsets are only needed when loads must be bounds-checked.
        if MASK_STEPS:
            k_offs_n = start_n + tl.arange(0, BLOCK_N)
        else:
            k_offs_n = None
        # Head-dim offsets are only needed when the head dim is padded.
        k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2)
        k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k)
        if HAS_PE:
            # The PE tile is never masked along the head dim (first offset is
            # None), so the first boundary argument is effectively unused.
            k_pe = _load_fn(
                k_pe_ptrs,
                None,
                k_offs_n,
                (BLOCK_DMODEL + BLOCK_DMODEL_PE),
                seqlen_k,
            )
        if PRELOAD_V:
            # Issue the V load before the QK dot instead of after it.
            v = _load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL)

        # Start from an all-true mask and narrow it below.
        mask = tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1)
        if MASK_STEPS:
            # On the last block of the range, when the key length is not a
            # multiple of BLOCK_N, mask out columns at or beyond seqlen_k.
            # The partial mask is computed unconditionally and selected with
            # tl.where rather than guarded by a runtime branch.
            bound_cond = (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0)
            size_n = start_n + OFFS_N[None, :]
            mask_partial = size_n < seqlen_k
            mask = tl.where(bound_cond, mask_partial, mask)

        # In-bounds mask used for the optional score / dropout-mask stores.
        q_mask = OFFS_M[:, None] < seqlen_q
        k_mask = (start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k
        p_mask = q_mask & k_mask
        qk_scale = SM_SCALE * RCP_LN2
        # -- scores: qk = (q_pe @ k_pe) + (q @ k), then scaled ---------------
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if HAS_PE:
            qk += tl.dot(q_pe, k_pe)
        qk += tl.dot(q, k)
        if IS_FP8:
            qk = qk * (qk_scale * descale_q * descale_k)
        else:
            qk = qk * qk_scale
        if IS_CAUSAL:
            # Keep entries whose row offset is >= the shifted column index.
            causal_boundary = start_n + offs_n_causal
            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
            mask = mask & causal_mask

        # Masked-out scores become -inf so they contribute exp2(-inf) = 0.
        qk = tl.where(mask, qk, float("-inf"))

        if alibi_slope is not None:
            # Global row/column positions of this tile for the ALiBi bias.
            global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            global_n_positions = start_n + tl.arange(0, BLOCK_N)
            alibi_block = _compute_alibi_block(
                alibi_slope, seqlen_q, seqlen_k, global_m_positions, global_n_positions
            )
            # Convert the bias to base-2 units like the scores.
            qk += alibi_block * RCP_LN2
        # New running row maximum.
        m_ij = tl.maximum(m_i, tl.max(qk, 1))

        # Un-normalised probabilities relative to the new maximum.
        p = tl.math.exp2(qk - m_ij[:, None])

        # Row sums are taken BEFORE dropout is applied to p.
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
            rng_output = tl.rand(
                philox_seed, philox_ptrs
            )
            # An element is kept when its random draw exceeds dropout_p.
            dropout_mask = rng_output > dropout_p
            tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask)

            # Stored scores encode the dropout decision in their sign:
            # kept entries are +p, dropped entries are -p.
            sd_mask = tl.where(dropout_mask, p, -p)
            tl.store(sd_mask_ptrs, sd_mask, mask=p_mask)

            # Zero the dropped entries before they are multiplied with V.
            p = tl.where(dropout_mask, p, 0.0)
        elif RETURN_SCORES:
            # No dropout: store the plain un-normalised probabilities.
            tl.store(sd_mask_ptrs, p, mask=p_mask)

        # -- rescale the previous state to the new maximum -------------------
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        # Update the running sum ...
        l_i = l_i * alpha + l_ij
        # ... and the running maximum.
        m_i = m_ij
        if not PRELOAD_V:
            v = _load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL)
        if IS_FP8:
            # p is rescaled before the cast to V's element type and the
            # product is de-scaled afterwards.
            # NOTE(review): presumably scale_p maps p into the FP8 range given
            # by FP8_MAX - confirm against _compute_fp8_scaling_factors.
            scale_p, descale_p = _compute_fp8_scaling_factors(p, FP8_MAX)
            acc += (
                tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v
            )
        else:
            acc += tl.dot(p.to(v.type.element_ty), v)

        # Advance every tile pointer to the next key block.
        k_ptrs += BLOCK_N * stride_kn
        if HAS_PE:
            k_pe_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
        if RETURN_SCORES:
            sd_mask_ptrs += BLOCK_N * stride_sn

        if ENABLE_DROPOUT:
            # The dropout-mask pointers and RNG offsets move in lock-step.
            dropout_mask_ptrs += BLOCK_N * stride_sn
            philox_ptrs += BLOCK_N * stride_sn

    return acc, l_i, m_i
|
|
|
|
# Repr callable handed to ``@triton.jit(repr=...)`` on ``_attn_fwd`` below.
# Every name in the list is a parameter of ``_attn_fwd``.
# NOTE(review): presumably make_kernel_repr encodes the values of these
# parameters into the kernel's displayed name - confirm against its definition.
_attn_fwd_repr = make_kernel_repr(
    "_attn_fwd",
    [
        "IS_CAUSAL",
        "NUM_Q_HEADS",
        "NUM_K_HEADS",
        "BLOCK_M",
        "BLOCK_N",
        "BLOCK_DMODEL",
        "RETURN_SCORES",
        "ENABLE_DROPOUT",
        "IS_FP8",
        "VARLEN",
        "NUM_XCD",
        "USE_INT64_STRIDES",
        "ENABLE_SINK",
    ],
)
|
|
|
|
@triton.jit(repr=_attn_fwd_repr)
def _attn_fwd(
    q_ptr: torch.Tensor,
    k_ptr: torch.Tensor,
    v_ptr: torch.Tensor,
    descale_q_ptr: torch.Tensor,
    descale_k_ptr: torch.Tensor,
    descale_v_ptr: torch.Tensor,
    out_ptr: torch.Tensor,
    alibi_slopes_ptr: torch.Tensor,
    s_dmask_ptr: torch.Tensor,
    dropout_mask_ptr: torch.Tensor,
    softmax_lse_ptr: torch.Tensor,
    sink_ptr: torch.Tensor,
    stride_qz_in,
    stride_qh_in,
    stride_qm_in,
    stride_qk_in,
    stride_kz_in,
    stride_kh_in,
    stride_kn_in,
    stride_kk_in,
    stride_vz_in,
    stride_vh_in,
    stride_vn_in,
    stride_vk_in,
    stride_descale_q_z_in,
    stride_descale_k_z_in,
    stride_descale_v_z_in,
    stride_oz_in,
    stride_oh_in,
    stride_om_in,
    stride_on_in,
    stride_alibi_z_in,
    stride_alibi_h_in,
    stride_sd_z_in,
    stride_sd_h_in,
    stride_sd_m_in,
    stride_sd_n_in,
    stride_lse_z_in,
    stride_lse_h_in,
    stride_lse_m_in,
    sm_scale,
    cu_seqlens_q,
    cu_seqlens_k,
    dropout_p,
    philox_seed,
    philox_offset_base_in,
    SEQLEN_Q,
    SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    NUM_Q_HEADS: tl.constexpr,
    NUM_K_HEADS: tl.constexpr,
    PRELOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DMODEL_POW2: tl.constexpr,
    BLOCK_DMODEL_PE: tl.constexpr,
    RETURN_SCORES: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr,
    IS_FP8: tl.constexpr,
    FP8_MAX: tl.constexpr,
    VARLEN: tl.constexpr,
    BATCH,
    NUM_XCD: tl.constexpr,
    USE_INT64_STRIDES: tl.constexpr,
    ENABLE_SINK: tl.constexpr,
):
    """Forward attention kernel: one program per (batch, query head, query block).

    Program id 0 is decoded into a query head (fastest varying, remapped
    through ``remap_xcd``), a query-block index ``start_m`` and a batch index
    ``off_z``.  Each program loads one ``BLOCK_M``-row query tile, streams the
    key/value blocks through ``_attn_fwd_inner`` (first the blocks that need
    no masking, then the trailing blocks that do), normalises the accumulator
    and stores the output tile and, when ``softmax_lse_ptr`` is not ``None``,
    the natural-log log-sum-exp of every row.

    Argument groups:
        *_ptr: tensors; ``alibi_slopes_ptr``, ``s_dmask_ptr``,
            ``dropout_mask_ptr`` and ``softmax_lse_ptr`` may be ``None`` to
            disable the matching feature.  ``sink_ptr`` is only read when
            ``ENABLE_SINK`` and is indexed by query head.
        stride_*_in: element strides; widened to int64 when
            ``USE_INT64_STRIDES``.  ``s_dmask_ptr`` and ``dropout_mask_ptr``
            share the ``stride_sd_*`` strides.
        cu_seqlens_q / cu_seqlens_k: cumulative sequence starts, only read
            when ``VARLEN``; entry ``off_z`` and ``off_z + 1`` delimit the
            sequence of batch element ``off_z``.
        SEQLEN_Q / SEQLEN_K: sequence lengths in the fixed-length case.
            ``SEQLEN_Q`` also sizes the query-block grid.
            NOTE(review): in VARLEN mode these are presumably the maximum
            lengths over the batch - confirm against the launcher.
        BLOCK_DMODEL / BLOCK_DMODEL_POW2: real head dim and the padded tile
            width used for it; ``BLOCK_DMODEL_PE`` > 0 enables an extra
            positional-embedding slice located right after the first
            ``BLOCK_DMODEL`` features of Q and K.

    Behavioural fixes relative to the previous revision:
        * the early-exit LSE store is masked with the per-sequence
          ``seqlen_q`` instead of ``SEQLEN_Q``, so VARLEN programs no longer
          write zeros into rows of a neighbouring sequence;
        * the Philox offsets are advanced past the full blocks before the
          masked-block pass, exactly like ``dropout_mask_ptrs``;
        * VARLEN programs whose query tile starts exactly at ``seqlen_q``
          (no valid rows) now return early as well.
    """
    NUM_BLOCKS = (SEQLEN_Q + BLOCK_M - 1) // BLOCK_M
    # One-dimensional grid; decode the workgroup id below.
    wid = tl.program_id(
        0
    )

    # Heads vary fastest, then query blocks, then batch entries.
    off_q_head = wid % NUM_Q_HEADS
    off_q_head = remap_xcd(off_q_head, NUM_Q_HEADS, NUM_XCD)
    start_m = (wid // NUM_Q_HEADS) % NUM_BLOCKS
    off_z = (wid // (NUM_BLOCKS * NUM_Q_HEADS)) % BATCH

    # Tile offsets along the query rows, key columns and head dim.
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL_POW2)
    HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0
    if HAS_PE:
        # The positional-embedding features follow the first BLOCK_DMODEL.
        offs_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL_PE)

    # Optionally widen every stride (and the Philox base offset) to int64 so
    # that offset arithmetic on large tensors cannot overflow 32 bits.
    if USE_INT64_STRIDES:
        stride_qz = tl.cast(stride_qz_in, tl.int64)
        stride_qh = tl.cast(stride_qh_in, tl.int64)
        stride_qm = tl.cast(stride_qm_in, tl.int64)
        stride_qk = tl.cast(stride_qk_in, tl.int64)
        stride_kz = tl.cast(stride_kz_in, tl.int64)
        stride_kh = tl.cast(stride_kh_in, tl.int64)
        stride_kn = tl.cast(stride_kn_in, tl.int64)
        stride_kk = tl.cast(stride_kk_in, tl.int64)
        stride_vz = tl.cast(stride_vz_in, tl.int64)
        stride_vh = tl.cast(stride_vh_in, tl.int64)
        stride_vn = tl.cast(stride_vn_in, tl.int64)
        stride_vk = tl.cast(stride_vk_in, tl.int64)
        if IS_FP8:
            stride_descale_q_z = tl.cast(stride_descale_q_z_in, tl.int64)
            stride_descale_k_z = tl.cast(stride_descale_k_z_in, tl.int64)
            stride_descale_v_z = tl.cast(stride_descale_v_z_in, tl.int64)
        stride_oz = tl.cast(stride_oz_in, tl.int64)
        stride_oh = tl.cast(stride_oh_in, tl.int64)
        stride_om = tl.cast(stride_om_in, tl.int64)
        stride_on = tl.cast(stride_on_in, tl.int64)
        stride_alibi_z = tl.cast(stride_alibi_z_in, tl.int64)
        stride_alibi_h = tl.cast(stride_alibi_h_in, tl.int64)

        # Dropout / returned-scores / LSE bookkeeping.
        philox_offset_base = tl.cast(philox_offset_base_in, tl.int64)
        stride_sd_z = tl.cast(stride_sd_z_in, tl.int64)
        stride_sd_h = tl.cast(stride_sd_h_in, tl.int64)
        stride_sd_m = tl.cast(stride_sd_m_in, tl.int64)
        stride_sd_n = tl.cast(stride_sd_n_in, tl.int64)
        stride_lse_z = tl.cast(stride_lse_z_in, tl.int64)
        stride_lse_h = tl.cast(stride_lse_h_in, tl.int64)
        stride_lse_m = tl.cast(stride_lse_m_in, tl.int64)
    else:
        stride_qz = stride_qz_in
        stride_qm = stride_qm_in
        stride_qk = stride_qk_in
        stride_qh = stride_qh_in
        stride_kz = stride_kz_in
        stride_kh = stride_kh_in
        stride_kn = stride_kn_in
        stride_kk = stride_kk_in
        stride_vz = stride_vz_in
        stride_vh = stride_vh_in
        stride_vn = stride_vn_in
        stride_vk = stride_vk_in
        stride_descale_q_z = stride_descale_q_z_in
        stride_descale_k_z = stride_descale_k_z_in
        stride_descale_v_z = stride_descale_v_z_in
        stride_oz = stride_oz_in
        stride_oh = stride_oh_in
        stride_om = stride_om_in
        stride_on = stride_on_in
        stride_alibi_z = stride_alibi_z_in
        stride_alibi_h = stride_alibi_h_in
        philox_offset_base = philox_offset_base_in
        stride_sd_z = stride_sd_z_in
        stride_sd_h = stride_sd_h_in
        stride_sd_m = stride_sd_m_in
        stride_sd_n = stride_sd_n_in
        stride_lse_z = stride_lse_z_in
        stride_lse_h = stride_lse_h_in
        stride_lse_m = stride_lse_m_in

    # Promise the compiler that strides and the Philox base are non-negative.
    tl.assume(stride_qz_in >= 0)
    tl.assume(stride_qh_in >= 0)
    tl.assume(stride_qm_in >= 0)
    tl.assume(stride_qk_in >= 0)
    tl.assume(stride_kz_in >= 0)
    tl.assume(stride_kh_in >= 0)
    tl.assume(stride_kn_in >= 0)
    tl.assume(stride_kk_in >= 0)
    tl.assume(stride_vz_in >= 0)
    tl.assume(stride_vh_in >= 0)
    tl.assume(stride_vn_in >= 0)
    tl.assume(stride_vk_in >= 0)
    if IS_FP8:
        tl.assume(stride_descale_q_z_in >= 0)
        tl.assume(stride_descale_k_z_in >= 0)
        tl.assume(stride_descale_v_z_in >= 0)
    tl.assume(stride_oz_in >= 0)
    tl.assume(stride_oh_in >= 0)
    tl.assume(stride_om_in >= 0)
    tl.assume(stride_on_in >= 0)
    tl.assume(stride_alibi_z_in >= 0)
    tl.assume(stride_alibi_h_in >= 0)
    # Dropout / returned-scores / LSE bookkeeping.
    tl.assume(philox_offset_base_in >= 0)
    tl.assume(stride_sd_z_in >= 0)
    tl.assume(stride_sd_h_in >= 0)
    tl.assume(stride_sd_m_in >= 0)
    tl.assume(stride_sd_n_in >= 0)
    tl.assume(stride_lse_z_in >= 0)
    tl.assume(stride_lse_h_in >= 0)
    tl.assume(stride_lse_m_in >= 0)

    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)

        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # The grid is sized from SEQLEN_Q, so shorter sequences get surplus
        # query blocks.  A tile whose first row is at or beyond seqlen_q has
        # no valid rows at all (">=": the tile starting exactly at seqlen_q
        # is empty too), so it has nothing to compute or store.
        if start_m * BLOCK_M >= seqlen_q:
            return
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = SEQLEN_Q
        seqlen_k = SEQLEN_K

    n_blocks = _cdiv_fn(seqlen_k, BLOCK_N)

    if IS_CAUSAL:
        # With the causal diagonal aligned to the bottom-right corner, the
        # last row of this tile may attend to keys up to (exclusive) column
        # (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q; later key blocks are
        # fully masked and can be skipped.
        n_blocks_seqlen = _cdiv_fn(
            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
        )

        n_blocks = min(n_blocks, n_blocks_seqlen)

        # No key block is visible to any row of this tile: the output rows
        # are all zero, so write zeros (and a zero LSE) and exit early.
        if n_blocks <= 0:
            offs_out = (
                off_z * stride_oz
                + off_q_head * stride_oh
                + cu_seqlens_q_start * stride_om
                + offs_m[:, None] * stride_om
                + offs_d[None, :] * stride_on
            )
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty)
            out_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL)
            tl.store(out_ptr + offs_out, acc, mask=out_mask)

            if softmax_lse_ptr is not None:
                offs_lse = (
                    off_z * stride_lse_z
                    + off_q_head * stride_lse_h
                    + cu_seqlens_q_start * stride_lse_m
                    + offs_m * stride_lse_m
                )
                # Mask with this sequence's own length.  The LSE rows are
                # addressed relative to cu_seqlens_q_start, so masking with
                # SEQLEN_Q would let a VARLEN program write zeros into rows
                # that belong to the following sequence.
                lse_mask = offs_m < seqlen_q
                lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
                tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask)
                # NOTE(review): the dropout mask and returned scores are not
                # written on this path.

            return

    # Grouped-query attention: several query heads share one K/V head.
    grp_sz: tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS
    if grp_sz != 1:
        off_k_head = off_q_head // grp_sz
    else:
        off_k_head = off_q_head

    # -- Q tile pointers: rows = query positions, cols = head dim -----------
    q_offs = (
        off_z * stride_qz
        + off_q_head * stride_qh
        + cu_seqlens_q_start * stride_qm
        + offs_m[:, None] * stride_qm
        + offs_d[None, :] * stride_qk
    )
    q_ptrs = q_ptr + q_offs
    if HAS_PE:
        q_pe_offs = (
            off_z * stride_qz
            + off_q_head * stride_qh
            + cu_seqlens_q_start * stride_qm
            + offs_m[:, None] * stride_qm
            + offs_pe[None, :] * stride_qk
        )
        q_pe_ptrs = q_ptr + q_pe_offs
    else:
        q_pe_ptrs = None

    # -- K tile pointers: rows = head dim, cols = key positions -------------
    # (K is addressed transposed so that tl.dot(q, k) yields the scores.)
    k_offs = (
        off_z * stride_kz
        + off_k_head * stride_kh
        + cu_seqlens_k_start * stride_kn
        + offs_d[:, None] * stride_kk
        + offs_n[None, :] * stride_kn
    )
    k_ptrs = k_ptr + k_offs
    if HAS_PE:
        k_pe_offs = (
            off_z * stride_kz
            + off_k_head * stride_kh
            + cu_seqlens_k_start * stride_kn
            + offs_pe[:, None] * stride_kk
            + offs_n[None, :] * stride_kn
        )
        k_pe_ptrs = k_ptr + k_pe_offs
    else:
        k_pe_ptrs = None

    # -- V tile pointers: rows = key positions, cols = head dim -------------
    v_offs = (
        off_z * stride_vz
        + off_k_head * stride_vh
        + cu_seqlens_k_start * stride_vn
        + offs_n[:, None] * stride_vn
        + offs_d[None, :] * stride_vk
    )
    v_ptrs = v_ptr + v_offs

    # One ALiBi slope per (batch, query head).
    if alibi_slopes_ptr is not None:
        alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offs)
    else:
        alibi_slope = None

    # Optional [rows, key-columns] tile for the returned scores.
    if s_dmask_ptr is not None:
        s_dmask_offs = (
            off_z * stride_sd_z
            + off_q_head * stride_sd_h
            + offs_m[:, None] * stride_sd_m
            + offs_n[None, :] * stride_sd_n
        )
        s_dmask_ptrs = s_dmask_ptr + s_dmask_offs
    else:
        s_dmask_ptrs = None

    # Optional dropout-mask tile plus matching RNG offsets.  ``philox_ptrs``
    # holds integer offsets (not addresses) laid out exactly like the
    # dropout-mask tile, shifted by ``philox_offset_base``.
    if dropout_mask_ptr is not None:
        dropout_mask_offs = (
            off_z * stride_sd_z
            + off_q_head * stride_sd_h
            + offs_m[:, None] * stride_sd_m
            + offs_n[None, :] * stride_sd_n
        )
        dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs
        philox_ptrs = (
            philox_offset_base
            + off_z * stride_sd_z
            + off_q_head * stride_sd_h
            + offs_m[:, None] * stride_sd_m
            + offs_n[None, :] * stride_sd_n
        )
    else:
        dropout_mask_ptrs = None
        philox_ptrs = None

    # Initial running maximum.  With a sink it is the per-head sink logit in
    # base-2 units; together with l_i = 1.0 below this makes the sink
    # contribute one exp2(sink - max) term to the softmax denominator.
    # Without a sink it is -inf, so the first rescale factor
    # exp2(-inf - m_ij) = 0 discards the initial l_i.
    if ENABLE_SINK:
        RCP_LN2: tl.constexpr = 1.4426950408889634
        m_i_value = tl.load(sink_ptr + off_q_head).to(tl.float32) * RCP_LN2
    else:
        m_i_value = float("-inf")

    m_i = tl.full([BLOCK_M], m_i_value, dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32)
    # Rows past seqlen_q (and padded head-dim columns, if any) load as 0.
    if BLOCK_DMODEL == BLOCK_DMODEL_POW2:
        q_mask = offs_m[:, None] < seqlen_q
    else:
        q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL)

    # Cache modifier for the Q loads; ".cg" is requested only when
    # BLOCK_M >= NUM_Q_HEADS (rationale is not visible in this file).
    if BLOCK_M >= NUM_Q_HEADS:
        q_cache_mod: tl.constexpr = ".cg"
    else:
        q_cache_mod: tl.constexpr = ""

    if HAS_PE:
        q_pe = tl.load(q_pe_ptrs, mask=q_mask, other=0.0, cache_modifier=q_cache_mod)
    else:
        q_pe = None
    q = tl.load(q_ptrs, mask=q_mask, other=0.0, cache_modifier=q_cache_mod)
    if IS_FP8:
        # One de-scale factor per (batch, head); Q uses the query head, K and
        # V use the shared K/V head.
        descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head)
        descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head)
        descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head)
    else:
        descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

    # Non-zero when the last key block is only partially filled.
    n_extra_tokens = 0
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N

    # Split the key blocks into leading "full" blocks that need no masking
    # and trailing blocks that do (padding and/or the causal diagonal).
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
    if IS_CAUSAL:
        # The diagonal crosses BLOCK_M // BLOCK_N key blocks; one more block
        # is masked when either sequence length is not block-aligned.
        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
    else:
        # Non-causal: only a padded trailing block needs masking.
        masked_blocks = padded_block_k
    # Never mask more blocks than there are to process.
    masked_blocks = min(masked_blocks, n_blocks)
    n_full_blocks = n_blocks - masked_blocks
    block_min = 0
    block_max = n_blocks * BLOCK_N
    # Pass 1: full blocks - no causal mask, no bounds checks, pipelined loop.
    if n_full_blocks > 0:
        block_max = (n_blocks - masked_blocks) * BLOCK_N
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            q_pe,
            k_ptrs,
            k_pe_ptrs,
            v_ptrs,
            stride_kn,
            stride_vn,
            stride_sd_n,
            start_m,
            seqlen_k,
            seqlen_q,
            dropout_p,
            s_dmask_ptrs,
            dropout_mask_ptrs,
            philox_seed,
            philox_ptrs,
            block_min,
            block_max,
            0,
            0,
            0,
            alibi_slope,
            descale_q,
            descale_k,
            descale_v,
            offs_m,
            offs_n,
            PRELOAD_V,
            BLOCK_M,
            BLOCK_N,
            BLOCK_DMODEL,
            BLOCK_DMODEL_POW2,
            BLOCK_DMODEL_PE,
            sm_scale,
            False,
            MASK_STEPS=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            RETURN_SCORES=RETURN_SCORES,
            PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            ENABLE_PIPELINING=True,
        )
        block_min = block_max
        block_max = n_blocks * BLOCK_N

    # Pass 2: trailing blocks that need padding and/or causal masking.
    if masked_blocks > 0:
        if IS_CAUSAL:
            # Shift key columns so the diagonal ends at the bottom-right.
            offs_n_causal = offs_n + (seqlen_q - seqlen_k)
        else:
            offs_n_causal = 0
        # The inner function advances its pointers locally only, so skip the
        # blocks consumed by pass 1 here (a no-op when n_full_blocks == 0).
        k_ptrs += n_full_blocks * BLOCK_N * stride_kn
        if HAS_PE:
            k_pe_ptrs += n_full_blocks * BLOCK_N * stride_kn
        v_ptrs += n_full_blocks * BLOCK_N * stride_vn
        if RETURN_SCORES:
            s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
        if ENABLE_DROPOUT:
            dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
            # The RNG offsets share the dropout-mask layout and must move
            # with it; otherwise the masked blocks would reuse the random
            # numbers drawn for the first key blocks.
            philox_ptrs += n_full_blocks * BLOCK_N * stride_sd_n
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            q_pe,
            k_ptrs,
            k_pe_ptrs,
            v_ptrs,
            stride_kn,
            stride_vn,
            stride_sd_n,
            start_m,
            seqlen_k,
            seqlen_q,
            dropout_p,
            s_dmask_ptrs,
            dropout_mask_ptrs,
            philox_seed,
            philox_ptrs,
            block_min,
            block_max,
            offs_n_causal,
            masked_blocks,
            n_extra_tokens,
            alibi_slope,
            descale_q,
            descale_k,
            descale_v,
            offs_m,
            offs_n,
            PRELOAD_V,
            BLOCK_M,
            BLOCK_N,
            BLOCK_DMODEL,
            BLOCK_DMODEL_POW2,
            BLOCK_DMODEL_PE,
            sm_scale,
            IS_CAUSAL,
            MASK_STEPS=True,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            RETURN_SCORES=RETURN_SCORES,
            PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            ENABLE_PIPELINING=False,
        )

    # -- epilogue: normalise by the softmax denominator ----------------------
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    if ENABLE_DROPOUT:
        # Rescale the kept entries by 1 / (1 - p).
        dropout_scale = 1 / (1 - dropout_p)
        acc = acc * dropout_scale
    # When seqlen_q > seqlen_k under causal masking, the first
    # (seqlen_q - seqlen_k) query rows see no key at all.  If that boundary
    # falls strictly inside this tile, force those rows of the output to 0.
    end_m_idx = (start_m + 1) * BLOCK_M
    start_m_idx = start_m * BLOCK_M
    causal_start_idx = seqlen_q - seqlen_k
    if IS_CAUSAL:
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
            out_mask_boundary = tl.full(
                (BLOCK_DMODEL_POW2,), causal_start_idx, dtype=tl.int32
            )
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))

    # Number of rows of this tile that lie past the end of the sequence.
    overflow_size = end_m_idx - seqlen_q
    if softmax_lse_ptr is not None:
        # ln(2): converts the base-2 log-sum-exp back to natural log.
        LN2: tl.constexpr = 0.6931471824645996
        # log2-sum-exp of the row ...
        softmax_lse = m_i + tl.math.log2(l_i)
        # ... converted to natural-log units.
        softmax_lse *= LN2

        if IS_CAUSAL:
            # Rows that attend to no key get an LSE of 0.
            lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
            softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse)

        offs_lse = (
            off_z * stride_lse_z
            + off_q_head * stride_lse_h
            + cu_seqlens_q_start * stride_lse_m
            + offs_m * stride_lse_m
        )
        if overflow_size > 0:
            # Only the first BLOCK_M - overflow_size rows are real.
            boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
            lse_mask = tl.arange(0, BLOCK_M) < boundary
            tl.store(
                softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask
            )
        else:
            tl.store(
                softmax_lse_ptr + offs_lse, softmax_lse
            )

    # -- store the output tile ------------------------------------------------
    offs_out = (
        off_z * stride_oz
        + off_q_head * stride_oh
        + cu_seqlens_q_start * stride_om
        + offs_m[:, None] * stride_om
        + offs_d[None, :] * stride_on
    )
    out_mask = tl.full([BLOCK_M, 1], 1, dtype=tl.int1)
    if overflow_size > 0:
        out_mask = out_mask & (offs_m[:, None] < seqlen_q)
    if BLOCK_DMODEL != BLOCK_DMODEL_POW2:
        out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL)
    op = acc.to(out_ptr.dtype.element_ty)
    tl.store(out_ptr + offs_out, op, mask=out_mask)
|
|
|
|
@functools.lru_cache(maxsize=1024)
def _get_config(
    enable_dropout: bool,
    dtype: torch.dtype,
    has_pe: bool = False,
):
    """Return the forward-kernel launch configuration for the given options.

    The per-architecture defaults are read once from
    ``{AITER_TRITON_CONFIGS_PATH}/{arch}-MHA-DEFAULT.json`` and kept on the
    function object; results are additionally memoised per argument tuple by
    ``lru_cache``.

    Selection order within the JSON file's ``"fwd"`` section:
        1. ``"pe_dropout_or_fp32"`` when ``has_pe`` and (dropout or fp32) and
           the key exists;
        2. a built-in configuration for any other ``has_pe`` request;
        3. ``"dropout_or_fp32"`` when dropout is enabled or ``dtype`` is
           ``torch.float32``;
        4. ``"default"`` otherwise.

    Args:
        enable_dropout: whether the kernel will run with dropout.
        dtype: element type of the attention inputs.
        has_pe: whether a positional-embedding slice is used.

    Returns:
        The configuration dict.  The same dict object is returned on repeated
        calls, so callers must not mutate it in place.

    Raises:
        OSError: the configuration file cannot be opened (raised again on
            every retry; nothing is cached on failure).
        json.JSONDecodeError: the file is not valid JSON.
        KeyError: the file lacks the ``"fwd"`` section or the selected entry.
    """
    if not hasattr(_get_config, "_config_dict"):
        dev = arch_info.get_arch()
        fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-MHA-DEFAULT.json"
        # JSON is UTF-8; do not depend on the process locale.
        with open(fpath, "r", encoding="utf-8") as file:
            config = json.load(file)
        # Publish the cache only after a successful load.  Creating the
        # attribute first would leave an empty dict behind on failure and
        # turn every later call into an unrelated ``KeyError: 'default'``.
        _get_config._config_dict = {"default": config}
    fwd_cfg = _get_config._config_dict["default"]["fwd"]
    has_dropout_or_fp32 = enable_dropout or dtype == torch.float32

    if has_pe and has_dropout_or_fp32 and "pe_dropout_or_fp32" in fwd_cfg:
        return fwd_cfg["pe_dropout_or_fp32"]
    if has_pe:
        # Built-in fallback for positional-embedding runs that have no
        # dedicated entry in the JSON file.
        return {
            "BLOCK_M": 128,
            "BLOCK_N": 64,
            "PRELOAD_V": True,
            "waves_per_eu": 1,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
        }
    if has_dropout_or_fp32:
        return fwd_cfg["dropout_or_fp32"]
    return fwd_cfg["default"]
|
|