""" Copyright (c) 2024 by SageAttention team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import torch, math import triton import triton.language as tl import torch.nn.functional as F @triton.jit def _attn_fwd_inner(acc, l_i, old_m, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, ): if STAGE == 1: lo, hi = 0, start_m * BLOCK_M elif STAGE == 2: lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M lo = tl.multiple_of(lo, BLOCK_M) K_scale_ptr += lo // BLOCK_N K_ptrs += stride_kn * lo V_ptrs += stride_vn * lo elif STAGE == 3: lo, hi = 0, kv_len for start_n in range(lo, hi, BLOCK_N): kbid = tl.load(K_bid_ptr + start_n//BLOCK_N) if kbid: k_mask = offs_n[None, :] < (kv_len - start_n) k = tl.load(K_ptrs, mask = k_mask) k_scale = tl.load(K_scale_ptr) qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) qk = qk + tl.where(mask, 0, -1.0e6) local_m = tl.max(qk, 1) new_m = tl.maximum(old_m, local_m) qk -= new_m[:, None] else: local_m = tl.max(qk, 1) new_m = tl.maximum(old_m, local_m) qk = qk - new_m[:, None] p = tl.math.exp2(qk) l_ij = tl.sum(p, 1) alpha = tl.math.exp2(old_m - new_m) l_i = l_i * alpha + l_ij acc = acc * alpha[:, None] v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n)) p = p.to(tl.float16) acc += tl.dot(p, v, out_dtype=tl.float16) old_m = new_m K_ptrs += BLOCK_N * stride_kn K_scale_ptr += 1 V_ptrs += BLOCK_N * stride_vn return acc, l_i, old_m @triton.jit def _attn_fwd(Q, K, K_blkid, V, Q_scale, K_scale, Out, stride_qz, stride_qh, stride_qn, stride_kz, stride_kh, stride_kn, stride_vz, stride_vh, stride_vn, stride_oz, stride_oh, stride_on, stride_kbidq, stride_kbidk, qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, STAGE: tl.constexpr ): start_m = tl.program_id(0) off_z = tl.program_id(2).to(tl.int64) off_h = tl.program_id(1).to(tl.int64) q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N) k_bid_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * stride_kbidq offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, HEAD_DIM) Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :] Q_scale_ptr = Q_scale + q_scale_offset + start_m K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None] K_scale_ptr = K_scale + k_scale_offset K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :] O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :] m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len) q_scale = tl.load(Q_scale_ptr) acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m, BLOCK_M, HEAD_DIM, BLOCK_N, 4 - STAGE, offs_m, offs_n ) if STAGE != 1: acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m, BLOCK_M, HEAD_DIM, BLOCK_N, 2, offs_m, offs_n ) acc = acc / l_i[:, None] tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len)) def forward(q, k, k_block_id, v, q_scale, k_scale, is_causal=False, tensor_layout="HND", output_dtype=torch.float16): BLOCK_M = 128 BLOCK_N = 64 stage = 3 if is_causal else 1 o = torch.empty(q.shape, dtype=output_dtype, device=q.device) if tensor_layout == "HND": b, h_qo, qo_len, head_dim = q.shape _, h_kv, kv_len, _ = k.shape stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2) stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2) elif tensor_layout == "NHD": b, qo_len, h_qo, head_dim = q.shape _, kv_len, h_kv, _ = k.shape stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1) stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1) else: raise ValueError(f"tensor_layout {tensor_layout} not supported") if is_causal: assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention" HEAD_DIM_K = head_dim num_kv_groups = h_qo // h_kv grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b ) _attn_fwd[grid]( q, k, k_block_id, v, q_scale, k_scale, o, stride_bz_q, stride_h_q, stride_seq_q, stride_bz_k, stride_h_k, stride_seq_k, stride_bz_v, stride_h_v, stride_seq_v, stride_bz_o, stride_h_o, stride_seq_o, k_block_id.stride(1), k_block_id.stride(2), qo_len, kv_len, h_qo, num_kv_groups, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, STAGE=stage, num_warps=4 if head_dim == 64 else 8, num_stages=4) return o