ColabWan / postprocessing /flashvsr /sparse_sage /sparse_int8_attn.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
7.46 kB
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch, math
import triton
import triton.language as tl
import torch.nn.functional as F
@triton.jit
def _attn_fwd_inner(acc, l_i, old_m, q, q_scale, kv_len,
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m,
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += lo // BLOCK_N
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
elif STAGE == 3:
lo, hi = 0, kv_len
for start_n in range(lo, hi, BLOCK_N):
kbid = tl.load(K_bid_ptr + start_n//BLOCK_N)
if kbid:
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask = k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk -= new_m[:, None]
else:
local_m = tl.max(qk, 1)
new_m = tl.maximum(old_m, local_m)
qk = qk - new_m[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(old_m - new_m)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
old_m = new_m
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += 1
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, old_m
@triton.jit
def _attn_fwd(Q, K, K_blkid, V, Q_scale, K_scale, Out,
stride_qz, stride_qh, stride_qn,
stride_kz, stride_kh, stride_kn,
stride_vz, stride_vh, stride_vn,
stride_oz, stride_oh, stride_on,
stride_kbidq, stride_kbidk,
qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
k_bid_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * stride_kbidq
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
Q_scale_ptr = Q_scale + q_scale_offset + start_m
K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
K_scale_ptr = K_scale + k_scale_offset
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
4 - STAGE, offs_m, offs_n
)
if STAGE != 1:
acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
2, offs_m, offs_n
)
acc = acc / l_i[:, None]
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))
def forward(q, k, k_block_id, v, q_scale, k_scale, is_causal=False, tensor_layout="HND", output_dtype=torch.float16):
BLOCK_M = 128
BLOCK_N = 64
stage = 3 if is_causal else 1
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
else:
raise ValueError(f"tensor_layout {tensor_layout} not supported")
if is_causal:
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b )
_attn_fwd[grid](
q, k, k_block_id, v, q_scale, k_scale, o,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_v, stride_h_v, stride_seq_v,
stride_bz_o, stride_h_o, stride_seq_o,
k_block_id.stride(1), k_block_id.stride(2),
qo_len, kv_len,
h_qo, num_kv_groups,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4)
return o