Add files using upload-large-folder tool
Browse files- all_results.json +10 -0
- fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/parallel_path_fwd.cpython-310.pyc +0 -0
- fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
- fla3/ops/rwkv7/fused_recurrent.py +328 -0
- fla3/ops/simple_gla/__pycache__/__init__.cpython-312.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
- fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc +0 -0
- fla3/ops/simple_gla/parallel.py +732 -0
- fla3/ops/ttt/fused_chunk.py +835 -0
- fla3/ops/ttt/naive.py +126 -0
- fla3/ops/utils/__init__.py +54 -0
- fla3/ops/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/asm.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/index.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/index.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/matmul.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/op.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/op.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/pack.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/pack.cpython-312.pyc +0 -0
- fla3/ops/utils/__pycache__/softmax.cpython-310.pyc +0 -0
- fla3/ops/utils/__pycache__/solve_tril.cpython-310.pyc +0 -0
- fla3/ops/utils/asm.py +17 -0
- fla3/ops/utils/cumsum.py +414 -0
- fla3/ops/utils/index.py +83 -0
- fla3/ops/utils/logcumsumexp.py +52 -0
- fla3/ops/utils/logsumexp.py +80 -0
- fla3/ops/utils/matmul.py +245 -0
- fla3/ops/utils/pack.py +208 -0
- fla3/ops/utils/pooling.py +207 -0
- fla3/ops/utils/softmax.py +111 -0
- fla3/ops/utils/solve_tril.py +276 -0
- flame/__init__.py +0 -0
- flame/__pycache__/__init__.cpython-310.pyc +0 -0
- flame/__pycache__/__init__.cpython-312.pyc +0 -0
- flame/__pycache__/data.cpython-310.pyc +0 -0
- flame/__pycache__/data.cpython-312.pyc +0 -0
- flame/__pycache__/logging.cpython-312.pyc +0 -0
- flame/__pycache__/parser.cpython-310.pyc +0 -0
- flame/__pycache__/parser.cpython-312.pyc +0 -0
- flame/data.py +246 -0
- flame/logging.py +118 -0
- flame/parser.py +94 -0
all_results.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"epoch": 0.7839559871158865,
|
| 3 |
+
"num_tokens": 104891154432,
|
| 4 |
+
"throughput": 12525.363357673923,
|
| 5 |
+
"total_flos": 8.619164947133655e+20,
|
| 6 |
+
"train_loss": 9.173326368447839,
|
| 7 |
+
"train_runtime": 261696.889,
|
| 8 |
+
"train_samples_per_second": 195.709,
|
| 9 |
+
"train_steps_per_second": 0.191
|
| 10 |
+
}
|
fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/parallel_path_fwd.cpython-310.pyc
ADDED
|
Binary file (4.95 kB). View file
|
|
|
fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc
ADDED
|
Binary file (2.26 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc
ADDED
|
Binary file (2.26 kB). View file
|
|
|
fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
fla3/ops/rwkv7/fused_recurrent.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule
|
| 13 |
+
from fla.ops.utils.op import exp
|
| 14 |
+
from fla.utils import input_guard, use_cuda_graph
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Fused recurrent forward kernel for RWKV7.
# Grid: (cdiv(V, BV), N * H) — each program owns one BV-wide slice of the value
# dimension for one (sequence, head) pair and scans its T steps sequentially.
# The running state b_h is kept value-major as [BV, BK]; per step it computes
#     S_t = exp(w_t) * S_{t-1}                       (per-key-channel decay)
#           + (S_{t-1} @ (-kk_t)) * (kk_t * a_t)^T   (rank-1 state correction)
#           + v_t k_t^T                              (outer-product write)
# and reads out o_t = S_t @ (scale * r_t).
# NOTE(review): "kk = removal key, a = in-context learning rate" is RWKV7
# terminology inferred from the surrounding file — confirm against callers.
@triton.heuristics({
    # Feature flags are derived from which optional pointers the caller passed.
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_rwkv7_fwd_kernel(
    r,            # receptance/query, contiguous [*, H, K] per time step
    w,            # log-space decay, [*, H, K]
    k,            # keys, [*, H, K]
    v,            # values, [*, H, V]
    kk,           # rank-1 correction key, [*, H, K]
    a,            # rank-1 correction gate, [*, H, K]
    o,            # output, [*, H, V]
    h0,           # optional initial state, [N*H, K, V] when USE_INITIAL_STATE
    ht,           # optional final state, [N*H, K, V] when STORE_FINAL_STATE
    cu_seqlens,   # optional cumulative sequence lengths [N+1] (varlen mode)
    scale,        # multiplier applied to r before the readout
    T,            # sequence length (runtime; overwritten per-sequence in varlen)
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,   # next_power_of_2(K): the whole key dim fits one block
    BV: tl.constexpr,   # autotuned value-dim tile
    REVERSE: tl.constexpr,             # scan the sequence back-to-front
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_DECODE: tl.constexpr,           # T == 1 fast path (no loop, no pointer bumps)
):
    # Program ids: value tile index, and flattened (sequence, head) index.
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        # Sequence boundaries come from cu_seqlens; T becomes this sequence's length.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        # Fixed-length batch: sequences are laid out back to back.
        # (eos is kept for symmetry; only bos is used below.)
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # All per-step pointers start at the first (or, if REVERSE, last) time step.
    p_r = r + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_kk = kk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k

    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    # Bounds masks for the (possibly padded) key/value tiles.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # Running state for this program's value slice, accumulated in fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    if IS_DECODE:
        # Single-token path: one state update + readout, no pointer advancing.
        b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
        b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
        b_act_a = -b_kk
        b_b = b_kk * b_a

        # tmp = S_{t-1} @ (-kk): per-value-channel projection onto the removal direction.
        tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
        b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = tl.sum(b_h * b_r[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
    else:
        # Sequential scan over the T time steps of this sequence.
        for _ in range(0, T):
            b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
            b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
            b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
            b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
            b_act_a = -b_kk
            b_b = b_kk * b_a

            # Same update as the decode path: decay, rank-1 correction, outer-product write.
            tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
            b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
            b_o = tl.sum(b_h * b_r[None, :], axis=1)

            tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
            # Advance one time step (backwards when REVERSE).
            p_r += (-1 if REVERSE else 1) * H*K
            p_w += (-1 if REVERSE else 1) * H*K
            p_k += (-1 if REVERSE else 1) * H*K
            p_v += (-1 if REVERSE else 1) * H*V
            p_a += (-1 if REVERSE else 1) * H*K
            p_kk += (-1 if REVERSE else 1) * H*K
            p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@input_guard
def fused_recurrent_rwkv7_fwd(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kk: torch.Tensor,
    a: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the fused recurrent RWKV7 forward Triton kernel.

    Args:
        r, w, k, kk, a: `[B, T, H, K]` tensors (receptance, log decay, keys,
            correction key, correction gate).
        v: `[B, T, H, V]` values.
        scale: multiplier applied to `r` inside the kernel.
        initial_state: optional `[N, H, K, V]` starting state.
        output_final_state: when True, also allocate and return the fp32 final state.
        reverse: process time steps back-to-front.
        cu_seqlens: optional `[N+1]` cumulative lengths for packed varlen input.

    Returns:
        (o, ht): output of the same shape/dtype as `v`, and the final state
        (`None` unless `output_final_state`).
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    # Number of independent sequences: batch size, or varlen segment count.
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    # The whole key dimension is handled by a single block.
    BK = triton.next_power_of_2(K)
    # Single-token inputs take the loop-free decode path in the kernel.
    IS_DECODE = (T == 1)

    h0 = initial_state
    ht = r.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    o = torch.empty_like(v)

    # One program per value tile per (sequence, head); BV is autotuned.
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), N * H)

    fused_recurrent_rwkv7_fwd_kernel[grid](
        r, w, k, v, kk, a, o, h0, ht, cu_seqlens, scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
        IS_DECODE=IS_DECODE
    )
    return o, ht
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def fused_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = 1.0,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
):
    """
    RWKV7 recurrence, delegated to the generalized DPLR delta-rule kernel.

    Args:
        r (torch.Tensor):
            r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        w (torch.Tensor):
            log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (float):
            scale of the attention.
        initial_state (torch.Tensor):
            initial state of shape `[B, H, K, V]` if cu_seqlens is None else `[N, H, K, V]`
            where N = len(cu_seqlens) - 1.
        output_final_state (bool):
            whether to output the final state.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (bool):
            whether to use head first. Recommended to be False to avoid extra transposes.
            Default: `False`.
    """
    # Map RWKV7's naming onto the DPLR delta-rule interface: r plays the role
    # of the query and w is the (log-space) gate over keys.
    dplr_kwargs = dict(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        head_first=head_first,
    )
    return fused_recurrent_dplr_delta_rule(**dplr_kwargs)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def fused_mul_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kk: torch.Tensor,
    a: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner
    (inside the kernel, `a_t = -kk_t` and `b_t = kk_t * a_t`, with a per-channel decay `exp(w_t)`).

    Args:
        r (torch.Tensor):
            queries (receptance) of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        w (torch.Tensor):
            decay term in log space(!) of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        kk (torch.Tensor):
            rank-1 correction keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        a (torch.Tensor):
            rank-1 correction gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (Optional[float]):
            Scale factor applied to `r`.
            If not provided (None), it will default to `1 / sqrt(K)`. Default: 1.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Deprecated. Default: `False`.

    Returns:
        Tuple of the output (same layout as `v`) and the final state (or `None`).
    """
    if head_first:
        # BUG FIX: this used to `raise DeprecationWarning(...)`, which aborted the
        # call and left both the rearrange below and the transpose-back at the end
        # unreachable. Emit a proper deprecation warning instead so head-first
        # inputs keep working while the flag is phased out.
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        r, w, k, v, kk, a = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (r, w, k, v, kk, a))
    if not head_first and r.shape[1] < r.shape[2]:
        # Heuristic sanity check: seq_len < num_heads usually means the caller
        # passed head-first tensors without setting the flag.
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if r.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`."
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = r.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = fused_recurrent_rwkv7_fwd(
        r,
        w,
        k,
        v,
        kk,
        a,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    # Return outputs in the caller's layout when head-first input was supplied.
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state
|
fla3/ops/simple_gla/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (393 Bytes). View file
|
|
|
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc
ADDED
|
Binary file (4.78 kB). View file
|
|
|
fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
fla3/ops/simple_gla/parallel.py
ADDED
|
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
from fla.ops.utils import prepare_chunk_indices
|
| 13 |
+
from fla.ops.utils.cumsum import chunk_global_cumsum, chunk_local_cumsum
|
| 14 |
+
from fla.ops.utils.op import safe_exp
|
| 15 |
+
from fla.utils import (
|
| 16 |
+
autocast_custom_bwd,
|
| 17 |
+
autocast_custom_fwd,
|
| 18 |
+
check_shared_mem,
|
| 19 |
+
input_guard,
|
| 20 |
+
is_intel_alchemist,
|
| 21 |
+
is_nvidia_hopper
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
# On Intel Alchemist, request the large GRF mode (workaround from the linked
# issue); all other backends get no extra launch options.
triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
# Warp-count candidates; Hopper is capped at 8, other GPUs also try 16.
# NOTE(review): presumably consumed as autotune num_warps choices — confirm at use sites.
NUM_WARPS = [2, 4, 8] if is_nvidia_hopper else [2, 4, 8, 16]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Parallel (attention-style) forward kernel for simple GLA.
# Grid: (NK * NV, #chunks, B * H). Each program computes one [BT, BV] output
# tile: a first loop over the blocks on the causal diagonal (masked), then a
# second loop walking backwards over all fully-past key/value blocks, with the
# gate statistics b_gq rescaled per block when USE_G. When NK > 1, partial
# outputs are written to per-i_k slices of `o` (offset i_k * B*T*H*V) for the
# caller to reduce.
@triton.heuristics({
    'NV': lambda args: triton.cdiv(args['V'], args['BV']),
    'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None,
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BS", "BK", "BV", "USE_G"],
)
@triton.jit
def parallel_simple_gla_fwd_kernel(
    q,              # queries, [*, H, K]
    k,              # keys, [*, H, K]
    v,              # values, [*, H, V]
    g,              # optional scalar gate per (t, h), log space (see safe_exp use)
    o,              # output, [NK, *, H, V] flattened; i_k slice chosen below
    attn,           # optional attention-score dump, [NK, B, H, T, T] flattened
    scale,          # multiplier applied to q
    cu_seqlens,     # optional [N+1] cumulative lengths (varlen)
    chunk_indices,  # varlen only: (sequence, local chunk) pairs per program
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # query-tile length
    BS: tl.constexpr,   # key/value block length per inner step
    BK: tl.constexpr,   # key-dim tile
    BV: tl.constexpr,   # value-dim tile
    NV: tl.constexpr,   # cdiv(V, BV), derived by the heuristic above
    OUTPUT_ATTENTIONS: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr
):
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    # Partial outputs for different key tiles land in disjoint slices of o.
    o += i_k * B * T * H * V

    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Rebase all tensor pointers to this (sequence, head).
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    if USE_G:
        g += bos * H + i_h
    if OUTPUT_ATTENTIONS:
        # For fixed-length input this equals (i_b*H + i_h)*T*T plus the i_k slice.
        attn += (bos * H + i_h * T) * T + i_k * B * H * T * T
    stride_qk = H * K
    stride_vo = H * V
    stride_g = H

    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # [BT] absolute query positions of this tile
    o_q = i_t * BT + tl.arange(0, BT)
    # [BS] absolute key positions of the current block (advanced in the first loop)
    o_k = i_t * BT + tl.arange(0, BS)
    # Q block and K block have overlap.
    # masks required
    if USE_G:
        p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        # [BT,]
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
        # rescale interchunk output
    else:
        b_gq = None

    # Pass 1: blocks on the causal diagonal — need the position mask m_s.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BS]
        m_s = o_q[:, None] >= o_k[None, :]
        b_s = tl.dot(b_q, b_k)
        if USE_G:
            p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gk = tl.load(p_gk, boundary_check=(0,))
            # Gate ratio between query and key positions within the chunk.
            b_s *= safe_exp(b_gq[:, None] - b_gk[None, :])
            b_s = tl.where(m_s, b_s, 0)
        else:
            b_s = tl.where(m_s, b_s, 0)
        # [BT, BV]
        # (i_s >= 0 always holds in this loop; the guard mirrors pass 2.)
        if i_s >= 0:
            b_o += tl.dot(b_s.to(b_q.dtype), b_v)
        if OUTPUT_ATTENTIONS:
            p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
        o_k += BS

    # Pass 2: strictly-past blocks, walked backwards — no causal mask needed.
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_s = tl.dot(b_q, b_k)
        if USE_G:
            p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            # b_gn: cumulative gate at this block's end; b_gp: at the previous block's end.
            b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gp = tl.load(g + (i_s-1) * stride_g) if i_s % BT > 0 else 0.
            # No concrete meaning. Just to avoid some layout bugs.
            b_s *= safe_exp(b_gq[:, None] + (b_gn - b_g)[None, :])
            # Accumulate the decay so b_gq stays relative to the next (older) block.
            b_gq += (b_gn - b_gp)
        if OUTPUT_ATTENTIONS:
            p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
        if i_s >= 0:
            b_o += tl.dot(b_s.to(b_v.dtype), b_v)
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel_dq(
    i_t,
    i_k,
    i_v,
    q,
    k,
    v,
    g,
    do,
    dq,
    dg,
    stride_qk,
    stride_vo,
    stride_g,
    scale,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr
):
    """Compute dq (and the dq-side contribution to dg) for one [BT, BK] query tile.

    Device function called from `parallel_simple_gla_bwd_kernel`; all pointers are
    pre-offset to the current (batch, head) slice. Sweeps key/value tiles of size
    BS in two phases: strictly-past tiles (no causal mask needed) and the tiles
    overlapping the current query block (causal mask required).
    """
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    # [BT, BV] upstream gradient for this query tile
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BT, BK] accumulator, kept in fp32 for accuracy
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)

    # Phase 1: key tiles strictly before the query block — fully causal, no mask.
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BV] @ [BV, BS] = [BT, BS] score gradient
        b_ds = tl.dot(b_do, b_v)
        if USE_G:
            p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            # decay at the last position of this key tile
            b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            # decay just before this key tile (0 at a BT-chunk boundary, where the
            # local cumsum restarts)
            b_gp = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
            # scale score grads by decay from each key position to the tile end
            b_ds *= safe_exp(b_gn - b_g)[None, :]
            if i_s > 0:
                # carry the accumulator across tile boundaries in log-decay space
                b_dq *= safe_exp(b_gn - b_gp)
        # [BT, BS] @ [BS, BK] = [BT, BK]
        b_dq += tl.dot(b_ds.to(b_v.dtype), b_k)

    if USE_G:
        p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        # [BT,] per-query local cumulative decay
        b_gq = tl.load(p_gq, boundary_check=(0,))
        # [BT, BK] apply each query's own decay to the inter-tile accumulation
        b_dq *= safe_exp(b_gq)[:, None]

    # [BT] absolute query positions
    o_q = i_t * BT + tl.arange(0, BT)
    # [BS] absolute key positions (advanced per tile below)
    o_k = i_t * BT + tl.arange(0, BS)
    # Phase 2: Q block and K block have overlap. masks required
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BV] @ [BV, BS] = [BT, BS]
        b_ds = tl.dot(b_do, b_v)
        if USE_G:
            p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gk = tl.load(p_gk, boundary_check=(0,))
            # intra-chunk decay between query and key positions
            b_ds *= safe_exp(b_gq[:, None] - b_gk[None, :])
        # causal mask: a query only attends to keys at or before it
        b_ds = tl.where(o_q[:, None] >= o_k[None, :], b_ds, 0)
        # [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), b_k)
        o_k += BS

    b_dq *= scale
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    if USE_G:
        # dg contribution from the q side: d/dg of q * exp(g) gives sum(dq * q)
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_dg = tl.sum(b_dq * b_q, 1)
        p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel_dkv(
    i_t,
    i_k,
    i_v,
    q,
    k,
    v,
    g,
    do,
    dk,
    dv,
    dg,
    scale,
    stride_qk,
    stride_vo,
    stride_g,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr
):
    """Compute dk, dv (and the dk-side contribution to dg) for one [BT, *] key/value tile.

    Device function called from `parallel_simple_gla_bwd_kernel`; all pointers are
    pre-offset to the current (batch, head) slice. Sweeps query tiles of size BS
    backwards from the end of the sequence down to (but excluding) the current
    block, then handles the overlapping query tiles with an explicit causal mask.
    """
    # [BT, BK] keys of the current block
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT, BV] values of the current block
    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)
    if USE_G:
        # [BT,] local cumulative decay at the key positions of this block
        p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_gk = tl.load(p_gk, boundary_check=(0,))
    NTS = tl.cdiv(T, BS)
    # Phase 1: query tiles strictly after this block, iterated back-to-front so
    # the [BT, BK]/[BT, BV] accumulators can be rescaled by the tile-to-tile decay.
    for i_s in range(NTS * BS - BS, (i_t + 1) * BT - BS, -BS):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BT, BS] gradient of the scores w.r.t. this (key tile, query tile) pair
        b_ds = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS] raw scores, recomputed instead of stored
        b_s = tl.dot(b_k, tl.trans(b_q))
        if USE_G:
            p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gq = tl.load(p_gq, boundary_check=(0,))
            # NOTE: naming looks swapped relative to the dq kernel (b_gp is the
            # tile-end decay, b_gn the pre-tile decay); the algebra below is
            # consistent with that assignment.
            b_gp = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gn = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
            if i_s >= 0:
                tmp = safe_exp(b_gp - b_gn)
                # rescale accumulators across the tile boundary
                b_dk *= tmp
                b_dv *= tmp
                tmp2 = safe_exp(b_gq - b_gn)
                b_ds *= tmp2[None, :]
                b_s *= tmp2[None, :]
        # [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        # [BT, BV]
        b_dv += tl.dot(b_s.to(b_do.dtype), b_do)

    if USE_G:
        # fold in the decay from each key position to the end of its own chunk
        b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * stride_g)
        if i_t >= 0:
            tmp2 = safe_exp(b_g_last - b_gk)[:, None]
            b_dk *= tmp2
            b_dv *= tmp2

    # absolute query / key positions for the overlapping (masked) phase
    o_q = i_t * BT + tl.arange(0, BS)
    o_k = i_t * BT + tl.arange(0, BT)
    # Phase 2: query tiles overlapping the current key block — causal mask required.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BT, BS] score gradient and recomputed scores
        b_ds = tl.dot(b_v, tl.trans(b_do))
        b_s = tl.dot(b_k, tl.trans(b_q))
        if USE_G:
            p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gq = tl.load(p_gq, boundary_check=(0,))
            if i_s >= 0:
                # intra-chunk decay between key and query positions
                tmp = safe_exp(-b_gk[:, None] + b_gq[None, :])
                b_ds *= tmp
                b_s *= tmp
        # causal mask: a key only influences queries at or after it
        m_s = o_k[:, None] <= o_q[None, :]
        b_s = tl.where(m_s, b_s, 0)
        b_ds = tl.where(m_s, b_ds, 0)
        # [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        b_dv += tl.dot(b_s.to(b_do.dtype), b_do)
        o_q += BS
    b_dk *= scale
    b_dv *= scale
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    if USE_G:
        # subtract the k-side contribution from the dg slot already written by
        # the dq kernel (the dispatcher runs dq first, then a barrier, then this)
        p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_dg = tl.load(p_dg, boundary_check=(0,))
        b_dg -= tl.sum(b_dk * b_k, 1)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
@triton.heuristics({
    'NV': lambda args: triton.cdiv(args['V'], args['BV']),
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        # `triton_config` is a module-level base config defined earlier in this
        # file (outside this view); only num_warps is swept here.
        triton.Config(triton_config, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BS', 'BK', 'BV', 'USE_G'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel(
    q,
    k,
    v,
    g,
    do,
    dq,
    dk,
    dv,
    dg,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr
):
    """Top-level backward kernel: resolves grid indices and base offsets, then
    dispatches to the dq and dkv device functions for one (K-split, V-split,
    time-block, batch*head) program instance.

    Grid: (NK * NV, NT, B * H). dq/dk are split over V tiles and dv over K
    tiles, hence the extra leading dimension summed on the host afterwards.
    """
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    # select the per-split slice of the (NV, ...) / (NK, ...) output buffers
    dq += i_v * B * H * T * K
    dk += i_v * B * H * T * K
    dv += i_k * B * H * T * V
    if USE_G:
        # dg has one slice per (i_k, i_v) pair
        dg += i_kv * B * H * T

    if IS_VARLEN:
        # map the flat block index to (sequence, block-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # advance every pointer to the start of this (sequence, head) slice;
    # layout is [B, T, H, D] flattened, so the token stride is H*D
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    do += (bos * H + i_h) * V
    dq += (bos * H + i_h) * K
    dk += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V
    if USE_G:
        g += bos * H + i_h
        dg += bos * H + i_h
    stride_qk = H * K
    stride_vo = H * V
    stride_g = H

    parallel_simple_gla_bwd_kernel_dq(
        i_t=i_t,
        i_k=i_k,
        i_v=i_v,
        q=q,
        k=k,
        v=v,
        g=g,
        do=do,
        dq=dq,
        dg=dg,
        scale=scale,
        stride_qk=stride_qk,
        stride_vo=stride_vo,
        stride_g=stride_g,
        T=T,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
        USE_G=USE_G
    )
    # dkv reads the dg slot dq just wrote, so order must be preserved
    tl.debug_barrier()
    parallel_simple_gla_bwd_kernel_dkv(
        i_t=i_t,
        i_k=i_k,
        i_v=i_v,
        q=q,
        k=k,
        v=v,
        g=g,
        do=do,
        dk=dk,
        dv=dv,
        dg=dg,
        scale=scale,
        stride_qk=stride_qk,
        stride_vo=stride_vo,
        stride_g=stride_g,
        T=T,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
        USE_G=USE_G
    )
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def parallel_simple_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float,
    output_attentions: bool = False,
    chunk_size: int = 128,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Host-side launcher for the parallel simple-GLA forward kernel.

    Args:
        q, k: `[B, T, H, K]` query/key tensors.
        v: `[B, T, H, V]` value tensor.
        g: `[B, T, H]` per-head log forget gates, or `None` to disable gating.
        scale: score scaling factor.
        output_attentions: also materialize the `[B, H, T, T]` attention matrix.
        chunk_size: time-block size BT (must be a multiple of the inner tile BS=32).
        cu_seqlens: optional `[N+1]` cumulative lengths for varlen batches.

    Returns:
        (o, g, attn): output `[B, T, H, V]`, the chunk-local cumsum of `g`
        (reused by the backward pass), and the attention matrix or `None`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT, BS = chunk_size, 32
    # pick K/V tile sizes by available shared memory per architecture
    if check_shared_mem('hopper', k.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    elif check_shared_mem('ampere', k.device.index):
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    else:
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert BT % BS == 0

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # local cumulative decay in log space
    if g is not None:
        g = chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens)
    grid = (NK * NV, NT, B * H)
    # K is split over NK programs; partial outputs are accumulated in fp32
    # (unless NK == 1) and reduced below
    o = torch.empty(NK, *v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device)
    attn = q.new_zeros(NK, B, H, T, T) if output_attentions else None

    parallel_simple_gla_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        o=o,
        attn=attn,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        B=B,
        H=H,
        T=T,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    # reduce over the K-split dimension
    o = o.sum(0)

    if output_attentions:
        attn = attn.sum(0)
    return o, g, attn
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def parallel_simple_gla_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    chunk_size: int = 128,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Host-side launcher for the parallel simple-GLA backward kernel.

    Args:
        q, k: `[B, T, H, K]` query/key tensors.
        v: `[B, T, H, V]` value tensor.
        g: chunk-local cumsum of the log gates as returned by the forward pass,
           or `None` when gating is disabled.
        do: `[B, T, H, V]` upstream gradient of the output.
        scale: score scaling factor (must match the forward pass).
        chunk_size: time-block size BT (multiple of the inner tile BS=32).
        cu_seqlens: optional `[N+1]` cumulative lengths for varlen batches.

    Returns:
        (dq, dk, dv, dg) with dg already converted from chunk-local form via a
        reversed global cumsum (or `None` when `g` is `None`).
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT, BS = chunk_size, 32
    # tile sizes by shared-memory budget; the backward keeps more tiles live
    # than the forward, hence the extra 'ada' tier and smaller fallback
    if check_shared_mem('hopper', k.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    elif check_shared_mem('ampere', k.device.index):
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    elif check_shared_mem('ada', k.device.index):
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
    else:
        BK = min(32, triton.next_power_of_2(K))
        BV = min(32, triton.next_power_of_2(V))

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert BT % BS == 0

    # per-split gradient buffers; fp32 accumulation unless there is one split
    dq = torch.empty(NV, * q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    dk = torch.empty(NV, * k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
    dv = torch.empty(NK, * v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device)
    dg = torch.empty(NK*NV, *g.shape, dtype=torch.float, device=q.device) if g is not None else None

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    grid = (NK * NV, NT, B * H)
    parallel_simple_gla_bwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        do=do,
        dq=dq,
        dk=dk,
        dv=dv,
        dg=dg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    # reduce over the split dimension
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    # the kernels produce chunk-local dg; the reversed global cumsum turns it
    # into the gradient w.r.t. the original (pre-cumsum) gates
    dg = chunk_global_cumsum(dg.sum(0), reverse=True, cu_seqlens=cu_seqlens) if g is not None else None
    return dq, dk, dv, dg
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class ParallelSimpleGLAFunction(torch.autograd.Function):
    """Autograd bridge between `parallel_simple_gla` and the Triton launchers.

    Saves the chunk-local-cumsum'd gates (as returned by the forward launcher)
    rather than the raw gates, so the backward pass can reuse them directly.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, output_attentions, cu_seqlens):
        chunk_size = 128
        # remember the input dtype so dg can be cast back in backward
        ctx.dtype = q.dtype

        # NOTE: `g` is rebound to the chunk-local cumsum computed inside
        o, g, attn = parallel_simple_gla_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            scale=scale,
            output_attentions=output_attentions,
            chunk_size=chunk_size,
            cu_seqlens=cu_seqlens,
        )
        ctx.save_for_backward(q, k, v, g, cu_seqlens)
        ctx.scale = scale
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), attn

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, da=None):
        # `da` is the (ignored) gradient of the optional attention output
        q, k, v, g, cu_seqlens = ctx.saved_tensors
        dq, dk, dv, dg = parallel_simple_gla_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            do=do,
            scale=ctx.scale,
            chunk_size=ctx.chunk_size,
            cu_seqlens=cu_seqlens,
        )
        # one gradient slot per forward argument; scale/output_attentions/
        # cu_seqlens are non-differentiable
        return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.dtype) if dg is not None else None, None, None, None
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
def parallel_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    output_attentions: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        g (torch.Tensor):
            Forget gates of shape `[B, T, H]`.
            Compared to GLA, the gating is head-wise instead of elementwise.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        output_attentions (bool):
            Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
        head_first (Optional[bool]):
            Deprecated. Passing `True` raises; inputs must be in `[B, T, H, ...]` format.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        attn (torch.Tensor):
            Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`
    """
    if head_first:
        # head_first support was removed; this raise is unconditional, so the
        # legacy rearrange branches that used to follow it were dead code and
        # have been deleted.
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    # heuristic sanity check: catch inputs that look head-first ([B, H, T, ...])
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
    if output_attentions:
        assert cu_seqlens is None, "output_attentions=True is not supported with variable-length sequences"

    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, attn = ParallelSimpleGLAFunction.apply(
        q,
        k,
        v,
        g,
        scale,
        output_attentions,
        cu_seqlens
    )
    return o, attn
|
fla3/ops/ttt/fused_chunk.py
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
from fla.modules.layernorm import group_norm
|
| 13 |
+
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, is_nvidia_hopper
|
| 14 |
+
|
| 15 |
+
NUM_WARPS = [1, 2] if is_nvidia_hopper else [1, 2, 4, 8]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_fwd_kernel(
    q,
    k,
    v,
    eta,
    w,
    b,
    o,
    scale,
    eps,
    h0,
    hb0,
    ht,
    hbt,
    cu_seqlens,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Fused forward kernel for chunked TTT-linear.

    One program per (sequence, head) pair (grid axis 0 = N*H). The fast-weight
    state is a [K, V] matrix `b_h` plus a [V] bias `b_hb`, updated chunk by
    chunk with learning rate `eta`; `w`/`b` are the per-head LayerNorm affine
    parameters applied inside the TTT update. `eps` is the LayerNorm epsilon.
    Processes chunks of BT timesteps sequentially (the state carries across
    chunks, so the time loop cannot be parallelized).
    """
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # lower-triangular (causal) mask within a chunk
    m_A = o_i[:, None] >= o_i[None, :]
    # per-head LayerNorm weight/bias, padded with zeros past V
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)

    # [BK, BV] fast-weight state
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV] fast-weight bias state
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    for i_t in range(NT):
        p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        # eta at the last valid timestep of this chunk (sequence end for the
        # final, possibly partial, chunk)
        p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # [BT, BV] prediction of the current state on this chunk's keys
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        # LayerNorm over the value dimension, masked to the valid V columns
        b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        # TTT residual target: LN(kh) * w + b - v + k (b_v is reused as scratch)
        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        # backprop of the residual through LayerNorm -> per-token update direction
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        # [BT] per-token learning rates
        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_q = (b_q * scale).to(b_k.dtype)

        # [BT, BT] causal intra-chunk interaction
        b_A = tl.dot(b_q, b_k, allow_tf32=False)
        b_A = tl.where(m_A, b_A, 0)
        b_Ae = tl.where(m_A, b_e[:, None], 0.0)

        # output = q @ h + bias corrections minus the eta-weighted intra-chunk updates
        b_o = - tl.dot(b_e[:, None] * b_A.to(b_v2.dtype), b_v2, allow_tf32=False)
        b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v2.dtype), b_v2, allow_tf32=False)
        b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        # state update with the chunk-final learning rate
        b_e_last = tl.load(p_e_last)
        b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
        # keep padded V columns zeroed so they never leak into later chunks
        b_h = tl.where((v_i < V)[None, :], b_h, 0.)
        b_hb = tl.where((v_i < V), b_hb, 0.)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_bwd_kernel_h(
    k,
    v,
    v2,
    x,
    y,
    r,
    w,
    b,
    eta,
    h0,
    hb0,
    h,
    do,
    dq,
    scale,
    eps,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
):
    """First backward stage: replay the chunked forward recurrence.

    One program per (sequence, head). For each chunk it recomputes and
    stores the intermediates needed by the second backward kernel —
    per-chunk hidden states into `h`, LayerNorm intermediates into
    `x` (normalized pre-activation), `y`, `r` (reciprocal std) and the
    post-LN values into `v2` — and computes `dq` on the fly.
    """
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    bos, _ = i_n * T, i_n * T + T
    NT = tl.cdiv(T, BT)
    boh = i_n * NT

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # causal (lower-triangular) mask within a chunk
    m_A = o_i[:, None] >= o_i[None, :]
    # per-head LayerNorm weight/bias; lanes beyond V are zero-padded
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)

    # [BK, BV] running hidden state
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV] running hidden-state bias
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    for i_t in range(NT):
        p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
        p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        p_dq = tl.make_block_ptr(dq+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        # eta of the chunk's last valid step (clamped to T-1 for the final, possibly ragged chunk)
        p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        # checkpoint the hidden state as seen at the START of this chunk
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # pre-activation k @ h + bias, then LayerNorm statistics over the
        # first V lanes (padding lanes zeroed before the reductions)
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        # LN output minus the reconstruction target (v - k), scaled by w;
        # b_v is reused as a scratch register here
        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        # backprop the gradient through LayerNorm to get the update direction
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
        # persist intermediates for the second backward kernel
        tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_v2, b_v2.to(p_v2.dtype.element_ty), boundary_check=(0, 1))

        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")

        # dq = do @ h^T - eta * (causal(do @ v2^T)) @ k^T, all times scale
        b_v2 = tl.where((v_i < V)[None, :], b_v2, 0.)
        b_ds = tl.dot(b_do, tl.trans(b_v2).to(b_do.dtype))
        b_ds = tl.where(m_A, b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        b_dq = tl.dot(b_do, tl.trans(b_h).to(b_do.dtype))
        b_dq -= tl.dot(b_ds, tl.trans(b_k)) * b_e[:, None]
        b_dq *= scale

        # advance the recurrent state to the next chunk using the chunk's
        # last learning rate, then re-zero padding lanes
        b_e_last = tl.load(p_e_last)
        b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
        b_h = tl.where((v_i < V)[None, :], b_h, 0.)
        b_hb = tl.where((v_i < V), b_hb, 0.)
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_bwd_kernel_dh(
    q,
    k,
    v,
    v2,
    x,
    y,
    r,
    w,
    b,
    eta,
    h,
    dht,
    dhbt,
    dh0,
    dhb0,
    do,
    dk,
    dv,
    de,
    dw,
    db,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
):
    """Second backward stage: sweep chunks in REVERSE order.

    One program per (sequence, head). Consumes the intermediates saved by
    `fused_chunk_ttt_linear_bwd_kernel_h` (`h`, `v2`, `x`, `y`, `r`) and
    produces dk, dv, d(eta), per-program dw/db partials (summed over batch
    by the host wrapper), and optionally the initial-state gradients.
    """
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    bos, _ = i_n * T, i_n * T + T
    NT = tl.cdiv(T, BT)
    boh = i_n * NT

    # [BK, BV] running gradient of the hidden state, seeded from dht if given
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV] running gradient of the hidden-state bias
    b_dhb = tl.zeros([BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
    if USE_FINAL_STATE_GRADIENT_B:
        p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")

    # [BV]
    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # lower-/upper-triangular masks (m_A_t is the transposed causal mask)
    m_A = o_i[:, None] >= o_i[None, :]
    m_A_t = o_i[:, None] <= o_i[None, :]
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)
    # accumulators for the LayerNorm weight/bias gradients of this program
    b_dw = tl.zeros([BV,], dtype=b_w.dtype)
    b_db = tl.zeros([BV,], dtype=b_b.dtype)
    p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
    p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (0,), (BV,), (0,))

    for i_t in range(NT - 1, -1, -1):
        # note: p_h is read TRANSPOSED, (V, K), relative to how stage 1 wrote it
        p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1))
        p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
        p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_de = tl.make_block_ptr(de+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        # eta of the chunk's last valid step (clamped to T-1 for the ragged last chunk)
        p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
        b_e_last = tl.load(p_e_last)
        # gradient flowing into v2 from the intra-chunk attention terms
        b_A = tl.dot(b_k, b_q)
        b_A = - tl.where(m_A_t, b_A * scale * b_e[None, :], 0).to(do.dtype.element_ty)
        b_Ae = - tl.where(m_A_t, b_e[None, :], 0).to(do.dtype.element_ty)
        b_dv_new = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
        # ... plus the contribution through the state update of this chunk
        b_dv_new -= tl.dot(b_e_last * b_k, b_dh.to(b_k.dtype))
        b_dv_new -= b_e_last * b_dhb.to(b_k.dtype)[None, :]

        b_v2 = tl.load(p_v2, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
        # backprop through the LayerNorm-backward expression of stage 1
        b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
                         b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
                          b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v2.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)

        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
        b_w = b_w.to(b_k.dtype)
        b_b = b_b.to(b_k.dtype)
        # direct dv/dk terms through g = w * (LN(kh) + b - (v - k))
        b_dv = -b_w * b_dy.to(b_k.dtype)
        b_dk = b_w * b_dy.to(b_k.dtype)
        b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
                       (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
        b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
        b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)

        b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
        b_q = (b_q * scale).to(b_q.dtype)
        # backprop through the forward LayerNorm into the pre-activation k @ h + hb
        b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
                          b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
        b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
        # zero out padding lanes AND rows past the end of the sequence
        b_dkh = tl.where((v_i < V)[None, :] * (o_i < T-i_t*BT)[:, None], b_dkh, 0.)
        b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)

        b_ds = tl.dot(b_do, tl.trans(b_v2))
        b_ds = tl.where(m_A, b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        # index of the chunk's last valid row; `mask` selects it so the
        # state-update eta gradient lands on exactly one time step
        i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1)
        mask = (o_i == i_last)
        b_dk -= b_e_last * tl.dot(b_v2, tl.trans(b_dh).to(b_v2.dtype))
        b_dk -= tl.dot(tl.trans(b_ds), tl.trans(b_q) * b_e[:, None])
        b_de = mask * tl.sum(- b_dh * tl.trans(tl.dot(tl.trans(b_v2), b_k))).to(b_k.dtype)
        b_de -= mask * tl.sum(b_dhb * tl.sum(b_v2, axis=0)).to(b_k.dtype)
        b_de -= tl.sum(tl.dot(b_ds, b_k) * tl.trans(b_q).to(b_k.dtype), axis=1)
        b_de -= tl.sum(b_ds, axis=1)
        # propagate the state gradient to the PREVIOUS chunk, re-zero padding
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
        b_dhb += tl.sum(b_do + b_dkh, axis=0)
        b_dh = tl.where((v_i < V)[None, :], b_dh, 0.)
        b_dhb = tl.where((v_i < V), b_dhb, 0.)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))
        # dw/db are running sums; each store overwrites, the last one wins
        tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0+i_nh*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
    if USE_INITIAL_STATE_B:
        p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (0,), (BV,), (0,))
        tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def fused_chunk_ttt_linear_bwd_h(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch stage 1 of the backward pass.

    Replays the chunked forward recurrence on the GPU, materializing the
    per-chunk hidden states and LayerNorm intermediates that stage 2
    needs, and computing the query gradient along the way.

    Returns:
        (dq, h, v2, x, y, r) where `h` has shape [B, NT, H, K, V],
        `r` holds the per-step reciprocal std in float32, and
        `v2`/`x`/`y` are value-shaped intermediate buffers.
    """
    assert cu_seqlens is None, "bwd of varlen is not implemented yet."
    B, T, H, K, V = *k.shape, v.shape[-1]
    num_chunks = triton.cdiv(T, BT)
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # Output/intermediate buffers the kernel fills in.
    dq = torch.empty_like(q)
    v2 = torch.empty_like(v)
    x = torch.empty_like(v)
    y = torch.empty_like(v)
    h = k.new_empty(B, num_chunks, H, K, V)
    r = v.new_empty(B, T, H, 1, dtype=torch.float32)

    # One program per (sequence, head) pair.
    fused_chunk_ttt_linear_bwd_kernel_h[(B * H,)](
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h0=initial_state,
        hb0=initial_state_bias,
        h=h,
        do=do,
        dq=dq,
        scale=scale,
        eps=eps,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, h, v2, x, y, r
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def fused_chunk_ttt_linear_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v2: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    r: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    h: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch stage 2 of the backward pass (reverse chunk sweep).

    Consumes the intermediates produced by `fused_chunk_ttt_linear_bwd_h`
    and returns gradients for k, v, eta, the LayerNorm weight/bias and,
    when initial states were given, the initial-state gradients.
    """
    assert cu_seqlens is None, "bwd of varlen is not implemented yet."
    B, T, H, K, V = *k.shape, v.shape[-1]
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # Initial-state gradients are only allocated when there is a state to
    # differentiate against; the kernel heuristics key off `None`.
    dh0 = None if initial_state is None else torch.empty_like(initial_state, dtype=torch.float32)
    dhb0 = None if initial_state_bias is None else torch.empty_like(initial_state_bias, dtype=torch.float32)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    de = torch.empty_like(eta)
    # Per-(batch, head) partials; reduced over the batch dimension below.
    dw = w.new_empty(B, H, V)
    db = b.new_empty(B, H, V)

    # One program per (sequence, head) pair.
    fused_chunk_ttt_linear_bwd_kernel_dh[(B * H,)](
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dh0=dh0,
        dhb0=dhb0,
        do=do,
        dk=dk,
        dv=dv,
        de=de,
        dw=dw,
        db=db,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    # Collapse the per-batch partial sums into the [H, V] parameter grads.
    dw = dw.sum(dim=0)
    db = db.sum(dim=0)
    return dk, dv, de, dw, db, dh0, dhb0
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def fused_chunk_ttt_linear_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    BT: int = 16
):
    """Launch the fused-chunk TTT-linear forward kernel.

    Returns:
        (o, final_state, final_state_bias): output of shape like `v`, and
        float32 final states of shape [N, H, K, V] / [N, H, 1, V] when
        `output_final_state` is set, otherwise `None`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    o = torch.empty_like(v)
    if output_final_state:
        final_state = k.new_empty(N, H, K, V, dtype=torch.float32)
        final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32)
    else:
        final_state = None
        final_state_bias = None

    # One program per (sequence, head) pair.
    fused_chunk_ttt_linear_fwd_kernel[(N * H,)](
        q=q,
        k=k,
        v=v,
        eta=eta,
        w=w,
        b=b,
        o=o,
        scale=scale,
        eps=eps,
        h0=initial_state,
        hb0=initial_state_bias,
        ht=final_state,
        hbt=final_state_bias,
        cu_seqlens=cu_seqlens,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return o, final_state, final_state_bias
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def fused_chunk_ttt_linear_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Full backward pass of fused-chunk TTT-linear, run as two stages.

    Stage 1 replays the forward recurrence to recover the per-chunk
    states/intermediates and computes dq; stage 2 sweeps the chunks in
    reverse to produce the remaining gradients.
    """
    assert cu_seqlens is None, "bwd of varlen is not implemented yet."
    # Stage 1: forward replay — dq plus saved intermediates.
    dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h(
        q=q, k=k, v=v, w=w, b=b, eta=eta,
        scale=scale, eps=eps, do=do, BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        cu_seqlens=cu_seqlens,
    )
    # Stage 2: reverse sweep — remaining input/parameter/state gradients.
    dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd_dh(
        q=q, k=k, v=v, v2=v2, x=x, y=y, r=rstd, w=w, b=b, eta=eta,
        scale=scale, h=h, do=do, dht=dht, dhbt=dhbt, BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        cu_seqlens=cu_seqlens,
    )
    return dq, dk, dv, de, dw, db, dh0, dhb0
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
class FusedChunkTTTLinearFunction(torch.autograd.Function):
    """Autograd bridge for the fused-chunk TTT-linear kernels.

    `forward` runs the Triton forward kernel and stashes the inputs needed
    to recompute intermediates in `backward` (the backward pass replays the
    forward recurrence rather than saving all chunk intermediates).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
                initial_state_bias, output_final_state, cu_seqlens):
        o, final_state, final_state_bias = fused_chunk_ttt_linear_fwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=scale,
            eps=eps,
            BT=BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        # Only tensors go through save_for_backward; plain Python values are
        # kept directly on ctx.
        ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
        ctx.BT = BT
        ctx.scale = scale
        ctx.eps = eps
        ctx.cu_seqlens = cu_seqlens
        return o.to(q.dtype), final_state, final_state_bias

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht, dhbt):
        q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
        dq, dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=ctx.scale,
            eps=ctx.eps,
            do=do,
            dht=dht,
            dhbt=dhbt,
            BT=ctx.BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            cu_seqlens=ctx.cu_seqlens,
        )
        # One return slot per forward input, in forward's argument order:
        # (q, k, v, w, b, BT, eta, scale, eps, initial_state,
        #  initial_state_bias, output_final_state, cu_seqlens).
        # `.to(tensor)` casts each gradient back to its input's dtype/device.
        return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def norm_residual(x, weight, bias, eps):
    """Apply per-head GroupNorm to ``x`` and add the result back as a residual.

    The heads are folded into the channel dimension so that ``num_groups=H``
    normalizes each head independently.

    Args:
        x: input of shape ``[B, T, H, D]``.
        weight: norm weight, flattened to ``[H * D]`` before the call.
        bias: norm bias, flattened likewise.
        eps: normalization epsilon.

    Returns:
        A NEW tensor ``x + GroupNorm(x)`` of the same shape. The previous
        implementation used ``x += ...``, which mutated the caller's tensor
        in place (and performed an in-place op on a non-leaf autograd
        tensor); all call sites reassign the return value, so an
        out-of-place add is safe and side-effect free.
    """
    B, T, H, D = x.shape
    normed = group_norm(
        # clones kept from the original code; presumably defensive against
        # group_norm touching its inputs — TODO confirm and drop if not needed
        x.reshape(B, T, -1).clone(),
        weight=weight.reshape(-1).clone(),
        bias=bias.reshape(-1).clone(),
        eps=eps,
        num_groups=H,
    ).reshape(x.shape)
    return x + normed
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def fused_chunk_ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
):
    r"""
    Fused-chunk TTT-linear attention with a differentiable Triton backend.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        w (torch.Tensor):
            layer norm weight of shape `(H, V)`
        b (torch.Tensor):
            layer norm bias of shape `(H, V)`
        eta (torch.Tensor):
            Learning rate for hidden state, of shape `[B, T, H, 1]`
            (a float is broadcast to that shape).
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        eps (float):
            Epsilon of the internal layer norm. Default: `1e-6`.
        chunk_size (int):
            chunk size. Default: `16`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        initial_state_bias (Optional[torch.Tensor]):
            Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Deprecated; passing `True` raises. Inputs must be in the
            `[B, T, H, ...]` layout. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
        final_state_bias (torch.Tensor):
            Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
    if isinstance(eta, float):
        # broadcast a scalar learning rate over all positions
        eta = torch.full_like(q[:, :, :, :1], eta)
    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        # NOTE(review): unreachable after the raise above; kept from the
        # pre-deprecation code path.
        q, k, v, eta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, eta))
    if not head_first and q.shape[1] < q.shape[2]:
        # Heuristic sanity check: T < H usually means the caller passed
        # head-first tensors by mistake.
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "Scale must be positive."
    o, final_state, final_state_bias = FusedChunkTTTLinearFunction.apply(
        q,
        k,
        v,
        w,
        b,
        chunk_size,
        eta,
        scale,
        eps,
        initial_state,
        initial_state_bias,
        output_final_state,
        cu_seqlens,
    )
    # final per-head GroupNorm + residual on the kernel output
    o = norm_residual(o, w, b, eps)
    if head_first:
        # NOTE(review): dead branch — head_first=True raised earlier.
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state, final_state_bias
|
fla3/ops/ttt/naive.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    """Naive (reference) TTT-linear recurrence in pure PyTorch.

    Args:
        q, k, v: head-first tensors of shape ``[B, H, T, D]``.
        w: per-head LayerNorm weight, reshaped internally to ``[H, 1, D]``.
        b: per-head LayerNorm bias, reshaped likewise.
        eta: per-step learning rate of shape ``[B, H, T, 1]``.
        scale: multiplier applied to the queries.
        eps: LayerNorm epsilon.
        mini_batch_size: chunk length ``BT``; ``T`` must be a multiple of it.
        initial_state: optional ``[B, H, D, D]`` float32 hidden state.
        initial_state_bias: optional ``[B, H, 1, D]`` float32 state bias.
        output_final_state: whether to return the final state/bias.

    Returns:
        (o, h, hb): outputs of shape ``[B, H, T, D]`` plus the final hidden
        state and bias (both ``None`` unless ``output_final_state``).
    """
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # Scale out-of-place BEFORE chunking: the old `q *= scale` (applied
    # after building `_q`) mutated the caller's tensor in place and only
    # reached `_q` because `reshape(...).permute(...)` happened to be a
    # view; this form is correct even when `reshape` copies.
    q = q * scale
    # [NT, B, H, mini_batch_size, D]
    _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)

    for i in range(NT):
        q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
        # pre-activation of the test-time model and its LayerNorm stats
        # (note: `std` is the standard deviation, used as a divisor)
        kh = k_i @ h + hb
        reconstruction_target = v_i - k_i

        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        std = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / std

        # gradient of the reconstruction loss through the LayerNorm
        g = w * kh_hat + b - reconstruction_target
        g *= w
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (std * D)

        # causal intra-chunk attention plus the running state contribution
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        # state update uses the chunk's LAST learning rate
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)

        # layer norm with residuals
        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        std = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / std * w + b

    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = False,
):
    """Reference entry point for chunked TTT-linear.

    Normalizes layout to [B, H, T, D], pads T up to a multiple of
    `mini_batch_size`, runs `ttt_linear`, then strips the padding and
    restores the caller's layout.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    if isinstance(eta, float):
        # broadcast a scalar learning rate to a [B, H, T, 1] tensor
        eta = torch.full_like(q[:, :, :, :1], eta)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if not head_first:
        # work internally in head-first [B, H, T, D] layout
        q, k, v, eta = (x.transpose(1, 2) for x in (q, k, v, eta))
    T = q.shape[-2]
    pad_len = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size
    if pad_len > 0:
        q, k, v, eta = (F.pad(x, (0, 0, 0, pad_len)) for x in (q, k, v, eta))
        # F.pad filled the tail with zeros; reuse the last real step's eta
        # for the final padded slot
        eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    q, k, v, eta, w, b = (x.to(torch.float32) for x in (q, k, v, eta, w, b))
    o, final_state, final_state_bias = ttt_linear(
        q,
        k,
        v,
        w,
        b,
        eta,
        scale,
        eps,
        mini_batch_size,
        initial_state,
        initial_state_bias,
        output_final_state,
    )
    # drop the padded timesteps
    o = o[:, :, :T, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias
|
fla3/ops/utils/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from .asm import fp32_to_tf32_asm
|
| 4 |
+
from .cumsum import (
|
| 5 |
+
chunk_global_cumsum,
|
| 6 |
+
chunk_global_cumsum_scalar,
|
| 7 |
+
chunk_global_cumsum_vector,
|
| 8 |
+
chunk_local_cumsum,
|
| 9 |
+
chunk_local_cumsum_scalar,
|
| 10 |
+
chunk_local_cumsum_vector
|
| 11 |
+
)
|
| 12 |
+
from .index import (
|
| 13 |
+
prepare_chunk_indices,
|
| 14 |
+
prepare_chunk_offsets,
|
| 15 |
+
prepare_cu_seqlens_from_mask,
|
| 16 |
+
prepare_lens,
|
| 17 |
+
prepare_lens_from_mask,
|
| 18 |
+
prepare_position_ids,
|
| 19 |
+
prepare_sequence_ids,
|
| 20 |
+
prepare_token_indices
|
| 21 |
+
)
|
| 22 |
+
from .logsumexp import logsumexp_fwd
|
| 23 |
+
from .matmul import addmm, matmul
|
| 24 |
+
from .pack import pack_sequence, unpack_sequence
|
| 25 |
+
from .pooling import mean_pooling
|
| 26 |
+
from .softmax import softmax_bwd, softmax_fwd
|
| 27 |
+
from .solve_tril import solve_tril
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
'chunk_global_cumsum',
|
| 31 |
+
'chunk_global_cumsum_scalar',
|
| 32 |
+
'chunk_global_cumsum_vector',
|
| 33 |
+
'chunk_local_cumsum',
|
| 34 |
+
'chunk_local_cumsum_scalar',
|
| 35 |
+
'chunk_local_cumsum_vector',
|
| 36 |
+
'pack_sequence',
|
| 37 |
+
'unpack_sequence',
|
| 38 |
+
'prepare_chunk_indices',
|
| 39 |
+
'prepare_chunk_offsets',
|
| 40 |
+
'prepare_cu_seqlens_from_mask',
|
| 41 |
+
'prepare_lens',
|
| 42 |
+
'prepare_lens_from_mask',
|
| 43 |
+
'prepare_position_ids',
|
| 44 |
+
'prepare_sequence_ids',
|
| 45 |
+
'prepare_token_indices',
|
| 46 |
+
'logsumexp_fwd',
|
| 47 |
+
'addmm',
|
| 48 |
+
'matmul',
|
| 49 |
+
'mean_pooling',
|
| 50 |
+
'softmax_bwd',
|
| 51 |
+
'softmax_fwd',
|
| 52 |
+
'fp32_to_tf32_asm',
|
| 53 |
+
'solve_tril',
|
| 54 |
+
]
|
fla3/ops/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
fla3/ops/utils/__pycache__/asm.cpython-310.pyc
ADDED
|
Binary file (482 Bytes). View file
|
|
|
fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc
ADDED
|
Binary file (21.4 kB). View file
|
|
|
fla3/ops/utils/__pycache__/index.cpython-310.pyc
ADDED
|
Binary file (3.12 kB). View file
|
|
|
fla3/ops/utils/__pycache__/index.cpython-312.pyc
ADDED
|
Binary file (5.48 kB). View file
|
|
|
fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
fla3/ops/utils/__pycache__/matmul.cpython-312.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
fla3/ops/utils/__pycache__/op.cpython-310.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
fla3/ops/utils/__pycache__/op.cpython-312.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
fla3/ops/utils/__pycache__/pack.cpython-310.pyc
ADDED
|
Binary file (4.56 kB). View file
|
|
|
fla3/ops/utils/__pycache__/pack.cpython-312.pyc
ADDED
|
Binary file (8.01 kB). View file
|
|
|
fla3/ops/utils/__pycache__/softmax.cpython-310.pyc
ADDED
|
Binary file (2.35 kB). View file
|
|
|
fla3/ops/utils/__pycache__/solve_tril.cpython-310.pyc
ADDED
|
Binary file (7.63 kB). View file
|
|
|
fla3/ops/utils/asm.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from ...utils import device_platform
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def fp32_to_tf32_asm() -> str:
    """
    Get the assembly code for converting FP32 to TF32.

    Returns:
        str: An inline-asm snippet for the current device platform, or an
        empty string if the platform has no known conversion instruction.
    """
    ASM_DICT = {
        'nvidia': 'cvt.rna.tf32.f32 $0, $1;'
    }
    # dict.get with a default replaces the membership test + second lookup
    return ASM_DICT.get(device_platform, "")
|
fla3/ops/utils/cumsum.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import triton
|
| 9 |
+
import triton.language as tl
|
| 10 |
+
|
| 11 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 12 |
+
from ...utils import check_shared_mem, input_guard
|
| 13 |
+
|
| 14 |
+
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,  # input pointer
    o,  # output pointer
    cu_seqlens,  # optional cumulative sequence lengths (varlen mode)
    chunk_indices,  # optional [NT, 2] (sequence id, chunk id) pairs
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,  # chunk (tile) size along T
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Per-chunk (local) inclusive cumsum of a scalar series along T.

    One program handles one [BT] tile of one (batch, head) pair; with
    REVERSE the tile's suffix sums are produced instead of prefix sums.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # map the flat chunk id to (sequence index, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        # [B, H, T] layout: rows of one head are contiguous (stride 1)
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # [B, T, H] layout: consecutive timesteps are H elements apart
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # suffix sum within the tile: total - prefix + current element
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,  # input pointer
    o,  # output pointer
    cu_seqlens,  # optional cumulative sequence lengths (varlen mode)
    chunk_indices,  # optional [NT, 2] (sequence id, chunk id) pairs
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,  # feature dimension
    BT: tl.constexpr,  # chunk (tile) size along T
    BS: tl.constexpr,  # tile size along S
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Per-chunk cumsum along T of an S-dimensional series.

    One program handles one [BT, BS] tile; the cumsum is realized as a
    triangular-mask matmul over the tile.
    """
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # map the flat chunk id to (sequence index, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    if REVERSE:
        # upper-triangular mask -> suffix sums
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        # lower-triangular mask -> prefix sums
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    # tf32 disabled so the masked matmul matches an exact fp32 cumsum
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
        for BT in [32, 64, 128, 256]
        for num_warps in [2, 4, 8]
        for num_stages in [1, 2, 3, 4]
    ],
    key=['B', 'H', 'IS_VARLEN', 'REVERSE']
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_scalar_kernel(
    s,  # input pointer
    o,  # output pointer
    cu_seqlens,  # optional cumulative sequence lengths (varlen mode)
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,  # tile size along T
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Whole-sequence (global) cumsum of a scalar series.

    One program per (sequence, head): walks the sequence in BT-sized tiles,
    carrying the running total across tiles in b_z. With REVERSE the tiles
    are visited back-to-front, yielding a suffix sum.
    """
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
    T = eos - bos

    b_z = tl.zeros([], dtype=tl.float32)  # carry: total of tiles done so far
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        # iterate tiles back-to-front when computing a reverse cumsum
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0)
        b_ss = tl.sum(b_s, 0)
        if REVERSE:
            # suffix sum within the tile: total - prefix + current element
            b_o = -b_o + b_ss + b_s
        b_o += b_z
        # NOTE(review): `i_c >= 0` is always true inside this loop; guard
        # retained unchanged — TODO confirm whether it is intentional
        if i_c >= 0:
            b_z += b_ss
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
        for BT in [16, 32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [1, 2, 3, 4]
    ],
    key=['B', 'H', 'S', 'IS_VARLEN', 'REVERSE']
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_vector_kernel(
    s,  # input pointer
    z,  # output pointer
    cu_seqlens,  # optional cumulative sequence lengths (varlen mode)
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,  # feature dimension
    BT: tl.constexpr,  # tile size along T
    BS: tl.constexpr,  # tile size along S
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Whole-sequence (global) cumsum of an S-dimensional series.

    One program per (sequence, head, feature-block): within-tile prefix
    (or suffix) sums come from a triangular-mask matmul; the cross-tile
    carry lives in b_z.
    """
    i_s, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
    T = eos - bos

    o_i = tl.arange(0, BT)
    if REVERSE:
        # upper-triangular mask -> suffix sums
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        # lower-triangular mask -> prefix sums
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_z = tl.zeros([BS], dtype=tl.float32)  # per-feature carry across tiles
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        # iterate tiles back-to-front when computing a reverse cumsum
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        # NOTE(review): `i_c >= 0` is always true inside this loop; guard
        # retained unchanged — TODO confirm whether it is intentional
        if i_c >= 0:
            b_z += tl.sum(b_s, 0)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Launch the scalar per-chunk cumsum kernel over `g`.

    `g` is [B, H, T] when `head_first` else [B, T, H]; the result keeps the
    same shape, in `output_dtype` (falling back to `g`'s dtype).
    """
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    src = g
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)
    chunk_local_cumsum_scalar_kernel[(NT, B * H)](
        src,
        out,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Per-chunk cumsum along T of a vector-valued series.

    In head-first layout this is equivalent to
    `g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)`, computed by
    the Triton kernel (cumulative normalizer kept in fp32).
    """
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    src = g
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)

    chunk_local_cumsum_vector_kernel[grid](
        src,
        out,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@input_guard
def chunk_global_cumsum_scalar(
    s: torch.Tensor,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Cumsum over the full sequence length for a scalar-valued series.

    `s` is [B, H, T] when `head_first` else [B, T, H]; one kernel program
    is launched per (sequence, head) pair.
    """
    if head_first:
        B, H, T = s.shape
    else:
        B, T, H = s.shape
    # number of sequences: batch size, or len(cu_seqlens)-1 in varlen mode
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    out = torch.empty_like(s, dtype=output_dtype or s.dtype)
    chunk_global_cumsum_scalar_kernel[(N * H,)](
        s,
        out,
        cu_seqlens,
        T=T,
        B=B,
        H=H,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@input_guard
def chunk_global_cumsum_vector(
    s: torch.Tensor,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Cumsum over the full sequence length for an S-dimensional series.

    `s` is [B, H, T, S] when `head_first` else [B, T, H, S]; the feature
    axis is tiled into BS-wide blocks, one kernel program per block.
    """
    if head_first:
        B, H, T, S = s.shape
    else:
        B, T, H, S = s.shape
    # number of sequences: batch size, or len(cu_seqlens)-1 in varlen mode
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    # feature-block width, capped at 32
    BS = min(32, triton.next_power_of_2(S))

    out = torch.empty_like(s, dtype=output_dtype or s.dtype)
    chunk_global_cumsum_vector_kernel[(triton.cdiv(S, BS), N * H)](
        s,
        out,
        cu_seqlens,
        T=T,
        B=B,
        H=H,
        S=S,
        BS=BS,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@input_guard
def chunk_global_cumsum(
    s: torch.Tensor,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Dispatch global cumsum to the scalar (3-D) or vector (4-D) variant."""
    if cu_seqlens is not None:
        assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
    ndim = s.dim()
    if ndim == 3:
        return chunk_global_cumsum_scalar(s, reverse, cu_seqlens, head_first, output_dtype)
    if ndim == 4:
        return chunk_global_cumsum_vector(s, reverse, cu_seqlens, head_first, output_dtype)
    raise ValueError(
        f"Unsupported input shape {s.shape}. "
        f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
        f"or [B, H, T]/[B, H, T, D] otherwise"
    )
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
    **kwargs
) -> torch.Tensor:
    """Dispatch per-chunk cumsum to the scalar (3-D) or vector (4-D) variant."""
    if not head_first and g.shape[1] < g.shape[2]:
        # sanity heuristic: seq_len smaller than num_heads hints the caller
        # may actually be passing head-first data
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
    ndim = g.dim()
    if ndim == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
    if ndim == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
    raise ValueError(
        f"Unsupported input shape {g.shape}. "
        f"which should be (B, T, H, D) if `head_first=False` "
        f"or (B, H, T, D) otherwise"
    )
|
fla3/ops/utils/index.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import triton
|
| 7 |
+
import triton.language as tl
|
| 8 |
+
|
| 9 |
+
from ...utils import tensor_cache
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    key=['B'],
)
@triton.jit
def prepare_position_ids_kernel(
    y,  # output pointer: flat per-token position ids
    cu_seqlens,  # cumulative sequence lengths
    B: tl.constexpr  # block width for each store step
):
    """Write position ids 0..T-1 for one sequence of a varlen batch.

    One program per sequence; positions are emitted B at a time with a
    mask guarding the final partial block.
    """
    i_n = tl.program_id(0)
    bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    T = eos - bos

    o = tl.arange(0, B)
    for i in range(0, tl.cdiv(T, B) * B, B):
        o_i = o + i
        # masked store: only positions < T are written
        tl.store(y + bos + o_i, o_i, o_i < T)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@tensor_cache
|
| 36 |
+
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Recover per-sequence lengths from cumulative sequence offsets."""
    # adjacent differences of the cumulative offsets are the lengths
    return cu_seqlens.diff()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@tensor_cache
|
| 41 |
+
def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
|
| 42 |
+
return mask.sum(dim=-1, dtype=torch.int32)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@tensor_cache
|
| 46 |
+
def prepare_cu_seqlens_from_mask(mask: torch.BoolTensor, out_dtype: torch.dtype = torch.int32) -> torch.LongTensor:
    """Build cumulative sequence offsets [0, l0, l0+l1, ...] from a mask."""
    lens = prepare_lens_from_mask(mask)
    cu = lens.cumsum(dim=0, dtype=out_dtype)
    # prepend the leading zero offset
    return F.pad(cu, (1, 0))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@tensor_cache
|
| 51 |
+
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Build flat 0..len-1 position ids for every sequence in a varlen batch."""
    ranges = [
        torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
        for n in prepare_lens(cu_seqlens).unbind()
    ]
    return torch.cat(ranges)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@tensor_cache
|
| 59 |
+
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Map each flat token position to the index of its sequence."""
    # every zero position marks a sequence start; running count - 1 = seq id
    starts = prepare_position_ids(cu_seqlens).eq(0)
    return starts.cumsum(0) - 1
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@tensor_cache
|
| 64 |
+
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return [total_tokens, 2] pairs of (sequence id, position in sequence)."""
    pos = prepare_position_ids(cu_seqlens)
    seq = prepare_sequence_ids(cu_seqlens)
    return torch.stack([seq, pos], 1).to(cu_seqlens)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@tensor_cache
|
| 70 |
+
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Return [num_chunks, 2] pairs of (sequence id, chunk id in sequence)."""
    # ceil-division gives the number of chunks per sequence
    counts = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    ids = torch.cat([torch.arange(n) for n in counts])
    # a zero chunk id marks a new sequence; running count - 1 = sequence id
    return torch.stack([ids.eq(0).cumsum(0) - 1, ids], 1).to(cu_seqlens)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@tensor_cache
|
| 79 |
+
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Cumulative chunk offsets per sequence, with a leading zero."""
    # ceil-division gives the number of chunks per sequence
    chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    zero = cu_seqlens.new_tensor([0])
    return torch.cat([zero, chunks]).cumsum(-1)
|
fla3/ops/utils/logcumsumexp.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
from ...ops.utils.op import exp, log
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def logcumsumexp_fwd_kernel(
    s,  # input pointer, one [T, S] slab per program
    z,  # output pointer, same layout
    T,
    S: tl.constexpr,  # feature dimension
    BT: tl.constexpr  # tile size along T
):
    """Streaming, numerically stable log-cumsum-exp along T.

    Carries a running max (b_mp) and a running sum normalized to that max
    (b_zp) across BT-sized tiles, rescaling the carry whenever the max
    grows; within a tile the inclusive prefix comes from a triangular
    matmul.
    """
    i_bh = tl.program_id(0)
    o_i = tl.arange(0, BT)
    # lower-triangular mask -> inclusive prefix sums within the tile
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)  # running max
    b_zp = tl.zeros([S,], dtype=tl.float32)  # running sum at scale exp(-b_mp)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))

        # [BT, S]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        # [S,]
        b_mc = tl.max(b_s, 0)
        b_mc = tl.maximum(b_mp, b_mc)  # updated running max
        b_zp = b_zp * exp(b_mp - b_mc)  # rescale carried sum to the new max
        # [BT, S]
        b_s = exp(b_s - b_mc)
        b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
        # [S,]
        # rows of b_z are per-column non-decreasing (sums of positives), so
        # the column max equals the last row — the carry for the next tile
        b_zc = tl.max(b_z, 0)
        b_mp = b_mc
        b_zp = b_zc
        # [BT, BS]
        # small eps to prevent underflows
        b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
        tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
|
fla3/ops/utils/logsumexp.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.op import exp, log
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    """Partial logsumexp over the last dimension.

    Each program reduces one [B]-sized slice of row `i_n` of `x` (shape [N, D])
    and writes a single partial logsumexp into `z[i_n, i_d]` (shape
    [N, cdiv(D, B)]). The host side finishes the reduction across the
    `cdiv(D, B)` partials (see `logsumexp_fwd`).
    """
    # int64 indices to avoid overflow for very large N * D offsets.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-range lanes load -inf so they contribute exp(-inf) = 0 to the sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Classic max-shift for numerical stability: lse = log(sum exp(x - m)) + m.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    The reduction is two-stage: a Triton kernel produces `cdiv(D, B)` partial
    logsumexp values per row in float32, then `torch.logsumexp` combines the
    partials on the host side. This keeps each kernel block within a
    power-of-two tile of at most 64K elements.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor.
    """

    shape = x.shape
    # Flatten all leading dims; the kernel reduces over the last dim only.
    x = x.view(-1, shape[-1])
    N, D = x.shape
    # Block size: next power of two covering D, capped at 64K elements.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # Partials are accumulated in float32 regardless of the input dtype.
    z = x.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    # Combine the per-block partials, then restore the leading dims.
    z = z.logsumexp(-1).view(*shape[:-1])
    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z
|
fla3/ops/utils/matmul.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
# code adapted from
|
| 5 |
+
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from ...ops.utils.op import exp
|
| 14 |
+
from ...utils import input_guard
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
| 18 |
+
# - A list of `triton.Config` objects that define different configurations of
|
| 19 |
+
# meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
|
| 20 |
+
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
|
| 21 |
+
# provided configs
|
| 22 |
+
@triton.heuristics({
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
        # Good config for fp8 inputs.
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,
    alpha,
    beta,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    stride_ab, stride_am, stride_ak,  # a: batch, M, K
    stride_bk, stride_bn,  # b: K, N
    stride_cb, stride_cm, stride_cn,  # c: batch, M, N
    # Meta-parameters
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
    G: tl.constexpr,
    ACTIVATION: tl.constexpr,
    HAS_INPUT: tl.constexpr,
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    X_DIM: tl.constexpr = 1,
):
    """Kernel for computing the (optionally batched, fused) matmul C = A x B.

    A has shape (batch, M, K), B has shape (K, N) and C has shape (batch, M, N).
    Optionally applies an activation, a scalar `alpha` scale, and a
    `beta`-scaled additive `input` (row vector when X_DIM == 1, full
    matrix when X_DIM == 2), all while the accumulator is still in fp32.

    NOTE(review): `alpha` and `beta` are read with `tl.load`, i.e. they must be
    1-element tensors on device, not Python floats — confirm against callers.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped (swizzled) ordering to promote L2 data reuse.
    i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NM, NN = tl.num_programs(1), tl.num_programs(2)
    i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B; advanced in the K loop.
    # `p_a` is a block of [BM, BK] pointers, `p_b` a block of [BK, BN] pointers.
    # Row/col indices are taken modulo M/N so the tile pointers stay in bounds;
    # the final store mask discards the wrapped-around rows/cols.
    a_batch_ptr = a + i_b * stride_ab
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)

    p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak)
    p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn)

    # fp32 accumulator for the [BM, BN] output tile.
    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0 so it does not affect the dot product.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
        # We accumulate along the K dimension.
        b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32)
        # Advance the ptrs to the next K block.
        p_a += BK * stride_ak
        p_b += BK * stride_bk

    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)

    b_c = b_acc
    # Fuse the activation while the accumulator is still in FP32.
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    elif ACTIVATION == "relu":
        b_c = relu(b_c)
    elif ACTIVATION == "sigmoid":
        b_c = sigmoid(b_c)
    elif ACTIVATION == "tanh":
        b_c = tanh(b_c)

    if HAS_ALPHA:
        b_c *= tl.load(alpha)

    if HAS_INPUT:
        # X_DIM == 1: `input` is a row vector broadcast over M;
        # X_DIM == 2: `input` is a full (M, N) matrix sharing C's strides.
        p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :]
        mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask
        b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    c_batch_ptr = c + i_b * stride_cb
    p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :]
    tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
|
| 153 |
+
@triton.jit
def leaky_relu(x):
    """LeakyReLU with fixed negative slope 0.01."""
    return tl.where(x < 0, 0.01 * x, x)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@triton.jit
def sigmoid(x):
    """Logistic sigmoid: sigma(x) = 1 / (1 + exp(-x))."""
    denom = 1.0 + exp(-x)
    return 1.0 / denom
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@triton.jit
def tanh(x):
    """Numerically stable hyperbolic tangent.

    The textbook form (exp(x) - exp(-x)) / (exp(x) + exp(-x)) overflows to
    inf/inf = NaN once exp(|x|) exceeds the float range (|x| ~ 44 in fp32).
    Rewriting with e = exp(-2|x|) in (0, 1]:

        tanh(|x|) = (1 - e) / (1 + e),   tanh(x) = sign(x) * tanh(|x|)

    never overflows and saturates cleanly to +/-1 for large |x|.
    """
    e = exp(-2.0 * tl.abs(x))
    t = (1.0 - e) / (1.0 + e)
    return tl.where(x >= 0, t, -t)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@triton.jit
def relu(x):
    """Rectified linear unit: max(x, 0)."""
    return tl.maximum(0.0, x)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@input_guard
def matmul(a: torch.Tensor, b: torch.Tensor, activation: str = '') -> torch.Tensor:
    """Compute `a @ b` with the Triton GEMM kernel, optionally fusing an activation.

    Args:
        a: (M, K) or batched (B, M, K) left operand.
        b: (K, N) right operand (shared across the batch).
        activation: one of '', 'leaky_relu', 'relu', 'sigmoid', 'tanh'.

    Returns:
        (M, N) if `a` was 2D, else (B, M, N), in `a`'s dtype.
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    if a.dim() == 2:
        # Promote to a batch of 1 so the kernel always sees (B, M, K).
        a_dim = 2
        a = a.unsqueeze(0).contiguous()  # (1, M, K)
    else:
        a_dim = 3
    # TF32 tensor-core math would silently lose precision for fp32 inputs.
    allow_tf32 = False if a.dtype == torch.float32 else True

    B, M, K = a.shape[0], a.shape[1], a.shape[2]
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    # 3D launch grid: (batch, M tiles, N tiles); tile sizes come from autotune meta.
    def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, None, None, None,  # no fused input/alpha/beta in plain matmul
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),  # stride_ab, stride_am, stride_ak
        b.stride(0), b.stride(1),  # stride_bk, stride_bn (b.dim() == 2)
        c.stride(0), c.stride(1), c.stride(2),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=activation,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=False,
    )
    # Drop the synthetic batch dim if the caller passed a 2D `a`.
    return c.squeeze(0) if a_dim == 2 else c
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@input_guard
def addmm(
    x: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
) -> torch.Tensor:
    """Fused `alpha * (a @ b) + beta * x` using the Triton GEMM kernel.

    Args:
        x: additive input; X_DIM is taken from `x.dim()` — a 1D row vector is
           broadcast over M, a 2D matrix is added elementwise.
        a: (M, K) or batched (B, M, K) left operand.
        b: (K, N) right operand.
        alpha: optional scale on the matmul result.
        beta: optional scale on `x`.

    NOTE(review): the kernel reads `alpha`/`beta` via `tl.load`, which expects
    1-element device tensors, so the `Optional[float]` annotations look
    inaccurate — confirm what callers actually pass.
    NOTE(review): `x` is addressed without a batch stride in the kernel, so for
    batched `a` the same `x` is added to every batch element — confirm intended.
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    if a.dim() == 2:
        # Promote to a batch of 1 so the kernel always sees (B, M, K).
        a_dim = 2
        a = a.unsqueeze(0).contiguous()  # (1, M, K)
    else:
        a_dim = 3
    # TF32 tensor-core math would silently lose precision for fp32 inputs.
    allow_tf32 = False if a.dtype == torch.float32 else True

    B, M, K = a.shape[0], a.shape[1], a.shape[2]
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    # 3D launch grid: (batch, M tiles, N tiles); tile sizes come from autotune meta.
    def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, x, alpha, beta,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),  # stride_ab, stride_am, stride_ak
        b.stride(0), b.stride(1),  # stride_bk, stride_bn (b.dim() == 2)
        c.stride(0), c.stride(1), c.stride(2),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=None,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=True,
        X_DIM=x.dim(),
    )
    # Drop the synthetic batch dim if the caller passed a 2D `a`.
    return c.squeeze(0) if a_dim == 2 else c
|
fla3/ops/utils/pack.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
# Code adapted from https://github.com/mayank31398/cute-kernels
|
| 5 |
+
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import triton
|
| 10 |
+
import triton.language as tl
|
| 11 |
+
|
| 12 |
+
from ...ops.utils.index import prepare_lens
|
| 13 |
+
from ...utils import input_guard
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    key=['D', 'PADDING_SIDE', 'PACK']
)
@triton.jit
def packunpack_sequence_kernel(
    x,
    y,
    cu_seqlens,
    S,
    D,
    BD: tl.constexpr,
    PADDING_SIDE: tl.constexpr,
    PACK: tl.constexpr,
):
    """Copy one [BD]-sized feature slice of one timestep between a padded batch
    layout [B, S, D] and a flat packed layout [total_tokens, D].

    Grid: (cdiv(D, BD), S, B). `PACK=True` copies padded -> packed,
    `PACK=False` copies packed -> padded. `cu_seqlens` gives per-sequence
    token offsets into the packed tensor.
    """
    i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Packed-tensor range [bos, eos) of sequence i_b.
    bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)

    T = eos - bos
    if PADDING_SIDE == 'left':
        # The first NP = S - T positions of the padded row are padding.
        NP = S - T
        if i_s < NP:
            return
        i_t = bos + (i_s - NP)
    else:
        # Right padding: positions >= T are padding.
        if i_s >= T:
            return
        i_t = bos + i_s

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    if PACK:
        b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
        tl.store(y + i_t * D + o_d, b_x, mask=mask)
    else:
        b_x = tl.load(x + i_t * D + o_d, mask=mask)
        tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def pack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
) -> torch.Tensor:
    """Copy the valid rows of a padded batch `x` [B, S, ...] into a flat
    packed tensor [total_tokens, ...].

    Used both as the forward of packing and the backward of unpacking.
    """
    batch_size, max_seqlen = x.shape[:2]
    feat_dim = x.numel() // (batch_size * max_seqlen)
    block_d = min(triton.next_power_of_2(feat_dim), 4096)
    num_blocks_d = triton.cdiv(feat_dim, block_d)

    total_tokens = cu_seqlens[-1].item()
    packed = torch.empty(total_tokens, *x.shape[2:], device=x.device, dtype=x.dtype)
    packunpack_sequence_kernel[num_blocks_d, max_seqlen, batch_size](
        x=x,
        y=packed,
        cu_seqlens=cu_seqlens,
        S=max_seqlen,
        D=feat_dim,
        BD=block_d,
        PADDING_SIDE=padding_side,
        PACK=True,
    )
    return packed
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def unpack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
    desired_shape: torch.Size,
) -> torch.Tensor:
    """Scatter a flat packed tensor [total_tokens, ...] back into a
    zero-padded batch [B, S, ...].

    Used both as the forward of unpacking and the backward of packing.
    """
    if desired_shape is None:
        # Infer the padded shape from the sequence lengths.
        num_seqs = len(cu_seqlens) - 1
        max_len = prepare_lens(cu_seqlens).max().item()
        desired_shape = (num_seqs, max_len, *x.shape[1:])
    # zeros, not empty: padding positions are never written by the kernel.
    padded = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)

    batch_size, max_seqlen = padded.shape[:2]
    feat_dim = padded.numel() // (batch_size * max_seqlen)
    block_d = min(triton.next_power_of_2(feat_dim), 4096)
    num_blocks_d = triton.cdiv(feat_dim, block_d)

    packunpack_sequence_kernel[num_blocks_d, max_seqlen, batch_size](
        x=x,
        y=padded,
        cu_seqlens=cu_seqlens,
        S=max_seqlen,
        D=feat_dim,
        BD=block_d,
        PADDING_SIDE=padding_side,
        PACK=False,
    )
    return padded
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class PackSequenceFunction(torch.autograd.Function):
    """Autograd wrapper: pack a padded batch into a flat token tensor.

    Forward packs [B, S, ...] -> [total_tokens, ...]; backward scatters the
    incoming gradient back into the original padded shape.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2

        # Saved for the backward scatter (cu_seqlens is not differentiable,
        # so stashing it on ctx directly is fine).
        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side
        ctx.desired_shape = x.shape

        y = pack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
        )
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        # Gradient of packing is unpacking back to the saved padded shape.
        dx = unpack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
            desired_shape=ctx.desired_shape,
        )
        # NOTE(review): forward takes 3 inputs but 11 values are returned here
        # (extra trailing Nones) — verify autograd accepts the surplus Nones.
        return dx, *[None] * 10
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class UnpackSequenceFunction(torch.autograd.Function):
    """Autograd wrapper: unpack a flat token tensor into a padded batch.

    Forward scatters [total_tokens, ...] -> [B, S, ...]; backward packs the
    incoming gradient back into the flat layout.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
        desired_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2
        if desired_shape is not None:
            # Batch dim must match the number of sequences; trailing feature
            # dims must match the packed input.
            assert desired_shape[0] == cu_seqlens.shape[0] - 1
            assert desired_shape[2:] == x.shape[1:]

        # Saved for the backward gather.
        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side

        y = unpack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
            desired_shape=desired_shape,
        )
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        # Gradient of unpacking is packing; padding gradients are dropped.
        dx = pack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
        )
        return dx, None, None, None
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def pack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left'
) -> torch.Tensor:
    """Pack a padded batch [B, S, ...] into a flat [total_tokens, ...] tensor.

    Differentiable wrapper around :class:`PackSequenceFunction`.
    """
    return PackSequenceFunction.apply(x, cu_seqlens, padding_side)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def unpack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left',
    desired_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """Unpack a flat [total_tokens, ...] tensor into a padded batch [B, S, ...].

    Differentiable wrapper around :class:`UnpackSequenceFunction`.
    """
    return UnpackSequenceFunction.apply(x, cu_seqlens, padding_side, desired_shape)
|
fla3/ops/utils/pooling.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 11 |
+
from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_fwd_kernel(
    x,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    """Mean-pool `x` over chunks of BT timesteps, one [BD] feature slice per
    program. Output `o` holds one pooled vector per (chunk, head).

    Grid: (cdiv(D, BD), num_chunks, B * H). In varlen mode `chunk_indices`
    maps each chunk program to its (sequence, chunk-within-sequence) pair and
    `cu_seqlens` gives per-sequence token offsets.
    """
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Global chunk index is the raw program id; remap i_t to the chunk
        # position inside its own sequence.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # x is laid out [*, T, H, D] flattened; o is [*, NT, H, D] flattened.
    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BT, BD]; out-of-range rows are zero-filled by the boundary check.
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # [BD]; divide by the actual chunk length so a ragged last chunk is a
    # true mean, not a sum over BT.
    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_bwd_kernel(
    do,
    dx,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    """Backward of chunked mean pooling: broadcast each pooled gradient `do`
    back over its chunk's timesteps, divided by the chunk length.

    Grid and index bookkeeping mirror `mean_pooling_fwd_kernel`.
    """
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Global chunk index is the raw program id; remap i_t to the chunk
        # position inside its own sequence.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BD] pooled gradient for this chunk/head.
    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
    # [BT, BD]: each timestep of the chunk receives do / chunk_len, matching
    # the forward division by the actual (possibly ragged) chunk length.
    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def mean_pooling_fwd(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the forward kernel: mean over each chunk of `chunk_size` steps."""
    B, T, H, D = x.shape
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    o = x.new_empty(B, NT, H, D)

    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)

    mean_pooling_fwd_kernel[grid](
        x,
        o,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
    )
    return o
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def mean_pooling_bwd(
    do: torch.Tensor,
    batch_size: int,
    seq_len: int,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the backward kernel: spread chunk gradients back over timesteps."""
    H, D = do.shape[-2:]
    B, T = batch_size, seq_len
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    dx = do.new_empty(B, T, H, D)

    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)

    mean_pooling_bwd_kernel[grid](
        do,
        dx,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
    )
    return dx
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MeanPoolingFunction(torch.autograd.Function):
    """Autograd wrapper for chunked mean pooling over the time dimension."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        chunk_size: int,
        cu_seqlens: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
        # Only shapes and metadata are needed for backward — no tensors saved.
        ctx.batch_size = x.shape[0]
        ctx.seq_len = x.shape[1]
        ctx.chunk_size = chunk_size
        ctx.cu_seqlens = cu_seqlens
        return o

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx, do
    ) -> Tuple[torch.Tensor, None, None]:
        batch_size = ctx.batch_size
        seq_len = ctx.seq_len
        chunk_size = ctx.chunk_size
        cu_seqlens = ctx.cu_seqlens
        dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, cu_seqlens)
        # chunk_size and cu_seqlens are non-differentiable inputs.
        return dx, None, None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def mean_pooling(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    """Mean-pool `x` over non-overlapping chunks of `chunk_size` timesteps.

    Args:
        x: input of shape [B, T, H, D] (or [B, H, T, D] when `head_first`).
        chunk_size: number of timesteps pooled into each output step.
        cu_seqlens: optional cumulative sequence lengths for variable-length
            inputs; requires the batch dimension to be 1.
        head_first: if True, `x` is [B, H, T, D] and the result is returned
            in the same layout.

    Returns:
        Pooled tensor with the time dimension reduced to the number of chunks.

    Raises:
        ValueError: if `cu_seqlens` is given and the batch size is not 1.
    """
    if head_first:
        x = x.transpose(1, 2)
    if cu_seqlens is not None:
        if x.shape[0] != 1:
            # Fixed message: was garbled ("..tten") and missing the space
            # between the two f-string fragments.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
    if head_first:
        o = o.transpose(1, 2)
    return o
|
fla3/ops/utils/softmax.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.op import exp
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def softmax_fwd_kernel(
    x,
    p,
    D: tl.constexpr,
    B: tl.constexpr
):
    # Numerically stable row-wise softmax: one program per row of the
    # flattened [N, D] input `x`, result written to `p`.
    # D is the row width; B is the block size (next power of two >= D).
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < D  # mask out the padding lanes beyond the true row width

    # padding lanes load -inf so they contribute exp(-inf) = 0 to the sum
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    b_m = tl.max(b_x, 0)
    # subtract the row max before exponentiating for numerical stability
    b_x = exp(b_x - b_m)
    b_p = b_x / tl.sum(b_x, 0)

    tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def softmax_bwd_kernel(
    p,
    dp,
    ds,
    D: tl.constexpr,
    B: tl.constexpr
):
    # Row-wise softmax backward: given probabilities `p` and upstream grad
    # `dp`, computes ds = p * (dp - sum(p * dp)) per row — the standard
    # softmax Jacobian-vector product. One program per row of [N, D].
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < D  # mask padding lanes beyond the true row width

    # padding lanes load 0 so they do not perturb the reduction
    b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.)
    b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.)
    b_pp = tl.sum(b_p * b_dp, 0)
    b_ds = b_p * b_dp - b_p * b_pp
    tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def softmax_fwd(
    x: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Softmax over the last dimension of `x`, computed by the Triton
    forward kernel. Returns a tensor of the same shape in `dtype`."""
    orig_shape = x.shape
    x2d = x.view(-1, orig_shape[-1])
    n_rows, dim = x2d.shape
    # the kernel block size must be a power of two covering the row width
    block = triton.next_power_of_2(dim)

    probs = torch.empty_like(x2d, dtype=dtype)
    softmax_fwd_kernel[(n_rows,)](x=x2d, p=probs, D=dim, B=block)
    return probs.view(*orig_shape)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def softmax_bwd(
    p: torch.Tensor,
    dp: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Backward pass of the last-dim softmax: maps probabilities `p` and
    upstream grad `dp` to the input grad, via the Triton backward kernel."""
    orig_shape = p.shape
    p2d = p.view(-1, orig_shape[-1])
    grad = torch.empty_like(p2d, dtype=dtype)

    n_rows, dim = p2d.shape
    # the kernel block size must be a power of two covering the row width
    block = triton.next_power_of_2(dim)
    softmax_bwd_kernel[(n_rows,)](p=p2d, dp=dp, ds=grad, D=dim, B=block)
    return grad.view(*orig_shape)
|
fla3/ops/utils/solve_tril.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import triton
|
| 8 |
+
import triton.language as tl
|
| 9 |
+
|
| 10 |
+
from ...ops.utils.index import prepare_chunk_indices
|
| 11 |
+
from ...utils import input_guard
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Invert (I + L) for each 16x16 diagonal tile, where L is the strictly
    # lower-triangular part of the tile, writing the inverse to Ad.
    # One program handles one 16-row tile of one (batch, head) slice.
    # A is laid out as [*, T, H, BT]; Ad as [*, T, H, 16].
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # variable-length mode: map the flat tile id to (sequence id,
        # within-sequence tile id) and clip T to this sequence's length
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # advance base pointers to this sequence/head
    A = A + (bos*H + i_h) * BT
    Ad = Ad + (bos*H + i_h) * 16

    # column offset of the diagonal 16x16 tile inside a BT-wide row block
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    # keep only the strictly lower triangle, negated: start from -L
    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)

    o_i = tl.arange(0, 16)
    # forward substitution, one row at a time: row i of the inverse's strict
    # lower part is -a_i + sum over previous rows (rows beyond the sequence
    # end are excluded by the min(16, T - i_t*16) bound)
    for i in range(1, min(16, T-i_t*16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H*BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    # add the identity to complete (I + L)^-1
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Assemble the 32x32 inverse from two 16x16 diagonal-block inverses (Ad)
    # plus the original off-diagonal tile of A, using the block identity
    # for L = [[L11, 0], [L21, L22]]: (L^-1)_21 = -L22^-1 @ L21 @ L11^-1.
    # One program handles one 32x32 chunk of one (batch, head) slice.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # variable-length mode: map flat chunk id to (sequence, chunk) and
        # clip T to this sequence's length
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # advance base pointers; A/Ai rows are 32 wide, Ad rows are 16 wide
    A += (bos*H + i_h) * 32
    Ad += (bos*H + i_h) * 16
    Ai += (bos*H + i_h) * 32

    # tile pointers: _21 is the off-diagonal source tile; _11/_22 are the
    # precomputed diagonal inverses; Ai tiles are the output positions
    p_A_21 = tl.make_block_ptr(A, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    # (L^-1)_21 = -L22^-1 @ L21 @ L11^-1; ieee precision keeps fp32 matmuls exact
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Assemble the 64x64 inverse from the four 16x16 diagonal-block inverses
    # (Ad) and the off-diagonal tiles of A. The 64x64 chunk is treated as a
    # 4x4 grid of 16x16 tiles; each below-diagonal tile of the inverse is
    # built from previously computed tiles via the block-triangular inverse
    # recurrence (e.g. (L^-1)_21 = -L22^-1 @ L21 @ L11^-1).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # variable-length mode: map flat chunk id to (sequence, chunk) and
        # clip T to this sequence's length
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # advance base pointers; A/Ai rows are 64 wide, Ad rows are 16 wide
    A += (bos*H + i_h) * 64
    Ad += (bos*H + i_h) * 16
    Ai += (bos*H + i_h) * 64

    # below-diagonal source tiles A_rc (row-block r, col-block c)
    p_A_21 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    # precomputed diagonal-block inverses
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)

    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)

    # first sub-diagonal: Ai_(r+1,r) = -Ai_(r+1,r+1) @ A_(r+1,r) @ Ai_(r,r)
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee')
    Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee')

    # deeper tiles accumulate the contributions of all intermediate blocks
    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision='ieee') +
        tl.dot(A_32, Ai_21, input_precision='ieee'),
        input_precision='ieee'
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision='ieee') +
        tl.dot(A_43, Ai_32, input_precision='ieee'),
        input_precision='ieee'
    )
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision='ieee') +
        tl.dot(A_42, Ai_21, input_precision='ieee') +
        tl.dot(A_43, Ai_31, input_precision='ieee'),
        input_precision='ieee'
    )

    # output tile pointers into the 64-wide inverse
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
    """
    Invert the unit lower-triangular matrix (I + A) chunk by chunk.

    A must be strictly lower triangular within each chunk (A.triu() == 0).

    Args:
        A (torch.Tensor):
            Tensor of shape [B, T, H, K] with K in {16, 32, 64}.
        cu_seqlens (torch.Tensor):
            Cumulative sequence lengths for variable-length inputs.
            Default: None.
        output_dtype (torch.dtype):
            Dtype of the returned tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A.
    """
    assert A.shape[-1] in [16, 32, 64]
    B, T, H, BT = A.shape

    # Stage 1: invert every 16x16 diagonal sub-block. When BT == 16 this is
    # already the final answer, so write directly in the requested dtype;
    # otherwise keep full precision for the merge stage.
    diag_dtype = output_dtype if BT == 16 else torch.float
    Ad = torch.empty(B, T, H, 16, device=A.device, dtype=diag_dtype)

    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, 16)
        n_tiles = len(chunk_indices)
    else:
        chunk_indices = None
        n_tiles = triton.cdiv(T, 16)
    solve_tril_16x16_kernel[n_tiles, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    if BT == 16:
        return Ad

    # Stage 2: combine the 16x16 inverses into full 32x32 or 64x64 inverses.
    Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        n_tiles = len(chunk_indices)
    else:
        chunk_indices = None
        n_tiles = triton.cdiv(T, BT)
    merge_fn[n_tiles, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    return Ai
|
flame/__init__.py
ADDED
|
File without changes
|
flame/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
flame/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
flame/__pycache__/data.cpython-310.pyc
ADDED
|
Binary file (8.17 kB). View file
|
|
|
flame/__pycache__/data.cpython-312.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
flame/__pycache__/logging.cpython-312.pyc
ADDED
|
Binary file (6.44 kB). View file
|
|
|
flame/__pycache__/parser.cpython-310.pyc
ADDED
|
Binary file (2.89 kB). View file
|
|
|
flame/__pycache__/parser.cpython-312.pyc
ADDED
|
Binary file (4.07 kB). View file
|
|
|
flame/data.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, Iterable, List, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from datasets import Dataset, IterableDataset
|
| 12 |
+
from flame.logging import get_logger
|
| 13 |
+
from transformers import PreTrainedTokenizer
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class HuggingfaceDataset(IterableDataset):
    """Iterable dataset that streams text samples, tokenizes them on the fly,
    and yields fixed-length chunks of ``context_len`` token ids drawn from a
    shuffle buffer of ``buffer_size`` chunks.

    The underlying dataset is sharded across distributed ranks, and the full
    iteration/shuffling state is checkpointable via ``state_dict`` /
    ``load_state_dict``.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> None:
        # FIX: __init__ was annotated as returning ``HuggingfaceDataset``;
        # constructors return None.
        self.dataset = dataset
        self.tokenizer = tokenizer

        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # pick the smallest integer dtype that can hold every token id to
        # keep the chunk buffer compact
        if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        # dataset iteration state (for checkpoint resume)
        self.states = None
        # shuffle buffer of [buffer_size, context_len] token-id chunks
        self.buffer = torch.tensor([], dtype=self.dtype)
        # flat overflow of token ids not yet formed into chunks
        self.tokens = []
        # position within the current batch of random indices
        self.rand_id = 0
        # position within self.tokens while refilling the buffer
        self.token_id = 0
        # RNG state of the shuffling generator (for checkpoint resume)
        self.rng_state = None
        self._epoch = 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # max number of tokens allowed in the chunk buffer
        n_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample
                # if the token buffer is full, start sampling
                # NOTE: we first convert the token ids to a tensor of shape [n_chunks, context_len] for efficiency
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            n_chunks = len(self.tokens) // self.context_len
            # handle the left tokens in the buffer
            if n_chunks > 0:
                n_tokens = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        """Tokenize `data` in batches of `batch_size` texts, yielding one
        token-id list per sample and tracking the dataset state that
        corresponds to the last yielded sample (for checkpoint resume)."""
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        # flush the trailing partial batch
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        """Yield chunks from random buffer slots, refilling each visited slot
        with fresh tokens from `self.tokens` until those are exhausted."""
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            self.token_id += self.context_len
            yield {'input_ids': self.buffer[i].to(torch.long)}
            # replace the emitted chunk so the buffer stays full
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: Union[torch.Generator, None] = None
    ) -> Iterable[int]:
        """Yield an endless stream of random ints in [low, high), recording
        the RNG state before each batch so sampling can resume exactly.

        FIX: the default was previously ``g=torch.Generator()``, a mutable
        default shared across all calls; `None` + lazy creation avoids that.
        """
        if g is None:
            g = torch.Generator()
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # record the generator states before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            # skip indices already consumed before a checkpoint resume
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        """Set the shuffling epoch, forwarding to the wrapped dataset if supported."""
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        """Return a deep snapshot of all iteration/shuffling state."""
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        """Restore the snapshot produced by :meth:`state_dict`."""
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        # normalize bare token-id lists into {'input_ids': ...} dicts
        if not isinstance(examples[0], Dict):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
            # convert list / ndarray fields to long tensors; pass tensors through
            tensorized = {}
            for key in ['input_ids', 'offsets']:
                if key not in example:
                    continue
                if isinstance(example[key], List):
                    tensorized[key] = torch.tensor(example[key], dtype=torch.long)
                elif isinstance(example[key], np.ndarray):
                    tensorized[key] = torch.from_numpy(example[key])
                else:
                    tensorized[key] = example[key]
            return tensorized

        examples = list(map(tensorize, examples))

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            # Check if padding is necessary.
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # If yes, check if we have a `pad_token`.
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            # varlen mode: concatenate everything into one [1, total_len] row
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                # determine boundaries by bos/eos positions
                if self.tokenizer.add_bos_token:
                    offsets = []
                    # prepend 0 only if the first token is not already a bos
                    if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
                        offsets.append(torch.tensor([0], dtype=torch.long))
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
                    offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                elif self.tokenizer.add_eos_token:
                    offsets = [torch.tensor([0], dtype=torch.long)]
                    # each eos ends a sequence, so the boundary is one past it
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
                    # append the final boundary only if the row doesn't end on eos
                    if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
                        offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                else:
                    raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

        # labels mirror input_ids, with pad positions masked out for the loss
        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
|
flame/logging.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from transformers.trainer_callback import (ExportableState, TrainerCallback,
|
| 10 |
+
TrainerControl, TrainerState)
|
| 11 |
+
from transformers.training_args import TrainingArguments
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_logger(name: str = None) -> logging.Logger:
    """Create a logger that writes to stdout only on the main process.

    Args:
        name: Logger name; ``None`` yields the root logger.

    Returns:
        A ``logging.Logger``. On rank 0 — including non-distributed runs,
        where the ``RANK`` environment variable is unset — the logger is set
        to ``INFO`` and given a stdout handler. On other ranks it is left
        untouched so it stays silent.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    # Treat a missing RANK (single-process run) as rank 0. The original check
    # required RANK to be present, which silently produced a handler-less,
    # level-0 logger outside distributed launches.
    if int(os.environ.get('RANK', 0)) == 0:
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)

    return logger
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Module-level logger; per get_logger above, only rank 0 gets a handler.
logger = get_logger(__name__)

# File name of the JSONL training log written inside ``args.output_dir``.
LOG_FILE_NAME = "trainer_log.jsonl"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class LogCallback(TrainerCallback, ExportableState):
    r"""
    Trainer callback that tracks wall-clock training time and token
    throughput, appending one JSON record per logging step to
    ``LOG_FILE_NAME`` inside ``args.output_dir``.

    Timing state is exported through the ``ExportableState`` protocol so it
    survives checkpoint resumption.
    """

    def __init__(self, start_time: float = None, elapsed_time: float = None):

        # Wall-clock timestamp at which training (re)started.
        self.start_time = time.time() if start_time is None else start_time
        # Seconds of training accumulated so far (restored on resume).
        self.elapsed_time = 0 if elapsed_time is None else elapsed_time
        # Timestamp of the most recent event, used to accumulate elapsed time.
        self.last_time = self.start_time

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        r"""
        Event called at the beginning of training.

        Resets the timing state, or restores it from
        ``state.stateful_callbacks`` when resuming from a checkpoint, and
        removes any stale log file when the output directory is overwritten.
        """
        if state.is_local_process_zero:
            if not args.resume_from_checkpoint:
                self.start_time = time.time()
                self.elapsed_time = 0
            else:
                # Restore timing exported by a previous run via self.state().
                self.start_time = state.stateful_callbacks['LogCallback']['start_time']
                self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time']

        # Exactly one process per node (or per world) owns the log file.
        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.last_time = time.time()
        log_file = os.path.join(args.output_dir, LOG_FILE_NAME)
        if os.path.exists(log_file) and args.overwrite_output_dir:
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(log_file)

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs,
        **kwargs
    ):
        r"""
        Event called after logging the last logs.

        Accumulates elapsed time, computes per-device token throughput when
        token counts are present in ``logs``, and appends a summary record to
        the JSONL log file.
        """
        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.elapsed_time += time.time() - self.last_time
        self.last_time = time.time()
        if 'num_input_tokens_seen' in logs:
            logs['num_tokens'] = logs.pop('num_input_tokens_seen')
            state.log_history[-1].pop('num_input_tokens_seen')
            # Guard against a zero interval (e.g. logging immediately after
            # on_train_begin) to avoid a ZeroDivisionError.
            throughput = logs['num_tokens'] / args.world_size / max(self.elapsed_time, 1e-12)
            state.log_history[-1]['throughput'] = logs['throughput'] = throughput
            state.stateful_callbacks["LogCallback"] = self.state()

        logs = dict(
            current_steps=state.global_step,
            total_steps=state.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
        )

        os.makedirs(args.output_dir, exist_ok=True)
        # Use LOG_FILE_NAME so the removal in on_train_begin and the append
        # here always refer to the same file (was hardcoded here before).
        with open(os.path.join(args.output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def state(self) -> dict:
        """Return the timing state exported for checkpoint resumption."""
        return {
            'start_time': self.start_time,
            'elapsed_time': self.elapsed_time
        }

    @classmethod
    def from_state(cls, state):
        """Rebuild the callback from a previously exported state dict."""
        return cls(state['start_time'], state['elapsed_time'])
|
flame/parser.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import transformers
|
| 9 |
+
from transformers import HfArgumentParser, TrainingArguments
|
| 10 |
+
|
| 11 |
+
from flame.logging import get_logger
|
| 12 |
+
|
| 13 |
+
logger = get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class TrainingArguments(TrainingArguments):
    """Extends ``transformers.TrainingArguments`` with model, tokenizer and
    dataset options for the flame training entry point.

    NOTE(review): this intentionally shadows the imported
    ``transformers.TrainingArguments`` name, so ``HfArgumentParser`` picks up
    both the base trainer flags and the fields below.
    """

    # Each field carries HF-style ``metadata={"help": ...}`` so it appears in
    # ``HfArgumentParser --help`` output.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="fla-hub/gla-1.3B-100B",
        metadata={"help": "Name of the tokenizer to use."}
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )
    varlen: bool = field(
        default=False,
        metadata={"help": "Enable training with variable length inputs."},
    )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_train_args():
    """Parse command-line flags into a ``TrainingArguments`` instance.

    Configures transformers' logging verbosity on processes that should log,
    seeds all RNGs with ``args.seed``, and returns the parsed arguments.

    Raises:
        ValueError: if any command-line flags are not recognized by the
            ``HfArgumentParser``.
    """
    hf_parser = HfArgumentParser(TrainingArguments)
    train_args, leftover = hf_parser.parse_args_into_dataclasses(return_remaining_strings=True)

    if leftover:
        print(hf_parser.format_help())
        print("Got unknown args, potentially deprecated arguments: {}".format(leftover))
        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(leftover))

    if train_args.should_log:
        hf_logging = transformers.utils.logging
        hf_logging.set_verbosity(train_args.get_process_log_level())
        hf_logging.enable_default_handler()
        hf_logging.enable_explicit_format()
    # Seed everything up front so runs are reproducible.
    transformers.set_seed(train_args.seed)
    return train_args
|