|
|
""" |
|
|
Implementation of Forgetting Attention. |
|
|
|
|
|
Our code is adapted from https://github.com/FlagOpen/FlagAttention/blob/ee91638dec6da8c00c4113d179f469e0ffcd5852/src/flag_attn/flash.py. The code is modified to implement Forgetting Attention. |
|
|
|
|
|
The original license info from FlagAttention: |
|
|
|
|
|
Copyright 2023 BAAI |
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
you may not use this file except in compliance with the License. |
|
|
You may obtain a copy of the License at |
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
See the License for the specific language governing permissions and |
|
|
limitations under the License. |
|
|
""" |
|
|
import pytest |
|
|
import math |
|
|
import torch |
|
|
import triton |
|
|
import triton.language as tl |
|
|
from einops import rearrange |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
__all__ = ["forgetting_attention"] |
|
|
|
|
|
|
|
|
|
|
|
def maybe_contiguous(x):
    """Return *x* as-is when its last dimension already has stride 1;
    otherwise return a contiguous copy (required by the Triton kernels)."""
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
|
|
|
|
|
def rounded_multiple(a, b):
    """Round *a* up to the nearest multiple of *b*."""
    ceil_quotient = (a + b - 1) // b
    return ceil_quotient * b
|
|
|
|
|
|
|
|
class ForgettingAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton Forgetting-Attention kernels.

    Tensors use the head-first layout (B, H, seq, D). The forward pass turns
    per-position log forget gates into a cumulative-sum tensor (``log_lambda``)
    so the kernels can apply the decay bias log_lambda[i] - log_lambda[j] to
    each attention score.
    """

    @staticmethod
    def forward(ctx, q, k, v, log_fgate, seq_start, causal, sm_scale, return_log_normalizer):
        # Only causal attention is implemented in the kernels below.
        assert causal, "Only causal attention is supported"
        Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Dq == Dk == Dv, "feature size of q, k, v should be equal"
        assert Dk in {16, 32, 64, 128}, "We only support head dims in {16, 32, 64, 128}"

        B, H, M, D = q.shape
        if seq_start is not None:
            has_seq_start = True
            assert seq_start.shape == (B,)
        else:
            # No left-padding: use an all-zero start index so the kernels can
            # take a uniform code path.
            has_seq_start = False
            seq_start = torch.zeros((B,), device=q.device, dtype=torch.long)
        N = k.shape[2]
        assert log_fgate.shape == (B, H, N)
        log_fgate = log_fgate.float()
        if has_seq_start:
            # Zero the gates of padded (pre-seq_start) positions so they do not
            # contribute to the cumulative sum. Clone first to avoid mutating
            # the caller's tensor.
            log_fgate = log_fgate.clone()
            mask_index = (torch.arange(N, device=q.device)[None, None, :] < seq_start[:, None, None])
            mask_index = torch.broadcast_to(mask_index, log_fgate.size())
            log_fgate[mask_index] = 0.0

        # Cumulative log forget gates; the kernels use differences of this as
        # the decay bias added to attention logits.
        log_lambda = torch.cumsum(log_fgate, dim=-1, dtype=log_fgate.dtype).float()

        Hk, Hv = k.shape[1], v.shape[1]
        assert Hk == Hv, "num of heads in k and v should be equal"
        assert H == Hk, "groupped query attention has not been tested. You can uncomment this if you know what you are doing."
        assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
        num_groups = H // Hk

        # P_SEQ: number of extra (prefix) key/value positions; requires M <= N.
        P_SEQ = N - M
        larger_m = M > N
        assert (not larger_m), "The key/value tensors must be longer than the query tensor"

        if sm_scale is None:
            sm_scale = 1. / math.sqrt(D)

        # The kernels assume stride(-1) == 1.
        q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v)

        device = torch.cuda.device_of(q)
        with torch.cuda.device(device):
            config = get_fwd_config(B, H, M, N, D, causal)
            BLOCK_M, BLOCK_N, num_stages, num_warps = config

            # Divisibility flags let the kernel skip bounds masking.
            divisible_m = M % BLOCK_M == 0
            divisible_n = N % BLOCK_N == 0

            # One program per (query block, head, batch).
            grid = (triton.cdiv(M, BLOCK_M), H, B)
            o = torch.empty_like(q)
            # L holds the per-row log-normalizer (logsumexp), used in backward.
            L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
            _fwd_kernel[grid](
                q, k, v, log_lambda, seq_start, sm_scale,
                L, o,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2),
                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                B, H, M, N, P_SEQ, num_groups,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
                IS_CAUSAL=causal, LARGER_M=larger_m, HAS_SEQ_START=has_seq_start,
                DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
                num_warps=num_warps, num_stages=num_stages,
            )

        ctx.save_for_backward(q, k, v, o, L, log_lambda, seq_start)
        ctx.sm_scale = sm_scale
        ctx.causal = causal
        ctx.has_seq_start = has_seq_start

        has_extra_return = return_log_normalizer
        if has_extra_return:
            outs = (
                o,
                L if return_log_normalizer else None,
            )
            return outs
        return o

    @staticmethod
    def backward(ctx, do, *ignored):
        # *ignored* absorbs the gradient of the optional log-normalizer output.
        q, k, v, o, L, log_lambda, seq_start = ctx.saved_tensors
        sm_scale = ctx.sm_scale
        causal = ctx.causal
        has_seq_start = ctx.has_seq_start

        B, H, M, D = q.shape
        N = k.shape[2]
        Hk = k.shape[1]
        num_groups = H // Hk
        P_SEQ = N - M
        larger_m = M > N

        if sm_scale is None:
            sm_scale = 1. / math.sqrt(D)

        device = torch.cuda.device_of(q)
        with torch.cuda.device(device):
            config = get_bwd_config(B, H, M, N, D, causal)
            BLOCK_M, BLOCK_N, num_stages, num_warps = config

            divisible_m = M % BLOCK_M == 0
            divisible_n = N % BLOCK_N == 0

            # Step 1: delta[i] = sum_d o[i, d] * do[i, d] (rowwise dot),
            # needed by both backward kernels.
            delta = torch.empty_like(L)
            grid = (triton.cdiv(M, BLOCK_M), H, B)
            _bwd_preprocess[grid](
                o, do,
                delta,
                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
                do.stride(0), do.stride(1), do.stride(2), do.stride(3),
                delta.stride(0), delta.stride(1), delta.stride(2),
                M,
                BLOCK_M=BLOCK_M, D_HEAD=D,
                DIVISIBLE_M=divisible_m,
            )

            # Step 2: dk, dv and the key-side part of dlog_lambda.
            # NOTE: block sizes are re-chosen per kernel, so the divisibility
            # flags are recomputed here.
            BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_kv_config(B, H, M, N, D, causal)
            divisible_m = M % BLOCK_M == 0
            divisible_n = N % BLOCK_N == 0

            dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device)
            dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device)
            dlog_lambda = torch.empty((B, H, N), dtype=log_lambda.dtype, device=q.device)
            grid = (triton.cdiv(N, BLOCK_N), H, B)
            _bwd_kv_kernel[grid](
                q, k, v, log_lambda, seq_start, sm_scale, do,
                dk, dv, dlog_lambda,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2),
                do.stride(0), do.stride(1), do.stride(2), do.stride(3),
                dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
                dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
                dlog_lambda.stride(0), dlog_lambda.stride(1), dlog_lambda.stride(2),
                B, H, M, N, P_SEQ,
                num_groups,
                BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
                DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, HAS_SEQ_START=has_seq_start,
                num_stages=num_stages, num_warps=num_warps,
            )

            # Step 3: dq and the query-side contribution, accumulated into the
            # same dlog_lambda buffer (the q-kernel reads and adds to it).
            BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_q_config(B, H, M, N, D, causal)
            divisible_m = M % BLOCK_M == 0
            divisible_n = N % BLOCK_N == 0
            dq = torch.zeros_like(q)
            grid = (triton.cdiv(M, BLOCK_M), H, B)
            _bwd_q_kernel[grid](
                q, k, v, log_lambda, seq_start, sm_scale, do,
                dq, dlog_lambda,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2),
                do.stride(0), do.stride(1), do.stride(2), do.stride(3),
                dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
                dlog_lambda.stride(0), dlog_lambda.stride(1), dlog_lambda.stride(2),
                B, H, M, N, P_SEQ,
                num_groups,
                BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
                CAUSAL=causal, LARGER_M=larger_m, HAS_SEQ_START=has_seq_start,
                DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
                num_stages=num_stages, num_warps = num_warps,
            )

        # Fold grouped query heads back onto the (fewer) key/value heads.
        dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2)
        dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2)
        # d(log_fgate)[j] = sum_{i >= j} d(log_lambda)[i]: a reversed cumsum,
        # computed as total minus the exclusive prefix sum.
        dcumsum = torch.cumsum(dlog_lambda, dim=-1, dtype=log_lambda.dtype)
        dlog_fgate = dlog_lambda + dcumsum[..., -1:] - dcumsum
        dlog_fgate = dlog_fgate.float()
        # Extra trailing Nones cover the non-tensor forward arguments.
        return dq, dk, dv, dlog_fgate, None, None, None, None, None, None, None
|
|
|
|
|
|
|
|
def forgetting_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    log_fgate: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
):
    """
    A FlashAttention-based implementation of Forgetting Attention.

    Note:
        - We recommend bfloat16/float16 for q, k, v and float32 for log_fgate.
          float32 for q, k, v is also supported, but the kernel will not use
          tensor cores if q, k, v are in float32 (which would be slow).
        - We only support seqlen_q <= seqlen_k
        - We only support causal attention
        - Head dimension must be in one of {16, 32, 64, 128}

    Arguments:
        - q: (batch_size, seqlen_q, num_heads, head_dim) unless head_first=True.
        - k: (batch_size, seqlen_k, num_heads, head_dim) unless head_first=True.
        - v: (batch_size, seqlen_k, num_heads, head_dim) unless head_first=True.
        - log_fgate: (batch_size, seqlen_k, num_heads) unless head_first=True.
          This should be the **log** of the forget gates. This is typically the
          output of torch.nn.functional.logsigmoid.
        - head_first: if True, the order of the num_heads and seqlen_* axes of
          all FloatTensor inputs and outputs should be (num_heads, seq_len_*)
          instead of (seq_len_*, num_heads)
        - seq_start: If not None, should be a LongTensor with shape
          (batch_size,) and range in [0, seq_len_k). For each batch index
          batch_id, no attention will be allocated to tokens before the token
          index seq_start[batch_id]. This is useful for left-padded inputs.
        - sm_scale: The scaling of attention scores before applying softmax.
          If None, it defaults to (1.0 / math.sqrt(head_dim))

    Returns:
        out (torch.Tensor): (batch_size, seqlen_q, num_heads, head_dim) unless head_first=True.
    """
    # The kernels expect the head-first layout (b, h, t, d); a transpose of the
    # time and head axes is a pure view, equivalent to the einops rearrange.
    if not head_first:
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        log_fgate = log_fgate.transpose(1, 2)
    out = ForgettingAttention.apply(q, k, v, log_fgate, seq_start, True, sm_scale, False)
    if not head_first:
        # Restore the caller's (b, t, h, d) layout.
        out = out.transpose(1, 2)
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_fwd_config(B, H, M, N, D, causal):
    """Choose (BLOCK_M, BLOCK_N, num_stages, num_warps) for `_fwd_kernel`.

    The launch configuration is keyed on the CUDA compute capability of the
    current device and the head dimension D. Only causal attention is
    supported (asserted below), matching the kernels in this file.

    Returns:
        tuple: (BLOCK_M, BLOCK_N, num_stages, num_warps)
    """
    assert causal
    # Query the device capability once instead of once per branch.
    capability = torch.cuda.get_device_capability()
    if capability == (8, 0):  # sm80, e.g. A100
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4
    elif capability == (9, 0):  # sm90, e.g. H100
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8
    elif capability == (8, 6):  # sm86
        # NOTE: the original non-causal sub-branch here was unreachable
        # (causal is asserted True above); only the causal settings remain.
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
    elif capability == (8, 9):  # sm89
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
    else:
        # Conservative default for untuned architectures.
        BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
|
|
|
|
|
|
|
|
@triton.jit
def _fwd_kernel(
    Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale,
    L, O,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, M, N, P_SEQ,
    num_groups,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr,
    DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
):
    # Forward Forgetting-Attention kernel (FlashAttention-style online softmax).
    # Each program computes one BLOCK_M slice of queries for one (batch, head),
    # streaming over key/value blocks and adding the decay bias
    # log_lambda[out] - log_lambda[in] to each score. Writes the output O and
    # the per-row log-normalizer L (used by the backward pass).
    input_dtype = Q.dtype.element_ty

    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)

    # Scores are exponentiated with exp2, so fold log2(e) into the scale.
    log2e: tl.constexpr = 1.4426950408889634
    loge2: tl.constexpr = 0.6931471805599453
    qk_scale = sm_scale * log2e

    # Offset base pointers to this (batch, head); K/V use the grouped head.
    off_hk = off_h // num_groups
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_hk * stride_kh
    V += off_z * stride_vz + off_hk * stride_vh
    LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h
    O += off_z * stride_oz + off_h * stride_oh
    L += (off_z * H + off_h) * M  # L is (Z, H, M) contiguous

    offs_m_base = tl.arange(0, BLOCK_M)
    offs_m = start_m * BLOCK_M + offs_m_base
    offs_n_base = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_DMODEL)

    q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    # Queries live at key positions P_SEQ + m in log_lambda.
    log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m) * stride_log_lambda_n
    o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)
    l_ptrs = L + offs_m

    # Online-softmax state: running max (m_i), running denominator (l_i),
    # and the unnormalized output accumulator.
    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    if DIVISIBLE_M:
        q = tl.load(q_ptrs, cache_modifier=".cg")
        log_lambda_out = tl.load(log_lambda_out_ptrs, cache_modifier=".cg")
    else:
        mask_m = offs_m < M
        q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
        log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m, cache_modifier=".cg")

    # Upper bound of key positions this query block may attend to.
    if IS_CAUSAL:
        hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
        if LARGER_M:
            hi = tl.maximum(0, hi)
    else:
        hi = N

    # Lower bound: skip key blocks entirely before seq_start (left padding).
    offs_n_init = offs_n_base
    if HAS_SEQ_START:
        SEQ_START += off_z
        seq_start = tl.load(SEQ_START)
        lo = tl.minimum(seq_start, hi)
        lo = (lo // BLOCK_N) * BLOCK_N  # round down to a block boundary
        offs_n_init += lo
    else:
        lo = 0
        seq_start = 0

    k_ptrs = K + (offs_k[:, None] * stride_kk + offs_n_init[None, :] * stride_kn)
    v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk)
    log_lambda_in_ptrs = LOG_LAMBDA + (offs_n_init * stride_log_lambda_n)
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        offs_n = start_n + offs_n_base

        if DIVISIBLE_N:
            k = tl.load(k_ptrs, cache_modifier=".cg")
            v = tl.load(v_ptrs, cache_modifier=".cg")
            log_lambda_in = tl.load(log_lambda_in_ptrs, cache_modifier=".cg")
        else:
            mask_n = offs_n < N
            k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg")
            v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg")
            log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n, cache_modifier=".cg")

        # Attention logits plus the forgetting decay bias (in log2 units).
        s = tl.dot(q, k, input_precision="ieee") * qk_scale
        decay_bias = log_lambda_out[:, None] - log_lambda_in[None, :]
        s += decay_bias * log2e

        # Mask out-of-range, non-causal, and pre-seq_start positions.
        if not DIVISIBLE_N:
            s = tl.where(mask_n[None, :], s, float("-inf"))
        if IS_CAUSAL:
            causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
            s = tl.where(causal_mask, s, float("-inf"))
        if HAS_SEQ_START:
            s = tl.where(offs_n[None, :] >= seq_start, s, float("-inf"))

        # Online softmax update: rescale previous state by alpha, add block.
        m_i_new = tl.maximum(m_i, tl.max(s, 1))
        alpha = tl.math.exp2((m_i - m_i_new))
        p = tl.math.exp2(s - m_i_new[:, None])
        p_sum = tl.sum(p, 1)
        acc *= alpha[:, None]
        acc += tl.dot(p.to(input_dtype), v, input_precision="ieee")
        l_i = l_i * alpha + p_sum
        m_i = m_i_new

        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vn
        log_lambda_in_ptrs += BLOCK_N * stride_log_lambda_n

    # Rows that attend to nothing (fully masked) get zero output and -inf
    # log-normalizer instead of a 0/0 division.
    if IS_CAUSAL and (LARGER_M or HAS_SEQ_START):
        is_empty_line = (offs_m + P_SEQ) < seq_start
        acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
        l = tl.where(is_empty_line, float("-inf"), m_i * loge2 + tl.log(l_i))
    else:
        acc = acc * (1.0 / l_i[:, None])
        l = m_i * loge2 + tl.log(l_i)

    if DIVISIBLE_M:
        tl.store(l_ptrs, l, cache_modifier=".cg")
        tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
    else:
        tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg")
        tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_bwd_config(B, H, M, N, D, causal):
    """Choose (BLOCK_M, BLOCK_N, num_stages, num_warps) for `_bwd_preprocess`.

    Keyed on the CUDA compute capability of the current device, whether the
    attention is causal, and the head dimension D.

    Returns:
        tuple: (BLOCK_M, BLOCK_N, num_stages, num_warps)
    """
    # Query the device capability once instead of once per branch.
    capability = torch.cuda.get_device_capability()
    if capability == (9, 0):  # sm90, e.g. H100
        if not causal:
            BLOCK_M = 128 if D <= 64 else 64
            BLOCK_N = 64
            num_stages = 2
            num_warps = 4
        else:
            BLOCK_M = 64
            BLOCK_N = 64
            num_stages = 3 if D <= 64 else 2
            num_warps = 4
    elif capability == (8, 0):  # sm80, e.g. A100
        if not causal:
            BLOCK_M = 128 if D <= 64 else 64
            BLOCK_N = 64
            num_stages = 2
            num_warps = 4
        else:
            BLOCK_M = 64
            BLOCK_N = 64
            num_stages = 3 if D <= 64 else 2
            num_warps = 4
    elif capability == (8, 6):  # sm86
        if not causal:
            if D <= 64:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
            else:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8
        else:
            if D <= 64:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
            else:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
    else:
        # Conservative default for untuned architectures.
        BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
|
|
|
|
|
def get_bwd_kv_config(B, H, M, N, D, causal):
    """Choose (BLOCK_M, BLOCK_N, num_stages, num_warps) for `_bwd_kv_kernel`.

    Keyed on the CUDA compute capability of the current device and the head
    dimension D. Only causal attention is supported (asserted below).

    Returns:
        tuple: (BLOCK_M, BLOCK_N, num_stages, num_warps)
    """
    assert causal
    # Query the device capability once instead of once per branch.
    capability = torch.cuda.get_device_capability()
    if capability == (8, 0):  # sm80, e.g. A100
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 4, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 128, 4, 8
    elif capability == (8, 6):  # sm86
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
    elif capability == (8, 9):  # sm89
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 128, 4, 8
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 128, 2, 8
    elif capability == (9, 0):  # sm90, e.g. H100
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
    else:
        # Conservative default for untuned architectures.
        BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
|
|
|
|
|
def get_bwd_q_config(B, H, M, N, D, causal):
    """Choose (BLOCK_M, BLOCK_N, num_stages, num_warps) for `_bwd_q_kernel`.

    Keyed on the CUDA compute capability of the current device and the head
    dimension D. Only causal attention is supported (asserted below).

    Returns:
        tuple: (BLOCK_M, BLOCK_N, num_stages, num_warps)
    """
    assert causal
    # Query the device capability once instead of once per branch.
    capability = torch.cuda.get_device_capability()
    if capability == (8, 0):  # sm80, e.g. A100
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 8
    elif capability == (8, 6):  # sm86
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
    elif capability == (8, 9):  # sm89
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
    elif capability == (9, 0):  # sm90, e.g. H100
        if D <= 64:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 4, 8
        else:
            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8
    else:
        # Conservative default for untuned architectures.
        BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
|
|
|
|
|
|
|
|
@triton.jit
def _bwd_preprocess(
    Out, DO,
    Delta,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_doz, stride_doh, stride_dom, stride_dok,
    stride_dz, stride_dh, stride_dm,
    M,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
):
    # Backward preprocessing: Delta[m] = sum_d Out[m, d] * DO[m, d]
    # (rowwise dot of the forward output with its incoming gradient),
    # consumed by both backward kernels.
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    # Offset base pointers to this (batch, head).
    Out += off_z * stride_oz + off_h * stride_oh
    DO += off_z * stride_doz + off_h * stride_doh
    Delta += off_z * stride_dz + off_h * stride_dh

    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)

    o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok
    do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok

    # Mask the tail block only when M is not a multiple of BLOCK_M.
    if DIVISIBLE_M:
        o = tl.load(o_ptrs).to(tl.float32)
        do = tl.load(do_ptrs).to(tl.float32)
    else:
        mask_m = off_m < M
        o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
        do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)

    delta = tl.sum(o * do, axis=1)

    d_ptrs = Delta + off_m * stride_dm
    if DIVISIBLE_M:
        tl.store(d_ptrs, delta)
    else:
        tl.store(d_ptrs, delta, mask=mask_m)
|
|
|
|
|
|
|
|
@triton.jit
def _bwd_kv_kernel(
    Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, DO,
    DK, DV, DLOG_LAMBDA,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n,
    stride_doz, stride_doh, stride_dom, stride_dok,
    stride_dkz, stride_dkh, stride_dkn, stride_dkk,
    stride_dvz, stride_dvh, stride_dvn, stride_dvk,
    stride_dlog_lambda_z, stride_dlog_lambda_h, stride_dlog_lambda_n,
    Z, H, M, N, P_SEQ,
    num_groups,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr,
    DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, HAS_SEQ_START: tl.constexpr,
):
    # Backward kernel for dK, dV and the key-side part of d(log_lambda).
    # Each program owns one BLOCK_N slice of keys/values for one (batch, head)
    # and loops over the query blocks that can attend to it. L holds the
    # per-row log-normalizer from the forward pass; D holds the delta values
    # from `_bwd_preprocess`.
    input_dtype = Q.dtype.element_ty

    start_n = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)
    # exp2-based softmax: fold log2(e) into the scale.
    log2e: tl.constexpr = 1.4426950408889634
    qk_scale = sm_scale * log2e

    # Offset base pointers to this (batch, head); K/V use the grouped head.
    off_hk = off_h // num_groups
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_hk * stride_kh
    V += off_z * stride_vz + off_hk * stride_vh
    LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h
    DO += off_z * stride_doz + off_h * stride_doh
    DK += off_z * stride_dkz + off_h * stride_dkh
    DV += off_z * stride_dvz + off_h * stride_dvh
    DLOG_LAMBDA += off_z * stride_dlog_lambda_z + off_h * stride_dlog_lambda_h
    # D and L are (Z, H, M) contiguous.
    D += (off_z * H + off_h) * M
    L += (off_z * H + off_h) * M

    # With causal masking, queries before this key block never attend to it,
    # so start the query loop at the first relevant BLOCK_M boundary.
    if CAUSAL:
        lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0)
        lo = (lo // BLOCK_M) * BLOCK_M
    else:
        lo = 0

    offs_m_init = lo + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_base = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_DMODEL)

    q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    # Queries live at key positions P_SEQ + m in log_lambda.
    log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m_init) * stride_log_lambda_n
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
    log_lambda_in_ptrs = LOG_LAMBDA + (offs_n * stride_log_lambda_n)
    do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok)

    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk)
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk)
    dlog_lambda_in_ptrs = DLOG_LAMBDA + (offs_n * stride_dlog_lambda_n)

    # This program's key/value block stays resident across the query loop.
    if DIVISIBLE_N:
        v = tl.load(v_ptrs)
        k = tl.load(k_ptrs)
        log_lambda_in = tl.load(log_lambda_in_ptrs)
    else:
        mask_n = offs_n < N
        v = tl.load(v_ptrs, mask=mask_n[:, None])
        k = tl.load(k_ptrs, mask=mask_n[:, None])
        log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n)

    # Key blocks entirely before seq_start get no gradient: collapse the loop
    # to empty (hi == lo) for them.
    if HAS_SEQ_START:
        SEQ_START += off_z
        seq_start = tl.load(SEQ_START)
        hi = tl.where(start_n * BLOCK_N + BLOCK_N >= seq_start - 1, M, lo)
    else:
        hi = M

    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dlog_lambda_in = tl.zeros([BLOCK_N], dtype=tl.float32)

    for start_m in range(lo, hi, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m = start_m + offs_m_base
        causal_mask = (P_SEQ + offs_m[None, :]) >= (offs_n[:, None])

        if DIVISIBLE_M:
            q = tl.load(q_ptrs)
            log_lambda_out = tl.load(log_lambda_out_ptrs)
        else:
            mask_m = offs_m < M
            valid_mask = mask_m[None, :]
            q = tl.load(q_ptrs, mask=mask_m[:, None])
            log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m)

        # Transposed logits (keys x queries) plus the decay bias in log2 units.
        sT = tl.dot(k, tl.trans(q), input_precision="ieee") * qk_scale
        decay_bias = log_lambda_out[None, :] - log_lambda_in[:, None]
        sT += decay_bias * log2e

        # Recompute attention probabilities from the stored log-normalizer L.
        if DIVISIBLE_M:
            l = tl.load(L + offs_m)
        else:
            l = tl.load(L + offs_m, mask=mask_m)
        pT = tl.math.exp2(sT - l[None, :] * log2e)

        if not DIVISIBLE_M:
            pT = tl.where(valid_mask, pT, 0.0)
        if CAUSAL:
            pT = tl.where(causal_mask, pT, 0.0)

        if DIVISIBLE_M:
            do = tl.load(do_ptrs)
        else:
            do = tl.load(do_ptrs, mask=mask_m[:, None])

        # dV += P^T @ dO
        dv += tl.dot(pT.to(input_dtype), do, input_precision="ieee")

        if DIVISIBLE_M:
            delta = tl.load(D + offs_m)
        else:
            delta = tl.load(D + offs_m, mask=mask_m)

        dpT = tl.dot(v, tl.trans(do), input_precision="ieee")
        # Softmax backward: dS = P * (dP - delta).
        dsT = pT * (dpT - delta[None, :])

        if not DIVISIBLE_M:
            dsT = tl.where(valid_mask, dsT, 0.0)
        if CAUSAL:
            dsT = tl.where(causal_mask, dsT, 0.0)

        # dK += dS^T @ Q; key-side decay gradient is -sum over queries.
        dk += tl.dot(dsT.to(input_dtype), q, input_precision="ieee")
        dlog_lambda_in += -tl.sum(dsT, axis=1)

        q_ptrs += BLOCK_M * stride_qm
        log_lambda_out_ptrs += BLOCK_M * stride_log_lambda_n
        do_ptrs += BLOCK_M * stride_dom

    dk *= sm_scale
    # Zero gradients for padded (pre-seq_start) key positions.
    if HAS_SEQ_START:
        seq_mask = (offs_n >= seq_start)
        dk = tl.where(seq_mask[:, None], dk, 0.0)
        dv = tl.where(seq_mask[:, None], dv, 0.0)
        dlog_lambda_in = tl.where(seq_mask, dlog_lambda_in, 0.0)
    if DIVISIBLE_N:
        tl.store(dk_ptrs, dk.to(input_dtype))
        tl.store(dv_ptrs, dv.to(input_dtype))
        tl.store(dlog_lambda_in_ptrs, dlog_lambda_in.to(tl.float32))
    else:
        tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None])
        tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None])
        tl.store(dlog_lambda_in_ptrs, dlog_lambda_in.to(tl.float32), mask=mask_n)
|
|
|
|
|
|
|
|
@triton.jit
def _bwd_q_kernel(
    Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, DO,
    DQ, DLOG_LAMBDA,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n,
    stride_doz, stride_doh, stride_dom, stride_dok,
    stride_dqz, stride_dqh, stride_dqm, stride_dqk,
    stride_dlog_lambda_z, stride_dlog_lambda_h, stride_dlog_lambda_n,
    Z, H, M, N, P_SEQ,
    num_groups,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr,
    DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
):
    # Backward kernel for dQ and the query-side part of d(log_lambda).
    # Each program owns one BLOCK_M slice of queries for one (batch, head) and
    # loops over key blocks. The query-side decay gradient is ADDED to
    # DLOG_LAMBDA, which already holds the key-side part from `_bwd_kv_kernel`
    # — so that kernel must run first.
    input_dtype = Q.dtype.element_ty

    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    off_z = tl.program_id(2)

    # exp2-based softmax: fold log2(e) into the scale.
    log2e: tl.constexpr = 1.4426950408889634
    qk_scale = sm_scale * log2e

    # Offset base pointers to this (batch, head); K/V use the grouped head.
    off_hk = off_h // num_groups
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_hk * stride_kh
    V += off_z * stride_vz + off_hk * stride_vh
    LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h
    DO += off_z * stride_doz + off_h * stride_doh
    # D and L are (Z, H, M) contiguous.
    D += (off_z * H + off_h) * M
    L += (off_z * H + off_h) * M

    DQ += off_z * stride_dqz + off_h * stride_dqh
    DLOG_LAMBDA += off_z * stride_dlog_lambda_z + off_h * stride_dlog_lambda_h

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_DMODEL)

    q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    # Queries live at key positions P_SEQ + m in log_lambda.
    log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m) * stride_log_lambda_n
    dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk)
    dlog_lambda_out_ptrs = DLOG_LAMBDA + (P_SEQ + offs_m) * stride_dlog_lambda_n
    do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok)

    d_ptrs = D + offs_m
    l_ptrs = L + offs_m

    # Per-row quantities stay resident across the key loop.
    if DIVISIBLE_M:
        q = tl.load(q_ptrs)
        do = tl.load(do_ptrs)
        delta = tl.load(d_ptrs)
        l = tl.load(l_ptrs)
        log_lambda_out = tl.load(log_lambda_out_ptrs)
    else:
        mask_m = offs_m < M
        q = tl.load(q_ptrs, mask=mask_m[:, None])
        do = tl.load(do_ptrs, mask=mask_m[:, None])
        delta = tl.load(d_ptrs, mask=mask_m)
        l = tl.load(l_ptrs, mask=mask_m)
        log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m)

    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    dlog_lambda_out = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Upper bound of key positions this query block can attend to.
    if CAUSAL:
        hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
        if LARGER_M:
            hi = tl.maximum(0, hi)
    else:
        hi = N

    # Lower bound: skip key blocks entirely before seq_start (left padding).
    offs_n_base = tl.arange(0, BLOCK_N)
    offs_n_init = offs_n_base
    if HAS_SEQ_START:
        SEQ_START += off_z
        seq_start = tl.load(SEQ_START)
        lo = tl.minimum(seq_start, hi)
        lo = (lo // BLOCK_N) * BLOCK_N  # round down to a block boundary
        offs_n_init += lo
    else:
        lo = 0
    k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk)
    v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk)
    log_lambda_in_ptrs = LOG_LAMBDA + (offs_n_init * stride_log_lambda_n)

    for start_n in range(lo, hi, BLOCK_N):
        offs_n = start_n + offs_n_base

        if DIVISIBLE_N:
            v = tl.load(v_ptrs)
            k = tl.load(k_ptrs)
            log_lambda_in = tl.load(log_lambda_in_ptrs)
        else:
            mask_n = offs_n < N
            v = tl.load(v_ptrs, mask=mask_n[:, None])
            k = tl.load(k_ptrs, mask=mask_n[:, None])
            log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n)

        if not DIVISIBLE_N:
            valid_mask = mask_n[None, :]
        if CAUSAL:
            causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :])

        # Recompute logits with the decay bias (log2 units), then the
        # attention probabilities from the stored log-normalizer L.
        s = tl.dot(q, tl.trans(k), input_precision="ieee") * qk_scale
        decay_bias = log_lambda_out[:, None] - log_lambda_in[None, :]
        s += decay_bias * log2e
        p = tl.math.exp2(s - l[:, None] * log2e)

        dp = tl.dot(do.to(input_dtype), tl.trans(v), input_precision="ieee")
        # Softmax backward: dS = P * (dP - delta).
        ds = p * (dp - delta[:, None])

        # Zero gradients for invalid, non-causal, and pre-seq_start entries.
        if not DIVISIBLE_N:
            ds = tl.where(valid_mask, ds, 0.0)
        if CAUSAL:
            ds = tl.where(causal_mask, ds, 0.0)
        if HAS_SEQ_START:
            ds = tl.where(offs_n[None, :] >= seq_start, ds, 0.0)

        # dQ += dS @ K; query-side decay gradient is +sum over keys.
        dq += tl.dot(ds.to(input_dtype), k, input_precision="ieee")
        dlog_lambda_out += tl.sum(ds, axis=1)

        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vn
        log_lambda_in_ptrs += BLOCK_N * stride_log_lambda_n

    dq *= sm_scale
    # Accumulate on top of the key-side contribution already in DLOG_LAMBDA.
    if DIVISIBLE_M:
        tmp = tl.load(dlog_lambda_out_ptrs)
    else:
        tmp = tl.load(dlog_lambda_out_ptrs, mask=mask_m)
    dlog_lambda_out += tmp
    if DIVISIBLE_M:
        tl.store(dq_ptrs, dq.to(input_dtype))
        tl.store(dlog_lambda_out_ptrs, dlog_lambda_out)
    else:
        tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None])
        tl.store(dlog_lambda_out_ptrs, dlog_lambda_out, mask=mask_m)
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("Z, H, M, N, HEAD_DIM", [(4, 2, 1020, 2098, 64), (4, 2, 1024, 2048, 64)])
@pytest.mark.parametrize("causal", [True])
def test_op(Z, H, M, N, HEAD_DIM, causal, dtype=torch.bfloat16):
    """Check the Triton forgetting-attention kernel against a pure-PyTorch reference.

    Compares forward output and the gradients of q, k, v, and log_fgate.
    Z/H are batch/head counts, M is the query length, N the key length
    (N >= M, i.e. the keys may include a prefix); the (4, 2, 1020, 2098, 64)
    case exercises the non-divisible block-size paths in the kernel.
    Requires a CUDA device.
    """
    torch.manual_seed(24)
    # NOTE: tensor creation order below fixes the RNG stream; do not reorder.
    q = (torch.empty((Z, H, M, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, H, N, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, H, N, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
    # Forget-gate logits in [5, 10) give gates close to 1 (log-sigmoid near 0),
    # keeping the decay bias numerically mild for the bf16 comparison.
    fgate_logit = torch.empty((Z, H, N), dtype=torch.float32, device="cuda").uniform_(5, 10)
    log_fgate = torch.nn.functional.logsigmoid(fgate_logit).requires_grad_()
    # Per-batch start offset: key positions before seq_start[b] are masked out.
    seq_start = torch.randint(low=0, high=N, size=(Z,), dtype=torch.long, device="cuda")

    sm_scale = 0.5
    dout = torch.randn_like(q)

    # --- reference implementation ---
    # P_SEQ shifts the causal diagonal so query i attends to keys 0..i+P_SEQ.
    P_SEQ = N - M
    mask = torch.tril(torch.ones((M, N), device="cuda"), diagonal=P_SEQ)
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    # Accumulate the attention logits in fp32 for accuracy.
    p = p.float()

    # decay_bias[i, j] = cumulative log forget gate between key j and query i
    # (queries correspond to the last M positions of the length-N gate sequence).
    log_lambda = torch.cumsum(log_fgate, dim=-1)
    decay_bias = log_lambda[..., -M:, None] - log_lambda[..., None, :]
    p = p + decay_bias
    if causal:
        # 2-D boolean index applies the (M, N) causal mask over the last two dims.
        p[:, :, mask == 0] = float("-inf")

    # Mask out key positions before each batch's seq_start.
    attention_mask = torch.arange(N, device="cuda") < seq_start[:, None, None, None]
    p = torch.where(attention_mask, float("-inf"), p)
    p = torch.softmax(p.float(), dim=-1).to(dtype)
    # Rows that are fully masked (all -inf logits) softmax to NaN; the kernel
    # outputs zeros there, so scrub NaNs on a clone (in-place edit needs a clone
    # because p participates in autograd).
    p = p.clone()
    p[torch.isnan(p)] = 0.0

    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    ref_dlog_fgate, log_fgate.grad = log_fgate.grad.clone(), None

    # --- Triton implementation ---
    tri_out = forgetting_attention(q, k, v, log_fgate, head_first=True, seq_start=seq_start, sm_scale=sm_scale)
    tri_out = tri_out.to(dtype)

    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    tri_dlog_fgate, log_fgate.grad = log_fgate.grad.clone(), None

    # Absolute-tolerance-only comparison (bf16 inputs, fp32 accumulation).
    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0), (ref_out - tri_out).abs().max()
    rtol = 0

    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol), (ref_dv - tri_dv).abs().max()
    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol), (ref_dk - tri_dk).abs().max()
    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol), (ref_dq - tri_dq).abs().max()
    assert torch.allclose(ref_dlog_fgate, tri_dlog_fgate, atol=1e-2, rtol=rtol), (ref_dlog_fgate - tri_dlog_fgate).abs().max()
|
|
|
|
|
# Optional FlashAttention-2 baseline for the benchmark below.
# Catch Exception rather than BaseException so KeyboardInterrupt / SystemExit
# still propagate; any import-time failure (missing package, ABI mismatch,
# CUDA init error) simply disables the "flash" provider.
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
except Exception:
    HAS_FLASH = False
|
|
|
|
|
# Whether this torch build exposes the float8 e5m2 dtype (informational).
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')

# Fixed benchmark problem size; the context length N_CTX is swept instead.
BATCH, N_HEADS, HEAD_DIM = 4, 32, 128

# One Benchmark per (mode, causal) combination. Only causal=True is swept,
# and the backward pass is only benchmarked in the causal setting.
configs = [
    triton.testing.Benchmark(
        x_names=["N_CTX"],
        x_vals=[2**i for i in range(14, 15)],
        line_arg="provider",
        # The "flash" line is added only when flash-attn imported successfully.
        line_vals=["flag"] + (["flash"] if HAS_FLASH else []),
        line_names=["Flag"] + (["Flash-2"] if HAS_FLASH else []),
        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
        ylabel="ms",
        plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}",
        args={
            "H": N_HEADS,
            "BATCH": BATCH,
            "HEAD_DIM": HEAD_DIM,
            "mode": mode,
            "causal": causal,
        },
    )
    for mode in ["fwd", "bwd"]
    for causal in [True]
    if not (mode == "bwd" and not causal)
]
|
|
|
|
|
|
|
|
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
    """Benchmark forgetting attention ("flag") against FlashAttention-2 ("flash").

    Times either the forward pass or the backward pass (per `mode`) and
    returns achieved TFLOP/s, halving the flop count for causal attention
    and scaling by 2.5 for the backward pass.
    """
    assert mode in ["fwd", "bwd"]
    warmup, rep = 25, 100
    dtype = torch.bfloat16

    def _to_backward(forward_fn):
        # Turn a forward callable into one that benchmarks its backward pass.
        out = forward_fn()
        grad = torch.randn_like(out)
        return lambda: out.backward(grad, retain_graph=True)

    if "flag" in provider:
        # Inputs for the Triton forgetting-attention kernel.
        shape = (BATCH, H, N_CTX, HEAD_DIM)
        q = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
        k = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
        v = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
        fgate_logit = torch.empty((BATCH, H, N_CTX), dtype=torch.float32, device="cuda").uniform_(5, 10)
        log_fgate = torch.nn.functional.logsigmoid(fgate_logit).requires_grad_()
        sm_scale = 1.3
        fn = lambda: forgetting_attention(q, k, v, log_fgate, head_first=True, sm_scale=sm_scale)
        if mode == "bwd":
            fn = _to_backward(fn)

    if provider == "flash":
        # FlashAttention's packed-QKV interface.
        qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            fn = _to_backward(fn)

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

    # Two matmuls (QK^T and PV) per attention; causal halves the work,
    # backward is conventionally counted as 2.5x the forward flops.
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5
    return total_flops / ms * 1e-9
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly launches the benchmark suite (requires a CUDA
    # GPU); the correctness tests above are collected by pytest instead.
    bench_flash_attention.run(save_path=".", print_data=True)
|
|
|