zaydzuhri commited on
Commit
772c35e
·
verified ·
1 Parent(s): 7dd6ad0

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc +0 -0
  2. fla/ops/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/ops/attn/parallel.py +629 -0
  4. fla/ops/based/__pycache__/naive.cpython-312.pyc +0 -0
  5. fla/ops/based/__pycache__/parallel.cpython-312.pyc +0 -0
  6. fla/ops/common/chunk_delta_h.py +399 -0
  7. fla/ops/common/chunk_h_parallel.py +650 -0
  8. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  9. fla/ops/delta_rule/chunk.py +373 -0
  10. fla/ops/delta_rule/fused_recurrent.py +607 -0
  11. fla/ops/delta_rule/naive.py +120 -0
  12. fla/ops/delta_rule/wy_fast.py +340 -0
  13. fla/ops/forgetting_attn/parallel.py +708 -0
  14. fla/ops/gated_delta_rule/chunk.py +392 -0
  15. fla/ops/gated_delta_rule/fused_recurrent.py +321 -0
  16. fla/ops/generalized_delta_rule/dplr/__init__.py +7 -0
  17. fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +446 -0
  19. fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +138 -0
  20. fla/ops/generalized_delta_rule/dplr/naive.py +96 -0
  21. fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +318 -0
  22. fla/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  23. fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  24. fla/ops/gla/__pycache__/naive.cpython-312.pyc +0 -0
  25. fla/ops/gsa/__init__.py +9 -0
  26. fla/ops/gsa/chunk.py +1264 -0
  27. fla/ops/gsa/naive.py +68 -0
  28. fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  29. fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  30. fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  31. fla/ops/linear_attn/__pycache__/naive.cpython-312.pyc +0 -0
  32. fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc +0 -0
  33. fla/ops/nsa/__pycache__/naive.cpython-312.pyc +0 -0
  34. fla/ops/nsa/parallel.py +1435 -0
  35. fla/ops/rebased/__pycache__/__init__.cpython-312.pyc +0 -0
  36. fla/ops/rebased/parallel.py +466 -0
  37. fla/ops/retention/__init__.py +13 -0
  38. fla/ops/retention/fused_chunk.py +365 -0
  39. fla/ops/retention/fused_recurrent.py +42 -0
  40. fla/ops/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  41. fla/ops/rwkv6/fused_recurrent.py +709 -0
  42. fla/ops/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  43. fla/ops/rwkv7/__pycache__/channel_mixing.cpython-312.pyc +0 -0
  44. fla/ops/simple_gla/README.md +10 -0
  45. fla/ops/simple_gla/__init__.py +11 -0
  46. fla/ops/simple_gla/__pycache__/chunk.cpython-312.pyc +0 -0
  47. fla/ops/ttt/__init__.py +9 -0
  48. fla/ops/ttt/chunk.py +1539 -0
  49. fla/ops/ttt/fused_chunk.py +896 -0
  50. fla/ops/ttt/naive.py +126 -0
fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla/ops/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (212 Bytes). View file
 
fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit
def parallel_attn_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Flash-attention-style forward pass.  Each program instance computes one
    # [BT, BV] output tile for one (batch, query-head) pair, streaming K/V in
    # blocks of BS while maintaining an online softmax: a running row-wise max
    # `b_m` and a running normalizer `b_acc`.  It also stores the per-row
    # log-sum-exp (`lse`), which the backward kernels reuse.
    # Grid: (num V blocks, num query chunks, B * HQ).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    # GQA: G consecutive query heads share one key/value head.
    i_h = i_hq // G

    if USE_OFFSETS:
        # Variable-length mode: `indices[i_t]` maps the flat chunk id to
        # (sequence id, chunk id within that sequence); `offsets` holds the
        # cumulative sequence boundaries, so T becomes this sequence's length.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    # Pre-scale queries once instead of scaling every score block.
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV] output accumulator in fp32
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # Online-softmax state: running row max and running sum of exponentials.
    b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([BT], dtype=tl.float32)
    # Phase 1: K/V blocks strictly before this query chunk -- fully visible,
    # so no causal mask is needed.
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BS] attention scores
        b_s = tl.dot(b_q, b_k)

        # Update running max; b_mp keeps the previous max for rescaling.
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        # Rescale factor for previously accumulated quantities.
        b_r = exp(b_mp - b_m)
        # [BT, BS] unnormalized probabilities
        b_p = exp(b_s - b_m[:, None])
        # [BT] running normalizer
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV] rescale old output and add this block's contribution
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m

    # Phase 2: the diagonal chunk -- apply the causal mask per element.
    # [BT] absolute query positions
    o_q = i_t * BT + tl.arange(0, BT)
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))

        # [BS] absolute key positions
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BS]
        b_s = tl.dot(b_q, b_k)
        # Causal mask: query i may only attend to keys j <= i.
        b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))

        # [BT] same online-softmax update as phase 1
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m
    # Final normalization, and fold the normalizer into the max to obtain the
    # per-row log-sum-exp: lse = m + log(sum exp(s - m)).
    b_o = b_o / b_acc[:, None]
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
128
+
129
+
130
@triton.jit
def parallel_attn_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,
    V: tl.constexpr
):
    # Backward preprocessing: for every (token, head) row, compute
    # delta = sum_v(o * do), the standard softmax-backward correction term
    # shared by the dq and dk/dv kernels.
    # Grid: one program per row of the flattened [B*T*HQ, V] output.
    # B here is next_power_of_2(V) (the block width), not the batch size.
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    # Mask off the padding lanes beyond the true head dimension V.
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    # Scalar reduction over the value dimension.
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
147
+
148
+
149
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_attn_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward pass w.r.t. the queries.  Each program owns one [BT, BK] dq
    # tile and iterates over all K/V blocks this query chunk attends to.
    # The softmax is recomputed from the stored log-sum-exp (`lse`) instead of
    # tracking a running max, so no rescaling is needed here.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    # GQA: map the query head onto its shared key/value head.
    i_h = i_hq // G

    if USE_OFFSETS:
        # Variable-length mode; see the forward kernel for the layout of
        # `indices` and `offsets`.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV] upstream gradient of the output
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BT] stored log-sum-exp and the precomputed sum(o * do) correction
    b_lse = tl.load(p_lse, boundary_check=(0,))
    b_delta = tl.load(p_delta, boundary_check=(0,))

    # [BT, BK] gradient accumulator
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    # Phase 1: fully visible K/V blocks before this query chunk (no mask).
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [BT, BS] recompute probabilities: p = exp(s - lse)
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])

        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        # Softmax backward: ds = p * (dp - delta)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))

    # Phase 2: the diagonal chunk, with the causal mask applied to p.
    # [BT] absolute query positions
    o_q = i_t * BT + tl.arange(0, BT)
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BS] absolute key positions
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [BT, BS]
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])
        # Zero out non-causal entries (key position > query position).
        b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)

        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))

    # Apply the attention scale once at the end (q was pre-scaled, so this
    # accounts for the d(scale*q)/dq factor).
    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
258
+
259
+
260
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_attn_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward pass w.r.t. keys and values.  Each program owns one [BT, *]
    # key/value chunk and iterates over all query blocks that can attend to
    # it (the diagonal chunk plus everything after it).  Gradients are stored
    # per query head (HQ); the host-side wrapper reduces them over the GQA
    # group afterwards.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    # GQA: the shared key/value head for this query head.
    i_h = i_hq // G

    if USE_OFFSETS:
        # Variable-length mode; see the forward kernel for the layout of
        # `indices` and `offsets`.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # Phase 1: the diagonal chunk, with the causal mask applied to p.
    # [BT] absolute key positions
    o_k = i_t * BT + tl.arange(0, BT)
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS] absolute query positions
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        # [BT, BS] transposed scores (keys x queries), p = exp(s - lse)
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        # Causal mask from the key side: key position <= query position.
        b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
        # [BT, BS] @ [BS, BV] -> [BT, BV]; dv = p^T @ do
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS] softmax backward: ds = p * (dp - delta)
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]; dk = ds^T @ (scale * q)
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    # Phase 2: strictly later query blocks -- fully visible, no mask needed.
    for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        # [BT, BS]
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
381
+
382
+
383
def parallel_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Host-side launcher for the parallel attention forward kernel.

    Args:
        q: queries of shape `[B, T, HQ, K]`.
        k: keys of shape `[B, T, H, K]` (HQ must be a multiple of H for GQA).
        v: values of shape `[B, T, H, V]`.
        scale: attention score scale factor.
        chunk_size: query block length BT processed per program.
        offsets: cumulative sequence boundaries for variable-length mode.
        indices: per-chunk (sequence, chunk) index pairs matching `offsets`.

    Returns:
        A tuple `(o, lse)`: outputs `[B, T, HQ, V]` and the fp32 per-position
        log-sum-exp `[B, T, HQ]` needed by the backward pass.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    # GQA group size: query heads per key/value head.
    G = HQ // H
    BT = chunk_size
    # Pick block sizes by available shared memory; larger BS/BV on newer GPUs.
    if check_shared_mem('hopper', q.device.index):
        BS = min(64, max(16, triton.next_power_of_2(T)))
        BK = min(256, max(16, triton.next_power_of_2(K)))
        BV = min(256, max(16, triton.next_power_of_2(V)))
    elif check_shared_mem('ampere', q.device.index):
        BS = min(32, max(16, triton.next_power_of_2(T)))
        BK = min(256, max(16, triton.next_power_of_2(K)))
        BV = min(128, max(16, triton.next_power_of_2(V)))
    else:
        BS = min(32, max(16, triton.next_power_of_2(T)))
        BK = min(256, max(16, triton.next_power_of_2(K)))
        BV = min(64, max(16, triton.next_power_of_2(V)))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # In variable-length mode the grid is sized by the precomputed chunk list.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # The kernel loads the full key dimension in one block, so K must fit in BK.
    assert NK == 1, "The key dimension can not be larger than 256"

    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    grid = (NV, NT, B * HQ)
    parallel_attn_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        offsets=offsets,
        indices=indices,
        B=B,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
439
+
440
+
441
def parallel_attn_bwd_preprocess(
    o: torch.Tensor,
    do: torch.Tensor
):
    """Compute the softmax-backward correction term delta = sum_v(o * do).

    `delta` has the shape of `o` with the value dimension reduced away and is
    always stored in fp32; one kernel program handles each (token, head) row.
    """
    head_dim = o.shape[-1]
    delta = torch.empty_like(o[..., 0], dtype=torch.float32)
    num_rows = delta.numel()
    parallel_attn_bwd_kernel_preprocess[(num_rows,)](
        o=o,
        do=do,
        delta=delta,
        B=triton.next_power_of_2(head_dim),
        V=head_dim,
    )
    return delta
455
+
456
+
457
def parallel_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    scale: Optional[float] = None,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Host-side launcher for the parallel attention backward pass.

    Runs the preprocessing kernel (delta = sum(o * do)), then the dq kernel
    and the dk/dv kernel, and finally sums the per-query-head key/value
    gradients over each GQA group.

    Returns:
        `(dq, dk, dv)` with dq of shape `[B, T, HQ, K]` and dk/dv reduced to
        the key/value head count `[B, T, H, K]` / `[B, T, H, V]`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    # GQA group size: query heads per key/value head.
    G = HQ // H
    BT = chunk_size
    # Smaller key/value blocks than the forward pass to fit the extra tensors.
    BS = max(16, triton.next_power_of_2(T))
    BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
    BK = max(16, triton.next_power_of_2(K))
    BV = max(16, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    # delta[b, t, h] = sum_v o * do, shared by both gradient kernels.
    delta = parallel_attn_bwd_preprocess(o, do)

    # With GQA (H != HQ) accumulate in fp32 because several query heads are
    # later summed into one key/value head.
    dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
    dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
    dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
    grid = (NV, NT, B * HQ)
    parallel_attn_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    parallel_attn_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    # Reduce per-query-head gradients over each GQA group of size G.
    dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
    dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
    return dq, dk, dv
536
+
537
+
538
# NOTE(review): applying @torch.compile to an autograd.Function *class* (rather
# than a plain function) is unusual -- confirm that `.apply` remains reachable
# on the wrapped object with the torch version in use.
@torch.compile
class ParallelAttentionFunction(torch.autograd.Function):
    """Autograd bridge for the parallel (flash-style) attention kernels.

    forward saves `(q, k, v, o, lse)` plus the launch configuration on `ctx`;
    backward replays the Triton backward kernels and returns gradients for
    the three tensor inputs only.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale, offsets):
        ctx.dtype = q.dtype

        # Clamp the query chunk length to a power of two in [16, 128].
        chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        o, lse = parallel_attn_fwd(
            q=q,
            k=k,
            v=v,
            scale=scale,
            chunk_size=chunk_size,
            offsets=offsets,
            indices=indices
        )
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.chunk_size = chunk_size
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_attn_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            scale=ctx.scale,
            chunk_size=ctx.chunk_size,
            offsets=ctx.offsets,
            indices=ctx.indices
        )
        # forward() takes exactly 5 inputs (q, k, v, scale, offsets), so
        # backward must return exactly 5 gradients -- one per input, with None
        # for the non-tensor arguments.  The original returned 11 values,
        # which makes autograd raise "returned an incorrect number of
        # gradients" at runtime.
        return dq.to(q), dk.to(k), dv.to(v), None, None
588
+
589
+
590
def parallel_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA will be applied if HQ is divisible by H.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    # Default to the standard 1/sqrt(d_k) scaling.
    if scale is None:
        scale = k.shape[-1] ** -0.5
    # Variable-length mode packs all sequences into a single batch row.
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    # Normalize to the time-first layout expected by the kernels.
    if head_first:
        q = rearrange(q, 'b h t d -> b t h d')
        k = rearrange(k, 'b h t d -> b t h d')
        v = rearrange(v, 'b h t d -> b t h d')
    o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
    # Restore the caller's layout if needed.
    return rearrange(o, 'b t h d -> b h t d') if head_first else o
fla/ops/based/__pycache__/naive.cpython-312.pyc ADDED
Binary file (4.13 kB). View file
 
fla/ops/based/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (22.6 kB). View file
 
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_gated_delta_rule_fwd_kernel_h(
    k,
    v,
    d,
    v_new,
    g,
    h,
    h0,
    ht,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Forward recurrence over the per-chunk hidden state of the (gated) delta
    # rule.  Each program owns one [BK, BV] tile of the state for one
    # (sequence, head) pair and walks the chunks sequentially: it stores the
    # state *entering* each chunk into `h`, writes the delta-corrected values
    # v_new = v - d @ h, and updates the state (with the gate's decay when
    # USE_G).  Grid: (num K tiles, num V tiles, N * H).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Variable-length mode: per-sequence time span and the precomputed
        # starting chunk offset (`chunk_offsets`) of this sequence.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running hidden state, optionally seeded from h0.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # Store the state as seen at the *start* of chunk i_t.
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # Accumulates this chunk's contribution before it is folded into b_h.
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        if USE_G:
            # Gate value at the last valid position of the chunk; used to decay
            # the carried state and to rescale keys within the chunk.
            last_idx = min((i_t + 1) * BT, T) - 1
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_nh * T + last_idx)
            else:
                b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
        else:
            b_g_last = None
            last_idx = None
        # since we need to keep all of DK in SRAM, we face a severe SRAM memory
        # burden; processing the chunk in sub-chunks of BC alleviates it
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            if HEAD_FIRST:
                p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
            else:
                p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
                p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
            b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
            # [BK, BC] keys, pre-decayed so the end-of-chunk gate factors out.
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
            # [BC, BK] decay-adjusted removal directions.
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
            # [BC, BV] delta-rule correction: subtract what the current state
            # already predicts for these positions.
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
            # [BK, BV]
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
            b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        # Decay the carried state by the chunk's gate, then add the new rank
        # updates accumulated over the sub-chunks.
        b_h *= exp(b_g_last) if USE_G else 1
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV', 'USE_G'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_gated_delta_rule_bwd_kernel_dhu(
    q,
    k,
    d,
    g,
    dht,
    dh0,
    do,
    dh,
    dv,
    dv2,
    offsets,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Backward state-gradient kernel for the (gated) delta rule.
    # Sweeps chunks in reverse time order, carrying the running state gradient
    # b_dh [BK, BV]; for each chunk it stores the incoming gradient into `dh`,
    # accumulates per-sub-chunk contributions from q/do and d/dv, and writes
    # the corrected value gradients into `dv2`.
    # One program handles one (key-block, value-block, sequence*head) triple.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Variable-length batch: derive this sequence's span and chunk base.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running gradient of the recurrent state, seeded from dht if given.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1))

    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # Store the state gradient as seen *before* processing this chunk.
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        if USE_G:
            # Gate value at the last valid position of this chunk (handles ragged tail).
            last_idx = min((i_t + 1) * BT, T) - 1
            if HEAD_FIRST:
                bg_last = tl.load(g + i_nh * T + last_idx)
            else:
                bg_last = tl.load(g + (bos + last_idx) * H + i_h)
        else:
            bg_last = None
            last_idx = None
        # Walk sub-chunks of size BC within the chunk, also in reverse order.
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            if HEAD_FIRST:
                p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
                p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            else:
                p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
                p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
            # [BK, BT] query block, pre-scaled (and gated when USE_G).
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
            # [BT, BK] key / decay-corrected blocks; bg_last - b_g propagates the
            # remaining within-chunk decay onto k.
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
            b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
            # [BT, V]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            # Corrected value gradient: local dv plus contribution through the state.
            b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            # NOTE(review): dtype taken from p_dv rather than p_dv2 — dv and dv2 are
            # allocated with the same dtype by the host wrapper, so this is equivalent.
            tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV] accumulate q@do and subtract d@dv2 (delta-rule correction).
            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
            b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
        # Decay the carried gradient across the chunk boundary, then fold in
        # this chunk's contribution.
        b_dh *= exp(bg_last) if USE_G else 1
        b_dh += b_dh_tmp

    if USE_INITIAL_STATE:
        # Gradient w.r.t. the initial state is the fully back-propagated b_dh.
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the forward hidden-state kernel for the gated delta rule.

    Computes the per-chunk recurrent states ``h``, the corrected values
    ``v_new`` and, optionally, the final state for each sequence.

    Args:
        k: keys, ``(B, H, T, K)`` if ``head_first`` else ``(B, T, H, K)``.
        w: delta-rule weight tensor, same layout as ``k``.
        u: values, last dimension ``V``.
        g: optional scalar gate per time step.
        initial_state: optional ``(N, H, K, V)`` starting state.
        output_final_state: if True, also return the final state (fp32).
        offsets: optional sequence boundaries for variable-length batches.
        indices: precomputed (sequence, chunk) pairs matching ``offsets``.
        head_first: layout selector for q/k/v-style tensors.
        chunk_size: time-chunk length BT (upper bound).

    Returns:
        ``(h, v_new, final_state)`` where ``final_state`` is None unless
        ``output_final_state`` is set.
    """
    if head_first:
        B, H, T, K, V = *k.shape, u.shape[-1]
    else:
        B, T, H, K, V = *k.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # N: the actual number of sequences in the batch (equal or variable lengths).
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = prepare_chunk_offsets(offsets, BT)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # Tile sizes depend on how much shared memory the device offers.
    if check_shared_mem('hopper', k.device.index):
        # H100 can afford larger blocks.
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', k.device.index):
        # A100.
        BV = 32
        BC = 64
    else:
        BV = 32
        BC = 32 if K <= 128 else 16
    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h_shape = (B, H, NT, K, V) if head_first else (B, NT, H, K, V)
    h = k.new_empty(*h_shape)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u)

    grid = (NK, NV, N * H)
    chunk_gated_delta_rule_fwd_kernel_h[grid](
        k=k,
        v=u,
        d=w,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first
    )
    return h, v_new, final_state
320
+
321
+
322
def chunk_gated_delta_rule_bwd_dhu(
    q: torch.Tensor,
    k: torch.Tensor,
    w: torch.Tensor,
    g: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the backward state-gradient kernel for the gated delta rule.

    Args:
        q, k, w: forward-pass operands (same layout, selected by ``head_first``).
        g: scalar gate values per time step.
        h0: initial state from the forward pass (or None).
        dht: optional gradient w.r.t. the final state.
        do: gradient of the attention output.
        dv: incoming gradient of the values.
        scale: query scaling factor.
        offsets / indices: variable-length batching metadata.
        head_first: layout selector.
        chunk_size: time-chunk length BT (upper bound).

    Returns:
        ``(dh, dh0, dv2)``: per-chunk state gradients, initial-state gradient
        (None when ``h0`` is None) and corrected value gradients.
    """
    if head_first:
        B, H, T, K, V = *q.shape, do.shape[-1]
    else:
        B, T, H, K, V = *q.shape, do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # N: the actual number of sequences in the batch (equal or variable lengths).
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = prepare_chunk_offsets(offsets, BT)

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension being larger than 256."

    # Tile sizes depend on available shared memory.
    if check_shared_mem('hopper', q.device.index):
        # H100.
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', q.device.index):
        # A100.
        BV = 32
        BC = 64 if K <= 128 else 32
    else:
        BV = 32 if K <= 128 else 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh_shape = (B, H, NT, K, V) if head_first else (B, NT, H, K, V)
    dh = q.new_empty(*dh_shape)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.empty_like(dv)

    grid = (NK, NV, N * H)
    chunk_gated_delta_rule_bwd_kernel_dhu[grid](
        q=q,
        k=k,
        d=w,
        g=g,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dh, dh0, dv2
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Fully parallelized state passing.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ import triton.language as tl
13
+
14
+ from fla.ops.utils.op import exp
15
+
16
+
17
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h_parallel(
    k,
    v,
    h,
    g,
    gk,
    gv,
    h0,
    ht,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Parallel phase of fully-parallelized state passing: each program computes
    # the *local* (per-chunk) contribution k^T @ v for one chunk and writes it
    # into the slot of the *next* chunk in `h`; the running prefix is formed
    # later by `chunk_fwd_kernel_h_reduction`.  Supports scalar decay (g) and
    # per-key / per-value vector decay (gk / gv).
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NV = tl.cdiv(V, BV)
    # i_b: batch index
    # i_h: head index
    # i_n: sequence index
    # i_t: chunk index within current sequence
    # i_tg: (global) chunk index across all sequences
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_b * T, i_b * T + T
        NT = tl.cdiv(T, BT)
        i_n, i_tg = i_b, i_b * NT + i_t
    i_nh = i_n * H + i_h

    if HEAD_FIRST:
        p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
    else:
        p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

    if i_t == 0:
        # The first chunk's slot holds the initial state (or zeros).
        if USE_INITIAL_STATE:
            p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
        else:
            b_h = tl.zeros([BK, BV], dtype=tl.float32)
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

    # [BK, BT]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))

    last_idx = min(i_t * BT + BT, T) - 1
    # scalar decay
    if USE_G:
        if HEAD_FIRST:
            b_g_last = tl.load(g + i_bh * T + last_idx)
            p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
            p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
        else:
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
        b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
        # Scale each value row by the residual decay up to the chunk end.
        b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

    # vector decay, h = Diag(gk) @ h
    if USE_GK:
        if HEAD_FIRST:
            p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
            p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
        else:
            p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

        b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)

        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)

    # vector decay, h = h @ Diag(gv)
    if USE_GV:
        if HEAD_FIRST:
            p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
            p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
        else:
            p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

        b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)

        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)

    # Local contribution of this chunk only (no prefix accumulated yet).
    b_h = tl.dot(b_k, b_v)
    if i_t < NT - 1:
        # Write into the next chunk's slot; the reduction pass turns these
        # per-chunk deltas into running states.
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
    elif STORE_FINAL_STATE:
        # Last chunk's contribution goes straight into the final-state buffer
        # (summed with the prefix in the reduction kernel).
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
157
+
158
@triton.heuristics({
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h_reduction(
    h,
    g,
    gk,
    gv,
    kvt,
    ht,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Reduction phase of fully-parallelized state passing: sequentially scans
    # the per-chunk contributions written by `chunk_fwd_kernel_h_parallel`,
    # applying the per-chunk decay between steps, and rewrites `h` in place so
    # each slot holds the true running state.  Optionally adds the last chunk's
    # contribution (stashed in `kvt`) to produce the final state `ht`.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running state accumulated over chunks.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
        # Slot 0 already holds the initial state; only later slots are rewritten.
        if i_t > 0:
            tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min(i_t * BT + BT, T) - 1
        # scalar decay
        if USE_G:
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_nh * T + last_idx)
            else:
                b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            b_h *= exp(b_g_last)

        # vector decay, h = Diag(gk) @ h
        if USE_GK:
            if HEAD_FIRST:
                p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_h *= exp(b_gk_last)[:, None]

        # vector decay, h = h @ Diag(gv)
        if USE_GV:
            if HEAD_FIRST:
                p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_h *= exp(b_gv_last)[None, :]

    if STORE_FINAL_STATE:
        # kvt holds the last chunk's raw contribution from the parallel pass;
        # adding the decayed prefix yields the true final state.
        p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
256
+
257
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh_parallel(
    q,
    g,
    gk,
    gv,
    do,
    dh,
    dht,
    dh0,
    offsets,
    indices,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Parallel phase of the backward state-gradient computation: each program
    # computes the *local* contribution q^T @ do of one chunk and writes it
    # into the slot of the *previous* chunk in `dh`; the reverse-order prefix
    # is formed by `chunk_bwd_kernel_dh_reduction`.  HQ/NG support GQA, where
    # NG query heads share one key/value head.
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NV = tl.cdiv(V, BV)
    # i_hq: query-head index; i_h: shared kv-head index; i_bg: gate index.
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
    i_h = i_hq // NG
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_b * T, i_b * T + T
        NT = tl.cdiv(T, BT)
        i_n, i_tg = i_b, i_b * NT + i_t
    i_nh = i_n * HQ + i_hq

    if HEAD_FIRST:
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
    else:
        p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

    if i_t == NT - 1:
        # The last chunk's slot is seeded with dht (or zeros).
        if USE_FINAL_STATE_GRADIENT:
            p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
        else:
            b_dh = tl.zeros([BK, BV], dtype=tl.float32)
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))

    # [BK, BT]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))

    if USE_G:
        if HEAD_FIRST:
            p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
            p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
        else:
            p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
        b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
        b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)

    if USE_GK:
        if HEAD_FIRST:
            p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        else:
            p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_q = (b_q * exp(b_gk)).to(b_q.dtype)

    if USE_GV:
        if HEAD_FIRST:
            p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_do = (b_do * exp(b_gv)).to(b_do.dtype)

    # Local contribution of this chunk only.
    b_dh = tl.dot(b_q, b_do)
    if i_t > 0:
        # Write into the previous chunk's slot (backward sweep direction).
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
    elif STORE_INITIAL_STATE_GRADIENT:
        # First chunk's contribution goes straight into the dh0 buffer
        # (combined with the suffix in the reduction kernel).
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
379
+
380
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh_reduction(
    g,
    gk,
    gv,
    dh,
    doq0,
    dh0,
    offsets,
    chunk_offsets,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Reduction phase of the backward state-gradient computation: scans chunks
    # in reverse, accumulating the per-chunk contributions written by
    # `chunk_bwd_kernel_dh_parallel` with decay applied between chunks, and
    # rewrites `dh` in place.  Optionally combines with `doq0` (the first
    # chunk's contribution) to produce the initial-state gradient `dh0`.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_nh // NG
    i_n, i_hq = i_nh // HQ, i_nh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running state gradient, accumulated back-to-front.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
        # The last slot already holds dht from the parallel pass; only earlier
        # slots are rewritten with the accumulated suffix.
        if i_t < NT - 1:
            tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min(i_t * BT + BT, T) - 1
        if USE_G:
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_bg * T + last_idx)
            else:
                b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_dh *= exp(b_g_last)

        if USE_GK:
            if HEAD_FIRST:
                p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_dh *= exp(b_gk_last)[:, None]

        if USE_GV:
            if HEAD_FIRST:
                p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_dh *= exp(b_gv_last)[None, :]

    if STORE_INITIAL_STATE_GRADIENT:
        # doq0 holds the first chunk's raw contribution from the parallel pass;
        # adding the decayed suffix yields the true initial-state gradient.
        p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
479
+
480
def chunk_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    h0: torch.Tensor,
    output_final_state: bool,
    states_in_fp32: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Two-phase (parallel + reduction) forward state passing.

    Phase one launches one program per chunk to compute local k^T v
    contributions; phase two scans chunks per sequence to form running states.

    Args:
        k, v: keys/values, layout selected by ``head_first``.
        g, gk, gv: optional scalar / per-key / per-value decay tensors.
        h0: optional initial state per sequence.
        output_final_state: also return the final state ``ht`` when True.
        states_in_fp32: keep ``h`` in fp32 instead of casting back to k.dtype.
        offsets / indices: variable-length batching metadata; ``indices`` is
            derived from ``offsets`` when omitted.
        head_first: layout selector.
        chunk_size: time-chunk length BT (upper bound).

    Returns:
        ``(h, ht)``; ``ht`` is None unless ``output_final_state`` is set.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # N: the actual number of sequences in the batch (equal or variable lengths).
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        if indices is None:
            # Build (sequence, chunk-within-sequence) index pairs per chunk.
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)

    h_shape = (B, H, NT, K, V) if head_first else (B, NT, H, K, V)
    h = k.new_empty(*h_shape, dtype=torch.float)
    ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None

    def grid_parallel(meta):
        return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_h_parallel[grid_parallel](
        k=k,
        v=v,
        h=h,
        g=g,
        gk=gk,
        gv=gv,
        h0=h0,
        ht=ht,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    # The parallel pass left the last chunk's raw contribution in `ht`;
    # hand it to the reduction pass as `kvt` and allocate a fresh `ht`.
    kvt = ht
    ht = torch.empty_like(kvt) if output_final_state else None

    def grid_reduction(meta):
        return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
    chunk_fwd_kernel_h_reduction[grid_reduction](
        h=h,
        g=g,
        gk=gk,
        gv=gv,
        kvt=kvt,
        ht=ht,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    if not states_in_fp32:
        h = h.to(k.dtype)
    return h, ht
+
557
+
558
def chunk_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    do: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    scale: float,
    states_in_fp32: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Two-phase (parallel + reduction) backward state-gradient computation.

    Mirrors :func:`chunk_fwd_h` in reverse: a parallel pass computes local
    q^T do contributions per chunk, then a reduction pass scans chunks
    back-to-front to form the running gradients.

    Args:
        q, k, v: forward operands; ``q`` may have more heads than ``k``/``v``
            (GQA), giving ``NG = HQ // H`` query heads per kv head.
        g, gk, gv: optional scalar / per-key / per-value decay tensors.
        do: gradient of the output.
        h0: initial state (or None); its gradient is returned as ``dh0``.
        dht: optional gradient w.r.t. the final state.
        scale: query scaling factor.
        states_in_fp32: keep ``dh`` in fp32 instead of casting to q.dtype.
        offsets / indices: variable-length batching metadata.
        head_first: layout selector.
        chunk_size: time-chunk length BT (upper bound).

    Returns:
        ``(dh, dh0)``; ``dh0`` is None when ``h0`` is None.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # N: the actual number of sequences in the batch (equal or variable lengths).
    # NG: number of groups in GQA.
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        if indices is None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
    NG = HQ // H

    dh_dtype = torch.float if states_in_fp32 else k.dtype
    dh_shape = (B, HQ, NT, K, V) if head_first else (B, NT, HQ, K, V)
    dh = k.new_empty(*dh_shape, dtype=dh_dtype)
    dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

    def grid_parallel(meta):
        return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_bwd_kernel_dh_parallel[grid_parallel](
        q=q,
        g=g,
        gk=gk,
        gv=gv,
        do=do,
        dh=dh,
        dht=dht,
        dh0=dh0,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )

    # The parallel pass left the first chunk's raw contribution in `dh0`;
    # hand it to the reduction pass as `doq0` and allocate a fresh `dh0`.
    doq0 = dh0
    dh0 = torch.empty_like(doq0) if doq0 is not None else None

    def grid_reduction(meta):
        return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
    chunk_bwd_kernel_dh_reduction[grid_reduction](
        g=g,
        gk=gk,
        gv=gv,
        dh=dh,
        doq0=doq0,
        dh0=dh0,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    if not states_in_fp32:
        dh = dh.to(q.dtype)
    return dh, dh0
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+
12
+
13
# One program instance computes, for a single (chunk, batch*head) pair, the
# strictly lower-triangular chunk-local Gram matrix
#   A[t, s] = beta[t] * <k[t], k[s]>   for s < t within the chunk.
@triton.heuristics({
    # Variable-length mode is active whenever cumulative offsets are supplied.
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'BT', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,        # keys: [B, H, T, K] if HEAD_FIRST else [B, T, H, K]
    beta,     # per-token scale: [B, H, T] if HEAD_FIRST else [B, T, H]
    A,        # output: [B, H, T, BT] if HEAD_FIRST else [B, T, H, BT]
    offsets,  # [N+1] cumulative sequence lengths, or None
    indices,  # [NT, 2] rows of (sequence idx, chunk-within-sequence idx), or None
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size along time
    BK: tl.constexpr,  # tile size along the key dim (autotuned)
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Remap the flat chunk id to (sequence, local chunk) and replace T with
        # this sequence's own length.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_t = tl.arange(0, BT)

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # Time-first layout: consecutive tokens of one head are H elements apart.
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Accumulate (beta * k) @ k^T over the key dimension in tiles of BK.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        # Cast the scaled keys back to the input dtype before the dot product.
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    # Keep only the strictly lower triangle; diagonal and above are zeroed.
    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
72
+
73
+
74
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    head_first: bool = False,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    r"""
    Compute the chunk-local scaled Gram matrix ``beta * K @ K^T``.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        head_first (bool):
            If False, the input/output tensor is in the shape of `[B, T, H, K]`.
            If True, the input/output tensor is in the shape of `[B, H, T, K]`.
            Default: False
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
        where `BT` is the chunk size.
    """
    if head_first:
        B, H, T, K_dim = k.shape
    else:
        B, T, H, K_dim = k.shape
    BT = chunk_size

    # With variable-length inputs the number of launched chunks is the total
    # chunk count across all sequences; otherwise it is ceil(T / BT).
    if cu_seqlens is None:
        indices = None
        NT = triton.cdiv(T, BT)
    else:
        indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(indices)

    out_shape = (B, H, T, BT) if head_first else (B, T, H, BT)
    A = torch.empty(*out_shape, device=k.device, dtype=output_dtype)

    grid = (NT, B * H)
    chunk_scaled_dot_kkt_fwd_kernel[grid](
        k=k,
        beta=beta,
        A=A,
        offsets=cu_seqlens,
        indices=indices,
        T=T,
        H=H,
        K=K_dim,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return A
fla/ops/delta_rule/chunk.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.common.utils import prepare_chunk_indices
14
+ from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
def chunk_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Forward pass of the chunkwise delta rule.

    Three stages: (1) build the WY representation from k/v/beta, (2) run the
    chunkwise state recurrence, (3) project q against the per-chunk states.

    Returns:
        o: the output tensor (same layout as v).
        A: the intra-chunk WY matrix, cached for the backward pass.
        final_state: the last hidden state if `output_final_state` else None.
    """
    seq_len = q.shape[2] if head_first else q.shape[1]
    # Clamp the chunk size for short sequences; at least 16 is required.
    BT = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))
    # Keyword arguments shared by every stage below.
    common = dict(offsets=offsets, indices=indices, head_first=head_first, chunk_size=BT)

    # Stage 1: WY representation; `u` plays the role of the new values.
    w, u, A = fwd_prepare_wy_repr(k=k, v=v, beta=beta, **common)

    # Stage 2: chunkwise recurrence (g=None: no gating in the plain delta rule).
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=None,
        initial_state=initial_state,
        output_final_state=output_final_state,
        **common
    )

    # Stage 3: output computation against the per-chunk hidden states.
    o = chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=None, scale=scale, **common)
    return o, A, final_state
69
+
70
+
71
def chunk_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Backward pass of the chunkwise delta rule.

    Recomputes the forward intermediates (w, u, h, v_new) from the cached
    intra-chunk matrix `A` (checkpointing trade-off), then propagates `do`
    (grad w.r.t. the output) and `dht` (grad w.r.t. the final state) back to
    the inputs.

    NOTE: the name `dv` is deliberately reused: it starts as the local value
    gradient, is refined in place by `chunk_gated_delta_rule_bwd_dhu`, and is
    finally replaced by the gradient w.r.t. the raw values from
    `bwd_prepare_wy_repr`. The call order below is load-bearing.

    Returns:
        dq, dk, dv, db, dh0: gradients w.r.t. q, k, v, beta and the initial state.
    """
    T = q.shape[2] if head_first else q.shape[1]
    # Must match the chunk size chosen in the forward pass.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # Recompute w/u from the cached A instead of saving them in the forward.
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # Recompute the hidden states and the replaced values of the forward pass.
    h, v_new, _ = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=None,
        initial_state=initial_state,
        output_final_state=False,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # Intra-chunk (local) part of the value gradient; dh is not available yet.
    dv = chunk_bwd_dv_local(
        q=q,
        k=k,
        do=do,
        g=None,
        dh=None,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # Backward state scan: produces dh per chunk, dh0, and the completed dv.
    dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
        q=q,
        k=k,
        w=w,
        g=None,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # Gradients w.r.t. q, k and the WY vector w (the trailing slot is unused
    # here since g=None).
    dq, dk, dw, _ = chunk_bwd_dqkwg(
        q=q,
        k=k,
        v=v_new,
        h=h,
        w=w,
        dv=dv,
        do=do,
        dh=dh,
        g=None,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # Backprop through the WY construction: yields the second k contribution,
    # the gradient w.r.t. the raw values, and the beta gradient.
    dk2, dv, db = bwd_prepare_wy_repr(
        k=k,
        v=v,
        beta=beta,
        A=A,
        dw=dw,
        du=dv,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    dk.add_(dk2)
    return dq, dk, dv, db, dh0
167
+
168
+
169
class ChunkDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around the chunkwise delta-rule kernels.

    Saves the *unnormalized* q/k plus the cached intra-chunk matrix `A`, so
    the backward pass re-applies the (cheap) l2 normalization instead of
    storing the normalized tensors — trading a little compute for memory.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        use_qk_l2norm_in_kernel: bool = True
    ):
        T = q.shape[2] if head_first else q.shape[1]
        # Chunk size is clamped for short sequences; at least 16 is required.
        chunk_size = min(64, max(triton.next_power_of_2(T), 16))

        # Keep references to the raw tensors for the backward pass.
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        o, A, final_state = chunk_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        # Only A is cached from the intermediates; everything else is recomputed.
        ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.head_first = head_first
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        q, k, v, beta, A, initial_state = ctx.saved_tensors
        use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel
        if use_qk_l2norm_in_kernel:
            # Re-normalize while keeping the raw tensors for l2norm_bwd below.
            q, q_orig = l2norm_fwd(q), q
            k, k_orig = l2norm_fwd(k), k

        dq, dk, dv, db, dh0 = chunk_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            A=A,
            scale=ctx.scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            offsets=ctx.offsets,
            indices=ctx.indices,
            head_first=ctx.head_first,
            chunk_size=ctx.chunk_size
        )
        if use_qk_l2norm_in_kernel:
            # Chain the gradients through the l2 normalization of q and k.
            dq = l2norm_bwd(q_orig, dq)
            dk = l2norm_bwd(k_orig, dk)
        # Gradients for (q, k, v, beta, scale, initial_state, output_final_state,
        # offsets, head_first, use_qk_l2norm_in_kernel).
        # NOTE(review): 12 values are returned for 10 forward inputs (excluding
        # ctx); this relies on autograd tolerating extra trailing Nones — confirm.
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None, None, None
258
+
259
+
260
@torch.compiler.disable
def chunk_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.
        use_qk_l2norm_in_kernel (Optional[bool]):
            Whether to use qk l2norm within the kernel for saving GPU memory.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.delta_rule import chunk_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
    # FIX: the old message claimed the shape is always (B, H, T); the expected
    # layout actually depends on `head_first`.
    assert beta.ndim == 3, "beta must be of shape [B, T, H] if head_first=False, otherwise [B, H, T]."

    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # FIX: the two f-strings previously concatenated without a space
            # ("...`cu_seqlens`.Please flatten...").
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError(
                "Sequences with variable lengths are not supported for head-first mode"
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if head_first:
        # Normalize to the time-first layout; the output is transposed back below.
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
        beta = rearrange(beta, 'b h t -> b t h')
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkDeltaRuleFunction.apply(
        q,
        k,
        v,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        False,  # inputs are already time-first at this point
        use_qk_l2norm_in_kernel
    )
    if head_first:
        o = rearrange(o, 'b t h v -> b h t v')
    return o, final_state
fla/ops/delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
12
+ from fla.utils import input_guard
13
+
14
+
15
# Token-by-token recurrent delta-rule forward. Each program instance owns one
# (value block i_v, key block i_k, sequence*head i_nh) slice of the state and
# walks the sequence once, updating h <- h + k (x) beta*(v - h k) and emitting
# o = h q at every step.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_delta_rule_fwd_kernel(
    q,        # queries
    k,        # keys
    v,        # values
    u,        # output buffer for the "pseudo values" v - h k (pre-beta)
    beta,     # scalar (per token) or headwise (per token, per V) gating
    o,        # output buffer, with a leading NK block dimension
    h0,       # optional initial state [N, H, K, V]
    ht,       # optional final-state buffer [N, H, K, V]
    offsets,  # optional [N+1] cumulative sequence lengths
    scale,    # multiplier applied to q
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        # NOTE: `all` (shadows the builtin) is the total token count used for
        # addressing the flattened time-first layout.
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # Element pointers to the first token of this slice; advanced each step.
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_u = u + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + i_nh * T
        p_o = o + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + bos * H + i_h
        p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    # Tail masks for the (possibly partial) last K/V blocks.
    mask_k = (i_k * BK + tl.arange(0, BK)) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    # Running state kept value-major: b_h[v, k].
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        # Delta rule: subtract the state's prediction h k from the value.
        b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
        b_v -= b_v_minus
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        # Persist the pre-beta residual as the "pseudo value" u.
        tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
        b_v *= b_beta
        # Rank-1 state update, then the output readout o = h q.
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Advance all pointers by one token.
        p_q += K if HEAD_FIRST else H*K
        p_k += K if HEAD_FIRST else H*K
        p_o += V if HEAD_FIRST else H*V
        p_v += V if HEAD_FIRST else H*V
        p_u += V if HEAD_FIRST else H*V
        p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
112
+
113
+
114
# Recurrent delta-rule backward. Two sequential passes per program instance:
#   1) a reverse-time scan accumulating the state gradient b_dh and emitting
#      dk, dv, db (and dh0 at the end);
#   2) after a barrier, a forward-time scan that rebuilds the state h to
#      correct dk and compute dq.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_delta_rule_bwd_kernel(
    q,
    k,
    v,
    beta,
    h0,       # optional initial state (re-used to rebuild h in pass 2)
    dh0,      # optional output: gradient w.r.t. the initial state
    dht,      # optional incoming gradient w.r.t. the final state
    do,       # incoming gradient w.r.t. the output
    dq,       # output grads; dq/dk carry a leading NV and dv a leading NK
    dk,       # block dimension that the host-side wrapper sums away
    dv,
    db,
    offsets,
    scale,
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NK: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar
    USE_INITIAL_STATE: tl.constexpr, # whether to use dh0
    USE_FINAL_STATE_GRADIENT: tl.constexpr, # whether to use dht
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        # NOTE: `all` (shadows the builtin) is the total token count used for
        # addressing the flattened time-first layout.
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    mask_k = i_k * BK + tl.arange(0, BK) < K
    mask_v = i_v * BV + tl.arange(0, BV) < V

    # --- Pass 1: reverse scan. Pointers start at the LAST token (t = T-1). ---
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
            p_dbeta = db + (i_v * NK*B*H + i_k * B*H + i_nh) * T*V + tl.arange(0, BV) + (T - 1) * V
        else:
            p_beta = beta + i_nh * T + T - 1
            p_dbeta = db + (i_v * B*H + i_nh) * T + T - 1
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
            p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV)
        else:
            p_beta = beta + (bos + T - 1) * H + i_h
            p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h

    # b_dh accumulates dL/dh for the state slice, seeded by dht if present.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)

    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        # Output readout contributes q (x) do to the state gradient.
        b_dh += b_q[:, None] * b_do[None, :]
        b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1)
        b_dv = tl.sum(b_dh * b_k[:, None], axis=0)

        # Gradient w.r.t. beta: elementwise for headwise beta, reduced for scalar.
        b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v)
        b_dv = b_dv * b_beta

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
        if IS_BETA_HEADWISE:
            tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v)
        else:
            tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty))

        # Undo this step's rank-1 contribution before moving one token back.
        b_dh -= b_k[:, None] * b_dv[None, :]

        p_q -= K if HEAD_FIRST else H*K
        p_k -= K if HEAD_FIRST else H*K
        p_v -= V if HEAD_FIRST else H*V
        p_do -= V if HEAD_FIRST else H*V
        p_dk -= K if HEAD_FIRST else H*K
        p_dv -= V if HEAD_FIRST else H*V
        p_dbeta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
        p_beta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)

    if USE_INITIAL_STATE:
        # What remains in b_dh after the full reverse scan is dL/dh0.
        p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])

    # Pass 2 reads the dk/dv just written above; keep the passes ordered.
    tl.debug_barrier()

    # --- Pass 2: forward scan rebuilding the state h from the start. ---
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + i_nh * T
        p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
        p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + bos * H + i_h
        p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    if USE_INITIAL_STATE:
        mask_h = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        # Correct dk in place with the state preceding this token.
        b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32)
        b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32)
        b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1)
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)

        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta

        # NOTE(review): this pass replays the state update with v*beta but
        # without the (v - h k) residual subtraction of the forward kernel;
        # presumably `v` here already holds the pseudo values `u` — confirm
        # against the autograd wrapper.
        b_h += b_k[:, None] * b_v[None, :]
        b_dq = b_h * b_do[None, :]
        d_q = tl.sum(b_dq, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += K if HEAD_FIRST else H*K
        p_v += V if HEAD_FIRST else H*V
        p_do += V if HEAD_FIRST else H*V
        p_dq += K if HEAD_FIRST else H*K
        p_dk += K if HEAD_FIRST else H*K
        p_dv += V if HEAD_FIRST else H*V
        p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
293
+
294
+
295
def fused_recurrent_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # FIX: the return annotation previously declared a 2-tuple, but the
    # function returns three values (o, u, final_state).
    """Launch the token-by-token recurrent delta-rule forward kernel.

    Args:
        q, k: `[B, H, T, K]` if `head_first` else `[B, T, H, K]`.
        v: `[B, H, T, V]` if `head_first` else `[B, T, H, V]`.
        beta: per-token scalar gating, or headwise gating when
            `beta.ndim == v.ndim`.
        scale: multiplier applied to q inside the kernel.
        initial_state: optional `[N, H, K, V]` starting state.
        output_final_state: whether to materialize the final state.
        offsets: optional `[N+1]` cumulative sequence lengths (var-len mode).
        head_first: input layout flag.

    Returns:
        o: outputs with the same shape as v.
        u: the "pseudo values" (v minus the state's prediction, pre-beta)
            written by the kernel.
        final_state: `[N, H, K, V]` float32 state, or None.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if offsets is None else len(offsets) - 1
    # The whole key dim is kept in one block; V is split into blocks of <= 8.
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 1
    num_warps = 1

    # Leading NK dimension mirrors the kernel's output addressing; squeezed below.
    o = q.new_empty(NK, *v.shape)
    if output_final_state:
        final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
    else:
        final_state = None

    grid = (NV, NK, N * H)
    u = torch.empty_like(v)
    fused_recurrent_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        u,
        beta,
        o,
        initial_state,
        final_state,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        HEAD_FIRST=head_first,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # NK is asserted to be 1, so the leading block dimension is dropped.
    o = o.squeeze(0)
    return o, u, final_state
350
+
351
+
352
def fused_recurrent_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    dht: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the recurrent delta-rule backward kernel and reduce its outputs.

    The kernel splits V into NV blocks (and nominally K into NK, asserted to
    be 1). Each block writes into its own leading slice of dq/dk/dv/db, which
    are summed over those block dimensions here before returning.

    Returns:
        dq, dk, dv, db, dh0: gradients w.r.t. q, k, v, beta and the initial
        state (dh0 is None unless `initial_state.requires_grad`).
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if offsets is None else len(offsets) - 1
    # Whole key dim in one block; V split into blocks of at most 32.
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 1
    num_warps = 2

    # Headwise beta has one value per (token, V) rather than a scalar per token.
    beta_vector = beta.ndim == v.ndim

    # Leading NV/NK dims are per-block partial results, reduced after launch.
    dq = q.new_empty(NV, *q.shape)
    dk = q.new_empty(NV, *k.shape)
    dv = q.new_empty(NK, *v.shape)
    if beta_vector:
        db = q.new_empty(NV, NK, B, H, T, V) if head_first else q.new_empty(NV, NK, B, T, H, V)
    else:
        db = q.new_empty(NV, B, H, T) if head_first else q.new_empty(NV, B, T, H)
    grid = (NV, NK, N * H)

    # Only allocate the initial-state gradient when autograd will consume it.
    if initial_state is not None and initial_state.requires_grad:
        dh0 = torch.empty_like(initial_state, dtype=torch.float32)
    else:
        dh0 = None

    fused_recurrent_delta_rule_bwd_kernel[grid](
        q,
        k,
        v,
        beta,
        initial_state,
        dh0,
        dht,
        do,
        dq,
        dk,
        dv,
        db,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        NK=NK,
        IS_BETA_HEADWISE=beta_vector,
        HEAD_FIRST=head_first,
        num_warps=num_warps,
        num_stages=num_stages
    )
    # Reduce the per-block partial gradients.
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    db = db.sum((0, 1)) if beta_vector else db.sum(0)

    return dq, dk, dv, db, dh0
425
+
426
+
427
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent delta-rule Triton kernels."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        use_qk_l2norm_in_kernel: bool = False
    ):
        # keep references to the pre-normalization tensors so backward can
        # recompute the l2norm and route gradients through it
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        o, u, final_state = fused_recurrent_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            head_first=head_first
        )

        # NOTE: `u` (the auxiliary values produced by the forward kernel) is
        # saved in `v`'s slot; backward unpacks it as `v` and feeds it to the
        # backward kernel in place of the raw values.
        ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state)
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.head_first = head_first
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        # `v` here is the `u` tensor saved in forward (see note above)
        q, k, v, beta, initial_state = ctx.saved_tensors
        if ctx.use_qk_l2norm_in_kernel:
            # recompute the normalized q/k instead of saving both versions
            q, q_orig = l2norm_fwd(q), q
            k, k_orig = l2norm_fwd(k), k
        dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            dht=dht,
            do=do,
            scale=ctx.scale,
            initial_state=initial_state,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        if ctx.use_qk_l2norm_in_kernel:
            # chain the gradients through the l2 normalization of q and k
            dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk)
        # one gradient slot per forward input; non-tensor inputs get None
        return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None, None
492
+
493
+
494
@torch.compiler.disable
def fused_recurrent_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    use_qk_l2norm_in_kernel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
            If not provided, defaults to all ones.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
        # inputs with equal lengths, in the [B, T, H, *] (non-head-first) layout
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            head_first=False
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens,
            head_first=False
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    # --- input validation for variable-length mode -------------------------
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError(
                "Sequences with variable lengths are not supported for head-first mode"
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    # the underlying autograd function always works in [B, T, H, *] layout,
    # so head-first inputs are transposed on the way in and out
    if head_first:
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
        beta = rearrange(beta, 'b h t -> b t h')
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        False,
        use_qk_l2norm_in_kernel
    )
    if head_first:
        o = rearrange(o, 'b t h v -> b h t v')
    return o, final_state
fla/ops/delta_rule/naive.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True):
    """Sequential reference implementation of the delta-rule recurrence.

    At each step t the state S receives a rank-one update
    S <- S + k_t (beta_t * (v_t - S^T k_t))^T and the output is q_t^T S.
    All math is done in float32; outputs are cast back to the input dtype.
    """
    input_dtype = q.dtype
    batch, heads, seq_len, key_dim = q.shape
    q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta])
    val_dim = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(batch, heads, key_dim, val_dim).to(v)
    # standard 1/sqrt(d_k) query scaling
    q = q * (key_dim ** -0.5)

    # scalar beta broadcasts over the value dimension
    if beta.ndim < v.ndim:
        beta = beta.unsqueeze(-1)

    if initial_state is not None:
        S = S + initial_state

    for t in range(seq_len):
        q_t = q[:, :, t]
        k_t = k[:, :, t]
        beta_t = beta[:, :, t]
        # prediction error of the current value under the running state
        err = v[:, :, t] - torch.einsum('bhk,bhkv->bhv', k_t, S)
        upd = beta_t * err
        # rank-one state update, then read out with the query
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, upd)
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q_t, S)
    S = None if output_final_state is False else S
    return o.to(input_dtype), S
33
+
34
+
35
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
    """Chunkwise (blocked) reference implementation of the delta rule.

    The sequence is split into chunks of `chunk_size`; within each chunk the
    recurrence is solved in closed form via a unit-lower-triangular inverse,
    while the state S carries information across chunks.
    Requires the sequence length to be divisible by `chunk_size`.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    v = v * beta[..., None]
    k_beta = k * beta[..., None]

    assert l % chunk_size == 0

    n = l // chunk_size
    # split into chunks: [b, h, n, c, d]
    q = q.reshape(b, h, n, chunk_size, d_k)
    k = k.reshape(b, h, n, chunk_size, d_k)
    v = v.reshape(b, h, n, chunk_size, d_v)
    k_beta = k_beta.reshape(b, h, n, chunk_size, d_k)

    # compute (I - tri(diag(beta) KK^T))^{-1} by forward substitution
    upper = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(upper, 0)
    for i in range(1, chunk_size):
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)

    u = attn @ v
    w = attn @ k_beta
    S = k.new_zeros(b, h, d_k, d_v)
    o = torch.zeros_like(v)
    strict_upper = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(n):
        q_i, k_i = q[:, :, i], k[:, :, i]
        # causal intra-chunk scores
        local = (q_i @ k_i.transpose(-1, -2)).masked_fill_(strict_upper, 0)
        # correct the pseudo-values by the inter-chunk state contribution
        u_i = u[:, :, i] - w[:, :, i] @ S
        o[:, :, i] = q_i @ S + local @ u_i
        S = S + k_i.transpose(-1, -2) @ u_i

    return o.reshape(b, h, l, d_v), S
66
+
67
+
68
+ def delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
69
+ b, h, l, d_k = q.shape
70
+ # d_v = v.shape[-1]
71
+ q = q * (d_k ** -0.5)
72
+ v = v * beta[..., None]
73
+ k_beta = k * beta[..., None]
74
+ # compute (I - tri(diag(beta) KK^T))^{-1}
75
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
76
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
77
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
78
+ for i in range(1, BN):
79
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
80
+ T = T + torch.eye(BN, dtype=torch.float, device=q.device)
81
+
82
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
83
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
84
+ o_intra = A_local @ v
85
+
86
+ # apply cumprod transition matrices on k to the last position within the chunk
87
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
88
+ # apply cumprod transition matrices on q to the first position within the chunk
89
+ q = q - A_local @ k_beta
90
+ o_intra = A_local @ v
91
+
92
+ A = torch.zeros(b, h, l, l, device=q.device)
93
+
94
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
95
+ o = torch.empty_like(v)
96
+ for i in range(0, l, BM):
97
+ q_i = q[:, :, i:i+BM]
98
+ o_i = o_intra[:, :, i:i+BM]
99
+ # intra block
100
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
101
+ k_j = k[:, :, j:j+BN]
102
+ A_ij = q_i @ k_j.transpose(-1, -2)
103
+ mask = torch.arange(i, i+BM) >= (j + BN)
104
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
105
+ A[:, :, i:i+BM, j:j+BN] = A_ij
106
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
107
+ o_i += A_ij @ v[:, :, j:j+BN]
108
+ # inter block
109
+ for j in range(i - BN, -BN, -BN):
110
+ k_j = k[:, :, j:j+BN]
111
+ A_ij = q_i @ k_j.transpose(-1, -2)
112
+ A[:, :, i:i+BM, j:j+BN] = A_ij
113
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
114
+ o_i += A_ij @ v[:, :, j:j+BN]
115
+ o[:, :, i:i+BM] = o_i
116
+
117
+ for i in range(0, l//BN):
118
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
119
+
120
+ return o, A
fla/ops/delta_rule/wy_fast.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
11
+ from fla.ops.utils.solve_tril import solve_tril
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    """Recompute the WY factors w = A @ (beta*k) and u = A @ (beta*v).

    One program handles one (chunk, batch*head) tile; V and K are swept in
    BV/BK-sized slices. Launched over a 2D grid (num_chunks, B*H).
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # variable-length mode: map the flat chunk id to (sequence, chunk)
        # and derive this sequence's [bos, eos) span from `offsets`
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # beta and A blocks for this chunk (layout depends on HEAD_FIRST)
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    # u = A @ (beta * v), computed slice-by-slice over V
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (beta * k), computed slice-by-slice over K
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
89
+
90
+
91
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    A,
    dw,
    du,
    dk,
    dv,
    dbeta,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    """Backward of the WY representation: given (dw, du), accumulate
    gradients (dk, dv, dbeta) for one (chunk, batch*head) tile.
    """
    # BUGFIX: the launch grid is 2D (NT, B*H); the original unpacked three
    # program ids into two names, which is an invalid tuple unpacking.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # variable-length mode: resolve (sequence, chunk) and the [bos, eos) span
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        # note: A is loaded transposed (strides swapped)
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))

    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    # accumulate dv (= beta * A^T du) and the du-part of dA, sweeping V
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
        b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
        b_dv = b_dv_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)

        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # accumulate the dw-part of dk and dA, sweeping K
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
        b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
        b_dk = b_dk_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # propagate dA through the triangular inverse: dA <- -tril(A (tril(dA) A), -1)
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

    # second pass over K: add the contributions of dA to dk and dbeta
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)

        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    if HEAD_FIRST:
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
212
+
213
+
214
def fwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = False,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the WY representation for the delta rule.

    Pipeline: (1) form the per-chunk scaled K K^T matrix, (2) invert its
    triangular part, (3) recompute the factors w and u from the inverse.
    Returns (w, u, A) where A is the inverted per-chunk matrix.
    """
    # per-chunk scaled dot products, accumulated in float32 for stability
    chunk_A = chunk_scaled_dot_kkt_fwd(
        k=k,
        beta=beta,
        cu_seqlens=offsets,
        head_first=head_first,
        chunk_size=chunk_size,
        output_dtype=torch.float32
    )
    # triangular solve, cast back to the key dtype
    chunk_A = solve_tril(
        A=chunk_A,
        cu_seqlens=offsets,
        head_first=head_first,
        output_dtype=k.dtype
    )
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=chunk_A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, chunk_A
249
+
250
+
251
def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the kernel recomputing w = A @ (beta*k) and u = A @ (beta*v)."""
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]
    # chunk length: at most `chunk_size`, at least 16, rounded to a power of 2
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    tile = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tile)
    BV = min(triton.next_power_of_2(V), tile)
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    w = torch.empty_like(k)
    u = torch.empty_like(v)
    fwd_recompute_w_u_kernel[(NT, B * H)](
        k,
        v,
        beta,
        w,
        u,
        A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u
292
+
293
+
294
def bwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the backward kernel of the WY representation.

    Given gradients (dw, du) of the WY factors, produce the gradients
    (dk, dv, dbeta) w.r.t. keys, values and betas.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]
    # block sizes: chunk length (>= 16, power of 2) plus K/V tiles
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    tile = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tile)
    BV = min(triton.next_power_of_2(V), tile)
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        k,
        v,
        beta,
        A,
        dw,
        du,
        dk,
        dv,
        dbeta,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dk, dv, dbeta
fla/ops/forgetting_attn/parallel.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
13
+ from fla.ops.utils.op import div, exp, log
14
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
15
+
16
+
17
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit
def parallel_forgetting_attn_fwd_kernel(
    q,
    k,
    v,
    g,
    o,
    lse,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    """Forward kernel of forgetting attention (attention with a cumulative
    per-step log gate g), using a running-max online softmax.

    One program computes a [BT, BV] output tile for one query block of one
    (batch, query-head); key/value blocks of size BS are scanned causally.
    Also stores the per-row log-sum-exp into `lse` for the backward pass.
    """
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    # G query heads share one kv head (grouped-query layout)
    i_h = i_hq // G

    if USE_OFFSETS:
        # variable-length mode: resolve (sequence, block) and the [bos, eos) span
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT,] cumulative log-gates for the query positions
    b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
    # [BT, BV]
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # online-softmax state: running max and running normalizer per row
    b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([BT], dtype=tl.float32)

    # [BT] absolute query positions (for the causal mask)
    o_q = i_t * BT + tl.arange(0, BT)
    # pass 1: the diagonal block, with an explicit causal mask
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,))
        # [BT, BS] scores with the gate difference g_q - g_k added in log space
        b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :]
        b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))

        # [BT] update running max; rescale previous accumulators by exp(m_old - m_new)
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m

    # pass 2: fully-causal blocks before this query block, scanned backwards;
    # b_gq accumulates the gate mass between the key block and the queries
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)

        # cumulative gate at the end of this key block / at the end of the previous one
        # NOTE(review): the `i_s % BT > 0` guard presumably detects a chunk-initial
        # block (where no previous gate exists) — confirm against the g layout
        b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
        b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
        # [BT, BS]
        b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :]

        b_gq += b_gn - b_gp
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m

    # final normalization; lse = m + log(sum exp(s - m))
    b_o = div(b_o, b_acc[:, None])
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
145
+
146
+
147
@triton.jit
def parallel_forgetting_attn_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,  # block width: V rounded up to a power of two
    V: tl.constexpr   # true head (value) dimension
):
    # FlashAttention-style backward preprocessing:
    # delta[i] = sum_v o[i, v] * do[i, v], one program per flattened row
    # (batch * seq * head). The result feeds the dq/dkv backward kernels.
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    # mask off lanes beyond the true value dim (B is padded to a power of 2)
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    # accumulate the inner product in fp32 for numerical accuracy
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
164
+
165
+
166
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_forgetting_attn_bwd_kernel_dq(
    q,
    k,
    v,
    g,       # per-(token, query-head) forget gates, chunk-local cumsum in log space
    lse,     # log-sum-exp from the forward pass
    delta,   # rowwise <o, do> from the preprocess kernel
    do,
    dq,      # output: query gradients
    dg,      # output: (partial) gate gradients
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,   # number of query heads per kv head (GQA group size)
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # query-block (chunk) size along T
    BS: tl.constexpr,  # key-block size along T
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward pass w.r.t. q (and a partial contribution to dg) for forgetting
    # attention. One program handles one (value-block, query-chunk, batch*head)
    # triple. The first loop covers key blocks inside the diagonal chunk
    # (needs the causal mask); the second walks earlier chunks right-to-left,
    # carrying the gate cumsum across chunk boundaries via b_gn/b_gp.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        # variable-length mode: map the flat chunk id to (sequence, chunk)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BT]
    b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
    b_lse = tl.load(p_lse, boundary_check=(0,))
    b_delta = tl.load(p_delta, boundary_check=(0,))

    # [BT] absolute query positions of this chunk
    o_q = i_t * BT + tl.arange(0, BT)
    # [BT, BK]
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT] gate-gradient accumulator
    b_dg = tl.zeros([BT,], dtype=tl.float32)
    # intra-chunk (diagonal) key blocks: causal mask required
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,))
        # [BT, BS] scores with gate difference; subtracting lse turns exp() into
        # the normalized probability directly (no running-max rescale needed here)
        b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :]
        b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')))

        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
        # [BT]
        b_dg += tl.sum(b_ds, 1)

    # earlier chunks, walked right-to-left; b_gq accumulates the gate cumsum
    # across chunk boundaries (b_gn: gate at block end, b_gp: gate before block)
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)

        b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
        # g holds chunk-local cumsums, so the correction is only needed when the
        # block does not start at a chunk boundary
        b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
        # [BT, BS] fully below the diagonal: no causal mask needed
        b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :]
        b_p = exp(b_s)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
        # [BT]
        b_dg += tl.sum(b_ds, 1)

        b_gq += b_gn - b_gp

    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
296
+
297
+
298
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_forgetting_attn_bwd_kernel_dkv(
    q,
    k,
    v,
    g,       # per-(token, query-head) forget gates, chunk-local cumsum in log space
    lse,     # log-sum-exp from the forward pass
    delta,   # rowwise <o, do> from the preprocess kernel
    do,
    dk,      # output: key gradients (per query head; reduced over G outside)
    dv,      # output: value gradients (per query head; reduced over G outside)
    dg,      # output: (partial) gate gradients
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,   # number of query heads per kv head (GQA group size)
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # key-chunk size along T
    BS: tl.constexpr,  # query-block size along T
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward pass w.r.t. k and v (and a partial, negated contribution to dg).
    # Mirror image of the dq kernel: this program owns one key chunk and scans
    # query blocks at and after it. The first loop handles the diagonal chunk
    # with a causal mask; the second walks later chunks left-to-right, folding
    # the inter-chunk gate cumsum into b_gk (note b_gk becomes negative offsets).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        # variable-length mode: map the flat chunk id to (sequence, chunk)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)
    # [BT]
    b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
    b_dg = tl.zeros([BT,], dtype=tl.float32)

    # [BT] absolute key positions of this chunk, masked to the sequence length
    o_k = i_t * BT + tl.arange(0, BT)
    m_k = o_k < T
    # intra-chunk (diagonal) query blocks: causal mask required
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)

        m_q = o_q < T
        # causal + in-bounds mask: key position must not exceed query position
        m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :]
        # [BT, BS] transposed scores; subtracting lse yields probabilities directly
        b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :]
        b_p = tl.where(m_s, exp(b_s), 0)
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        # [BT] gate gradient flows with opposite sign on the key side
        b_dg -= tl.sum(b_ds, 1)

    # rebase b_gk relative to the end of this chunk before scanning later chunks
    b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32)
    for i_s in range((i_t + 1) * BT, T, BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)

        b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
        # g holds chunk-local cumsums, so the correction is only needed when the
        # block does not start at a chunk boundary
        b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
        # [BT, BS] fully above the diagonal: no causal mask needed
        b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :]
        b_p = exp(b_s)
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        # [BT]
        b_dg -= tl.sum(b_ds, 1)

        b_gk -= b_gn - b_gp

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
444
+
445
+
446
def parallel_forgetting_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Launch the forward kernel for parallel forgetting attention.

    Returns the attention output `o` of shape `[B, T, HQ, V]` and the
    per-row log-sum-exp `lse` of shape `[B, T, HQ]` (fp32) needed by backward.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    assert V <= 256, "V must be less than or equal to 256"

    BT = chunk_size
    BK = max(16, triton.next_power_of_2(K))
    # Hopper-class GPUs have enough shared memory for a wider key block.
    BS = min(64 if check_shared_mem('hopper') else 32, max(16, triton.next_power_of_2(T)))
    BV = min(256, max(16, triton.next_power_of_2(V)))
    NV = triton.cdiv(V, BV)
    # With variable-length inputs the chunk count comes from the index table.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    parallel_forgetting_attn_fwd_kernel[(NV, NT, B * HQ)](
        q=q,
        k=k,
        v=v,
        g=g,
        o=o,
        lse=lse,
        scale=scale,
        offsets=offsets,
        indices=indices,
        B=B,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
497
+
498
+
499
def parallel_forgetting_attn_bwd_preprocess(
    o: torch.Tensor,
    do: torch.Tensor
):
    """Compute delta[i] = <o_i, do_i> per (batch, token, head) row in fp32."""
    v_dim = o.shape[-1]
    # one fp32 scalar per row; same layout as `o` with the value dim dropped
    delta = torch.empty_like(o[..., 0], dtype=torch.float)
    grid = (delta.numel(),)
    parallel_forgetting_attn_bwd_kernel_preprocess[grid](
        o=o,
        do=do,
        delta=delta,
        B=triton.next_power_of_2(v_dim),
        V=v_dim,
    )
    return delta
513
+
514
+
515
def parallel_forgetting_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    scale: float = None,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Backward driver: runs the preprocess, dq and dkv kernels and reduces
    the per-query-head dk/dv back to the kv-head layout (GQA).

    Returns (dq, dk, dv, dg); dg is the chunk-local partial gate gradient —
    the caller still applies a reverse global cumsum.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BT = chunk_size
    BS = min(32, max(16, triton.next_power_of_2(T)))
    BK = max(16, triton.next_power_of_2(K))
    BV = max(16, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)
    # With variable-length inputs the chunk count comes from the index table.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    # delta[i] = <o_i, do_i>, shared by both backward kernels
    delta = parallel_forgetting_attn_bwd_preprocess(o, do)
    dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
    # dk/dv are produced per query head; fp32 when G > 1 so the later
    # sum-reduction over the group does not lose precision
    dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
    dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
    dg = q.new_empty(g.shape, dtype=torch.float)
    # NOTE: the original `dg` can be destroyed during autotuning
    # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?)
    # so we need to make a copy of `dg`
    dg2 = q.new_empty(g.shape, dtype=torch.float)
    grid = (NV, NT, B * HQ)
    parallel_forgetting_attn_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        dg=dg,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    parallel_forgetting_attn_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        dg=dg2,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    # collapse the GQA group dimension back to the kv-head layout
    dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
    dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
    # combine the query-side and key-side gate gradient contributions
    dg = dg.add_(dg2)
    return dq, dk, dv, dg
603
+
604
+
605
# NOTE(review): `@torch.compile` applied to an autograd.Function subclass is
# unusual — confirm this is intended rather than compiling forward/backward
# individually.
@torch.compile
class ParallelForgettingAttentionFunction(torch.autograd.Function):
    """Autograd wrapper around the parallel forgetting-attention kernels.

    forward precomputes the chunk-local log-gate cumsum; backward undoes it
    with a reverse global cumsum on the gate gradient.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, offsets):
        ctx.dtype = q.dtype
        # larger chunks are affordable on Hopper-class shared memory
        if check_shared_mem('hopper'):
            chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
        else:
            chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        # gates are consumed as chunk-local cumulative sums (log space)
        g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
        o, lse = parallel_forgetting_attn_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            scale=scale,
            chunk_size=chunk_size,
            offsets=offsets,
            indices=indices
        )
        # save the cumsum-ed g (not the raw gates) — backward expects it
        ctx.save_for_backward(q, k, v, g, o, lse)
        ctx.chunk_size = chunk_size
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, g, o, lse = ctx.saved_tensors
        dq, dk, dv, dg = parallel_forgetting_attn_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            o=o,
            lse=lse,
            do=do,
            scale=ctx.scale,
            chunk_size=ctx.chunk_size,
            offsets=ctx.offsets,
            indices=ctx.indices
        )
        # reverse global cumsum converts the chunk-local partials into the
        # gradient w.r.t. the raw (pre-cumsum) gates
        dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
        # trailing Nones cover the non-tensor forward inputs
        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None
661
+
662
+
663
def parallel_forgetting_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    r"""Parallel forgetting attention (softmax attention with forget gates).

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA will be applied if HQ is divisible by H.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if g is not None:
        # gates participate in cumulative sums; keep them in fp32
        g = g.float()
    if head_first:
        # normalize to time-first layout for the kernels
        q = rearrange(q, 'b h t d -> b t h d')
        k = rearrange(k, 'b h t d -> b t h d')
        v = rearrange(v, 'b h t d -> b t h d')
        g = rearrange(g, 'b h t -> b t h')
    o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens)
    return rearrange(o, 'b t h d -> b h t d') if head_first else o
fla/ops/gated_delta_rule/chunk.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
14
+ from fla.ops.utils import chunk_local_cumsum
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Forward pipeline for the chunked gated delta rule.

    Steps: (1) chunk-local cumsum of the log gates, (2) WY representation of
    the rank-1 delta updates, (3) chunkwise recurrence over hidden states,
    (4) output projection. Returns (g, o, Aw, Au, final_state) where the
    cumsum-ed `g` and the WY factors Aw/Au are kept for backward.
    """
    g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
    # obtain WY representation. u is actually the new v.
    w, u, Aw, Au = fwd_prepare_wy_repr(
        k=k,
        v=v,
        beta=beta,
        g=g,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # chunkwise hidden-state recurrence; v_new replaces v for the output step
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # obtain output
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return g, o, Aw, Au, final_state
72
+
73
+
74
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    Aw: torch.Tensor,
    Au: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Backward pipeline for the chunked gated delta rule.

    `g` here is the chunk-local cumsum saved by forward; w/u/h/v_new are
    recomputed from the cached WY factors (Aw, Au) instead of being stored.
    Returns (dq, dk, dv, db, dg, dh0).
    """
    T = q.shape[2] if head_first else q.shape[1]
    # clamp the chunk size for short sequences (at least 16 for tensor cores)
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # recompute w/u from the cached WY factors rather than storing them
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # recompute the hidden states (cheaper than caching them in forward)
    h, v_new, _ = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=False,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # intra-chunk (local) part of dv; completed by the dhu pass below
    dv = chunk_bwd_dv_local(
        q=q,
        k=k,
        g=g,
        do=do,
        dh=None,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # backward recurrence over hidden states; also finalizes dv in place
    dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
        q=q,
        k=k,
        w=w,
        g=g,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    dq, dk, dw, dg = chunk_bwd_dqkwg(
        q=q,
        k=k,
        v=v_new,
        w=w,
        g=g,
        h=h,
        dv=dv,
        do=do,
        dh=dh,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    # backprop through the WY representation; note `du=dv` feeds the dv
    # computed above back in, and the returned dv replaces it
    dk2, dv, db, dg2 = bwd_prepare_wy_repr(
        k=k,
        v=v,
        beta=beta,
        g=g,
        Aw=Aw,
        Au=Au,
        dw=dw,
        du=dv,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    dk.add_(dk2)
    dg.add_(dg2)
    assert dg.dtype == torch.float32, "dg should be fp32"
    # reverse chunk-local cumsum: gradient w.r.t. the raw (pre-cumsum) gates
    dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first)
    return dq, dk, dv, db, dg, dh0
178
+
179
+
180
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper for the chunked gated delta rule.

    Saves the *pre-l2norm* q/k; backward re-applies the normalization and
    routes the gradients through `l2norm_bwd` when
    `use_qk_l2norm_in_kernel=True`.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        use_qk_l2norm_in_kernel: bool = False
    ):
        chunk_size = 64
        # keep the un-normalized inputs for backward (l2norm is recomputed there)
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = None
        if offsets is not None:
            # built on CPU, then moved to the device/dtype of `offsets`
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)

        # note: the returned `g` is the chunk-local cumsum, saved for backward
        g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size,
        )
        ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices)
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.head_first = head_first
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors
        if ctx.use_qk_l2norm_in_kernel:
            # recompute the normalized q/k while keeping the originals for l2norm_bwd
            q, q_orig = l2norm_fwd(q), q
            k, k_orig = l2norm_fwd(k), k
        dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            Aw=Aw,
            Au=Au,
            scale=ctx.scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            offsets=offsets,
            indices=indices,
            head_first=ctx.head_first,
            chunk_size=ctx.chunk_size
        )
        if ctx.use_qk_l2norm_in_kernel:
            dq = l2norm_bwd(q_orig, dq)
            dk = l2norm_bwd(k_orig, dk)
        # gradient order matches forward inputs (q, k, v, g, beta, scale,
        # initial_state, output_final_state, offsets, head_first, use_qk_l2norm...)
        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
270
+
271
+
272
+ @torch.compiler.disable
273
+ def chunk_gated_delta_rule(
274
+ q: torch.Tensor,
275
+ k: torch.Tensor,
276
+ v: torch.Tensor,
277
+ g: torch.Tensor,
278
+ beta: torch.Tensor,
279
+ scale: float = None,
280
+ initial_state: torch.Tensor = None,
281
+ output_final_state: bool = False,
282
+ cu_seqlens: Optional[torch.LongTensor] = None,
283
+ head_first: bool = False,
284
+ use_qk_l2norm_in_kernel: bool = False
285
+ ):
286
+ r"""
287
+ Args:
288
+ q (torch.Tensor):
289
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
290
+ k (torch.Tensor):
291
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
292
+ v (torch.Tensor):
293
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
294
+ g (torch.Tensor):
295
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
296
+ beta (torch.Tensor):
297
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
298
+ scale (Optional[int]):
299
+ Scale factor for the RetNet attention scores.
300
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
301
+ initial_state (Optional[torch.Tensor]):
302
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
303
+ For equal-length input sequences, `N` equals the batch size `B`.
304
+ Default: `None`.
305
+ output_final_state (Optional[bool]):
306
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
307
+ cu_seqlens (torch.LongTensor):
308
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
309
+ consistent with the FlashAttention API.
310
+ head_first (Optional[bool]):
311
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
312
+ Default: `False`.
313
+
314
+ Returns:
315
+ o (torch.Tensor):
316
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
317
+ final_state (torch.Tensor):
318
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
319
+
320
+ Examples::
321
+ >>> import torch
322
+ >>> import torch.nn.functional as F
323
+ >>> from einops import rearrange
324
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
325
+ # inputs with equal lengths
326
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
327
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
328
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
329
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
330
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
331
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
332
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
333
+ >>> o, ht = chunk_gated_delta_rule(
334
+ q, k, v, g, beta,
335
+ initial_state=h0,
336
+ output_final_state=True,
337
+ head_first=False
338
+ )
339
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
340
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
341
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
342
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
343
+ >>> o_var, ht_var = chunk_gated_delta_rule(
344
+ q, k, v, g, beta,
345
+ initial_state=h0,
346
+ output_final_state=True,
347
+ cu_seqlens=cu_seqlens,
348
+ head_first=False
349
+ )
350
+ """
351
+ assert q.dtype == k.dtype == v.dtype
352
+ assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
353
+ assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False."
354
+
355
+ if cu_seqlens is not None:
356
+ if q.shape[0] != 1:
357
+ raise ValueError(
358
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
359
+ f"Please flatten variable-length inputs before processing."
360
+ )
361
+ if head_first:
362
+ raise RuntimeError(
363
+ "Sequences with variable lengths are not supported for head-first mode"
364
+ )
365
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
366
+ raise ValueError(
367
+ f"The number of initial states is expected to be equal to the number of input sequences, "
368
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
369
+ )
370
+ if head_first:
371
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
372
+ beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g))
373
+ if scale is None:
374
+ scale = k.shape[-1] ** -0.5
375
+ else:
376
+ assert scale > 0, "Scale must be positive."
377
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
378
+ q,
379
+ k,
380
+ v,
381
+ g,
382
+ beta,
383
+ scale,
384
+ initial_state,
385
+ output_final_state,
386
+ cu_seqlens,
387
+ False,
388
+ use_qk_l2norm_in_kernel
389
+ )
390
+ if head_first:
391
+ o = rearrange(o, 'b t h v -> b h t v')
392
+ return o, final_state
fla/ops/gated_delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
# Forward kernel of the gated delta rule, computed as a step-by-step recurrence.
# Each program scans one full sequence for one head, carrying a [BK, BV] slice
# of the recurrent state `h` across timesteps:
#     h_t = exp(g_t) * h_{t-1} + k_t (beta_t * (v_t - h_{t-1}^T k_t))^T
#     o_t = h_t^T q_t * scale
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,        # queries; pointer math below implies a time-major [*, T, H, K] layout
    k,        # keys, same layout as q
    v,        # values, [*, T, H, V]
    g,        # log decay gates, one scalar per (t, h): [*, T, H]
    beta,     # update strength; per-(t, h) scalar or per-(t, h, v) vector (IS_BETA_HEADWISE)
    o,        # output buffer, [NK, *, T, H, V]
    h0,       # optional initial state, [N*H, K, V] flattened (indexed by i_nh)
    ht,       # optional final-state buffer, same layout as h0
    offsets,  # optional cumulative sequence lengths, [N + 1] (varlen mode)
    scale,    # multiplicative scale applied to q
    T,        # sequence length (total tokens in varlen mode)
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,  # key tile size (covers all of K: NK == 1 is asserted by the launcher)
    BV: tl.constexpr,  # value tile size
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is a headwise vector or a scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,  # L2-normalize q/k inside the kernel
    USE_OFFSETS: tl.constexpr  # variable-length (cu_seqlens) mode
):
    # Grid: (NK, NV, N * H); each program owns one (key-tile, value-tile, seq*head).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Varlen: sequence i_n spans tokens [bos, eos) of the flattened batch;
        # `all` is the total token count used to stride the NK axis of `o`.
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Pointers to timestep `bos`; each loop iteration advances them by one
    # timestep (H*K elements for q/k, H*V for v/o, H for g and scalar beta).
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * H + i_h) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * H + i_h) * V + o_v
    else:
        p_beta = beta + bos * H + i_h
    p_g = g + bos * H + i_h
    p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Recurrent state carried across the time loop, accumulated in fp32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # eps=1e-6 guards against zero-norm rows.
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
        b_q = b_q * scale
        # decay the state: [BK, BV]
        b_h *= exp(b_g)
        # delta rule: subtract the state's current prediction for k. [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # rank-1 state update: [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # readout: [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # advance all pointers by one timestep
        p_q += H*K
        p_k += H*K
        p_o += H*V
        p_v += H*V
        p_g += H
        p_beta += H * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
111
+
112
+
113
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool = False,
    offsets: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated-delta-rule forward kernel.

    Inputs are expected in `[B, T, H, ...]` layout. Returns the per-step
    outputs (same shape as ``v``) and, when requested, the final recurrent
    state of shape ``[N, H, K, V]`` in fp32 (``None`` otherwise).
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    # One sequence per batch entry, unless varlen offsets define N sequences.
    N = B if offsets is None else len(offsets) - 1

    # The whole key dim fits in a single tile; values are tiled in chunks of <= 8.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 8)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"

    # Leading NK axis mirrors the kernel's output layout; squeezed on return.
    o = q.new_empty(NK, *v.shape)
    final_state = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    grid = (NK, NV, N * H)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        scale=scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        # beta with one value per output channel (same rank as v) is "headwise"
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=1,
        num_stages=3,
    )
    return o.squeeze(0), final_state
165
+
166
+
167
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent gated-delta-rule forward.

    Only the forward direction is implemented; invoking backward raises.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False
    ):
        # Nothing is saved on `ctx`: backward is intentionally unsupported.
        return fused_recurrent_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            offsets=offsets
        )

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass is not implemented yet and we do not have plans to implement it "
            "because we haven't figured out how to compute dg without materializing the full "
            "hidden states for all time steps."
        )
207
+
208
+
209
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Recurrent (step-by-step) forward pass of the gated delta rule.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
            If not provided, defaults to all ones.
        scale (Optional[int]):
            Scale factor applied to the queries.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        use_qk_l2norm_in_kernel (Optional[bool]):
            Whether to L2-normalize `q` and `k` inside the kernel. Default: `False`.
        head_first (Optional[bool]):
            Whether inputs are laid out head-first (`[B, H, T, ...]`). Incompatible
            with `cu_seqlens`. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda'))
        >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    # Varlen mode requires a flattened batch (B == 1) and time-major layout.
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError(
                "Sequences with variable lengths are not supported for head-first mode"
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    # beta defaults to 1, reducing the gated delta rule to the plain update.
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    # Kernels operate in time-major layout; convert head-first inputs and back.
    if head_first:
        q, k, v, g, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g, beta))
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel
    )
    if head_first:
        o = rearrange(o, 'b t h v -> b h t v')
    return o, final_state
fla/ops/generalized_delta_rule/dplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_dplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_dplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes). View file
 
fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import check_shared_mem, is_gather_supported, use_cuda_graph
12
+
13
+
14
# Backward kernel for the intra-chunk part of the DPLR (diagonal-plus-low-rank)
# delta rule. From the four score gradients (dAqk, dAqb, dAak, dAab) and the
# gate-weighted gradients coming from the inter-chunk pass (dqg/dkg/dag/dbg),
# each program recovers the gradients of q/k/a/b for one
# (key-tile, sub-chunk, batch*head) tile, plus two per-element gate gradient
# contributions (dgk, dgk_offset) that `chunk_dplr_bwd_dgk_kernel` turns into
# the final gate gradient.
#
# Fix vs. previous revision: a duplicated `b_dA_qk_j = tl.sum(gather(...))`
# statement in the per-row loop was removed (exact repeat, redundant work).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK', 'NC', 'BT', 'K'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_intra(
    q,
    k,
    a,
    b,
    gi,           # cumulative log gates, inclusive
    ge,           # cumulative log gates, exclusive
    dAqk,
    dAqb,
    dAak,
    dAab,
    dq,
    dk,
    da,
    db,
    dqg,          # inter-chunk gradient for gated q
    dkg,          # inter-chunk gradient for gated k
    dag,          # inter-chunk gradient for gated a
    dbg,          # inter-chunk gradient for gated b
    dgk,          # out: per-element gate gradient contribution
    dgk_offset,   # out: correction term subtracted in the dgk post-pass
    offsets,
    indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,   # chunk size
    BC: tl.constexpr,   # sub-chunk size
    BK: tl.constexpr,   # key tile size
    NC: tl.constexpr,   # sub-chunks per chunk
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr
):
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_t, i_i = i_c // NC, i_c % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos
    if i_t * BT + i_i * BC >= T:
        return

    # offset calculation
    ge += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    gi += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    q += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    a += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    b += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    k += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dq += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    da += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    db += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dqg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dag += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dkg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dbg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dgk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dgk_offset += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
    dAqk += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
    dAqb += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
    dAak += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
    dAab += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT

    stride_qk = K if HEAD_FIRST else H*K
    stride_A = BT if HEAD_FIRST else H*BT

    p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # [BC, BK]
    b_ge = tl.load(p_ge, boundary_check=(0, 1))
    b_gi = tl.load(p_gi, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    b_da = tl.zeros([BC, BK], dtype=tl.float32)
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    b_db = tl.zeros([BC, BK], dtype=tl.float32)
    # intra chunk gradient calculation
    p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
    p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
    p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
    p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
    o_i = tl.arange(0, BC)
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
    p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
    p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
    b_b = tl.load(p_b, boundary_check=(0, 1)).to(tl.float32)
    b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
    b_a = tl.load(p_a, boundary_check=(0, 1)).to(tl.float32)
    b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32)
    b_dAab = tl.load(p_dAab, boundary_check=(0, 1)).to(tl.float32)
    b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)).to(tl.float32)
    b_dAak = tl.load(p_dAak, boundary_check=(0, 1)).to(tl.float32)

    # inter (sub-)chunk gradient calculation: dq/da gather contributions from
    # all earlier sub-chunks of the same chunk, decayed relative to b_gn.
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K
    if i_i > 0:
        p_gn = gi + (i_t * BT + i_i * BC - 1) * stride_qk + o_k
        p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        # [BK,]
        for i_j in range(0, i_i):
            p_kj = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_bj = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_gkj = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_dAqikj = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            p_dAaibj = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            p_dAqibj = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            p_dAaikj = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            # [BC, BK]
            b_kj = tl.load(p_kj, boundary_check=(0, 1))
            b_bj = tl.load(p_bj, boundary_check=(0, 1))
            b_gkj = tl.load(p_gkj, boundary_check=(0, 1))
            tmp = exp(b_gn[None, :] - b_gkj)
            b_kjg = b_kj * tmp
            b_bjg = b_bj * tmp
            # [BC, BC]
            b_dAqikj = tl.load(p_dAqikj, boundary_check=(0, 1))
            b_dAaibj = tl.load(p_dAaibj, boundary_check=(0, 1))
            b_dAqibj = tl.load(p_dAqibj, boundary_check=(0, 1))
            b_dAaikj = tl.load(p_dAaikj, boundary_check=(0, 1))
            # [BC, BK]
            b_dq += tl.dot(b_dAqikj, b_kjg)
            b_dq += tl.dot(b_dAqibj, b_bjg)
            # [BC, BC]
            b_da += tl.dot(b_dAaibj, b_bjg)
            b_da += tl.dot(b_dAaikj, b_kjg)
        b_dq *= exp(b_gi - b_gn[None, :])
        b_da *= exp(b_ge - b_gn[None, :])

    # dk/db gather contributions from all later sub-chunks of the same chunk.
    NC = min(NC, tl.cdiv(T - i_t * BT, BC))
    if i_i < NC - 1:
        p_gn = gi + (min(i_t * BT + i_i * BC + BC, T) - 1)*stride_qk + o_k
        p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        for i_j in range(i_i + 1, NC):
            m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T
            p_qj = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_aj = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_gij = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_gej = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
            p_dAqjki = tl.make_block_ptr(dAqk, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            p_dAajbi = tl.make_block_ptr(dAab, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            p_dAqjbi = tl.make_block_ptr(dAqb, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            p_dAajki = tl.make_block_ptr(dAak, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            b_qj = tl.load(p_qj, boundary_check=(0, 1))
            b_aj = tl.load(p_aj, boundary_check=(0, 1))
            b_gij = tl.load(p_gij, boundary_check=(0, 1))
            b_gej = tl.load(p_gej, boundary_check=(0, 1))
            # mask out-of-range rows so exp() underflows to 0 contribution
            b_gij = tl.where(m_j[:, None] & m_k, b_gij, float('-inf'))
            b_gej = tl.where(m_j[:, None] & m_k, b_gej, float('-inf'))
            b_qjg = b_qj * exp(b_gij - b_gn[None, :])
            b_ajg = b_aj * exp(b_gej - b_gn[None, :])
            # [BC, BC]
            b_dAqjki = tl.load(p_dAqjki, boundary_check=(0, 1))
            b_dAajbi = tl.load(p_dAajbi, boundary_check=(0, 1))
            b_dAqjbi = tl.load(p_dAqjbi, boundary_check=(0, 1))
            b_dAajki = tl.load(p_dAajki, boundary_check=(0, 1))
            b_dk += tl.dot(b_dAqjki, b_qjg)
            b_dk += tl.dot(b_dAajki, b_ajg)
            b_db += tl.dot(b_dAqjbi, b_qjg)
            b_db += tl.dot(b_dAajbi, b_ajg)
        tmp = exp(b_gn[None, :] - b_gi)
        b_dk *= tmp
        b_db *= tmp

    # intra sub-chunk gradient calculation: loop over rows j of this sub-chunk
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # trick to index row/column j of the tiles without dynamic slicing
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BK], j, dtype=tl.int16)
            col_idx = tl.full([BC, 1], j, dtype=tl.int16)
            row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
            # [1, BK]
            b_kj = gather(b_k, row_idx, axis=0)
            b_bj = gather(b_b, row_idx, axis=0)
            b_gij = gather(b_gi, row_idx, axis=0)
            b_gej = gather(b_ge, row_idx, axis=0)
            b_qj = gather(b_q, row_idx, axis=0)
            b_aj = gather(b_a, row_idx, axis=0)
            # [BC, 1]
            b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
            b_dAab_j = gather(b_dAab, col_idx, axis=1)
            b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
            b_dAak_j = gather(b_dAak, col_idx, axis=1)
            # [1, BC] -> [BC, 1]
            b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
            b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
            b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
            b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
        else:
            mask_idx = tl.arange(0, BC) == j
            b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
            b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
            b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
            b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
            b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
            b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
            b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
            b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
            b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
            b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
            b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
            b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
            # [1, BK] b_qj, b_aj
            b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
            b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
        # causal masks: strict (m_e) excludes the diagonal, inclusive (m_i) keeps it
        m_e = o_i[:, None] > j
        m_i = o_i[:, None] >= j
        tmp1 = exp(b_gi - b_gij)
        tmp2 = exp(b_ge - b_gij)
        b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
        b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
        b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
        b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)

        m_i = o_i[:, None] <= j
        m_e = o_i[:, None] < j
        tmp1 = exp(b_gij - b_gi)
        tmp2 = exp(b_gej - b_gi)
        b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
        b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
        b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
        b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
    # post processing: fold in the inter-chunk (gated) gradients and emit results
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # gate cumsum at the last valid row of the whole chunk
    p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k
    p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
    b_gn = tl.load(p_gn, mask=m_k, other=0)
    b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
    b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
    tmp = exp(b_gn[None, :] - b_gi)
    b_dk += tl.load(p_dkg, boundary_check=(0, 1)) * tmp
    b_db += tl.load(p_dbg, boundary_check=(0, 1)) * tmp
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
    # raw gate gradient contribution plus the offset correction term
    b_dgk = b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b
    b_dgk_offset = b_da * b_a
    tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1))
289
+
290
+
291
# Post-processing kernel for the DPLR gate gradient: converts the per-element
# contributions (dgk) into the final gradient via a reverse cumulative sum over
# each chunk, adds the chunk-final contribution (dgk_last) and subtracts the
# offset correction (dgk_offset).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
        for BK in [32, 64]
    ],
    key=['BK', 'BT', 'K'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_dgk_kernel(
    dgk,         # per-element gate gradient contributions, [.., T, K]
    dgk_offset,  # per-element correction subtracted from the cumsum
    dgk_last,    # per-chunk tail contribution, one [K] vector per chunk
    dgk_output,  # out: final gate gradient, same layout as dgk
    offsets,     # optional cumulative sequence lengths (varlen mode)
    indices,     # optional (seq, chunk) pairs per grid step (varlen mode)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size
    BK: tl.constexpr,  # key tile size (autotuned)
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Grid: (NT, NK, B * H) — one program per (chunk, key-tile, batch*head).
    i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen: remap the flat chunk id to (sequence, local chunk).
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        T = eos - bos
    stride_qk = K if HEAD_FIRST else H * K
    # Base offsets for this (batch, head); dgk_last is indexed per chunk.
    dgk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dgk_offset += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dgk_last += ((i_bh * NT + i_t) * K) if HEAD_FIRST else (i_tg * H + i_h) * K
    dgk_output += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
    m_k = tl.arange(0, BK) + i_k * BK < K
    b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
    p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
    b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
    # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
    # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
    # Reverse (suffix) cumsum within the chunk; replaces the dot-product
    # formulation kept above for reference.
    b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
    b_dgk_cumsum += b_dgk_last[None, :]
    b_dgk_cumsum -= b_dgk_offset
    p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1))
352
+
353
+
354
def chunk_dplr_bwd_dqk_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    dAqk: torch.Tensor,
    dAqb: torch.Tensor,
    dAak: torch.Tensor,
    dAab: torch.Tensor,
    dqg: torch.Tensor,
    dkg: torch.Tensor,
    dag: torch.Tensor,
    dbg: torch.Tensor,
    dgk_last: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    scale: float = 1.0,
    chunk_size: int = 64,
):
    """Launch the intra-chunk DPLR backward kernels.

    Returns the gradients ``(dq, dk, da, db, dgk_output)``, where
    ``dgk_output`` is the fully post-processed gate gradient.
    """
    if head_first:
        B, H, T, K = q.shape
    else:
        B, T, H, K = q.shape

    # Tile sizes: chunk BT, sub-chunk BC, and key tile BK (smaller when
    # shared memory is limited).
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    max_bk = 64 if check_shared_mem() else 32
    BK = min(max_bk, triton.next_power_of_2(K))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    NC = triton.cdiv(BT, BC)
    NK = triton.cdiv(K, BK)

    # Output buffers; gate gradients are accumulated in fp32.
    dq, dk = torch.empty_like(q), torch.empty_like(k)
    da, db = torch.empty_like(a), torch.empty_like(b)
    dgk = torch.empty_like(gi, dtype=torch.float)
    dgk_offset = torch.empty_like(gi, dtype=torch.float)

    intra_grid = (NK, NT * NC, B * H)
    chunk_dplr_bwd_kernel_intra[intra_grid](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dAqk,
        dAqb=dAqb,
        dAak=dAak,
        dAab=dAab,
        dq=dq,
        dk=dk,
        da=da,
        db=db,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        dgk=dgk,
        dgk_offset=dgk_offset,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
        NC=NC,
        HEAD_FIRST=head_first,
        GATHER_SUPPORTED=is_gather_supported
    )

    # Second pass: turn per-element contributions into the final gate gradient.
    dgk_output = torch.empty_like(dgk)

    def dgk_grid(meta):
        # BK here is autotuned inside chunk_dplr_bwd_dgk_kernel.
        return (NT, triton.cdiv(K, meta['BK']), B * H)

    chunk_dplr_bwd_dgk_kernel[dgk_grid](
        dgk=dgk,
        dgk_offset=dgk_offset,
        dgk_last=dgk_last,
        dgk_output=dgk_output,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return dq, dk, da, db, dgk_output
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, use_cuda_graph
11
+
12
# Candidate K/V tile sizes for the autotuner; the larger set is only tried on
# devices where check_shared_mem() reports enough shared memory per block.
BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
13
+
14
+
15
# Forward output kernel for the chunked DPLR rule. For each (value-tile, chunk,
# batch*head) program it computes
#     o = qg @ h  +  tril(A_qk) @ v  +  tril(A_qb) @ v_new
# i.e. an inter-chunk contribution through the chunk-level state `h` plus two
# causally-masked intra-chunk contributions.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    qg,  # queries (presumably pre-multiplied by decay gates — TODO confirm at call site)
    v,  # values
    v_new,  # replacement values produced by the WY-representation stage
    A_qk,  # intra-chunk q·k scores, last dim is the chunk axis of size BT
    A_qb,  # intra-chunk q·b scores, same layout as A_qk
    h,  # per-chunk hidden states, one [K, V] matrix per (chunk, head)
    o,  # output, same layout as v
    offsets,  # optional [N+1] sequence boundaries (varlen mode)
    indices,  # optional flattened (sequence, chunk) pairs (varlen mode)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk (time-block) size
    BK: tl.constexpr,  # key tile size (autotuned)
    BV: tl.constexpr,  # value tile size (autotuned)
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # True: [B, H, T, ...] layout; False: [B, T, H, ...]
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # Varlen mode: translate the flat chunk id into (sequence, local chunk)
        # and recompute the local sequence length T for this sequence.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Inter-chunk part: accumulate qg @ h over key tiles in fp32.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    if HEAD_FIRST:
        p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    else:
        p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # Causal (inclusive-diagonal) mask for the intra-chunk score matrices.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    # Add the two intra-chunk contributions and write out in the output dtype.
    b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
98
+
99
+
100
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch `chunk_dplr_fwd_kernel_o` and return the output tensor.

    The output has the same shape/layout as `v`. The value tile size BV is
    picked by the kernel's autotuner, hence the meta-parameterized grid.
    """
    if head_first:
        B, H, T, K = qg.shape
    else:
        B, T, H, K = qg.shape
    V = v.shape[-1]

    # Chunk length: capped by `chunk_size`, at least 16, rounded to a power of 2.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # In varlen mode every entry of `indices` names one chunk.
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_fwd_kernel_o[grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return o
fla/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
    """Naive step-by-step reference of the DPLR (diagonal-plus-low-rank) rule.

    Per step t (with S_{t-1} the previous state, used *before* decay for the
    low-rank term, as in the original reference):

        S_t = S_{t-1} * exp(gk_t) + k_t v_t^T + beta_t (alpha_t^T S_{t-1})
        o_t = (q_t * d_k^{-1/2}) @ S_t

    Args:
        q, k, alpha, beta, gk: [B, H, L, D_K]
        v: [B, H, L, D_V]
        initial_state: optional [B, H, D_K, D_V] added to the zero-initialized state.
        output_final_state: when False the returned state is None.

    Returns:
        o: [B, H, L, D_V] cast back to the input dtype.
        S: final state in float32, or None.
    """
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    # Run the recurrence entirely in fp32.
    # FIX: the original omitted `alpha` from the upcast, leaving it in the
    # input dtype and silently mixing precisions in the low-rank term.
    q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i].clone()
        _beta = beta[:, :, i].clone()
        # Rank-1 delta: k v^T plus the low-rank correction beta (alpha^T S).
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S
34
+
35
+
36
def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
    """Chunk-parallel reference of the DPLR recurrence (same math as
    `dplr_recurrence`, computed chunk by chunk via a WY-style representation).

    Args:
        q, k, alpha, beta, gk: [B, H, L, D_K]; `l` must be divisible by `chunk_size`.
        v: [B, H, L, D_V]
        initial_state: optional [B, H, D_K, D_V] starting state.
        output_final_state: when False the returned state is None.

    Returns:
        o: [B, H, L, D_V] (float32); S: final state or None.

    Cleanups vs. the original: the no-op statements ``v = v`` / ``A_ab = A_ab``
    and two ``mask = torch.triu(...)`` assignments that were never read have
    been removed; einops.rearrange is replaced by the equivalent reshape so the
    function only depends on torch.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    assert l % chunk_size == 0

    S = k.new_zeros(b, h, d_k, d_v).to(q)
    if initial_state is not None:
        S += initial_state

    n = l // chunk_size
    # [B, H, L, D] -> [B, H, N, C, D] in fp32
    # (torch equivalent of einops.rearrange('b h (n c) d -> b h n c d', c=chunk_size))
    q, k, v, alpha, beta, gk = map(
        lambda x: x.reshape(b, h, n, chunk_size, x.shape[-1]).float(),
        [q, k, v, alpha, beta, gk]
    )

    gk_cumsum = gk.cumsum(-2)

    # Intra-chunk score matrices; the diagonal of the a/k pair is masked out.
    A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)

    for i in range(chunk_size):
        alpha_i = alpha[:, :, :, i, None]
        q_i = q[:, :, :, i, None]
        gk_i = gk_cumsum[:, :, :, i, None]
        # Inclusive causal mask for the q-rows.
        mask = (torch.arange(chunk_size) <= i).to(q.device)
        attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
        A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
        # Strict causal mask (shift by one) for the alpha-rows.
        mask = (torch.arange(chunk_size) < i).to(q.device)
        attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
        A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()

    # Forward substitution: turn A_ab into (I - A_ab)^{-1} - I row by row.
    for i in range(1, chunk_size):
        A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)

    A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    # WY representation: pseudo-values u and pseudo-decays w.
    u = A_ab @ (A_ak @ v)
    w = A_ab @ ((gk_cumsum-gk).exp() * alpha)

    o = torch.zeros_like(v)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        v2_i = u_i + w_i @ S

        o_1 = A_qk[:, :, i] @ v_i          # intra-chunk q·k contribution
        o_2 = A_qb[:, :, i] @ v2_i         # intra-chunk q·beta contribution
        o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S  # inter-chunk state contribution
        o[:, :, i] = o_1 + o_2 + o_3
        # Decay each position to the end of the chunk, then fold into the state.
        decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
        S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
            (beta_i * decay).transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return o.reshape(b, h, l, d_v), S
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
# Per-chunk inversion kernel used when BT <= 32: given the strictly-lower
# score matrix A_ab, it computes (I - strict_tril(A_ab))^{-1} by forward
# substitution and stores the full inverse (identity diagonal included) into
# A_ab_inv.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fwd_prepare_wy_repr_kernel_chunk32(
    A_ab,
    A_ab_inv,
    offsets,  # optional [N+1] sequence boundaries (varlen mode)
    indices,  # optional flattened (sequence, chunk) pairs (varlen mode)
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete (keeps signature aligned with chunk64)
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen mode: map the flat chunk id to (sequence, local chunk) and
        # recompute the local sequence length.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    # Keep only the strictly lower-triangular part.
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    # Forward substitution: row i of the inverse is row i of A plus its
    # combination with the already-finished rows j < i.
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    # Add the identity diagonal so the stored matrix is the complete inverse.
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
61
+
62
+
63
# Per-chunk inversion kernel used when BT == 64: the BT x BT unit-lower
# triangular matrix (I - strict_tril(A_ab)) is inverted blockwise as a 2x2
# grid of BC x BC tiles. The two diagonal tiles are inverted by forward
# substitution, the lower-left tile is combined from them, and the upper-right
# tile is zeroed (causality).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fwd_prepare_wy_repr_kernel_chunk64(
    A_ab,
    A_ab_inv,
    offsets,  # optional [N+1] sequence boundaries (varlen mode)
    indices,  # optional flattened (sequence, chunk) pairs (varlen mode)
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # sub-block size, BT == 2 * BC here
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen mode: map the flat chunk id to (sequence, local chunk) and
        # recompute the local sequence length.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Tile layout of the BT x BT block: 1 = top-left, 2 = bottom-right,
    # 3 = bottom-left, 4 = top-right (zeroed).
    if HEAD_FIRST:
        p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    else:
        p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    # Keep only the strictly lower-triangular parts of the diagonal tiles.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # Forward substitution on both diagonal tiles simultaneously.
    for i in range(1, BC):
        if GATHER_SUPPORTED:
            # Fast path: extract row i with a hardware gather.
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BK] -> [BK]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        # `mask` must exist on both paths for the write-back below.
        mask = tl.arange(0, BC) == i
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # Blockwise computation of the lower triangular matrix's inverse,
    # i.e. [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1].
    # No explicit negation is needed here: the matrix actually inverted is
    # (I - A_ab), whose off-diagonal block already carries the minus sign.
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # Upper-right tile is identically zero (causal mask).
    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
# Per-chunk kernel computing the WY-representation outputs:
#     w = tril(A_ab_inv) @ ag              (pseudo decay-weighted a's)
#     u = (tril(A_ab_inv) @ strict_tril(A_ak)) @ v   (pseudo values)
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fwd_wu_kernel(
    u,  # output pseudo-values, same layout as v
    w,  # output, same layout as ag
    ag,
    v,
    A_ab_inv,  # inverse produced by the prepare_wy_repr kernels
    A_ak,
    offsets,  # optional [N+1] sequence boundaries (varlen mode)
    indices,  # optional flattened (sequence, chunk) pairs (varlen mode)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen mode: map the flat chunk id to (sequence, local chunk) and
        # recompute the local sequence length.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    o_s = tl.arange(0, BT)
    # A_ab_inv is lower triangular (diagonal inclusive); A_ak strictly lower.
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    # w = A_ab_inv @ ag, tiled over the key dimension.
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    # u = (A_ab_inv @ A_ak) @ v, tiled over the value dimension.
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
232
+
233
+
234
def fwd_prepare_wy_repr(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the WY representation: invert the per-chunk A_ab matrices, then
    derive (w, u) from them via `fwd_wu`.

    Returns:
        w, u: see `fwd_wu`; A_ab_inv: the per-chunk inverse, same shape as A_ab.
    """
    if head_first:
        B, H, T, K = ag.shape
    else:
        B, T, H, K = ag.shape

    # Chunk length: capped by `chunk_size`, at least 16, rounded to a power of 2.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)
    BC = min(BT, 32)

    # The 64-wide chunk uses a blockwise 2x2 inversion; smaller chunks use the
    # single-block variant.
    kernel = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32

    A_ab_inv = torch.empty_like(A_ab)
    kernel[(NT, B * H)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        BT=BT,
        BC=BC,
        HEAD_FIRST=head_first
    )
    w, u = fwd_wu(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    return w, u, A_ab_inv
276
+
277
+
278
def fwd_wu(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch `fwd_wu_kernel`, producing w (shape of `ag`) and u (shape of `v`)."""
    if head_first:
        B, H, T, K = ag.shape
    else:
        B, T, H, K = ag.shape
    V = v.shape[-1]

    # Chunk length: capped by `chunk_size`, at least 16, rounded to a power of 2.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)
    # Tile sizes over the key/value dimensions, capped at 64.
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    fwd_wu_kernel[(NT, B * H)](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u
fla/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (5.69 kB). View file
 
fla/ops/gla/__pycache__/naive.cpython-312.pyc ADDED
Binary file (2.09 kB). View file
 
fla/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gsa
4
+ from .fused_recurrent import fused_recurrent_gsa
5
+
6
+ __all__ = [
7
+ 'chunk_gsa',
8
+ 'fused_recurrent_gsa'
9
+ ]
fla/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import reduce
10
+
11
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
12
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
13
+ from fla.ops.utils import chunk_local_cumsum, softmax_bwd, softmax_fwd
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import input_guard
16
+
17
+
18
+ @triton.heuristics({
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in [32, 64]
25
+ for BV in [32, 64]
26
+ for num_warps in [2, 4, 8]
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['BT']
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_gsa_fwd_k_kernel_inter(
33
+ q,
34
+ k,
35
+ h,
36
+ g,
37
+ o,
38
+ A,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ HQ: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ NG: tl.constexpr,
51
+ USE_OFFSETS: tl.constexpr,
52
+ HEAD_FIRST: tl.constexpr
53
+ ):
54
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
55
+ i_bg = i_bh // NG
56
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
57
+ i_h = i_hq // NG
58
+ if USE_OFFSETS:
59
+ i_tg = i_t
60
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
61
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
62
+ T = eos - bos
63
+ NT = tl.cdiv(T, BT)
64
+ else:
65
+ NT = tl.cdiv(T, BT)
66
+ i_tg = i_b * NT + i_t
67
+ bos, eos = i_b * T, i_b * T + T
68
+
69
+ o_i = tl.arange(0, BT)
70
+ m_s = o_i[:, None] >= o_i[None, :]
71
+
72
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
73
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
74
+ for i_k in range(tl.cdiv(K, BK)):
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
77
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
78
+ p_h = tl.make_block_ptr(h + (i_bg * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
79
+ else:
80
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
82
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
83
+
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ b_q = (b_q * scale).to(b_q.dtype)
87
+ # [BK, BT]
88
+ b_k = tl.load(p_k, boundary_check=(0, 1))
89
+ # [BK, BV]
90
+ b_h = tl.load(p_h, boundary_check=(0, 1))
91
+ # [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+ if HEAD_FIRST:
96
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
97
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
98
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
99
+ else:
100
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
101
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
102
+ p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
103
+ # [BT, BV]
104
+ b_g = tl.load(p_g, boundary_check=(0, 1))
105
+ b_o = b_o * exp(b_g)
106
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
107
+
108
+ # [BT, BT]
109
+ b_A = tl.where(m_s, b_A, 0.)
110
+ if i_v == 0:
111
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
# Intra-chunk forward kernel (K-variant): adds the within-chunk contribution of
# the gated attention to the output `o`, on top of the inter-chunk part already
# stored there. Works on sub-chunks of size BC inside each chunk of size BT.
# Grid: (ceil(V/BV), NT*NC, B*HQ). NG = HQ // H is the GQA group size
# (NG query heads share one k/v/g head).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_fwd_k_kernel_intra(
    v,
    g,
    o,
    A,
    offsets,
    indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # i_bg: batch*group index into the shared k/v/g heads; i_h: kv head of this query head
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    # i_t: chunk index, i_i: sub-chunk index within the chunk
    i_t, i_i = i_c // NC, i_c % NC
    if USE_OFFSETS:
        # variable-length mode: remap (i_t) to (sequence, chunk-in-sequence)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # this sub-chunk lies entirely past the end of the sequence
    if i_t * BT + i_i * BC > T:
        return

    if HEAD_FIRST:
        p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + min(i_t * BT + i_i * BC, T) * V + o_v, BV), BV)
    else:
        p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
    # [BV,]  gate value at the start of this sub-chunk, used as a reference point
    # so that exponentials below stay in a safe range
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_o = tl.zeros([BC, BV], dtype=tl.float32)
    # accumulate contributions from all earlier sub-chunks (i_j < i_i) of this chunk
    for i_j in range(0, i_i):
        if HEAD_FIRST:
            p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
            p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        else:
            p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
            p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
            p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        # rescale v by exp(g_ref - g_v) relative to the sub-chunk start
        b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        b_o += tl.dot(b_A, b_vg)
    # [BC, BV]  apply the remaining gate factor for the current rows
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o *= exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    # flat offsets into A for the diagonal sub-chunk (row-by-row access below)
    if HEAD_FIRST:
        o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
    else:
        o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # handle the causal diagonal sub-chunk one key position j at a time
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        if HEAD_FIRST:
            p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
            p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
        else:
            p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
            p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
        # [BC,]
        b_A = tl.load(A + o_A + j, mask=m_A, other=0)
        # [BV,]
        b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
        b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
        # [BC, BV]
        b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
        # avoid 0 * inf = inf
        b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
    if HEAD_FIRST:
        p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    else:
        p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # accumulate into the inter-chunk output written by the previous kernel
    b_o += tl.load(p_o, boundary_check=(0, 1))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
214
+
215
# Backward kernel computing dA (gradient w.r.t. the intra-chunk attention
# matrix) for the K-variant. Each program handles one (sub-chunk row i_i,
# sub-chunk column i_j) pair of one chunk; the causal structure means only
# i_i >= i_j pairs contribute. Results are written per V-block and summed
# over V-blocks on the host side.
# Grid: (NV, NT*NC*NC, B*HQ).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8]
    ],
    key=["BT"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dA(
    v,
    g,
    do,
    dA,
    indices,
    offsets,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    # decompose the second grid axis into (chunk, row sub-chunk, column sub-chunk)
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    if i_t * BT + i_i * BC > T:
        return

    if HEAD_FIRST:
        p_dA = tl.make_block_ptr(dA+(i_v*B*H+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    else:
        p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))

    # [BC, BC]
    b_dA = tl.zeros([BC, BC], dtype=tl.float32)
    if i_i > i_j:
        # strictly-lower sub-chunk block: full BCxBC dot with gate rescaling
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
            p_gv = tl.make_block_ptr(g + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
            p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
            p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
            p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
            p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        # [BV,]  reference gate at the start of the row sub-chunk
        b_gn = tl.load(p_gn, mask=m_v, other=0.)
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_dA = tl.dot(b_do, b_vg)
    elif i_i == i_j:
        # diagonal sub-chunk block: compute column-by-column under the causal mask
        if HEAD_FIRST:
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
            p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
            p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
        else:
            p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
            p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
            p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
        m_v = o_v < V

        o_i = tl.arange(0, BC)
        # [BC, BC]  causal mask within the diagonal block
        m_dA = o_i[:, None] >= o_i[None, :]
        for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
            # [BV,]
            b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
            b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
            # [BC,]  column j of dA for this V-block
            b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
            b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)

            # advance to the next key position
            p_v += (1 if HEAD_FIRST else H) * V
            p_gv += (1 if HEAD_FIRST else H) * V
        b_dA = tl.where(m_dA, b_dA, 0.)
    tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
331
+
332
# Backward kernel for the K-variant computing dq, dk, dv, and gate gradients
# (dg, dgv), while also recomputing the attention matrix A = (q*scale) @ k^T.
# Each program handles one (K-block, chunk) pair and loops over V-blocks;
# per-K-block partial results (A, dv, dgv) are summed on the host.
# Grid: (NK, NT, B*HQ).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dg,
    dgv,
    dA,
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        # i_tg: global chunk index (used to address per-chunk states h/dh)
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    # o_t: exclusive end of this chunk, clipped to the sequence length
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    if HEAD_FIRST:
        p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bg * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh) * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]  recompute the causal-masked partial attention for this K-block
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (o_t - 1) * V + o_v, BV), BV)
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dgv = tl.make_block_ptr(dgv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_h = tl.make_block_ptr(h + i_bg * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
            p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]  gate at the last position of the chunk (reference point)
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]  gate gradient contribution from the recurrent state
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        # only the first K-block folds the externally-provided dg into dgv;
        # the others carry just their own contribution (summed later on host)
        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    if HEAD_FIRST:
        p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    else:
        p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]  fold in the intra-chunk attention gradient
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
488
+
489
# Intra-chunk backward kernel: adds the within-chunk contribution to dv and
# computes the final gate gradient dg = o*do - v*dv. Works on sub-chunks of
# size BC; each program covers one (V-block, sub-chunk) pair.
# Grid: (ceil(V/BV), NT*NC, B*HQ).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_intra_dvg(
    v,
    g,
    o,
    A,
    do,
    dv,
    dg,
    offsets,
    indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    i_t, i_i = i_c // NC, i_c % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    if i_t * BT + i_i * BC > T:
        return

    if HEAD_FIRST:
        p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (min(i_t * BT + i_i * BC + BC, T) - 1) * V + o_v, BV), BV)
    else:
        p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
    # [BV,]  gate at the last position of this sub-chunk (reference point)
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)
    # accumulate dv contributions from later sub-chunks (i_j > i_i), which
    # attend back to the rows of this sub-chunk
    for i_j in range(i_i + 1, NC):
        if HEAD_FIRST:
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
            p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        else:
            p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
            p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
        # [BC, BC]  note: A accessed transposed ((BT, T) shape) for the backward pass
        b_A = tl.load(p_A, boundary_check=(0, 1))
        # [BC, BV]
        b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
    # rescale the accumulated gradient from the reference point back to each row
    b_dv *= exp(b_gn[None, :] - b_gv)

    o_i = tl.arange(0, BC)
    o_c = i_i * BC + tl.arange(0, BC)

    if HEAD_FIRST:
        p_g = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
        p_A = tl.max_contiguous(tl.multiple_of(A + i_bh * T*BT + (i_t * BT + i_i * BC) * BT + o_c, BC), BC)
        p_do = tl.max_contiguous(tl.multiple_of(do + i_bh * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
    else:
        p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
        p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
        p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v

    # diagonal sub-chunk: process one query position j at a time
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_A = tl.load(p_A)
        # [BV,]
        b_g = tl.load(p_g, mask=m_v, other=0)
        b_do = tl.load(p_do, mask=m_v, other=0)
        # [BC, BV]  only rows at or before j receive gradient from query j
        m_i = o_i[:, None] <= j
        b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)

        # advance the flat pointers to the next query position
        p_g += (1 if HEAD_FIRST else H) * V
        p_A += (1 if HEAD_FIRST else HQ) * BT
        p_do += (1 if HEAD_FIRST else HQ) * V
    if HEAD_FIRST:
        p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    else:
        p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))

    b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
    b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
    # fold in the inter-chunk dv computed by the dqkvg kernel
    b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
    # gate gradient: dg = o * do - v * dv
    b_dg = b_o * b_do - b_v * b_dv
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
608
+
609
def chunk_gsa_fwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float = 1.,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the V-variant of chunked GSA.

    Delegates entirely to ``chunk_gla_fwd``: the gates ``g`` are already
    chunk-local cumulative sums, so they are passed via ``g_cumsum`` with
    ``g=None``.

    Returns:
        A tuple ``(A, h, ht, o)`` of the intra-chunk attention matrix,
        the per-chunk states, the final state, and the output.
    """
    outputs = chunk_gla_fwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    # chunk_gla_fwd returns (g_cumsum, A, h, ht, o); drop the cumsum, which
    # the caller already has.
    return outputs[1], outputs[2], outputs[3], outputs[4]
637
+
638
+
639
def chunk_gsa_fwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the K-variant of chunked GSA.

    Computes the per-chunk recurrent states via ``chunk_fwd_h`` (with the gate
    applied on the value dimension, ``gv=g``), then launches the inter- and
    intra-chunk kernels to produce the output ``o`` and the intra-chunk
    attention matrix ``A``.

    Returns:
        ``(A, h, ht, o)``: intra-chunk attention, per-chunk states, final
        state (``None`` unless ``output_final_state``), and output.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # BT: chunk size; BC: sub-chunk size; BV: V-block size
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[1] if head_first else q.shape[2]
    # NT: number of chunks (per-sequence chunk count comes from `indices` in
    # variable-length mode); NC: sub-chunks per chunk; NG: GQA group size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    NC = triton.cdiv(BT, BC)
    NG = HQ // H

    h, ht = chunk_fwd_h(
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        h0=h0,
        output_final_state=output_final_state,
        offsets=offsets,
        head_first=head_first,
        chunk_size=BT,
        states_in_fp32=False
    )
    o = v.new_empty(B, *((HQ, T) if head_first else (T, HQ)), V)
    A = q.new_empty(B, *((HQ, T) if head_first else (T, HQ)), BT)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
    # inter-chunk contributions (writes o and A)
    chunk_gsa_fwd_k_kernel_inter[grid](
        q,
        k,
        h,
        g,
        o,
        A,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        HEAD_FIRST=head_first
    )

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    # intra-chunk contributions (accumulates into o)
    chunk_gsa_fwd_k_kernel_intra[grid](
        v,
        g,
        o,
        A,
        offsets=offsets,
        indices=indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        HEAD_FIRST=head_first,
        num_warps=4,
        num_stages=2
    )
    return A, h, ht, o
722
+
723
+
724
def chunk_gsa_bwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Backward pass of the V-variant of chunked GSA.

    Delegates to ``chunk_gla_bwd`` with the pre-computed chunk-local gate
    cumsum passed as ``g_cumsum``. The ``dg`` argument is accepted for
    interface symmetry with ``chunk_gsa_bwd_k`` but is not consumed here.

    Returns:
        ``(dq, dk, dv, dg, dh0)`` gradients.
    """
    return chunk_gla_bwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=h0,
        h=h,
        A=A,
        do=do,
        dht=dht,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
759
+
760
+
761
def chunk_gsa_bwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h: torch.Tensor,
    h0: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Backward pass of the K-variant of chunked GSA.

    Recomputes forward states if ``h`` is ``None`` (checkpointing), backprops
    through the recurrent states (``chunk_bwd_dh``), then runs three kernels:
    ``dA`` (intra-chunk attention gradient), ``dqkvg`` (dq/dk/dv and gate
    gradients), and ``intra_dvg`` (intra-chunk dv plus the final gate
    gradient). Per-block partial buffers are reduced with ``.sum(0)`` on the
    host.

    Returns:
        ``(dq, dk, dv, dg, dh0)`` gradients.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # block sizes: BT chunk, BC sub-chunk, BK/BV tile sizes over K/V
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BK = min(64, triton.next_power_of_2(K))
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[1] if head_first else q.shape[2]
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    NC = triton.cdiv(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    NG = HQ // H

    # recompute per-chunk states when they were not saved (checkpoint_level > 1)
    if h is None:
        h, _ = chunk_fwd_h(
            k=k,
            v=v,
            g=None,
            gk=None,
            gv=g,
            h0=h0,
            output_final_state=False,
            offsets=offsets,
            head_first=head_first,
            chunk_size=BT,
            states_in_fp32=False
        )
    dh, dh0 = chunk_bwd_dh(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        do=do,
        h0=h0,
        dht=dht,
        scale=scale,
        offsets=offsets,
        head_first=head_first,
        chunk_size=BT,
        states_in_fp32=True
    )
    # dA is computed per V-block (leading NV dim) and reduced below
    dA = q.new_empty(NV, B, *((HQ, T) if head_first else (T, HQ)), BT)
    grid = (NV, NT * NC * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_dA[grid](
        v,
        g,
        do,
        dA,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        HEAD_FIRST=head_first
    )
    dA = dA.sum(0, dtype=dA.dtype)

    # A/dv/dgv carry a leading NK dim (one slice per K-block), reduced below
    A = do.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), BT)
    dq = torch.empty_like(q)
    dk = k.new_empty(B, *((HQ, T) if head_first else (T, HQ)), K)
    dv = v.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V)
    dgv = g.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V, dtype=torch.float)
    grid = (NK, NT, B * HQ)
    chunk_gsa_bwd_k_kernel_dqkvg[grid](
        q,
        k,
        v,
        h,
        g,
        A,
        do,
        dh,
        dq,
        dk,
        dv,
        dg,
        dgv,
        dA,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NG=NG,
        HEAD_FIRST=head_first
    )
    A = A.sum(0, dtype=A.dtype)
    dv = dv.sum(0, dtype=dv.dtype)
    dgv = dgv.sum(0, dtype=dgv.dtype)

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    # intra-chunk dv plus the gate gradient dg = o*do - v*dv (written into dg)
    chunk_gsa_bwd_k_kernel_intra_dvg[grid](
        v,
        g,
        o,
        A,
        do,
        dv,
        dg,
        offsets=offsets,
        indices=indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        HEAD_FIRST=head_first,
        num_warps=4,
        num_stages=2
    )
    # final dg: state-path gate gradient (dgv) plus the reverse chunk-local
    # cumsum of the pointwise gate gradient
    dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets, indices=indices, head_first=head_first))

    return dq, dk, dv, dg, dh0
914
+
915
+
916
def chunk_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Full GSA forward: a K-pass producing attention logits, a softmax over
    the slot dimension, and a V-pass reading out the values.

    Returns:
        ``(Ak, hk, hkt, ok, p, Av, hv, hvt, ov)`` — intermediates of both
        passes, the softmax probabilities ``p``, and the final output ``ov``.
    """
    # split the optional pair of initial states for the two passes
    hk0, hv0 = initial_state if initial_state is not None else (None, None)

    # first pass: attention of q against keys, gated on the slot dimension
    Ak, hk, hkt, ok = chunk_gsa_fwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        output_final_state=output_final_state,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # p is kept in fp32 for safe softmax backward
    p = softmax_fwd(ok, dtype=torch.float)
    qv = p.to(q.dtype)

    # second pass: the softmax probabilities act as queries over the values
    Av, hv, hvt, ov = chunk_gsa_fwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        scale=1.,
        initial_state=hv0,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
965
+
966
+
967
def chunk_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    ok: torch.Tensor,
    p: torch.Tensor,
    A: Tuple[torch.Tensor, torch.Tensor],
    h: Tuple[torch.Tensor, torch.Tensor],
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
    scale: float,
    do: torch.Tensor,
    dht: Tuple[torch.Tensor, torch.Tensor],
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Full GSA backward: V-pass backward, softmax backward, then K-pass
    backward; finally merges the slot gradients and reduces grouped heads.

    Returns:
        ``(dq, dk, dv, ds, dg, dhk0, dhv0)`` gradients (the last two are the
        gradients w.r.t. the initial states; ``dhv0``/``dhk0`` come from the
        respective passes).
    """
    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state

    # only the V-pass attention matrix is needed here; Ak is recomputed
    _, Av = A
    hk, hv = h
    dhkt, dhvt = dht

    qv = p.to(q.dtype)
    dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        h0=hv0,
        h=hv,
        A=Av,
        do=do,
        dht=dhvt,
        dg=None,
        scale=1.,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # softmax gradient, equivalent to:
    # dok = qv * (dqv - (qv * dqv).sum(-1, True))
    dok = softmax_bwd(p, dqv, dtype=ok.dtype)

    dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        h=hk,
        o=ok,
        do=dok,
        dht=dhkt,
        dg=dg,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # the slot tensor s served as values in the K-pass and keys in the V-pass;
    # its gradient is the sum of both contributions
    ds = dsv.add_(dsk)
    # GQA: sum gradients of shared kv heads over their query-head group
    if q.shape[1] != k.shape[1]:
        dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
    dg = dg.to(s.dtype)
    return dq, dk, dv, ds, dg, dhk0, dhv0
1040
+
1041
+
1042
class ChunkGSAFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked GSA forward/backward, with optional
    activation checkpointing controlled by ``checkpoint_level``:

    - level 0: save gate cumsum and per-chunk states for backward;
    - level 1: drop the gate cumsum (recomputed in backward);
    - level >1: additionally drop the per-chunk states (recomputed from the
      saved initial states).
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: float,
        hk0: Optional[torch.Tensor],
        hv0: Optional[torch.Tensor],
        output_final_state: bool,
        checkpoint_level: int,
        offsets: Optional[torch.LongTensor],
        head_first: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        T = q.shape[2] if head_first else q.shape[1]
        chunk_size = min(64, max(16, triton.next_power_of_2(T)))

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = None
        if offsets is not None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        # keep the raw gates around so checkpointing can re-derive the cumsum
        g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
        Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )

        # decide what to keep for backward based on the checkpoint level
        if checkpoint_level >= 1:
            del g
            g = g_org
        if checkpoint_level > 1:
            del hk
            del hv
            hk, hv = None, None
        else:
            # states are saved, so the initial states need not be kept
            hk0, hv0 = None, None

        ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
        ctx.checkpoint_level = checkpoint_level
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.head_first = head_first
        ctx.chunk_size = chunk_size
        return ov, hkt, hvt

    @staticmethod
    @input_guard
    def backward(ctx, dov, dhkt=None, dhvt=None):
        q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
        scale = ctx.scale
        offsets = ctx.offsets
        indices = ctx.indices
        head_first = ctx.head_first
        chunk_size = ctx.chunk_size

        # level >= 1 saved raw gates; recompute the chunk-local cumsum
        if ctx.checkpoint_level >= 1:
            g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
        dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            ok=ok,
            p=p,
            A=(None, Av),
            h=(hk, hv),
            initial_state=(hk0, hv0),
            scale=scale,
            do=dov,
            dht=(dhkt, dhvt),
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        # one gradient slot per forward argument (None for non-tensor args)
        return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
1139
+
1140
+
1141
@torch.compiler.disable
def chunk_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    checkpoint_level: Optional[int] = 2,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
            GQA is performed if `H` is not equal to `HQ`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
            If not provided, this function is equivalent to vanilla ABC.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
            Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `2`:
            - Level `0`: no memory saved, no recomputation.
            - Level `1`: recompute the fp32 cumulative values during backward.
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (Tuple[torch.Tensor]):
            Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
            `None` otherwise.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gsa import chunk_gsa
        # inputs with equal lengths
        >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> s = torch.randn(B, T, H, M, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
        >>> o, (hk, hv) = chunk_gsa(q, k, v, s, g,
                                    initial_state=h0,
                                    output_final_state=True,
                                    head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g,
                                                initial_state=h0,
                                                output_final_state=True,
                                                cu_seqlens=cu_seqlens,
                                                head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert hk.allclose(hk_var)
        >>> assert hv.allclose(hv_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.")
    assert checkpoint_level in [0, 1, 2]
    if g is None:
        # vanilla ABC: derive log-space gates from the slots themselves
        # TODO: this 3 steps took huge amount of time, ought to be optimized
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z).to(k.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    o, *final_state = ChunkGSAFunction.apply(
        q,
        k,
        v,
        s,
        g,
        scale,
        hk0,
        hv0,
        output_final_state,
        checkpoint_level,
        cu_seqlens,
        head_first
    )
    return o, final_state
fla/ops/gsa/naive.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-

from typing import Optional, Tuple

import torch
from einops import repeat
9
def naive_recurrent_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    """Reference (pure-PyTorch, token-by-token) implementation of recurrent GSA.

    Runs two chained gated linear recurrences:
      1. ``hk[t] = hk[t-1] * exp(g[t]) + k[t] (outer) s[t]``, read with the scaled
         queries to produce slot scores ``ok``;
      2. the softmax of ``ok`` then queries
         ``hv[t] = hv[t-1] * exp(g[t]) + s[t] (outer) v[t]`` to produce the output.

    Args:
        q: queries of shape ``[B, HQ, T, K]``.
        k: keys of shape ``[B, H, T, K]``; GQA is applied when ``HQ != H``.
        v: values of shape ``[B, H, T, V]``.
        s: slot representations of shape ``[B, H, T, M]``.
        g: optional log-space forget gates of shape ``[B, H, T, M]``.
            When ``None``, gates are derived from ``s`` via log-cumsum-exp
            normalization (vanilla ABC).
        scale: attention score scale; defaults to ``K ** -0.5``.
        initial_state: optional ``([B, H, K, M], [B, H, M, V])`` state tuple.
        output_final_state: when ``True``, also return the final ``(hk, hv)``.

    Returns:
        Tuple of the output of shape ``[B, HQ, T, V]`` (cast back to the input
        dtype) and the final state tuple (or ``None``).
    """
    dtype = q.dtype

    # number of query heads sharing one kv head (GQA replication factor)
    NG = q.shape[1] // k.shape[1]
    # [batch_size, n_heads, seq_len, n_slots]
    if g is None:
        # vanilla ABC: gates from a log-cumsum-exp normalization of the slots
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    q, k, v, s, g = (x.float() for x in (q, k, v, s, g))
    # replicate each kv head NG times with adjacent copies, equivalent to
    # einops' `repeat(x, 'b h t d -> b (h g) t d', g=NG)`
    k, v, s, g = (x.repeat_interleave(NG, dim=1) for x in (k, v, s, g))
    if initial_state is not None:
        initial_state = tuple(x.repeat_interleave(NG, dim=1) for x in initial_state)

    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]

    # 1st recurrence: key -> slot state, read by the scaled queries
    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None
    if initial_state is not None:
        hk += initial_state[0]

    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = s[:, :, i]
        g_i = g[:, :, i].exp()
        hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
        ok[:, :, i] = (q_i[..., None] * hk).sum(-2)

    # 2nd recurrence: softmax-normalized slot scores query the slot -> value state
    qv = ok.softmax(-1)
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv += initial_state[1]

    for i in range(T):
        q_i = qv[:, :, i]
        k_i = s[:, :, i]
        v_i = v[:, :, i]
        g_i = g[:, :, i].exp()
        hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
        ov[:, :, i] = (q_i[..., None] * hv).sum(-2)

    if output_final_state:
        # one state per kv head; the NG replicas within a group are identical
        final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
    return ov.to(dtype), final_state
fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (315 Bytes). View file
 
fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (3.75 kB). View file
 
fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/ops/linear_attn/__pycache__/naive.cpython-312.pyc ADDED
Binary file (1.96 kB). View file
 
fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc ADDED
Binary file (554 Bytes). View file
 
fla/ops/nsa/__pycache__/naive.cpython-312.pyc ADDED
Binary file (5.82 kB). View file
 
fla/ops/nsa/parallel.py ADDED
@@ -0,0 +1,1435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+
12
+ from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices
13
+ from fla.ops.nsa.utils import _bitonic_merge
14
+ from fla.ops.utils import mean_pooling
15
+ from fla.ops.utils.op import exp, log
16
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
17
+
18
+ try:
19
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
20
+ except ImportError:
21
+ warnings.warn(
22
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
23
+ category=ImportWarning
24
+ )
25
+ flash_attn_func = None
26
+
27
+
28
# Forward kernel for NSA's compressed-attention branch.
# Each program handles one query token (i_t), one value tile (i_v) and one kv
# head (i_bh), running an online-softmax sweep over the compressed K/V
# representations (one representation per BS-sized block of tokens).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_compression_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # variable-length path: remap (i_t -> sequence n, local position)
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)
    # max scores for the current block
    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    # lse = log(acc) + m
    b_acc = tl.zeros([G], dtype=tl.float32)

    for i_c in range(0, NC, BC):
        o_c = i_c + tl.arange(0, BC)

        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BC, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        # mask out representations beyond the causal horizon
        b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))

        # online-softmax rescaling
        # [G]
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [G, BC]
        b_p = exp(b_s - b_m[:, None])
        # [G]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)

        # [G, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m
    if NC == 0:
        # no complete compressed block is visible yet: zero output / zero lse
        b_lse = tl.zeros([G], dtype=tl.float32)
    else:
        b_o = b_o / b_acc[:, None]
        b_lse = b_m + log(b_acc)

    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    # lse is per (token, query head); only one value tile needs to write it
    if i_v == 0:
        tl.store(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G), b_lse.to(lse.dtype.element_ty))
128
+
129
+
130
# Backward kernel (query gradients) for the compressed-attention branch.
# Recomputes the attention probabilities from the saved lse and accumulates
# dq = (p * (do @ v^T - delta)) @ k, scaled at the end.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_compression_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    # advance the base pointers to this query token
    q += (bos + i_t) * HQ*K
    do += (bos + i_t) * HQ*V
    lse += (bos + i_t) * HQ
    delta += (bos + i_t) * HQ
    # dq holds one partial per value tile (i_v); summed on the host side
    dq += (i_v * B * T + bos + i_t) * HQ*K

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + i_h * G + tl.arange(0, G)
    p_delta = delta + i_h * G + tl.arange(0, G)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS

    # [G, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [G]
    b_lse = tl.load(p_lse)
    b_delta = tl.load(p_delta)

    # [G, BK]
    b_dq = tl.zeros([G, BK], dtype=tl.float32)
    for i_c in range(0, NC, BC):
        o_c = i_c + tl.arange(0, BC)
        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # recompute probabilities from the saved lse
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])
        b_p = tl.where((o_c < NC)[None, :], b_p, 0)

        # [G, BV] @ [BV, BC] -> [G, BC]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [G, BC] @ [BC, BK] -> [G, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
    b_dq *= scale
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
230
+
231
+
232
# Backward kernel (key/value gradients) for the compressed-attention branch.
# Each program owns a block of BC compressed representations and scans over all
# later query tokens that can attend to them, accumulating dk and dv.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_compression_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    offsets,
    chunk_indices,
    chunk_offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)

    p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
    # dk/dv keep one partial per value tile (i_v); reduced on the host side
    p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))

    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    # [BC, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)

    # scan every query token from the first that can see this block onwards
    for i in range(i_c * BC * BS, T):
        o_c = i_c * BC + tl.arange(0, BC)

        p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
        # [G, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)

        p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
        p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G)
        p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G)
        # [G, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [G]
        b_lse = tl.load(p_lse)
        b_delta = tl.load(p_delta)
        # [BC, G]
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        # token i only attends to representations whose block is complete by i
        b_p = tl.where((i >= max(0, (o_c + 1) * BS - 1))[:, None], b_p, 0)
        # [BC, G] @ [G, BV] -> [BC, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BC, BV] @ [BV, G] -> [BC, G]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BC, G]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BC, G] @ [G, BK] -> [BC, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
327
+
328
+
329
# Top-k block-selection kernel for NSA.
# Phase 1 (re)computes per-token lse over the compressed keys if not given;
# phase 2 scores every candidate block and keeps the S highest-scoring block
# indices via an in-register bitonic merge sort, writing them to block_indices.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK'],
)
@triton.jit
def parallel_nsa_kernel_topk(
    q,
    k,
    lse,
    scale,
    block_indices,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    S: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS
    ################################
    # 1. lse computation
    ################################
    if lse is not None:
        b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G))
    else:
        # max scores for the current block
        b_m = tl.full([G], float('-inf'), dtype=tl.float32)
        # lse = log(acc) + m
        b_acc = tl.zeros([G], dtype=tl.float32)
        for i_c in range(0, NC, BC):
            o_c = i_c + tl.arange(0, BC)

            p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))

            # [G, BC]
            b_s = tl.dot(b_q, b_k)
            b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = exp(b_mp - b_m)
            # [G, BC]
            b_p = exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)

            b_mp = b_m
        if NC == 0:
            b_lse = tl.zeros([G], dtype=tl.float32)
        else:
            b_lse = b_m + log(b_acc)

    ################################
    # 2. topk selection
    ################################
    # running top values (b_i) and 1-based block ids (o_i); 0 marks "empty"
    # [BC]
    b_i = tl.full([BC], -1, dtype=tl.float32)
    o_i = tl.zeros([BC], dtype=tl.int32)
    # mask selecting the upper half kept from the previous iteration
    m_i = tl.arange(0, BC) < BC//2
    for i_c in range(0, i_t // BS + 1, BC):
        o_c = i_c + tl.arange(0, BC)

        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf'))
        # the current (incomplete) block always gets importance 1.0
        # [G, BC]
        b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), exp(b_s - b_lse[:, None]))
        # the importance scores of the current block
        # [BC]
        b_i, b_ip = tl.sum(b_p, 0), b_i
        o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i

        n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0])
        # sort the fresh scores with a bitonic network
        for i in tl.static_range(1, n_dims):
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims)

        if i_c != 0:
            # merge with the survivors of the previous iterations
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims)
            b_i_new = b_ip * m_i + b_i * (1 - m_i)
            o_i_new = o_ip * m_i + o_i * (1 - m_i)
            b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims)
        else:
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims)

    # take the S best entries and convert ids back to 0-based block indices
    m_top = tl.arange(0, BC//S) == 0
    b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0)

    p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,))
    tl.store(p_b, b_top.to(p_b.dtype.element_ty))
459
+
460
+
461
# Forward kernel for NSA's selected-attention branch.
# Each program attends one query token to the (at most S) selected key/value
# blocks listed in block_indices, with an online-softmax accumulation.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    block_indices += (bos + i_t) * H*S + i_h * S

    # number of selected blocks for this token, fixed S otherwise
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)

    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # skip out-of-range or future blocks
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BS, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))
            # [G, BS]
            b_s = tl.dot(b_q, b_k)
            # causal mask within the selected block
            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

            # online-softmax rescaling
            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = exp(b_mp - b_m)
            # [G, BS]
            b_p = exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)
            # [G, BV]
            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

            b_mp = b_m
    b_o = b_o / b_acc[:, None]
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))
557
+
558
+
559
# Scatters the sparse block_indices into a dense boolean block_mask of shape
# [B, T, H, NS]: entry (t, h, b_i) is set when block b_i is selected for token
# t (and, with per-token counts, when its slot is within block_counts).
@triton.heuristics({
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.jit
def parallel_nsa_kernel_mask(
    block_indices,
    block_counts,
    block_mask,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    NS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h, i_s = i_hs // S, i_hs % S

    # selected block id for slot i_s of (batch i_b, token i_t, head i_h)
    b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s)
    if USE_BLOCK_COUNTS:
        # valid only if the block starts at or before i_t and the slot is in use
        b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h)
    else:
        b_m = b_i * BS <= i_t

    if b_i < NS and b_i >= 0:
        tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty))
585
+
586
+
587
@triton.jit
def parallel_nsa_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,
    V: tl.constexpr
):
    """FlashAttention-style backward preprocessing.

    For each flattened (token, head) row `i_n`, computes
    `delta[i_n] = sum_v o[i_n, v] * do[i_n, v]`, the correction term used by
    the dq/dk/dv backward kernels.
    """
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    # B is V rounded up to a power of two; mask off the padded tail
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    # accumulate the inner product in fp32 for numerical stability
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
604
+
605
+
606
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    """Backward pass for query gradients of NSA selected attention.

    One program handles one (token, value-split, batch*head) triple and
    accumulates dq for the G query heads sharing KV head `i_h`, visiting only
    the key/value blocks this query selected (`block_indices`).
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # varlen mode: remap the flat token program id to (sequence, local token)
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # shift all per-token pointers to the current query position
    q += (bos + i_t) * HQ*K
    do += (bos + i_t) * HQ*V
    lse += (bos + i_t) * HQ
    delta += (bos + i_t) * HQ
    # dq is laid out as [NV, B, T, HQ, K]; i_v picks the value split
    dq += (i_v * B * T + bos + i_t) * HQ*K
    block_indices += (bos + i_t) * H*S + i_h * S

    if USE_BLOCK_COUNTS:
        # per-query number of selected blocks
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    # pre-scale q so attention scores are already scaled inside the loop
    b_q = (b_q * scale).to(b_q.dtype)

    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + i_h * G + tl.arange(0, G)
    p_delta = delta + i_h * G + tl.arange(0, G)

    # [G, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [G]
    b_lse = tl.load(p_lse)
    b_delta = tl.load(p_delta)

    # [G, BK]
    b_dq = tl.zeros([G, BK], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # skip non-causal or invalid (negative) block ids
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BV, BS]
            b_v = tl.load(p_v, boundary_check=(0, 1))

            # [G, BS] attention scores against this key block
            b_s = tl.dot(b_q, b_k)
            # softmax probabilities recovered from the saved log-sum-exp
            b_p = exp(b_s - b_lse[:, None])
            # zero out positions after the query token (causal mask)
            b_p = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p, 0)

            # [G, BV] @ [BV, BS] -> [G, BS]
            b_dp = tl.dot(b_do, b_v)
            # softmax backward: ds = p * (dp - delta)
            b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
            # [G, BS] @ [BS, BK] -> [G, BK]
            b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
    # d(score)/d(q) carries the scale factor once
    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
712
+
713
+
714
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    block_mask,
    offsets,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    M: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    """Backward pass for key/value gradients of NSA selected attention.

    One program owns one key/value block `i_s` and scans all query tokens at or
    after that block, using the precomputed dense `block_mask` (last dim `M`)
    to visit only queries that actually selected this block.
    """
    i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # varlen mode: remap the flat chunk program id to (sequence, local chunk)
        i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))
    # dk is laid out as [NV, B, T, H, K]; i_v picks the value split
    p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))

    # [BS, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BS, BK], dtype=tl.float32)
    # [BS, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BS, BV], dtype=tl.float32)

    # scan queries from the first token of this key block to the end;
    # earlier queries cannot attend to this block (causality)
    for i in range(i_s * BS, T):
        b_m = tl.load(block_mask + (bos + i) * H*M + i_h * M + i_s)
        if b_m:
            p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
            # [G, BK]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)

            p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
            p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G)
            p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G)
            # [G, BV]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            # [G]
            b_lse = tl.load(p_lse)
            b_delta = tl.load(p_delta)
            # [BS, G] scores between this key block and the G query heads
            b_s = tl.dot(b_k, tl.trans(b_q))
            b_p = exp(b_s - b_lse[None, :])
            # causal mask within the block: key position must not exceed query i
            b_p = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p, 0)
            # [BS, G] @ [G, BV] -> [BS, BV]
            b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
            # [BS, BV] @ [BV, G] -> [BS, G]
            b_dp = tl.dot(b_v, tl.trans(b_do))
            # [BS, G] softmax backward: ds = p * (dp - delta)
            b_ds = b_p * (b_dp - b_delta[None, :])
            # [BS, G] @ [G, BK] -> [BS, BK]; b_q is pre-scaled, so dk gets the scale
            b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
804
+
805
+
806
def parallel_nsa_compression_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_size: int,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the forward kernel of NSA compressed attention.

    Args:
        q: queries of shape `[B, T, HQ, K]`.
        k: compressed keys of shape `[B, TC, H, K]` (one entry per block).
        v: compressed values of shape `[B, TC, H, V]`.
        block_size: compression block size.
        scale: attention score scaling factor.
        offsets: cumulative sequence lengths for varlen mode, or `None`.
        token_indices: per-token (sequence, position) pairs for varlen mode.

    Returns:
        o: outputs of shape `[B, T, HQ, V]`.
        lse: per-(token, head) log-sum-exp of shape `[B, T, HQ]` (fp32).
    """
    B, T, HQ, K, V = *q.shape, v.shape[-1]
    H = k.shape[2]
    G = HQ // H
    # BC == BS here, though chunk and block sizes are conceptually decoupled
    BC = BS = block_size
    # larger tiles on GPUs with enough shared memory (Hopper-class)
    if check_shared_mem('hopper', q.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # the kernel keeps the whole key dimension in one tile
    assert NK == 1, "The key dimension can not be larger than 256"

    chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None

    grid = (T, NV, B * H)
    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    parallel_nsa_compression_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        offsets=offsets,
        token_indices=token_indices,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BC=BC,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
857
+
858
+
859
def parallel_nsa_compression_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the backward kernels of NSA compressed attention.

    Computes gradients w.r.t. `q`, `k` and `v` given the saved forward outputs
    `o`/`lse` and the incoming gradient `do`.

    Returns:
        (dq, dk, dv) with the same shapes as (q, k, v).
    """
    B, T, HQ, K, V = *q.shape, v.shape[-1]
    H = k.shape[2]
    G = HQ // H
    BC = BS = block_size
    BK = triton.next_power_of_2(K)
    BV = min(128, triton.next_power_of_2(v.shape[-1]))
    NV = triton.cdiv(V, BV)
    if offsets is not None:
        # varlen mode: build (sequence, chunk) index pairs for the dkv grid
        lens = prepare_lens(offsets)
        chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()])
        chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets)
        chunk_offsets = prepare_chunk_offsets(offsets, BS)
        NC = len(chunk_indices)
    else:
        chunk_indices, chunk_offsets = None, None
        NC = triton.cdiv(triton.cdiv(T, BS), BC)

    # delta = rowsum(o * do), shared by the dq and dkv kernels
    delta = parallel_nsa_bwd_preprocess(o, do)

    # dq is accumulated per value split (leading NV dim), then reduced below;
    # fp32 accumulation is used whenever there is more than one split
    dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    grid = (T, NV, B * H)
    parallel_nsa_compression_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        scale=scale,
        offsets=offsets,
        token_indices=token_indices,
        chunk_offsets=chunk_offsets,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BC=BC,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dq = dq.sum(0)

    dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
    dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)

    grid = (NV, NC, B * H)
    parallel_nsa_compression_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        offsets=offsets,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BC=BC,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dk = dk.sum(0)
    return dq, dk, dv
949
+
950
+
951
class ParallelNSACompressionFunction(torch.autograd.Function):
    """Autograd wrapper around the NSA compressed-attention Triton kernels.

    `forward` returns both the attention output and the per-head log-sum-exp
    (the latter is reused by the top-k block selection); `backward` only
    propagates gradients to q/k/v and receives `*args` for the extra lse
    gradient, which is discarded.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        q,
        k,
        v,
        block_size,
        scale,
        offsets
    ):
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o, lse = parallel_nsa_compression_fwd(
            q=q,
            k=k,
            v=v,
            block_size=block_size,
            scale=scale,
            offsets=offsets,
            token_indices=token_indices
        )
        ctx.save_for_backward(q, k, v, o, lse)
        # non-tensor context saved as plain attributes
        ctx.offsets = offsets
        ctx.token_indices = token_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype), lse

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, *args):
        # *args swallows the gradient of `lse`, which is not backpropagated
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_nsa_compression_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            block_size=ctx.block_size,
            scale=ctx.scale,
            offsets=ctx.offsets,
            token_indices=ctx.token_indices
        )
        # one None per non-tensor forward input (block_size, scale, offsets)
        return dq.to(q), dk.to(k), dv.to(v), None, None, None
1007
+
1008
+
1009
def parallel_nsa_topk(
    q: torch.Tensor,
    k: torch.Tensor,
    lse: torch.Tensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
) -> torch.LongTensor:
    """Select, per query token and KV head, the top-scoring compressed blocks.

    Uses the compressed keys and the saved `lse` from compressed attention to
    score blocks and writes the chosen block ids into an int32 tensor of shape
    `[B, T, H, S]`, where `S` is the (power-of-two padded) maximum number of
    selected blocks.
    """
    B, T, HQ, K = q.shape
    H = k.shape[2]
    G = HQ // H
    # maximum block budget across all queries, padded to a power of two
    S = block_counts if isinstance(block_counts, int) else block_counts.max().item()
    S = triton.next_power_of_2(S)
    # here we set BC = BS, but beware that they are actually decoupled
    BC = BS = block_size
    BK = triton.next_power_of_2(K)

    block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device)
    token_indices = prepare_token_indices(offsets) if offsets is not None else None
    chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None
    grid = (T, B * H)
    parallel_nsa_kernel_topk[grid](
        q=q,
        k=k,
        lse=lse,
        scale=scale,
        block_indices=block_indices,
        offsets=offsets,
        token_indices=token_indices,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        S=S,
        BC=BC,
        BS=BS,
        BK=BK
    )
    return block_indices
1051
+
1052
+
1053
def parallel_nsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the forward kernel of NSA selected (sparse block) attention.

    Args:
        q: queries `[B, T, HQ, K]`.
        k: keys `[B, T, H, K]`.
        v: values `[B, T, H, V]`.
        block_indices: selected block ids `[B, T, H, S]`.
        block_counts: per-query number of selected blocks, or a fixed int.
        block_size: size of each selected block.
        scale: attention score scaling factor.
        offsets / token_indices: varlen metadata, or `None`.

    Returns:
        o: outputs `[B, T, HQ, V]`.
        lse: log-sum-exp per (token, head) `[B, T, HQ]` (fp32).
    """
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    # larger tiles on GPUs with enough shared memory (Hopper-class)
    if check_shared_mem('hopper', q.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # the kernel keeps the whole key dimension in one tile
    assert NK == 1, "The key dimension can not be larger than 256"

    grid = (T, NV, B * H)
    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    parallel_nsa_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        block_indices=block_indices,
        block_counts=block_counts,
        offsets=offsets,
        token_indices=token_indices,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
1105
+
1106
+
1107
def parallel_nsa_block_mask(
    block_indices: torch.LongTensor,
    block_counts: Union[torch.LongTensor, int],
    offsets: torch.LongTensor,
    block_size: int,
):
    """Densify the per-query block selection into a boolean mask.

    Returns a `[B, T, H, NS]` bool tensor where `NS` is the number of
    key/value blocks; entry `[b, t, h, s]` is True iff query `t` of head `h`
    selected block `s` (subject to causality and the per-query block budget).
    """
    B, T, H, S = block_indices.shape
    # number of key/value blocks covering the longest sequence
    if offsets is not None:
        num_kv_blocks = triton.cdiv(prepare_lens(offsets).max().item(), block_size)
    else:
        num_kv_blocks = triton.cdiv(T, block_size)
    block_mask = torch.zeros(
        B, T, H, num_kv_blocks,
        dtype=torch.bool,
        device=block_indices.device,
    )

    grid = (T, B, H * S)
    parallel_nsa_kernel_mask[grid](
        block_indices=block_indices,
        block_counts=block_counts,
        block_mask=block_mask,
        T=T,
        H=H,
        S=S,
        BS=block_size,
        NS=num_kv_blocks,
    )
    return block_mask
1132
+
1133
+
1134
def parallel_nsa_bwd_preprocess(
    o: torch.Tensor,
    do: torch.Tensor
):
    """Compute `delta = rowsum(o * do)` for every (token, head) row.

    The result is the standard FlashAttention backward correction term,
    stored in fp32 with shape `o.shape[:-1]`.
    """
    head_dim = o.shape[-1]
    delta = torch.empty_like(o[..., 0], dtype=torch.float32)
    grid = (delta.numel(),)
    parallel_nsa_bwd_kernel_preprocess[grid](
        o=o,
        do=do,
        delta=delta,
        # pad the reduction width to a power of two for tl.arange
        B=triton.next_power_of_2(head_dim),
        V=head_dim,
    )
    return delta
1148
+
1149
+
1150
def parallel_nsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    block_indices: torch.Tensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the backward kernels of NSA selected attention.

    Computes gradients w.r.t. `q`, `k` and `v` given the forward outputs
    `o`/`lse`, the incoming gradient `do` and the sparse block selection.

    Returns:
        (dq, dk, dv) with the same shapes as (q, k, v).
    """
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    BK = triton.next_power_of_2(K)
    BV = min(128, triton.next_power_of_2(v.shape[-1]))
    NV = triton.cdiv(V, BV)

    # delta = rowsum(o * do), shared by the dq and dkv kernels
    delta = parallel_nsa_bwd_preprocess(o, do)

    # dq is accumulated per value split (leading NV dim) and reduced below;
    # fp32 accumulation is used whenever there is more than one split
    dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    grid = (T, NV, B * H)
    parallel_nsa_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        block_indices=block_indices,
        block_counts=block_counts,
        offsets=offsets,
        token_indices=token_indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dq = dq.sum(0)

    if offsets is not None:
        chunk_indices = prepare_chunk_indices(offsets, BS)
        NS = len(chunk_indices)
    else:
        chunk_indices = None
        NS = triton.cdiv(T, BS)

    # [B, T, H, M] dense selection mask so the dkv kernel can skip
    # queries that never attended to a given key block
    block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size)
    dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
    dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)

    grid = (NV, NS, B * H)
    parallel_nsa_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        block_mask=block_mask,
        offsets=offsets,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        M=block_mask.shape[-1],
        BS=BS,
        BK=BK,
        BV=BV
    )
    dk = dk.sum(0)
    return dq, dk, dv
1243
+
1244
+
1245
# NOTE: do NOT decorate this class with `@torch.compile`. `torch.compile`
# expects a plain callable or `nn.Module`; wrapping a `torch.autograd.Function`
# subclass replaces the class with a compiled-callable wrapper, breaking the
# `ParallelNSAFunction.apply(...)` call sites and custom-autograd dispatch.
class ParallelNSAFunction(torch.autograd.Function):
    """Autograd wrapper around the NSA selected-attention Triton kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, block_indices, block_counts, block_size, scale, offsets):
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o, lse = parallel_nsa_fwd(
            q=q,
            k=k,
            v=v,
            block_indices=block_indices,
            block_counts=block_counts,
            block_size=block_size,
            scale=scale,
            offsets=offsets,
            token_indices=token_indices
        )
        ctx.save_for_backward(q, k, v, o, lse)
        # non-tensor context saved as plain attributes
        ctx.block_indices = block_indices
        ctx.block_counts = block_counts
        ctx.offsets = offsets
        ctx.token_indices = token_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_nsa_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            block_indices=ctx.block_indices,
            block_counts=ctx.block_counts,
            block_size=ctx.block_size,
            scale=ctx.scale,
            offsets=ctx.offsets,
            token_indices=ctx.token_indices
        )
        # one None per non-tensor forward input
        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
1300
+
1301
+
1302
def parallel_nsa_compression(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None
):
    """Compressed attention over mean-pooled key/value blocks.

    Thin convenience wrapper that dispatches to the autograd-enabled
    implementation and returns `(o, lse)`.
    """
    return ParallelNSACompressionFunction.apply(q, k, v, block_size, scale, offsets)
1318
+
1319
+
1320
def parallel_nsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_cmp: torch.Tensor,
    g_slc: torch.Tensor,
    g_swa: torch.Tensor,
    block_indices: Optional[torch.LongTensor] = None,
    block_counts: Union[torch.LongTensor, int] = 16,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g_cmp (torch.Tensor):
            Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        g_slc (torch.Tensor):
            Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        g_swa (torch.Tensor):
            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
            If `g_cmp` is provided, the passed `block_indices` will be ignored.
        block_counts (Optional[Union[torch.LongTensor, int]]):
            Number of selected blocks for each query.
            If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
            each query can select the same number of blocks.
            If not provided, it will default to 16.
        block_size (int):
            Selected block size. Default: 64.
        window_size (int):
            Sliding window size. Default: 0.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    assert block_counts is not None, "block counts must be provided for selection"
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if head_first:
        # normalize everything to [B, T, H, ...] layout internally
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
        g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa))
        if not isinstance(block_counts, int):
            block_counts = rearrange(block_counts, 'b h t -> b t h')
    assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"

    # compressed K/V: one mean-pooled entry per block
    k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens)
    o_cmp, lse_cmp = None, None
    if g_cmp is not None:
        # compressed-attention branch; its lse drives the block selection below
        o_cmp, lse_cmp = parallel_nsa_compression(
            q=q,
            k=k_cmp,
            v=v_cmp,
            block_size=block_size,
            scale=scale,
            offsets=cu_seqlens
        )
        if block_indices is not None:
            warnings.warn("`block_indices` will be ignored when `g_cmp` is provided")
        block_indices = parallel_nsa_topk(
            q=q,
            k=k_cmp,
            lse=lse_cmp,
            block_counts=block_counts,
            block_size=block_size,
            scale=scale,
            offsets=cu_seqlens
        )
    # selected (sparse block) attention branch
    o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens)
    o = o_slc * g_slc.unsqueeze(-1)
    if o_cmp is not None:
        # o += g_cmp * o_cmp
        o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1))
    if window_size > 0:
        # sliding-window branch, delegated to FlashAttention
        if cu_seqlens is not None:
            max_seqlen = q.shape[1]
            o_swa = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(window_size-1, 0)
            ).unsqueeze(0)
        else:
            o_swa = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(window_size-1, 0)
            )
        # o += g_swa * o_swa
        o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1))
    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o
fla/ops/rebased/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (226 Bytes). View file
 
fla/ops/rebased/parallel.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
10
+
11
+ # Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
12
+ # https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
13
+
14
+
15
@triton.jit(do_not_specialize=['T'])
def parallel_rebased_fwd_kernel(
    q,
    k,
    v,
    o,
    z,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    """Forward kernel for ReBased linear attention with a squared-dot-product
    feature map: o = (q k^T)^2 v (causal), z = rowsum((q k^T)^2).

    The query chunk [BTL] is processed in two phases: strictly-past key chunks
    need no causal mask, while the diagonal chunk applies one.
    """
    # i_c: chunk index. used for sequence parallelism
    i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NV = tl.cdiv(V, BV)
    # decompose the fused key/value-split program id
    i_k = i_kv // (NV)
    i_v = i_kv % (NV)

    p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, 0), (BK, BTS), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v*BV), (BTS, BV), (1, 0))

    # [BQ, BD] block Q, in the shared memory throughout the whole kernel
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    b_o = tl.zeros([BTL, BV], dtype=tl.float32)
    b_z = tl.zeros([BTL], dtype=tl.float32)

    # Q block and K block have no overlap
    # no need for mask, thereby saving flops
    for _ in range(0, i_c*BTL, BTS):
        # [BK, BTS]
        b_k = tl.load(p_k, boundary_check=(0, 1))

        # [BTS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS] squared attention scores (the "rebased" feature map)
        b_s = tl.dot(b_q, (b_k), allow_tf32=False)
        b_s = b_s * b_s
        # accumulate the normalizer rowsum
        b_z += tl.sum(b_s, axis=1)

        # [BQ, BD]
        b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
        p_k = tl.advance(p_k, (0, BTS))
        p_v = tl.advance(p_v, (BTS, 0))

    # # rescale interchunk output
    tl.debug_barrier()
    o_q = tl.arange(0, BTL)
    # # sync threads, easy for compiler to optimize
    # tl.debug_barrier()

    o_k = tl.arange(0, BTS)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, i_c*BTL), (BK, BTS), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTS, BV), (1, 0))
    # Q block and K block have overlap. masks required
    for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
        # [BK, BTS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BTS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS] causal mask within the diagonal chunk
        m_s = o_q[:, None] >= o_k[None, :]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        b_s = b_s * b_s
        b_s = tl.where(m_s, b_s, 0)
        b_z += tl.sum(b_s, axis=1)
        # [BTL, BV]
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
        p_k = tl.advance(p_k, (0, BTS))
        p_v = tl.advance(p_v, (BTS, 0))
        o_k += BTS

    # outputs are sharded over key splits (i_k); reduction happens on the host
    p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
    p_z = z + (i_bh + B * H * i_k) * T + i_c*BTL + tl.arange(0, BTL)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c*BTL + tl.arange(0, BTL)) < T))
99
+
100
@triton.jit(do_not_specialize=['T'])
def _parallel_rebased_bwd_dq(
    i_bh,
    i_c,
    i_k,
    i_v,
    i_h,
    q,
    k,
    v,
    do,
    dz,
    dq,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Gradient w.r.t. q for the "rebased" attention backward pass.
    # The attention score uses the squared feature map (s -> s^2), so the
    # chain rule introduces the 2*s factors seen below.
    # i_bh: flattened (batch, head) index; i_c: query tile of length BTL;
    # i_k / i_v: feature tiles of sizes BK / BV along K and V.
    p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
    p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
    b_q = (b_q * scale).to(b_q.dtype)
    # fp32 accumulator for numerical stability
    b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k*BK), (BTS, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, 0), (BV, BTS), (0, 1))
    p_dz = dz + i_bh * T + i_c*BTL + tl.arange(0, BTL)
    # gradient of the normalizer z; folded into ds by the i_v == 0 program only,
    # so it is counted exactly once across the value-tile grid
    b_dz = tl.load(p_dz, mask=(i_c*BTL + tl.arange(0, BTL)) < T)

    # Phase 1: key/value tiles strictly before this query tile — the whole
    # range [0, i_c*BTL) precedes every query row, so no causal mask is needed.
    for _ in range(0, i_c*BTL, BTS):
        # [BTS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BTS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS]
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[:, None]
        else:
            b_ds = b_ds  # no-op branch, kept for structural symmetry
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
        # chain rule through s^2: d(s^2)/ds = 2*s
        # [BQ, BD]
        b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)
        p_k = tl.advance(p_k, (BTS, 0))
        p_v = tl.advance(p_v, (0, BTS))

    b_dq *= scale
    o_q = tl.arange(0, BTL)
    o_k = tl.arange(0, BTS)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTS, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, i_c*BTL), (BV, BTS), (0, 1))
    # Q block and K block have overlap. masks required
    for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
        # [BTS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BTS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS] causal mask: query position >= key position
        m_s = o_q[:, None] >= o_k[None, :]
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[:, None]
        else:
            b_ds = b_ds  # no-op branch, kept for structural symmetry
        b_ds = tl.where(m_s, b_ds, 0) * scale
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0)
        # [BTL, BK]
        b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),
                       b_k, allow_tf32=False)
        p_k = tl.advance(p_k, (BTS, 0))
        p_v = tl.advance(p_v, (0, BTS))
        o_k += BTS
    # each i_v program writes its own slice; the host sums over the leading dim
    p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    return
182
+
183
+
184
@triton.jit(do_not_specialize=['T'])
def _parallel_rebased_bwd_dkv(
    i_bh,
    i_c,
    i_k,
    i_v,
    i_h,
    q,
    k,
    v,
    do,
    dz,
    dk,
    dv,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    # Gradients w.r.t. k and v for the rebased attention backward pass.
    # This program owns key/value rows [i_c*BTL, (i_c+1)*BTL) and iterates
    # over all query tiles at or after them (only those attend to these keys
    # under the causal mask).
    # compute dk dv
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
    b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
    # fp32 accumulators
    b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
        [BTL, BV], dtype=tl.float32)

    # Phase 1: strictly-future query tiles, walked backwards from the padded
    # sequence end down to (but excluding) the diagonal tile; no mask needed.
    for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
        p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
        # [BK, BTS]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BV, BTS]
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
        b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
        # [BTL, BTS]
        b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale
        b_s2 = b_s * b_s
        b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
        b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
        if i_v == 0:
            # normalizer gradient enters once, via the i_v == 0 program
            b_ds += b_dz[None, :] * scale
        else:
            b_ds = b_ds  # no-op branch, kept for structural symmetry
        # chain rule through s^2: d(s^2)/ds = 2*s
        b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)

    tl.debug_barrier()
    o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
    # Phase 2: the overlapping (diagonal) query tiles; apply the causal mask
    # (key position <= query position).
    for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
        p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
        b_q = tl.load(p_q, boundary_check=(0, 1))  # [BD, BQ]
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
        b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
        # [BK, BQ]
        m_s = o_k[:, None] <= o_q[None, :]
        b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
        b_s2 = b_s * b_s
        b_s = tl.where(m_s, b_s, 0)
        b_s2 = tl.where(m_s, b_s2, 0)

        b_ds = tl.dot(b_v, b_do, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[None, :]
        else:
            b_ds = b_ds  # no-op branch, kept for structural symmetry
        b_ds = tl.where(m_s, b_ds, 0) * scale
        # [BK, BD]
        b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
        b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
        o_q += BTS

    # i_v programs own disjoint dk slices, i_k programs own disjoint dv slices;
    # the host sums over the leading dim afterwards.
    p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    return
268
+
269
+
270
@triton.jit(do_not_specialize=['T'])
def parallel_rebased_bwd_kernel(
    q,
    k,
    v,
    do,
    dz,
    dq,
    dk,
    dv,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Top-level backward kernel: decode the flattened (key-tile, value-tile)
    # program id, then run the dq pass followed by the dk/dv pass.
    i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NV = tl.cdiv(V, BV)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_h = i_bh % H
    _parallel_rebased_bwd_dq(
        i_bh, i_c, i_k, i_v, i_h,
        q, k, v, do, dz, dq, scale,
        B=B, H=H, T=T, K=K, V=V,
        BTL=BTL, BTS=BTS, BK=BK, BV=BV
    )
    # make sure the dq pass is fully done before reusing shared resources
    tl.debug_barrier()
    _parallel_rebased_bwd_dkv(
        i_bh, i_c, i_k, i_v, i_h,
        q, k, v, do, dz, dk, dv, scale,
        B=B, H=H, T=T, K=K, V=V,
        BTL=BTL, BTS=BTS, BK=BK, BV=BV
    )
344
+
345
+
346
class ParallelBasedFunction(torch.autograd.Function):
    """Autograd wrapper for the parallel (re)based attention Triton kernels.

    ``forward`` returns the *unnormalized* output ``o`` and the normalizer
    ``z``; normalization is left to the caller (see ``parallel_rebased``).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale):
        # BTL: outer (query) tile length; BTS: inner (key/value) tile length
        BTL, BTS = 128, 32
        assert BTL % BTS == 0
        # assert q.shape[-1] % 16 == 0
        # feature-dim tile sizes clamped to [16, 128] (tensor-core friendly)
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
        BK, BV = max(BK, 16), max(BV, 16)
        B, H, T, K, V = *k.shape, v.shape[-1]
        num_stages = 2
        num_warps = 4
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        grid = (NK * NV, triton.cdiv(T, BTL), B * H)

        assert NK == 1, "will encounter some synchronization issue if not."

        # leading NK dim holds per-K-tile partial results (fp32 by default),
        # summed below before casting back to the input dtype
        o = torch.empty(NK, B, H, T, V, device=q.device)
        z = torch.empty(NK, B, H, T, device=q.device)
        parallel_rebased_fwd_kernel[grid](
            q,
            k,
            v,
            o,
            z,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BTL=BTL,
            BTS=BTS,
            BK=BK,
            BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ctx.save_for_backward(q, k, v)
        ctx.scale = scale
        return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dz):
        # do: gradient of the unnormalized output; dz: gradient of the normalizer
        q, k, v = ctx.saved_tensors
        scale = ctx.scale
        # smaller query tile than forward to bound register pressure
        BTL, BTS = 64, 32
        assert BTL % BTS == 0
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
        BK, BV = max(BK, 16), max(BV, 16)
        B, H, T, K, V = *k.shape, v.shape[-1]
        num_stages = 2
        num_warps = 4
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        grid = (NK * NV, triton.cdiv(T, BTL), B * H)

        assert NK == 1, "will encounter some synchronization issue if not"

        # dq/dk are sliced over value tiles, dv over key tiles (matching the
        # kernel's output indexing); summed over the leading dim afterwards
        dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
        dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
        dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)

        parallel_rebased_bwd_kernel[grid](
            q,
            k,
            v,
            do,
            dz,
            dq,
            dk,
            dv,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BTL=BTL,
            BTS=BTS,
            BK=BK,
            BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )

        return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
440
+
441
+
442
def parallel_rebased(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    eps: float = 1e-5,
    use_scale: bool = True,
    use_normalize: bool = True,
    return_both: bool = False,
    head_first: bool = True
):
    """Parallel rebased linear attention (squared feature map).

    Args:
        q, k: query/key tensors of shape `[B, H, T, K]` if `head_first=True`
            else `[B, T, H, K]`; the feature dim K must be <= 128.
        v: value tensor of shape `[B, H, T, V]` (or `[B, T, H, V]`).
        eps: small constant added to the normalizer to avoid division by zero.
        use_scale: if True, scale queries by `K ** -0.5`.
        use_normalize: if True, divide the output by the normalizer `z`.
        return_both: if True, return the unnormalized output and the
            normalizer `(o, z)` instead of the normalized output.
        head_first: whether inputs (and outputs) use the head-first layout.

    Returns:
        Normalized output `o`, or `(o, z)` when `return_both=True`, in the
        same layout as the inputs.
    """
    assert q.shape[-1] <= 128, "only support feature dim up to 128"
    scale = q.shape[-1] ** -0.5 if use_scale else 1
    if not head_first:
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
    o, z = ParallelBasedFunction.apply(q, k, v, scale)
    if return_both:
        # Bug fix: the early-return path previously skipped restoring the
        # caller's layout when head_first=False; mirror the normalized path.
        if not head_first:
            o, z = o.transpose(1, 2), z.transpose(1, 2)
        return o, z
    if use_normalize:
        o = o / (z[..., None] + eps)
    if not head_first:
        o = o.transpose(1, 2)
    return o.to(q.dtype)
fla/ops/retention/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_retention
4
+ from .fused_chunk import fused_chunk_retention
5
+ from .fused_recurrent import fused_recurrent_retention
6
+ from .parallel import parallel_retention
7
+
8
+ __all__ = [
9
+ 'chunk_retention',
10
+ 'fused_chunk_retention',
11
+ 'parallel_retention',
12
+ 'fused_recurrent_retention'
13
+ ]
fla/ops/retention/fused_chunk.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from packaging import version
10
+
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
@triton.jit(do_not_specialize=['T'])
def fused_chunk_retention_fwd_kernel(
    q,
    k,
    v,
    o,
    h0,
    ht,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Chunkwise retention forward: each program handles one (value tile,
    # key tile, batch-head) triple and sweeps the sequence chunk by chunk,
    # carrying the decayed recurrent state b_h across chunks.
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H

    o_i = tl.arange(0, BT)
    # decay rate given the head index
    b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))

    # d_b: overall decay for the entire chunk
    # d_o: cumulative decay from the start of the chunk
    # d_h: cumulative decay from the end of the chunk
    d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)

    # [BT, BT] causal decay matrix within a chunk
    m_s = o_i[:, None] >= o_i[None, :]
    d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
    # [BK, BV] recurrent state carried across chunks
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    # make block pointers
    p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (i_k*B*H+i_bh).to(tl.int64) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    NT = tl.cdiv(T, BT)
    for i in range(0, NT):
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [BT, BT] intra-chunk attention with decay
        b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
        # [BT, BV] intra-chunk contribution
        b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
        # CHECK duplicates the else-branch for i == 0 as a workaround for a
        # Triton compiler issue (see the warning in the host-side forward)
        if CHECK and i == 0:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
            b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
        else:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
            if i == NT - 1 and (T % BT) != 0:
                # last chunk may be ragged: recompute decays for its true length
                d_b = tl.math.exp2((T % BT) * b_b)
                d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
            b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+
99
@triton.jit(do_not_specialize=['T'])
def fused_chunk_retention_bwd_kernel(
    q,
    k,
    v,
    do,
    dq,
    dk,
    dv,
    h0,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Chunkwise retention backward. Two sweeps:
    #   1. forward over chunks, carrying state b_h, to produce dq;
    #   2. backward over chunks, carrying gradient state b_dh, for dk/dv.
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H

    o_i = tl.arange(0, BT)
    # per-head decay rate (log2 domain)
    b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))
    # d_q: decay from chunk start (folds in scale); d_k: decay to chunk end
    d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)
    # d_b: whole-chunk decay
    d_b = tl.math.exp2(BT * b_b)

    # [BT, BT] causal decay matrix (scale folded in)
    m_s = o_i[:, None] >= o_i[None, :]
    d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
    # [BV, BK] forward state, transposed layout for the dq sweep
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    # Sweep 1 (forward over chunks): dq
    for i in range(0, tl.cdiv(T, BT)):
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))

        # [BT, K]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [V, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dd = (b_do * d_q[:, None]).to(b_do.dtype)

        # [BT, BT] intra-chunk score gradient
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        b_ds = (b_ds * d_s).to(b_k.dtype)
        # [BT, K]
        b_dq = tl.dot(b_ds, b_k, allow_tf32=False)
        # [V, K] — CHECK duplicates the else branch (Triton compiler workaround)
        if CHECK and i == 0:
            b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
            b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
        else:
            b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
            b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)

        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # sync threads
    b_h = None
    tl.debug_barrier()
    d_s = tl.trans(d_s)
    # [BK, BV] gradient state for the reverse sweep
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # Sweep 2 (backward over chunks, from sequence end): dk and dv
    for i in range(1, tl.cdiv(T, BT) + 1):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H).to(tl.int64) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
        # [K, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BT, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dd = (b_do * d_q[:, None]).to(b_do.dtype)

        # [BT, BT]
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
        b_ds = (b_ds * d_s).to(b_k.dtype)

        # [BT, BT]
        b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
        # [BT, BK]
        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
        # [BT, BV]
        b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
        # CHECK duplicates the else branch (Triton compiler workaround)
        if CHECK and i == 1:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
            b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
        else:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
            b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
209
+
210
+
211
class FusedChunkRetentionFunction(torch.autograd.Function):
    """Autograd wrapper for the fused chunkwise retention kernels.

    ``forward`` computes the retention output (and optionally the final
    recurrent state); ``backward`` produces gradients for q, k and v.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale, initial_state, output_final_state):
        B, H, T, K, V = *k.shape, v.shape[-1]

        BT = 64
        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4

        # leading NK dim holds per-key-tile partial outputs, summed below
        o = q.new_empty(NK, B, H, T, V)

        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
        else:
            final_state = None
        # the bug still exists even for Triton 2.2 on H100 GPUs
        # so we always enable initial checks
        CHECK = True
        if version.parse(triton.__version__) < version.parse('2.2.0'):
            import warnings
            warnings.warn(
                "Triton<2.2.0 detected for running this kernel, "
                "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
                "that lead to significant precision loss. "
                "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
                "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
            )
            CHECK = True

        grid = (NV, NK, B * H)
        fused_chunk_retention_fwd_kernel[grid](
            q,
            k,
            v,
            o,
            initial_state,
            final_state,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BT=BT,
            BK=BK,
            BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            CHECK=CHECK,
            num_warps=num_warps,
            num_stages=num_stages
        )

        o = o.sum(0)
        ctx.save_for_backward(q, k, v, initial_state)
        ctx.CHECK = CHECK
        # Bug fix: remember the scale actually used in the forward pass so the
        # backward pass agrees with it even when the caller passed a custom one.
        ctx.scale = scale
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        q, k, v, initial_state = ctx.saved_tensors
        B, H, T, K, V = *k.shape, v.shape[-1]
        # Bug fix: previously recomputed as `K ** -0.5`, which silently
        # produced wrong gradients whenever forward was given a custom scale.
        scale = ctx.scale

        BT = 64
        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4

        # dq/dk sliced over value tiles, dv over key tiles (kernel indexing)
        dq = q.new_empty(NV, B, H, T, K)
        dk = q.new_empty(NV, B, H, T, K)
        dv = q.new_empty(NK, B, H, T, V)
        grid = (NV, NK, B * H)

        fused_chunk_retention_bwd_kernel[grid](
            q,
            k,
            v,
            do,
            dq,
            dk,
            dv,
            initial_state,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BT=BT,
            BK=BK,
            BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            CHECK=ctx.CHECK,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
320
+
321
+
322
def fused_chunk_retention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Fused chunkwise retention.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
        scale (Optional[int]):
            Scale factor for the attention scores; defaults to `1 / sqrt(K)`
            when `None`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial recurrent state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to also return the final state of shape `[B, H, K, V]`.
            Default: `False`.
        head_first (Optional[bool]):
            Whether the inputs are laid out head-first. Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    # the kernels operate in head-first layout; convert on the way in/out
    if not head_first:
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    o, final_state = FusedChunkRetentionFunction.apply(
        q, k, v, scale, initial_state, output_final_state
    )
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
fla/ops/retention/fused_recurrent.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla
9
+
10
+
11
def fused_recurrent_retention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Retention is simple GLA with a fixed per-head log-decay
    # log(1 - 2^(-5 - h)); build that gate and delegate.
    num_heads = q.shape[1] if head_first else q.shape[2]
    base = q.new_tensor(2., dtype=torch.float)
    exponents = q.new_tensor(range(num_heads), dtype=torch.float)
    decay = (1 - base.pow(-5. - exponents)).log()
    # broadcast the per-head decay over (batch, time)
    if head_first:
        g = decay[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    else:
        g = decay[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    return fused_recurrent_simple_gla(
        q=q,
        k=k,
        v=v,
        g=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
fla/ops/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (288 Bytes). View file
 
fla/ops/rwkv6/fused_recurrent.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_rwkv6_fwd_kernel(
    q,  # query [B, H, T, K]/[B, T, H, K]
    k,  # key [B, H, T, K]/[B, T, H, K]
    v,  # value [B, H, T, V]/[B, T, H, V]
    w,  # log gate [B, H, T]/[B, T, H] or None
    u,  # bonus [B, H, K]
    o,  # output [NK, B, H, T, V]/[NK, B, T, H, V]
    h0,  # initial hidden state [B, H, K, V]
    ht,  # final hidden state [B, H, K, V]
    offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,  # whether to reverse the recurrence
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    USE_OFFSETS: tl.constexpr,  # variable-length (cu_seqlens) mode
    HEAD_FIRST: tl.constexpr
):
    # Token-by-token RWKV6 recurrence: h <- h * exp(w) + k^T v, with the
    # "bonus" u applied to the current token before the state update.
    i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # variable-length mode: sequence n spans [bos, eos) of a packed batch
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # pointers start at the last token when REVERSE, the first otherwise
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
        p_w = w + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + o_v
    else:
        p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
        p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_u = u + i_h * K + o_k

    # masks for ragged feature tiles at the K/V boundaries
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32)

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
        # rank-1 update k^T v for the current token
        b_kv = b_k[:, None] * b_v[None, :]
        # output reads the pre-update state plus the u-weighted current token
        b_o = tl.sum((b_h + b_kv * b_u[:, None]) * b_q[:, None], 0)
        # decay the state by exp(w) and absorb the current token
        b_h = b_h * exp(b_w)[:, None] + b_kv
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # step one token forward (or backward when REVERSE)
        p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        p_w += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
106
+
107
+
108
+ @triton.heuristics({
109
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
110
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
111
+ })
112
+ @triton.autotune(
113
+ configs=[
114
+ triton.Config({}, num_warps=1),
115
+ triton.Config({}, num_warps=2),
116
+ triton.Config({}, num_warps=4),
117
+ ],
118
+ key=['BK', 'BV']
119
+ )
120
+ @triton.jit(do_not_specialize=['T'])
121
+ def fused_recurrent_rwkv6_bwd_kernel_dq(
122
+ k, # key [B, H, T, V]/[B, T, H, V]
123
+ v, # value [B, H, T, V]/[B, T, H, V]
124
+ w, # log gate [B, H, T]/[B, T, H]
125
+ u, # bonus [B, H, K]
126
+ do, # gradient of output [B, H, T, V]/[B, T, H, V]
127
+ dq, # gradient of query [NV, B, H, T, K]/[NV, B, T, H, K]
128
+ dq1, # gradient of query_aux [NV, B, H, T, K]/[NV, B, T, H, K]
129
+ h0,
130
+ offsets,
131
+ scale,
132
+ T,
133
+ B: tl.constexpr,
134
+ H: tl.constexpr,
135
+ K: tl.constexpr,
136
+ V: tl.constexpr,
137
+ BK: tl.constexpr,
138
+ BV: tl.constexpr,
139
+ REVERSE: tl.constexpr,
140
+ USE_INITIAL_STATE: tl.constexpr,
141
+ USE_OFFSETS: tl.constexpr,
142
+ HEAD_FIRST: tl.constexpr
143
+ ):
144
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
145
+ i_n, i_h = i_nh // H, i_nh % H
146
+ if USE_OFFSETS:
147
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
148
+ all = T
149
+ T = eos - bos
150
+ else:
151
+ bos, eos = i_n * T, i_n * T + T
152
+ all = B * T
153
+
154
+ o_k = i_k * BK + tl.arange(0, BK)
155
+ o_v = i_v * BV + tl.arange(0, BV)
156
+ if HEAD_FIRST:
157
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
158
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
159
+ p_w = w + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
160
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
161
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + o_k
162
+ p_dq1 = dq1 + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + o_k
163
+ else:
164
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
165
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
166
+ p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
167
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
168
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
169
+ p_dq1 = dq1 + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
170
+ p_u = u + i_h * K + o_k
171
+
172
+ mask_k = o_k < K
173
+ mask_v = o_v < V
174
+ mask_h = mask_k[:, None] & mask_v[None, :]
175
+
176
+ b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32)
177
+
178
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
179
+ if USE_INITIAL_STATE:
180
+ p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
181
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
182
+
183
+ for _ in range(0, T):
184
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
185
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
186
+ b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
187
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
188
+ b_kv = b_k[:, None] * b_v[None, :]
189
+
190
+ b_hq = b_h * b_do[None, :]
191
+ b_dq = tl.sum(b_hq + b_kv * b_u[:, None] * b_do[None, :], 1) * scale
192
+ b_dq1 = tl.sum(b_hq, 1)
193
+ b_h = b_h * exp(b_w)[:, None]
194
+ b_h += b_kv
195
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
196
+ tl.store(p_dq1, b_dq1.to(p_dq1.dtype.element_ty), mask=mask_k)
197
+
198
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
199
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
200
+ p_w += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
201
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
202
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
203
+ p_dq1 += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
204
+
205
+
206
+ @triton.heuristics({
207
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
208
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
209
+ })
210
+ @triton.autotune(
211
+ configs=[
212
+ triton.Config({}, num_warps=1),
213
+ triton.Config({}, num_warps=2),
214
+ triton.Config({}, num_warps=4),
215
+ ],
216
+ key=['BK', 'BV']
217
+ )
218
+ @triton.jit(do_not_specialize=['T'])
219
+ def fused_recurrent_rwkv6_bwd_kernel_dkv(
220
+ q, # query [B, H, T, K]/[B, T, H, K]
221
+ k, # key [B, H, T, V]/[B, T, H, V]
222
+ v, # value [B, H, T, V]/[B, T, H, V]
223
+ w, # log gate [B, H, T]/[B, T, H]
224
+ u, # bonus [B, H, K]
225
+ do, # gradient of output [B, H, T, V]/[B, T, H, V]
226
+ dk, # gradient of key [NV, B, H, T, K]/[NK, B, T, H, K]
227
+ dk1, # gradient of key_aux [NV, B, H, T, K]/[NK, B, T, H, K]
228
+ dv, # gradient of value [NK, B, H, T, V]/[NV, B, T, H, V]
229
+ dh0, # gradient of initial hidden state [N, H, K, V]
230
+ offsets,
231
+ scale,
232
+ T,
233
+ B: tl.constexpr,
234
+ H: tl.constexpr,
235
+ K: tl.constexpr,
236
+ V: tl.constexpr,
237
+ BK: tl.constexpr,
238
+ BV: tl.constexpr,
239
+ REVERSE: tl.constexpr,
240
+ USE_INITIAL_STATE: tl.constexpr,
241
+ USE_OFFSETS: tl.constexpr,
242
+ HEAD_FIRST: tl.constexpr,
243
+ ):
244
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
245
+ i_n, i_h = i_nh // H, i_nh % H
246
+ if USE_OFFSETS:
247
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
248
+ all = T
249
+ T = eos - bos
250
+ else:
251
+ bos, eos = i_n * T, i_n * T + T
252
+ all = B * T
253
+
254
+ o_k = i_k * BK + tl.arange(0, BK)
255
+ o_v = i_v * BV + tl.arange(0, BV)
256
+ if HEAD_FIRST:
257
+ p_q = q + i_nh * T*K + ((T-1) * K if not REVERSE else 0) + o_k
258
+ p_k = k + i_nh * T*K + ((T-1) * K if not REVERSE else 0) + o_k
259
+ p_v = v + i_nh * T*V + ((T-1) * V if not REVERSE else 0) + o_v
260
+ p_w = w + i_nh * T*K + ((T-1) * K if not REVERSE else 0) + o_k
261
+ p_do = do + i_nh * T*V + ((T-1) * V if not REVERSE else 0) + o_v
262
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T-1) * K if not REVERSE else 0) + o_k
263
+ p_dk1 = dk1 + (i_v * B*H + i_nh) * T*K + ((T-1) * K if not REVERSE else 0) + o_k
264
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T-1) * V if not REVERSE else 0) + o_v
265
+ else:
266
+ p_q = q + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k
267
+ p_k = k + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k
268
+ p_v = v + (bos + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v
269
+ p_w = w + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k
270
+ p_do = do + (bos + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v
271
+ p_dk = dk + ((i_v * all + bos) + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k
272
+ p_dk1 = dk1 + ((i_v * all + bos) + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k
273
+ p_dv = dv + ((i_k * all + bos) + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v
274
+ p_u = u + i_h * K + o_k
275
+
276
+ mask_k = o_k < K
277
+ mask_v = o_v < V
278
+ mask_h = mask_k[:, None] & mask_v[None, :]
279
+
280
+ b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32)
281
+
282
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
283
+ for _ in range(T - 1, -1, -1):
284
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
285
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
286
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
287
+ b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
288
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
289
+ b_dkv = b_q[:, None] * b_do[None, :]
290
+ b_dk = tl.sum(b_dh * b_v[None, :], 1)
291
+ tl.store(p_dk1, b_dk.to(p_dk1.dtype.element_ty), mask=mask_k)
292
+ b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], 1)
293
+ b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], 0)
294
+
295
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
296
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
297
+ b_dh *= exp(b_w)[:, None]
298
+ b_dh += b_dkv
299
+
300
+ p_q += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K
301
+ p_k += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K
302
+ p_v += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * V
303
+ p_w += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K
304
+ p_do += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * V
305
+ p_dk += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K
306
+ p_dk1 += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K
307
+ p_dv += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * V
308
+
309
+ if USE_INITIAL_STATE:
310
+ p_dh0 = dh0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
311
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
312
+
313
+
314
+ @triton.heuristics({
315
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
316
+ })
317
+ @triton.autotune(
318
+ configs=[
319
+ triton.Config({'BT': BT, 'BK': BK}, num_warps=num_warps)
320
+ for BT in [16, 32, 64]
321
+ for BK in [32, 64]
322
+ for num_warps in [1, 2, 4, 8]
323
+ ],
324
+ key=['K']
325
+ )
326
+ @triton.jit(do_not_specialize=['T'])
327
+ def fused_recurrent_rwkv6_bwd_kernel_dw(
328
+ q,
329
+ k,
330
+ dq,
331
+ dk,
332
+ dw,
333
+ offsets,
334
+ scale,
335
+ T,
336
+ H: tl.constexpr,
337
+ K: tl.constexpr,
338
+ BT: tl.constexpr,
339
+ BK: tl.constexpr,
340
+ REVERSE: tl.constexpr,
341
+ HEAD_FIRST: tl.constexpr,
342
+ USE_OFFSETS: tl.constexpr
343
+ ):
344
+ i_k, i_nh = tl.program_id(0), tl.program_id(1)
345
+ i_n, i_h = i_nh // H, i_nh % H
346
+ if USE_OFFSETS:
347
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
348
+ else:
349
+ bos, eos = i_n * T, i_n * T + T
350
+ T = eos - bos
351
+ NT = tl.cdiv(T, BT)
352
+
353
+ o_i = tl.arange(0, BT)
354
+ m_i = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) if not REVERSE else tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
355
+
356
+ b_z = tl.zeros([BK], dtype=tl.float32)
357
+
358
+ i_t = 0 if not REVERSE else NT - 1
359
+ for _ in range(NT):
360
+ if HEAD_FIRST:
361
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (T, K), (K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0))
362
+ p_dq = tl.make_block_ptr(dq + i_nh * T*K, (T, K), (K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0))
363
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T-1, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
364
+ p_dk = tl.make_block_ptr(dk + i_nh * T*K, (T-1, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
365
+ p_dw = tl.make_block_ptr(dw + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
366
+ else:
367
+ p_q = tl.make_block_ptr(q + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0))
368
+ p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0))
369
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T-1, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
370
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T-1, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
371
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
372
+ # [BT, BK]
373
+ b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
374
+ b_dq = tl.load(p_dq, boundary_check=(0, 1)).to(tl.float32)
375
+ b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
376
+ b_dk = tl.load(p_dk, boundary_check=(0, 1)).to(tl.float32)
377
+ b_dw = (b_q * b_dq * scale) - b_k * b_dk
378
+ b_c = b_z[None, :] + tl.dot(m_i, b_dw, allow_tf32=False)
379
+ tl.store(p_dw, b_c.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
380
+ if i_t >= 0:
381
+ b_z += tl.sum(b_dw, 0)
382
+
383
+ i_t += (1 if not REVERSE else -1)
384
+
385
+
386
+ def fused_recurrent_rwkv6_fwd(
387
+ q: torch.Tensor,
388
+ k: torch.Tensor,
389
+ v: torch.Tensor,
390
+ w: torch.Tensor,
391
+ u: torch.Tensor,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ output_final_state: bool = False,
395
+ reverse: bool = False,
396
+ offsets: Optional[torch.LongTensor] = None,
397
+ head_first: bool = True
398
+ ):
399
+ if head_first:
400
+ B, H, T, K, V = *k.shape, v.shape[-1]
401
+ else:
402
+ B, T, H, K, V = *k.shape, v.shape[-1]
403
+ N = B if offsets is None else len(offsets) - 1
404
+ BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ h0 = initial_state
408
+ ht = q.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
409
+ o = q.new_empty(NK, *v.shape, dtype=torch.float)
410
+
411
+ grid = (NV, NK, N * H)
412
+ fused_recurrent_rwkv6_fwd_kernel[grid](
413
+ q,
414
+ k,
415
+ v,
416
+ w,
417
+ u,
418
+ o,
419
+ h0,
420
+ ht,
421
+ offsets,
422
+ scale,
423
+ T=T,
424
+ B=B,
425
+ H=H,
426
+ K=K,
427
+ V=V,
428
+ BK=BK,
429
+ BV=BV,
430
+ REVERSE=reverse,
431
+ HEAD_FIRST=head_first
432
+ )
433
+ o = o.sum(0)
434
+ return o, ht
435
+
436
+
437
+ def fused_recurrent_rwkv6_bwd(
438
+ q: torch.Tensor,
439
+ k: torch.Tensor,
440
+ v: torch.Tensor,
441
+ w: torch.Tensor,
442
+ u: torch.Tensor,
443
+ do: torch.Tensor,
444
+ scale: Optional[float] = None,
445
+ initial_state: Optional[torch.Tensor] = None,
446
+ reverse: bool = False,
447
+ offsets: Optional[torch.LongTensor] = None,
448
+ head_first: bool = True
449
+ ):
450
+ if head_first:
451
+ B, H, T, K, V = *k.shape, v.shape[-1]
452
+ else:
453
+ B, T, H, K, V = *k.shape, v.shape[-1]
454
+ N = B if offsets is None else len(offsets) - 1
455
+
456
+ BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)
457
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
458
+
459
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float)
460
+ dq1 = torch.empty_like(dq)
461
+
462
+ grid = (NV, NK, N * H)
463
+ fused_recurrent_rwkv6_bwd_kernel_dq[grid](
464
+ k,
465
+ v,
466
+ w,
467
+ u,
468
+ do,
469
+ dq,
470
+ dq1,
471
+ initial_state,
472
+ offsets,
473
+ scale,
474
+ T=T,
475
+ B=B,
476
+ H=H,
477
+ K=K,
478
+ V=V,
479
+ BK=BK,
480
+ BV=BV,
481
+ REVERSE=reverse,
482
+ HEAD_FIRST=head_first
483
+ )
484
+ dq = dq.sum(0)
485
+ dq1 = dq1.sum(0)
486
+
487
+ BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
488
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
489
+
490
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float)
491
+ dk1 = q.new_empty(NV, *k.shape, dtype=torch.float)
492
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float)
493
+
494
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
495
+ grid = (NV, NK, N * H)
496
+ fused_recurrent_rwkv6_bwd_kernel_dkv[grid](
497
+ q,
498
+ k,
499
+ v,
500
+ w,
501
+ u,
502
+ do,
503
+ dk,
504
+ dk1,
505
+ dv,
506
+ dh0,
507
+ offsets,
508
+ scale,
509
+ T=T,
510
+ B=B,
511
+ H=H,
512
+ K=K,
513
+ V=V,
514
+ BK=BK,
515
+ BV=BV,
516
+ REVERSE=reverse,
517
+ HEAD_FIRST=head_first
518
+ )
519
+ dk = dk.sum(0)
520
+ dk1 = dk1.sum(0)
521
+ dv = dv.sum(0)
522
+
523
+ dw = torch.empty_like(w)
524
+ def grid(meta): return (triton.cdiv(meta['K'], meta['BK']), N * H)
525
+ fused_recurrent_rwkv6_bwd_kernel_dw[grid](
526
+ q,
527
+ k,
528
+ dq1,
529
+ dk1,
530
+ dw,
531
+ offsets,
532
+ scale,
533
+ T=T,
534
+ H=H,
535
+ K=K,
536
+ REVERSE=not reverse,
537
+ HEAD_FIRST=head_first
538
+ )
539
+ du = (do.float() * v).sum(-1, True, dtype=torch.float) * q * k * scale
540
+ du = du.sum((0, 2)) if head_first else du.sum((0, 1))
541
+ return dq, dk, dv, dw, du, dh0
542
+
543
+
544
+ class FusedRecurrentRWKV6Function(torch.autograd.Function):
545
+
546
+ @staticmethod
547
+ @input_guard
548
+ @autocast_custom_fwd
549
+ def forward(
550
+ ctx,
551
+ q: torch.Tensor,
552
+ k: torch.Tensor,
553
+ v: torch.Tensor,
554
+ w: torch.Tensor,
555
+ u: torch.Tensor,
556
+ scale: Optional[float] = None,
557
+ initial_state: Optional[torch.Tensor] = None,
558
+ output_final_state: bool = False,
559
+ reverse: bool = False,
560
+ offsets: Optional[torch.LongTensor] = None,
561
+ head_first: bool = True
562
+ ):
563
+ o, ht = fused_recurrent_rwkv6_fwd(
564
+ q=q,
565
+ k=k,
566
+ v=v,
567
+ w=w,
568
+ u=u,
569
+ scale=scale,
570
+ initial_state=initial_state,
571
+ output_final_state=output_final_state,
572
+ reverse=reverse,
573
+ offsets=offsets,
574
+ head_first=head_first
575
+ )
576
+ ctx.save_for_backward(q, k, v, w, u, initial_state)
577
+ ctx.scale = scale
578
+ ctx.reverse = reverse
579
+ ctx.offsets = offsets
580
+ ctx.head_first = head_first
581
+ return o.to(v), ht
582
+
583
+ @staticmethod
584
+ @input_guard
585
+ @autocast_custom_bwd
586
+ def backward(ctx, do, dht):
587
+ q, k, v, w, u, initial_state = ctx.saved_tensors
588
+
589
+ dq, dk, dv, dw, du, dh0 = fused_recurrent_rwkv6_bwd(
590
+ q=q,
591
+ k=k,
592
+ v=v,
593
+ w=w,
594
+ u=u,
595
+ do=do,
596
+ scale=ctx.scale,
597
+ initial_state=initial_state,
598
+ reverse=ctx.reverse,
599
+ offsets=ctx.offsets,
600
+ head_first=ctx.head_first
601
+ )
602
+ dh0 = dh0.to(initial_state) if dh0 is not None else dh0
603
+ return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0, None, None, None, None
604
+
605
+
606
+ def fused_recurrent_rwkv6(
607
+ r: torch.Tensor,
608
+ k: torch.Tensor,
609
+ v: torch.Tensor,
610
+ w: torch.Tensor,
611
+ u: torch.Tensor,
612
+ scale: Optional[int] = None,
613
+ initial_state: Optional[torch.Tensor] = None,
614
+ output_final_state: bool = False,
615
+ reverse: bool = False,
616
+ cu_seqlens: Optional[torch.LongTensor] = None,
617
+ head_first: bool = True
618
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
619
+ r"""
620
+ Args:
621
+ r (torch.Tensor):
622
+ reception of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
623
+ Alias: q, query in linear attention.
624
+ k (torch.Tensor):
625
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
626
+ v (torch.Tensor):
627
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
628
+ w (torch.Tensor):
629
+ data-dependent decays of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` in log space! Alias: g.
630
+ u (torch.Tensor):
631
+ bonus of shape `[H, K]`
632
+ scale (Optional[int]):
633
+ Scale factor for the attention scores.
634
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
635
+ initial_state (Optional[torch.Tensor]):
636
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
637
+ For equal-length input sequences, `N` equals the batch size `B`.
638
+ Default: `None`.
639
+ output_final_state (Optional[bool]):
640
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
641
+ reverse (Optional[bool]):
642
+ If `True`, process the state passing in reverse order. Default: `False`.
643
+ cu_seqlens (torch.LongTensor):
644
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
645
+ consistent with the FlashAttention API.
646
+ head_first (Optional[bool]):
647
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
648
+ Default: `True`.
649
+
650
+ Returns:
651
+ o (torch.Tensor):
652
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
653
+ final_state (Optional[torch.Tensor]):
654
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
655
+
656
+ Examples::
657
+ >>> import torch
658
+ >>> import torch.nn.functional as F
659
+ >>> from einops import rearrange
660
+ >>> from fla.ops.rwkv6 import fused_recurrent_rwkv6
661
+ # inputs with equal lengths
662
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
663
+ >>> q = torch.randn(B, T, H, K, device='cuda')
664
+ >>> k = torch.randn(B, T, H, K, device='cuda')
665
+ >>> v = torch.randn(B, T, H, V, device='cuda')
666
+ >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
667
+ >>> u = torch.randn(H, K, device='cuda')
668
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
669
+ >>> o, ht = fused_recurrent_rwkv6(q, k, v, g, u,
670
+ initial_state=h0,
671
+ output_final_state=True,
672
+ head_first=False)
673
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
674
+ >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
675
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
676
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
677
+ >>> o_var, ht_var = fused_recurrent_rwkv6(q, k, v, g, u,
678
+ initial_state=h0,
679
+ output_final_state=True,
680
+ cu_seqlens=cu_seqlens,
681
+ head_first=False)
682
+ >>> assert o.allclose(o_var.view(o.shape))
683
+ >>> assert ht.allclose(ht_var)
684
+ """
685
+ if cu_seqlens is not None:
686
+ if r.shape[0] != 1:
687
+ raise ValueError(f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`."
688
+ f"Please flatten variable-length inputs before processing.")
689
+ if head_first:
690
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
691
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
692
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
693
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
694
+ if scale is None:
695
+ scale = k.shape[-1] ** -0.5
696
+ o, final_state = FusedRecurrentRWKV6Function.apply(
697
+ r,
698
+ k,
699
+ v,
700
+ w,
701
+ u,
702
+ scale,
703
+ initial_state,
704
+ output_final_state,
705
+ reverse,
706
+ cu_seqlens,
707
+ head_first
708
+ )
709
+ return o, final_state
fla/ops/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (288 Bytes). View file
 
fla/ops/rwkv7/__pycache__/channel_mixing.cpython-312.pyc ADDED
Binary file (14.7 kB). View file
 
fla/ops/simple_gla/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Simple GLA
2
+
3
+ Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).
4
+
5
+ Compared to GLA, the gating is head-wise instead of elementwise.
6
+ As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
7
+ It is faster than GLA but has less expressive power.
8
+ I will use it as a baseline for the GLA.
9
+
10
+ $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
fla/ops/simple_gla/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_simple_gla
4
+ from .fused_recurrent import fused_recurrent_simple_gla
5
+ from .parallel import parallel_simple_gla
6
+
7
+ __all__ = [
8
+ 'chunk_simple_gla',
9
+ 'fused_recurrent_simple_gla',
10
+ 'parallel_simple_gla'
11
+ ]
fla/ops/simple_gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (11.3 kB). View file
 
fla/ops/ttt/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_ttt_linear
4
+ from .fused_chunk import fused_chunk_ttt_linear
5
+
6
+ __all__ = [
7
+ 'fused_chunk_ttt_linear',
8
+ 'chunk_ttt_linear'
9
+ ]
fla/ops/ttt/chunk.py ADDED
@@ -0,0 +1,1539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.modules.layernorm import group_norm
12
+ from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
18
+ 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=num_warps)
25
+ for num_warps in [1, 2, 4, 8]
26
+ ],
27
+ key=['BT', 'BK', 'BV']
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_ttt_linear_fwd_kernel_h(
31
+ k,
32
+ v,
33
+ v_new,
34
+ eta,
35
+ w,
36
+ b,
37
+ eps,
38
+ h,
39
+ hb,
40
+ h0,
41
+ hb0,
42
+ ht,
43
+ hbt,
44
+ offsets,
45
+ chunk_offsets,
46
+ T,
47
+ H: tl.constexpr,
48
+ K: tl.constexpr,
49
+ V: tl.constexpr,
50
+ BT: tl.constexpr,
51
+ BK: tl.constexpr,
52
+ BV: tl.constexpr,
53
+ NT: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ USE_INITIAL_STATE_B: tl.constexpr,
56
+ STORE_FINAL_STATE: tl.constexpr,
57
+ USE_OFFSETS: tl.constexpr,
58
+ HEAD_FIRST: tl.constexpr,
59
+ ):
60
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
61
+ i_n, i_h = i_nh // H, i_nh % H
62
+ if USE_OFFSETS:
63
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
64
+ T = eos - bos
65
+ NT = tl.cdiv(T, BT)
66
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
67
+ else:
68
+ bos, eos = i_n * T, i_n * T + T
69
+ NT = tl.cdiv(T, BT)
70
+ boh = i_n * NT
71
+
72
+ # [BK, BV]
73
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
74
+ # [BV]
75
+ b_hb = tl.zeros([BV], dtype=tl.float32)
76
+ if USE_INITIAL_STATE:
77
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
79
+ if USE_INITIAL_STATE_B:
80
+ p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
81
+ b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)
82
+
83
+ offs = tl.arange(0, BV)
84
+ b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.)
85
+ b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.)
86
+
87
+ for i_t in range(NT):
88
+ if HEAD_FIRST:
89
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
90
+ p_hb = tl.make_block_ptr(hb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
91
+ else:
92
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
93
+ p_hb = tl.make_block_ptr(hb + ((boh + i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
94
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
95
+ tl.store(p_hb, b_hb.to(p_hb.dtype.element_ty), boundary_check=(0,))
96
+ if HEAD_FIRST:
97
+ p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
98
+ p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
100
+ p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
103
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
104
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
105
+ p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
106
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
107
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
108
+
109
+ b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
110
+ b_kh = tl.where((offs < V)[None, :], b_kh, 0.)
111
+ mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
112
+ xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.)
113
+ var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
114
+ rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
115
+ b_kh_hat = (b_kh - mean) * rstd
116
+
117
+ b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
118
+ b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
119
+ b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
120
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
121
+ * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
122
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
123
+ b_eta_last = tl.load(p_eta_last)
124
+ b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
125
+ b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
129
+ p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
130
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
131
+ tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))
132
+
133
+
134
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_fwd_kernel_o(
    q,
    k,
    v,
    eta,
    h,
    hb,
    o,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Forward output kernel for chunked TTT-linear attention.
    # Combines the inter-chunk contribution (q @ h, plus the bias state hb)
    # with a causally masked intra-chunk correction built from q @ k^T and
    # the per-step learning rate `eta`.
    # Grid: (value blocks i_v, time chunks i_t, batch*heads i_bh).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # Varlen mode: remap the flat chunk index to (sequence i_n, local
        # chunk i_t) via `indices`, and recompute this sequence's length T.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    # HEAD_FIRST expects [B, H, T, ...] layouts; otherwise [B, T, H, ...].
    q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
    o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    # h/hb hold one state snapshot per chunk (the state *entering* chunk i_t).
    h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
    hb += ((i_bh * NT + i_t) * V) if HEAD_FIRST else ((i_tg * H + i_h) * V)
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_eta = 1 if HEAD_FIRST else H

    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (0, i_t * BT), (BK, BT), (0, 1))
    p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,))
    p_h = tl.make_block_ptr(h, (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0))
    p_hb = tl.make_block_ptr(hb, (V,), (1,), (i_v * BV,), (BV,), (0,))
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
    # [BK, BT]
    b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
    # [BT, 1]
    b_eta = tl.load(p_eta, boundary_check=(0,), padding_option="zero")
    # [BK, BV]
    b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
    # [BV]
    b_hb = tl.load(p_hb, boundary_check=(0,), padding_option="zero")
    # [BT, BK] @ [BK, BV] -> [BT, BV]
    b_o = tl.dot(b_q, b_h, allow_tf32=False)
    # [BT, BK] @ [BK, BT] -> [BT, BT]
    b_A = tl.dot(b_q, b_k, allow_tf32=False)

    # Causal (lower-triangular, inclusive of diagonal) masks for the
    # intra-chunk attention scores and the eta broadcast.
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)
    b_Ae = tl.where(m_A, b_eta[:, None], 0.0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
    # Inter-chunk term minus eta-weighted intra-chunk correction, scaled;
    # then add the bias-state term minus its intra-chunk correction.
    b_o = (b_o - tl.dot(b_eta[:, None] * b_A.to(b_v.dtype), b_v, allow_tf32=False)) * scale
    b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v.dtype), b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
224
+
225
+
226
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_bwd_kernel_h(
    k,
    v,
    v_new,
    eta,
    w,
    b,
    eps,
    h,
    h0,
    hb0,
    x,
    y,
    r,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Backward-pass recomputation kernel: replays the forward chunked state
    # recurrence (same math as the forward `h` kernel) while additionally
    # saving the LayerNorm intermediates needed later by the backward norm
    # kernel: normalized input x (= kh_hat), pre-norm residual y, and the
    # reciprocal std r. Sequential over chunks; grid (NK, NV, N*H).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Varlen: per-sequence length T and chunk base offset boh.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV]
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    # Per-head LayerNorm affine parameters (weight w, bias b), padded to BV.
    offs = tl.arange(0, BV)
    b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.)
    b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.)

    for i_t in range(NT):
        # Store the state *entering* this chunk (mirrors the forward pass).
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0))
            # eta of the last valid step in the chunk (clamped for the tail chunk).
            p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
        else:
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # LayerNorm statistics of k^T @ h + hb along the V dimension,
        # masked to the valid V columns (BV may be padded past V).
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((offs < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        # Residual through the affine LN output, then backprop through
        # the normalization to obtain the updated value target v2.
        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
        # Save intermediates for the backward norm kernel.
        tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        # State update uses the chunk's last eta (chunk-wise TTT step).
        b_eta_last = tl.load(p_eta_last)
        b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0)
340
+
341
+
342
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_bwd_kernel_dv_local(
    q,
    k,
    eta,
    do,
    dv,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Computes the intra-chunk ("local") part of dv: the gradient flowing
    # into v through the causal q@k^T score matrix and the eta path of the
    # forward output kernel. Grid: (time chunks, batch*heads).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen: remap chunk index and recompute the sequence length.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
    do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_eta = 1 if HEAD_FIRST else H

    # Accumulate the transposed score matrix A = k @ q^T over K blocks.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_A += tl.dot(b_k, b_q)

    p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,))
    b_eta = tl.load(p_eta, boundary_check=(0,))
    # Upper-triangular mask: in the transposed orientation (rows index v
    # positions, columns index output positions) causality keeps col >= row.
    mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
    # Negated because v enters the forward output with a minus sign.
    b_A = - tl.where(mask, b_A * scale * b_eta[None, :], 0).to(do.dtype.element_ty)
    b_Ae = - tl.where(mask, b_eta[None, :], 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # dv = A^T-path gradient + eta-bias-path gradient.
        b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
411
+
412
+
413
@triton.heuristics({
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_bwd_kernel_norm(
    q,
    k,
    v,
    v_new,
    x,
    y,
    r,
    w,
    b,
    eta,
    h,
    dht,
    dhbt,
    dh0,
    dhb0,
    do,
    dh,
    dhb,
    dv,
    dv_new,
    dk,
    dw,
    db,
    offsets,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Main backward kernel: walks the chunks in reverse, propagating the
    # state gradients (dh, dhb) and backpropagating through the LayerNorm
    # using the intermediates (x = normalized kh, y = pre-norm residual,
    # r = 1/std) saved by the bwd_h kernel. Produces dv, dk, and the
    # per-(batch,head) partial gradients dw/db of the LN affine parameters.
    # Grid: (NK, NV, N*H); NK == NV == 1 is assumed by the host wrapper.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV]
    b_dhb = tl.zeros([BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
    if USE_FINAL_STATE_GRADIENT_B:
        p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")

    # [BV]
    offs_v = tl.arange(0, BV)
    offs_t = tl.arange(0, BT)
    b_w = tl.load(w + i_h * V + offs_v, mask=offs_v < V, other=0.)
    b_b = tl.load(b + i_h * V + offs_v, mask=offs_v < V, other=0.)
    # dw/db accumulate across all chunks; stored once per iteration so the
    # final iteration leaves the full per-(batch,head) partials behind.
    b_dw = tl.zeros([BV,], dtype=b_w.dtype)
    b_db = tl.zeros([BV,], dtype=b_b.dtype)
    p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
    p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))

    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_dhb = tl.make_block_ptr(dhb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        else:
            p_h = tl.make_block_ptr(h + ((boh+i_t) * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_dhb = tl.make_block_ptr(dhb + ((boh+i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        # Persist the state gradient *entering* this chunk (from the right).
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dhb, b_dhb.to(p_dhb.dtype.element_ty), boundary_check=(0,))
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv_new = tl.make_block_ptr(dv_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r + i_nh * T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1
        else:
            p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv_new = tl.make_block_ptr(dv_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, i_k * BK), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_dv_new = tl.load(p_dv_new, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_eta_last = tl.load(p_eta_last)
        # Add the gradient flowing into v2 from the state-update equations
        # h -= eta_last * k @ v2  and  hb -= eta_last * sum(v2).
        b_dv_new -= tl.dot(b_eta_last * b_k, b_dh.to(b_k.dtype))
        b_dv_new -= b_eta_last * b_dhb.to(b_k.dtype)[None, :]

        b_v_new = tl.load(p_v_new, boundary_check=(0, 1), padding_option="zero")
        b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
        # Backprop through the LayerNorm-style transform that produced v2
        # (see bwd_h): dy and dx are gradients w.r.t. the pre-/post-norm
        # intermediates; drstd is the gradient w.r.t. the reciprocal std.
        b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
                         b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
                          b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v_new.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)

        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
        b_w = b_w.to(b_k.dtype)
        b_b = b_b.to(b_k.dtype)
        # v entered the residual with a minus sign, k with a plus sign.
        b_dv = -b_w * b_dy.to(b_k.dtype)
        b_dk = b_w * b_dy.to(b_k.dtype)
        # w appears twice in the forward (affine scale and outer multiply),
        # hence the 2*w*x term plus the residual term.
        b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
                       (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
        b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
        b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)

        # d_rstd, dx --> dkh --> dk, dh
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
        b_q = (b_q * scale).to(b_q.dtype)
        b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
                          b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
        b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
        # Zero out padded V columns and time rows beyond the sequence end.
        b_dkh = tl.where((offs_v < V)[None, :] * (offs_t < T-i_t*BT)[:, None], b_dkh, 0.)
        b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)
        # Propagate the state gradient one chunk to the left.
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
        b_dhb += tl.sum(b_do + b_dkh, axis=0)
        b_dh = tl.where((offs_v < V)[None, :], b_dh, 0.)
        b_dhb = tl.where((offs_v < V), b_dhb, 0.)

        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
    if USE_INITIAL_STATE_B:
        p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))
591
+
592
+
593
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dqke(
    q,
    k,
    v,
    e,
    h,
    do,
    dh,
    dhb,
    dq,
    dk,
    de,
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Computes dq, dk and the learning-rate gradient de for one time chunk,
    # combining inter-chunk contributions (via h / dh / dhb) with the
    # intra-chunk causal score gradients. Grid: (K blocks, time chunks, B*H).
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen: remap chunk index and recompute per-sequence T / NT.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
    dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
    dhb += (i_bh * NT + i_t) * V if HEAD_FIRST else (i_tg * H + i_h) * V
    q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    e += i_bh * T if HEAD_FIRST else (bos * H + i_h)
    de += i_bh * T if HEAD_FIRST else (bos * H + i_h)
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_e = 1 if HEAD_FIRST else H

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    b_de = tl.zeros([BT,], dtype=tl.float32)

    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # The state update used only the chunk's *last* valid eta, so its
    # gradient lands on that single position: `mask` is a one-hot over BT.
    p_e_last = (e + (i_t*BT+BT-1)*stride_e) if (i_t*BT+BT) <= T else (e + (T-1)*stride_e)
    i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1)
    mask = (tl.arange(0, BT) == i_last)
    b_e_last = tl.load(p_e_last)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dhb = tl.make_block_ptr(dhb, (V,), (1,), (i_v * BV,), (BV,), (0,))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]
        b_dhb = tl.load(p_dhb, boundary_check=(0,))
        # [BT, BV] @ [BV, BT] -> [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk -= b_e_last * tl.dot(b_v, b_dh.to(b_v.dtype))
        # de at the last position: gradients through h -= e*k@v and hb -= e*sum(v).
        b_de -= mask * tl.sum(tl.trans(b_dh) * tl.dot(tl.trans(b_k), b_v.to(b_k.dtype)))
        b_de -= mask * tl.sum(b_dhb * tl.sum(b_v, axis=0).to(b_k.dtype))

    o_i = tl.arange(0, BT)
    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_e = tl.make_block_ptr(e, (T,), (stride_e,), (i_t * BT,), (BT,), (0,))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_e = tl.load(p_e, boundary_check=(0,))

    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_de = tl.make_block_ptr(de, (T,), (stride_e,), (i_t * BT,), (BT,), (0,))

    # Causal mask on the intra-chunk score gradient, then fold it into
    # dq, dk and the per-position eta gradient (scores entered with -eta).
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
    b_ds = b_ds.to(b_k.dtype)
    b_dq -= tl.dot(b_ds, b_k) * b_e[:, None]
    b_dk -= tl.dot(tl.trans(b_ds), b_q * b_e[:, None]) * scale
    b_de -= tl.sum(scale * tl.dot(b_ds, b_k) * b_q, axis=1)
    b_de -= tl.sum(b_ds, axis=1)
    b_dq *= scale
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))
715
+
716
+
717
def chunk_ttt_linear_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    eps: float,
    initial_state: Optional[torch.Tensor] = None,
    initial_state_bias: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the TTT-linear forward state kernel over all chunks.

    Allocates the per-chunk state snapshots ``h``/``hb`` (and, optionally,
    the final states) and runs ``chunk_ttt_linear_fwd_kernel_h``.

    Returns:
        ``(h, hb, v_new, final_state, final_state_bias)``; the last two are
        ``None`` unless ``output_final_state`` is set.
    """
    if head_first:
        B, H, T = k.shape[:3]
    else:
        B, T, H = k.shape[:3]
    K, V = k.shape[-1], v.shape[-1]
    BT = chunk_size
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = prepare_chunk_offsets(offsets, BT)
    # Single block over each of K and V; the sequential TTT recurrence
    # cannot be split across value blocks.
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
    assert NV == 1, 'NV > 1 is not supported by TTT update rule.'

    state_shape = (B, H, NT, K, V) if head_first else (B, NT, H, K, V)
    bias_shape = (B, H, NT, 1, V) if head_first else (B, NT, H, 1, V)
    h = k.new_empty(*state_shape)
    hb = k.new_empty(*bias_shape)
    if output_final_state:
        final_state = k.new_empty(N, H, K, V, dtype=torch.float32)
        final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32)
    else:
        final_state, final_state_bias = None, None
    v_new = torch.empty_like(v)

    chunk_ttt_linear_fwd_kernel_h[(NK, NV, N * H)](
        k=k,
        v=v,
        v_new=v_new,
        eta=eta,
        w=w,
        b=b,
        eps=eps,
        h=h,
        hb=hb,
        h0=initial_state,
        hb0=initial_state_bias,
        ht=final_state,
        hbt=final_state_bias,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first,
    )
    return h, hb, v_new, final_state, final_state_bias
789
+
790
+
791
def chunk_ttt_linear_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    eta: torch.Tensor,
    h: torch.Tensor,
    hb: torch.Tensor,
    scale: Optional[float] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Compute the TTT-linear forward output from the per-chunk states.

    Launches ``chunk_ttt_linear_fwd_kernel_o`` over (value blocks, chunks,
    batch*heads); ``scale`` defaults to ``1/sqrt(K)``.
    """
    if head_first:
        B, H, T = q.shape[:3]
    else:
        B, T, H = q.shape[:3]
    K, V = q.shape[-1], v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    BT = chunk_size
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)
    # Whole-head blocks only; the kernel assumes a single K and V block.
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
    assert NV == 1, 'NV > 1 is not supported by TTT update rule.'

    o = torch.empty_like(v)
    chunk_ttt_linear_fwd_kernel_o[(NV, NT, B * H)](
        q,
        k,
        v,
        eta,
        h,
        hb,
        o,
        offsets,
        indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first,
    )
    return o
843
+
844
+
845
def chunk_ttt_linear_bwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    eps: float,
    initial_state: Optional[torch.Tensor] = None,
    initial_state_bias: Optional[torch.Tensor] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Recompute forward states and LayerNorm intermediates for backward.

    Runs ``chunk_ttt_linear_bwd_kernel_h``, which replays the forward state
    recurrence and stashes the normalization intermediates (``x``, ``y``,
    ``rstd``) needed by the backward norm kernel.

    Returns:
        ``(h, v_new, x, y, rstd)``.
    """
    if head_first:
        B, H, T = k.shape[:3]
    else:
        B, T, H = k.shape[:3]
    K, V = k.shape[-1], v.shape[-1]
    BT = chunk_size
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = prepare_chunk_offsets(offsets, BT)
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
    assert NV == 1, 'NV > 1 is not supported by TTT update rule.'

    if head_first:
        h = k.new_empty(B, H, NT, K, V)
        rstd = v.new_empty(B, H, T, 1, dtype=torch.float32)
    else:
        h = k.new_empty(B, NT, H, K, V)
        rstd = v.new_empty(B, T, H, 1, dtype=torch.float32)
    x, y, v_new = torch.empty_like(v), torch.empty_like(v), torch.empty_like(v)

    chunk_ttt_linear_bwd_kernel_h[(NK, NV, N * H)](
        k=k,
        v=v,
        v_new=v_new,
        eta=eta,
        w=w,
        b=b,
        eps=eps,
        h=h,
        h0=initial_state,
        hb0=initial_state_bias,
        x=x,
        y=y,
        r=rstd,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first,
    )
    return h, v_new, x, y, rstd
916
+
917
+
918
def chunk_ttt_linear_bwd_dv_local(
    q: torch.Tensor,
    k: torch.Tensor,
    eta: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16
) -> torch.Tensor:
    """Compute the intra-chunk component of dv.

    Launches ``chunk_ttt_linear_bwd_kernel_dv_local``, which backpropagates
    ``do`` through the causal intra-chunk score and eta paths only.
    """
    if head_first:
        B, H, T = k.shape[:3]
    else:
        B, T, H = k.shape[:3]
    K, V = k.shape[-1], do.shape[-1]
    BT = chunk_size
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)
    # Cap block sizes at 128; the kernel loops over K/V blocks internally.
    BK = min(triton.next_power_of_2(K), 128)
    BV = min(triton.next_power_of_2(V), 128)

    dv = torch.empty_like(do)
    chunk_ttt_linear_bwd_kernel_dv_local[(NT, B * H)](
        q,
        k,
        eta,
        do,
        dv,
        offsets,
        indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first,
    )
    return dv
959
+
960
+
961
def chunk_ttt_linear_bwd_norm(
    q: torch.Tensor,  # [B, H, L, D]
    k: torch.Tensor,  # [B, H, L, D]
    v: torch.Tensor,  # [B, H, L, D]
    v_new: torch.Tensor,  # [B, H, L, D]
    x: torch.Tensor,  # [B, H, L, D]
    y: torch.Tensor,  # [B, H, L, D]
    rstd: torch.Tensor,  # [B, H, L, 1]
    w: torch.Tensor,  # [H, D]
    b: torch.Tensor,  # [H, D]
    eta: torch.Tensor,  # [B, H, L, 1]
    h0: torch.Tensor,  # [B, H, D, D]
    hb0: torch.Tensor,  # [B, H, 1, D]
    h: torch.Tensor,  # [B, H, NT, D, D]
    dht: Optional[torch.Tensor],  # [B, H, D, D]
    dhbt: Optional[torch.Tensor],  # [B, H, 1, D]
    dv_new: Optional[torch.Tensor],  # [B, H, L, D]
    do: torch.Tensor,  # [B, H, L, D]
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass through the layer-norm step of the TTT update.

    Launches ``chunk_ttt_linear_bwd_kernel_norm`` to compute gradients
    w.r.t. the hidden states (`dh`, `dhb`, and optionally `dh0`/`dhb0`),
    the inputs (`dv`, `dk`), and the layer-norm parameters (`dw`, `db`).
    The kernel accumulates `dw`/`db` per batch; they are summed over the
    batch dimension before returning.

    Raises:
        AssertionError: for varlen inputs (not implemented) or when K/V
            exceed a single Triton block (NK/NV > 1).
    """
    # torch implementation of `dkh, dw, db, dk, dv` for LN^2
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K, V = *q.shape, do.shape[-1]
    else:
        B, T, H, K, V = *q.shape, do.shape[-1]
    BT = chunk_size
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)

    # The whole head dimension must fit in one block: the recurrence
    # cannot be split across programs without synchronization.
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported by TTT.'
    assert NV == 1, 'NV > 1 is not supported by TTT.'

    # Per-chunk hidden-state gradients, layout matching `h`.
    if head_first:
        dh = q.new_empty(B, H, NT, K, V)
        dhb = q.new_empty(B, H, NT, 1, V)
    else:
        dh = q.new_empty(B, NT, H, K, V)
        dhb = q.new_empty(B, NT, H, 1, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dhb0 = torch.empty_like(hb0, dtype=torch.float32) if hb0 is not None else None
    dv = torch.empty_like(v)
    dk = torch.empty_like(k)
    # Per-batch partial sums of the LN parameter grads, reduced below.
    dw = w.new_empty(B, H, V)
    db = b.new_empty(B, H, V)

    grid = (NK, NV, N * H)
    chunk_ttt_linear_bwd_kernel_norm[grid](
        q=q,
        k=k,
        v=v,
        v_new=v_new,
        x=x,
        y=y,
        r=rstd,
        w=w,
        b=b,
        eta=eta,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dh0=dh0,
        dhb0=dhb0,
        do=do,
        dh=dh,
        dhb=dhb,
        dv=dv,
        dv_new=dv_new,
        dk=dk,
        dw=dw,
        db=db,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    # Reduce the per-batch partials to the parameter shape [H, V].
    dw = dw.sum(dim=0)
    db = db.sum(dim=0)
    return dh, dhb, dh0, dhb0, dv, dk, dw, db
1057
+
1058
+
1059
def chunk_ttt_linear_bwd_norm_ref(
    q: torch.Tensor,  # [B, H, L, D]
    k: torch.Tensor,  # [B, H, L, D]
    v: torch.Tensor,  # [B, H, L, D]
    v_new: torch.Tensor,  # [B, H, L, D]
    kh: torch.Tensor,  # [B, H, L, D]
    y: torch.Tensor,  # [B, H, L, D]
    w: torch.Tensor,  # [H, D]
    b: torch.Tensor,  # [H, D]
    eta: torch.Tensor,  # [B, H, L, 1]
    h0: torch.Tensor,  # [B, H, D, D]
    h: torch.Tensor,  # [B, H, NT, D, D]
    dht: Optional[torch.Tensor],  # [B, H, D, D]
    dv_new: Optional[torch.Tensor],  # [B, H, L, D]
    do: torch.Tensor,  # [B, H, L, D]
    scale: float,
    eps: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for the LN backward (used to validate the kernel).

    Iterates chunks in reverse, propagating the hidden-state gradient
    `b_dh` and accumulating `dv`, `dk`, `dw`, `db`. NOTE(review): despite
    the parameter-comment layout shapes, inputs are transposed from
    `[B, L, H, D]` to `[B, H, L, D]` below — confirm against callers.
    """
    # torch implementation of `dkh, dw, db, dk, dv` for LN^2
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K, V = *q.shape, do.shape[-1]
    else:
        B, T, H, K, V = *q.shape, do.shape[-1]
    # [B, L, H, D] -> [B, H, L, D]
    q, k, v, v_new, kh, y, h, eta, dv_new, do = [
        x.transpose(1, 2) for x in
        [q, k, v, v_new, kh, y, h, eta, dv_new, do]
    ]
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # Right-pad the time axis so it splits evenly into BT-sized chunks.
    pad_len = (BT - (T % BT)) % BT
    if pad_len > 0:
        q, k, v, v_new, kh, y, eta, dv_new, do = [
            F.pad(x, (0, 0, 0, pad_len)) for x in
            [q, k, v, v_new, kh, y, eta, dv_new, do]
        ]
        # The kernel reads the chunk's last eta; replicate the last real one
        # into the padded tail so the final chunk uses a valid value.
        eta[:, :, -1, :] = eta[:, :, -(pad_len+1), :]
    # [NT, B, H, BT, D]
    q, k, v, v_new, kh, y, eta, dv_new, do = [
        x.reshape(B, H, NT, BT, -1).permute(2, 0, 1, 3, 4) for x in
        [q, k, v, v_new, kh, y, eta, dv_new, do]
    ]
    h = h.permute(2, 0, 1, 3, 4)

    # allocate
    dh = q.new_zeros(NT, B, H, K, V)
    dv = torch.zeros_like(v)
    dk = torch.zeros_like(k)
    dw = torch.zeros_like(w)
    db = torch.zeros_like(b)
    # recurrent state
    b_dh = dht if dht is not None else torch.zeros_like(dh[0])
    b_dh = b_dh.to(torch.float32)

    # [H, 1, D]
    _w = w.reshape(H, 1, V).to(torch.float32)
    _b = b.reshape(H, 1, V).to(torch.float32)

    # d_state passing: walk chunks backwards, carrying b_dh.
    for i_t in range(NT - 1, -1, -1):
        # Gradient entering this chunk's hidden state.
        dh[i_t] = b_dh.to(dh.dtype)
        # [B, H, BT, D]
        _q, _k, _v, _v_new, _kh, _y, _h, _eta, _dv_new, _do = [
            x[i_t].to(torch.float32) for x in
            (q, k, v, v_new, kh, y, h, eta, dv_new, do)
        ]
        # Remove the state-update path's contribution from dv_new.
        _dv_new -= (_eta[:, :, -1, :, None] * _k) @ b_dh

        # Recompute the forward layer-norm statistics for this chunk.
        mean = _kh.mean(dim=-1, keepdim=True)
        var = _kh.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = 1 / torch.sqrt(var + eps).to(torch.float32)
        x = (_kh - mean) * rstd
        # [B, H, BT, D] — layer-norm backward (dy, dx) and rstd grad.
        dy = rstd * (_dv_new*V - _dv_new.sum(dim=-1, keepdim=True) - x*(x*_dv_new).sum(dim=-1, keepdim=True)) / V
        dx = -rstd * (_dv_new*(x*_y).sum(dim=-1, keepdim=True) + _y*(x*_dv_new).sum(dim=-1, keepdim=True)) / V
        d_rstd = (_dv_new * _v_new / rstd).sum(dim=-1, keepdim=True)

        dv[i_t] = (-_w*dy).to(dv.dtype)
        dk[i_t] += (_w*dy).to(dk.dtype)
        # Accumulate LN parameter gradients over batch and time.
        dw += (2*_w*x*dy+(_b-_v+_k)*dy).sum(dim=(0, 2)).to(dw.dtype)
        db += (_w*dy).sum(dim=(0, 2)).to(db.dtype)
        dx += _w*_w*dy

        # d_rstd, dx --> dkh --> dk, dh
        dkh = rstd * (V * dx - dx.sum(dim=-1, keepdim=True) - x * (x * dx).sum(dim=-1, keepdim=True)) / V
        dkh -= rstd**2 * d_rstd * x / V
        dk[i_t] += (dkh @ _h.transpose(-2, -1)).to(dk.dtype)
        # Propagate state gradient to the previous chunk.
        b_dh += (_q.transpose(-2, -1) * scale) @ _do + _k.transpose(-2, -1) @ dkh
    dh0 = b_dh.to(torch.float32) if h0 is not None else None

    # [NT, B, H, BT, D] -> [B, H, T, D] (strip any padding)
    dv = dv.permute(1, 2, 0, 3, 4).reshape(B, H, -1, V)[:, :, :T, :]
    dk = dk.permute(1, 2, 0, 3, 4).reshape(B, H, -1, K)[:, :, :T, :]
    # [B, H, NT, D, D]
    dh = dh.permute(1, 2, 0, 3, 4)
    if not head_first:
        dv, dk, dh = [x.transpose(1, 2) for x in (dv, dk, dh)]
    dh, dv, dk, dw, db = [x.contiguous() for x in (dh, dv, dk, dw, db)]
    dh0 = dh0.contiguous() if h0 is not None else None
    return dh, dh0, dv, dk, dw, db
1164
+
1165
+
1166
def chunk_ttt_linear_bwd_dqke(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    eta: torch.Tensor,
    h: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    dhb: torch.Tensor,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute gradients for queries, keys, and the learning rate eta.

    Launches ``chunk_bwd_kernel_dqke`` on a `(NK, NT, B * H)` grid.
    Note `v` here is expected to be the updated values (`v_new` from the
    forward recurrence) — see the caller in `chunk_ttt_linear_bwd`.

    Returns:
        (dq, dk, de): gradients with the same shapes as `q`, `k`, `eta`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    # Whole key dim must fit in one block; V may be tiled at up to 64.
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 64)
    NK = triton.cdiv(K, BK)
    assert NK == 1, "NK > 1 is not supported."

    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    de = torch.empty_like(eta)
    grid = (NK, NT, B * H)

    chunk_bwd_kernel_dqke[grid](
        q=q,
        k=k,
        v=v,
        e=eta,
        h=h,
        do=do,
        dh=dh,
        dhb=dhb,
        dq=dq,
        dk=dk,
        de=de,
        offsets=offsets,
        indices=indices,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dq, dk, de
1225
+
1226
+
1227
def chunk_ttt_linear_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    BT: int = 16
):
    """Chunked TTT-linear forward: state recurrence, then output projection.

    First runs the hidden-state recurrence (`chunk_ttt_linear_fwd_h`),
    producing per-chunk states `h`/`hb`, the updated values `v_new`, and
    (optionally) the final states; then computes the output via
    `chunk_ttt_linear_fwd_o` using the updated values.
    """
    # Arguments shared by both kernel-launch helpers.
    layout_kwargs = dict(
        eta=eta,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT,
    )
    # Stage 1: hidden-state recurrence over chunks.
    h, hb, v_new, final_state, final_state_bias = chunk_ttt_linear_fwd_h(
        k=k,
        v=v,
        w=w,
        b=b,
        eps=eps,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        output_final_state=output_final_state,
        **layout_kwargs,
    )
    # Stage 2: output from queries against the per-chunk states.
    o = chunk_ttt_linear_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        hb=hb,
        scale=scale,
        **layout_kwargs,
    )
    return o, final_state, final_state_bias
1273
+
1274
+
1275
def chunk_ttt_linear_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Chunked TTT-linear backward.

    Pipeline:
      1. recompute the forward recurrence intermediates (`h`, `v_new`,
         layer-norm `x`/`y`/`rstd`);
      2. local gradient of the updated values (`dv_new`);
      3. backward through the layer-norm step, yielding state grads and
         `dv`, `dk`, `dw`, `db`;
      4. gradients for queries, keys, and eta; the key grads from stages
         3 and 4 are accumulated.
    """
    # Keyword args common to every stage below.
    layout_kwargs = dict(
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT,
    )
    # Stage 1: recompute forward intermediates needed by the backward.
    h, v_new, x, y, rstd = chunk_ttt_linear_bwd_h(
        k=k,
        v=v,
        w=w,
        b=b,
        eta=eta,
        eps=eps,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        **layout_kwargs,
    )
    # Stage 2: intra-chunk gradient w.r.t. the updated values.
    dv_new = chunk_ttt_linear_bwd_dv_local(
        q=q,
        k=k,
        eta=eta,
        do=do,
        scale=scale,
        **layout_kwargs,
    )
    # Stage 3: backward through the LN step and the state recurrence.
    dh, dhb, dh0, dhb0, dv, dk, dw, db = chunk_ttt_linear_bwd_norm(
        q=q,
        k=k,
        v=v,
        v_new=v_new,
        x=x,
        y=y,
        rstd=rstd,
        w=w,
        b=b,
        eta=eta,
        h0=initial_state,
        hb0=initial_state_bias,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dv_new=dv_new,
        do=do,
        scale=scale,
        **layout_kwargs,
    )
    # Stage 4: query/key/eta gradients; note v_new is passed as `v`.
    dq, dk_qke, de = chunk_ttt_linear_bwd_dqke(
        q=q,
        k=k,
        v=v_new,
        eta=eta,
        h=h,
        do=do,
        dh=dh,
        dhb=dhb,
        scale=scale,
        **layout_kwargs,
    )
    # Combine the two key-gradient contributions in place.
    dk.add_(dk_qke)
    return dq, dk, dv, de, dw, db, dh0, dhb0
1360
+
1361
+
1362
class ChunkTTTLinearFunction(torch.autograd.Function):
    """Autograd wrapper tying `chunk_ttt_linear_fwd` to `chunk_ttt_linear_bwd`.

    The positional order of `forward`'s arguments must match the order of
    gradients returned by `backward` (non-tensor args get `None`).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
                initial_state_bias, output_final_state, offsets, head_first):
        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None
        o, final_state, final_state_bias = chunk_ttt_linear_fwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=scale,
            eps=eps,
            BT=BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
        )
        # Save tensors for backward; scalars/flags go on ctx attributes.
        ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
        ctx.BT = BT
        ctx.scale = scale
        ctx.eps = eps
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.head_first = head_first
        return o.to(q.dtype), final_state, final_state_bias

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht, dhbt):
        q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
        dq, dk, dv, de, dw, db, dh0, dhb0 = chunk_ttt_linear_bwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=ctx.scale,
            eps=ctx.eps,
            do=do,
            dht=dht,
            dhbt=dhbt,
            BT=ctx.BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            offsets=ctx.offsets,
            indices=ctx.indices,
            head_first=ctx.head_first
        )
        # Gradients mirror forward's signature:
        # (q, k, v, w, b, BT, eta, scale, eps, h0, hb0, output_final_state, offsets, head_first)
        return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None
1425
+
1426
+
1427
def norm_residual(x, weight, bias, eps, head_first):
    """Apply GroupNorm (one group per head) and add it back as a residual.

    Computes ``x + GroupNorm(x)`` in place over the flattened head
    dimension. The head-first layout is handled by transposing to
    `(B, T, H, D)`, applying the identical normalization, and transposing
    back — previously the body was duplicated across both branches.

    Args:
        x: input of shape `(B, H, T, D)` if `head_first` else `(B, T, H, D)`.
        weight: norm weight, reshaped flat to `(H*D,)`.
        bias: norm bias, reshaped flat to `(H*D,)`.
        eps: numerical epsilon for the normalization.
        head_first: layout flag.

    Returns:
        Tensor of the same shape as `x` (modified in place).
    """
    # Normalize in (B, T, H, D) layout; transpose yields a view, so the
    # in-place `+=` below updates the caller's tensor exactly as before.
    if head_first:
        x = x.transpose(1, 2)
    B, T, H, D = x.shape
    x += group_norm(
        x.reshape(B, T, -1).clone(),
        weight=weight.reshape(-1).clone(),
        bias=bias.reshape(-1).clone(),
        eps=eps,
        num_groups=H,
    ).reshape(x.shape)
    if head_first:
        x = x.transpose(1, 2)
    return x
1450
+
1451
+
1452
def chunk_ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            layer norm weight of shape `(H, V)`
        b (torch.Tensor):
            layer norm bias of shape `(H, V)`
        eta (torch.Tensor):
            Learning rate for hidden state, of shape `(B, H, T, 1)`.
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        chunk_size (int):
            chunk size. Default: `16`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        initial_state_bias (Optional[torch.Tensor]):
            Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]`
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
    """
    assert q.dtype == k.dtype == v.dtype
    # The TTT update rule requires square hidden states (K == V).
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
    # A scalar eta is broadcast into a per-token tensor matching q's layout.
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if cu_seqlens is not None:
        # Varlen mode: inputs must be pre-flattened (batch of 1, time-major).
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "Scale must be positive."
    o, final_state, final_state_bias = ChunkTTTLinearFunction.apply(
        q,
        k,
        v,
        w,
        b,
        chunk_size,
        eta,
        scale,
        eps,
        initial_state,
        initial_state_bias,
        output_final_state,
        cu_seqlens,
        head_first,
    )
    # Final GroupNorm + residual applied outside the autograd function.
    o = norm_residual(o, w, b, eps, head_first)
    return o, final_state, final_state_bias
fla/ops/ttt/fused_chunk.py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.modules.layernorm import group_norm
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_fwd_kernel(
    q,
    k,
    v,
    eta,
    w,
    b,
    o,
    scale,
    eps,
    h0,
    hb0,
    ht,
    hbt,
    offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Fused forward kernel: one program per (batch, head) walks all chunks
    sequentially, maintaining the hidden state `b_h` and bias `b_hb` in
    registers while computing outputs `o` chunk by chunk."""
    # indices: one program per (sequence, head) pair
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Varlen: sequence bounds come from the offsets array.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # Lower-triangular (causal) mask within a chunk.
    m_A = o_i[:, None] >= o_i[None, :]
    # Layer-norm weight/bias for this head, masked beyond V.
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)

    # [BK, BV] hidden state carried across chunks
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV] hidden-state bias
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    for i_t in range(NT):
        # Block pointers for the current chunk; p_e_last points at the
        # chunk's last valid eta (clamped to T-1 for the ragged final chunk).
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_o = tl.make_block_ptr(o+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
            p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
        else:
            p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_o = tl.make_block_ptr(o+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
            p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # [BT, BV] pre-norm projection k @ h + hb, masked past V.
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
        # Layer-norm statistics over the value dimension.
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        # TTT residual: LN(kh)*w + b - v + k, then backprop through LN
        # to obtain the update direction b_v2.
        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        # [BT]
        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_q = (b_q * scale).to(b_k.dtype)

        # [BT, BT] causal intra-chunk attention scores.
        b_A = tl.dot(b_q, b_k, allow_tf32=False)
        b_A = tl.where(m_A, b_A, 0)
        b_Ae = tl.where(m_A, b_e[:, None], 0.0)

        # Output: intra-chunk correction terms plus the inter-chunk q @ h.
        b_o = - tl.dot(b_e[:, None] * b_A.to(b_v2.dtype), b_v2, allow_tf32=False)
        b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v2.dtype), b_v2, allow_tf32=False)
        b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        # State update uses the chunk's last eta as the step size.
        b_e_last = tl.load(p_e_last)
        b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
        # Keep padded lanes (>= V) zeroed so they never leak into dots.
        b_h = tl.where((v_i < V)[None, :], b_h, 0.)
        b_hb = tl.where((v_i < V), b_hb, 0.)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))
146
+
147
+
148
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_bwd_kernel_h(
    k,
    v,
    v2,
    x,
    y,
    r,
    w,
    b,
    eta,
    h0,
    hb0,
    h,
    do,
    dq,
    scale,
    eps,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Backward recompute kernel: replays the forward state recurrence,
    storing per-chunk states `h` and the LN intermediates (`x`, `y`, `r`,
    `v2`) needed later, and computes `dq` along the way. No varlen path —
    sequence bounds are derived from the fixed T."""
    # indices: one program per (sequence, head) pair
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    bos, _ = i_n * T, i_n * T + T
    NT = tl.cdiv(T, BT)
    # Base chunk offset for this sequence in the per-chunk state tensor.
    boh = i_n * NT

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # Causal mask within a chunk.
    m_A = o_i[:, None] >= o_i[None, :]
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)

    # [BK, BV] hidden state replayed from the forward pass
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV]
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    for i_t in range(NT):
        # Block pointers for this chunk's inputs, outputs, and intermediates.
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
            p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
            p_dq = tl.make_block_ptr(dq+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
        else:
            p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
            p_dq = tl.make_block_ptr(dq+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        # Persist the state as seen at the *start* of this chunk.
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # Recompute the forward LN step (same math as the fwd kernel).
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
        # Stash LN intermediates for the subsequent dh kernel.
        tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_v2, b_v2.to(p_v2.dtype.element_ty), boundary_check=(0, 1))

        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")

        # dq = scale * (do @ h^T - eta * causal(do @ v2^T) @ k^T)
        b_v2 = tl.where((v_i < V)[None, :], b_v2, 0.)
        b_ds = tl.dot(b_do, tl.trans(b_v2).to(b_do.dtype))
        b_ds = tl.where(m_A, b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        b_dq = tl.dot(b_do, tl.trans(b_h).to(b_do.dtype))
        b_dq -= tl.dot(b_ds, tl.trans(b_k)) * b_e[:, None]
        b_dq *= scale

        # Replay the forward state update with the chunk's last eta.
        b_e_last = tl.load(p_e_last)
        b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
        b_h = tl.where((v_i < V)[None, :], b_h, 0.)
        b_hb = tl.where((v_i < V), b_hb, 0.)
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
279
+
280
+
281
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_bwd_kernel_dh(
    q,
    k,
    v,
    v2,
    x,
    y,
    r,
    w,
    b,
    eta,
    h,
    dht,
    dhbt,
    dh0,
    dhb0,
    do,
    dk,
    dv,
    de,
    dw,
    db,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Backward sweep of the TTT-linear chunked recurrence.
    # One program per (sequence, head) pair walks the time chunks in REVERSE
    # order, carrying the state gradients (b_dh, b_dhb) backwards while
    # emitting dk, dv, de (eta grad) per chunk and accumulating dw/db.
    # Inputs v2/x/y/r are the forward intermediates cached by the companion
    # `fused_chunk_ttt_linear_bwd_kernel_h` pass.

    # indices: program id -> (sequence, head)
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    bos, _ = i_n * T, i_n * T + T
    NT = tl.cdiv(T, BT)
    boh = i_n * NT

    # [BK, BV] running gradient of the fast-weight state
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV] running gradient of the state bias
    b_dhb = tl.zeros([BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
    if USE_FINAL_STATE_GRADIENT_B:
        p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")

    # [BV]
    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # lower-/upper-triangular masks for the causal intra-chunk attention
    m_A = o_i[:, None] >= o_i[None, :]
    m_A_t = o_i[:, None] <= o_i[None, :]
    # per-head LayerNorm weight/bias, masked to the real feature width V
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)
    b_dw = tl.zeros([BV,], dtype=b_w.dtype)
    b_db = tl.zeros([BV,], dtype=b_b.dtype)
    # dw/db are written per (n, h) program and reduced over the batch on the host
    p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
    p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (0,), (BV,), (0,))

    # walk chunks from last to first
    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            # [B, H, T, ...] layout: contiguous per (n, h) slab
            p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1))
            p_q = tl.make_block_ptr(q+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1))
            p_k = tl.make_block_ptr(k+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
            p_dv = tl.make_block_ptr(dv+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_dk = tl.make_block_ptr(dk+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_de = tl.make_block_ptr(de+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
            # eta of the chunk's last VALID position (last real token for the tail chunk)
            p_e_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1
        else:
            # [B, T, H, ...] layout: stride H between consecutive time steps
            p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1))
            p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
            p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
            p_de = tl.make_block_ptr(de+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
            p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
        b_e_last = tl.load(p_e_last)
        # dv_new: gradient flowing into the updated values of this chunk,
        # from (i) the intra-chunk attention and (ii) the carried state grads
        b_A = tl.dot(b_k, b_q)
        b_A = - tl.where(m_A_t, b_A * scale * b_e[None, :], 0).to(do.dtype.element_ty)
        b_Ae = - tl.where(m_A_t, b_e[None, :], 0).to(do.dtype.element_ty)
        b_dv_new = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
        b_dv_new -= tl.dot(b_e_last * b_k, b_dh.to(b_k.dtype))
        b_dv_new -= b_e_last * b_dhb.to(b_k.dtype)[None, :]

        # cached forward intermediates; b_rstd is the LN std (divided by), per row
        b_v2 = tl.load(p_v2, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
        # backprop through the fused LayerNorm of the inner loss
        b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
                         b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
                          b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v2.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)

        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
        b_w = b_w.to(b_k.dtype)
        b_b = b_b.to(b_k.dtype)
        b_dv = -b_w * b_dy.to(b_k.dtype)
        b_dk = b_w * b_dy.to(b_k.dtype)
        # accumulate LayerNorm parameter gradients across chunks
        b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
                       (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
        b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
        b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)

        b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
        b_q = (b_q * scale).to(b_q.dtype)
        # gradient through the pre-LN projection k @ h; mask out-of-range
        # feature columns and (for the tail chunk) padded time rows
        b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
                          b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
        b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
        b_dkh = tl.where((v_i < V)[None, :] * (o_i < T-i_t*BT)[:, None], b_dkh, 0.)
        b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)

        b_ds = tl.dot(b_do, tl.trans(b_v2))
        b_ds = tl.where(m_A, b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        # index of the chunk's last valid row (handles a ragged tail chunk)
        i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1)
        mask = (o_i == i_last)
        b_dk -= b_e_last * tl.dot(b_v2, tl.trans(b_dh).to(b_v2.dtype))
        b_dk -= tl.dot(tl.trans(b_ds), tl.trans(b_q) * b_e[:, None])
        # de: the state-update terms only touch the chunk's last eta (mask),
        # the attention terms touch every position's eta
        b_de = mask * tl.sum(- b_dh * tl.trans(tl.dot(tl.trans(b_v2), b_k))).to(b_k.dtype)
        b_de -= mask * tl.sum(b_dhb * tl.sum(b_v2, axis=0)).to(b_k.dtype)
        b_de -= tl.sum(tl.dot(b_ds, b_k) * tl.trans(b_q).to(b_k.dtype), axis=1)
        b_de -= tl.sum(b_ds, axis=1)
        # carry the state gradients one chunk further back
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
        b_dhb += tl.sum(b_do + b_dkh, axis=0)
        b_dh = tl.where((v_i < V)[None, :], b_dh, 0.)
        b_dhb = tl.where((v_i < V), b_dhb, 0.)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))
        # dw/db overwrite the same slot each iteration; the values stored on
        # the final (i_t == 0) iteration are the fully accumulated gradients
        tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0+i_nh*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
    if USE_INITIAL_STATE_B:
        p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (0,), (BV,), (0,))
        tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))
461
+
462
+
463
def fused_chunk_ttt_linear_bwd_h(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """First backward pass: replay the forward recurrence chunk-by-chunk.

    Materializes the per-chunk hidden states ``h`` and the LayerNorm
    intermediates (``v2``, ``x``, ``y``, ``r``) needed by the second backward
    kernel, while computing the query gradient ``dq`` on the fly.

    Returns:
        Tuple ``(dq, h, v2, x, y, r)``.
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]
    NT = triton.cdiv(T, BT)
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # The cached state/rstd layouts follow the chosen memory format.
    if head_first:
        h = k.new_empty(B, H, NT, K, V)
        r = v.new_empty(B, H, T, 1, dtype=torch.float32)
    else:
        h = k.new_empty(B, NT, H, K, V)
        r = v.new_empty(B, T, H, 1, dtype=torch.float32)
    v2 = torch.empty_like(v)
    x = torch.empty_like(v)
    y = torch.empty_like(v)
    dq = torch.empty_like(q)

    # One program per (sequence, head) pair.
    fused_chunk_ttt_linear_bwd_kernel_h[(B * H,)](
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h0=initial_state,
        hb0=initial_state_bias,
        h=h,
        do=do,
        dq=dq,
        scale=scale,
        eps=eps,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dq, h, v2, x, y, r
528
+
529
+
530
def fused_chunk_ttt_linear_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v2: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    r: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    h: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Second backward pass: sweep chunks last-to-first.

    Consumes the intermediates produced by :func:`fused_chunk_ttt_linear_bwd_h`
    and produces ``dk``, ``dv``, ``de`` (eta gradient), the LayerNorm parameter
    gradients ``dw``/``db`` (reduced over the batch dimension), and the
    initial-state gradients ``dh0``/``dhb0`` when initial states were given.
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # Initial-state gradients are only materialized when a state was supplied.
    dh0 = None if initial_state is None else torch.empty_like(initial_state, dtype=torch.float32)
    dhb0 = None if initial_state_bias is None else torch.empty_like(initial_state_bias, dtype=torch.float32)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    de = torch.empty_like(eta)
    # Per-(batch, head) partial parameter gradients; reduced below.
    dw = w.new_empty(B, H, V)
    db = b.new_empty(B, H, V)

    fused_chunk_ttt_linear_bwd_kernel_dh[(B * H,)](
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dh0=dh0,
        dhb0=dhb0,
        do=do,
        dk=dk,
        dv=dv,
        de=de,
        dw=dw,
        db=db,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    # Collapse the per-batch partials into the final parameter gradients.
    return dk, dv, de, dw.sum(dim=0), db.sum(dim=0), dh0, dhb0
606
+
607
+
608
def fused_chunk_ttt_linear_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    BT: int = 16
):
    """Launch the fused forward kernel.

    Returns:
        Tuple ``(o, final_state, final_state_bias)``; the final states are
        ``None`` unless ``output_final_state`` is set.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]
    # N: the actual number of sequences in the batch (variable-length aware).
    N = B if offsets is None else len(offsets) - 1
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    o = torch.empty_like(v)
    if output_final_state:
        final_state = k.new_empty(N, H, K, V, dtype=torch.float32)
        final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32)
    else:
        final_state = None
        final_state_bias = None

    fused_chunk_ttt_linear_fwd_kernel[(N * H,)](
        q=q,
        k=k,
        v=v,
        eta=eta,
        w=w,
        b=b,
        o=o,
        scale=scale,
        eps=eps,
        h0=initial_state,
        hb0=initial_state_bias,
        ht=final_state,
        hbt=final_state_bias,
        offsets=offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return o, final_state, final_state_bias
662
+
663
+
664
def fused_chunk_ttt_linear_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Full backward: recompute forward intermediates, then sweep backwards.

    Stage 1 (:func:`fused_chunk_ttt_linear_bwd_h`) replays the recurrence and
    yields ``dq`` plus the cached intermediates; stage 2
    (:func:`fused_chunk_ttt_linear_bwd_dh`) consumes them to produce the
    remaining gradients.
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h(
        q=q, k=k, v=v, w=w, b=b, eta=eta,
        scale=scale, eps=eps, do=do, BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        offsets=offsets,
        head_first=head_first,
    )
    dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd_dh(
        q=q, k=k, v=v, v2=v2, x=x, y=y, r=rstd, w=w, b=b, eta=eta,
        scale=scale, h=h, do=do, dht=dht, dhbt=dhbt, BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        offsets=offsets,
        head_first=head_first,
    )
    return dq, dk, dv, de, dw, db, dh0, dhb0
722
+
723
+
724
class FusedChunkTTTLinearFunction(torch.autograd.Function):
    """Autograd bridge for the fused-chunk TTT-linear kernels.

    ``forward`` launches the fused Triton forward kernel; ``backward``
    recomputes the intermediates and returns gradients in the exact
    positional order of ``forward``'s inputs (after ``ctx``).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
                initial_state_bias, output_final_state, offsets, head_first):
        o, final_state, final_state_bias = fused_chunk_ttt_linear_fwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=scale,
            eps=eps,
            BT=BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            output_final_state=output_final_state,
            offsets=offsets,
            head_first=head_first
        )
        # Backward recomputes everything else from these inputs.
        ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
        ctx.BT = BT
        ctx.scale = scale
        ctx.eps = eps
        ctx.offsets = offsets
        ctx.head_first = head_first
        return o.to(q.dtype), final_state, final_state_bias

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht, dhbt):
        q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
        dq, dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=ctx.scale,
            eps=ctx.eps,
            do=do,
            dht=dht,
            dhbt=dhbt,
            BT=ctx.BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        # One gradient slot per forward input, in order:
        # q, k, v, w, b, BT (None), eta, scale (None), eps (None),
        # initial_state, initial_state_bias, output_final_state (None),
        # offsets (None), head_first (None).
        # `.to(tensor)` casts each gradient to its input's dtype/device.
        return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None
779
+
780
+
781
def norm_residual(x, weight, bias, eps, head_first):
    """Apply per-head GroupNorm and add it back onto `x` (residual).

    NOTE(review): `x += ...` is an in-place add (in the head-first branch it
    is applied through a transposed view), so the caller's tensor is mutated
    as a side effect — confirm callers rely on / tolerate this.
    `group_norm` is presumably the project's fused GroupNorm helper imported
    at the top of the file — verify its signature matches usage here.
    """
    # GroupNorm and Residual
    if head_first:
        B, H, T, D = x.shape
        # Move to (B, T, H, D) so heads can be flattened into groups.
        x = x.transpose(1, 2)
        x += group_norm(
            x.reshape(B, T, -1).clone(),
            weight=weight.reshape(-1).clone(),
            bias=bias.reshape(-1).clone(),
            eps=eps,
            num_groups=H,  # one normalization group per head
        ).reshape(x.shape)
        x = x.transpose(1, 2)
    else:
        B, T, H, D = x.shape
        x += group_norm(
            x.reshape(B, T, -1).clone(),
            weight=weight.reshape(-1).clone(),
            bias=bias.reshape(-1).clone(),
            eps=eps,
            num_groups=H,
        ).reshape(x.shape)
    return x
804
+
805
+
806
def fused_chunk_ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            layer norm weight of shape `(H, V)`
        b (torch.Tensor):
            layer norm bias of shape `(H, V)`
        eta (torch.Tensor):
            Learning rate for hidden state, of shape `(B, H, T, 1)`.
            A plain float is broadcast to that shape.
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        eps (float):
            Epsilon of the internal layer norm. Default: `1e-6`.
        chunk_size (int):
            chunk size. Default: `16`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        initial_state_bias (Optional[torch.Tensor]):
            Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]`
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
        final_state_bias (torch.Tensor):
            Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
    # Broadcast a scalar learning rate to a per-token (B, ..., T, 1) tensor.
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "Scale must be positive."
    # Positional order must match FusedChunkTTTLinearFunction.forward:
    # (q, k, v, w, b, BT, eta, scale, eps, h0, hb0, output_final_state, offsets, head_first).
    o, final_state, final_state_bias = FusedChunkTTTLinearFunction.apply(
        q,
        k,
        v,
        w,
        b,
        chunk_size,
        eta,
        scale,
        eps,
        initial_state,
        initial_state_bias,
        output_final_state,
        cu_seqlens,
        head_first
    )
    # Final per-head GroupNorm + residual applied outside the autograd Function.
    o = norm_residual(o, w, b, eps, head_first)
    return o, final_state, final_state_bias
fla/ops/ttt/naive.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    """Naive (pure PyTorch) TTT-linear recurrence over mini-batches.

    Args:
        q, k, v: `(B, H, T, D)` tensors; `T` must be a multiple of `mini_batch_size`.
        w, b: per-head LayerNorm weight/bias with `H * D` elements each.
        eta: per-token learning rates of shape `(B, H, T, 1)`.
        scale: multiplier applied to the queries.
        eps: LayerNorm epsilon.
        mini_batch_size: chunk length of the recurrence.
        initial_state: optional `(B, H, D, D)` fast-weight state (zeros if `None`).
        initial_state_bias: optional `(B, H, 1, D)` state bias (zeros if `None`).
        output_final_state: if `False`, final state/bias are returned as `None`.

    Returns:
        `(o, h, hb)`: outputs of shape `(B, H, T, D)` plus the final state and bias.
    """
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # Fix: scale queries out-of-place *before* chunking. The previous version
    # did `q *= scale` after `_q = q.reshape(...)`, which (a) mutated the
    # caller's tensor in place and (b) silently applied no scaling whenever
    # `reshape` had to copy (non-contiguous input) since `_q` would then not
    # alias `q`.
    q = q * scale
    # [NT, B, H, mini_batch_size, D]
    _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)

    for i in range(NT):
        q_i, k_i, v_i, eta_i = (x[i] for x in (_q, _k, _v, _eta))
        kh = k_i @ h + hb
        reconstruction_target = v_i - k_i

        # LayerNorm statistics of the pre-activation. Note: despite the name,
        # `rstd` holds the standard deviation (it is divided by below).
        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / rstd

        # Gradient of the LayerNorm reconstruction loss w.r.t. the state output.
        g = w * kh_hat + b - reconstruction_target
        g = g * w
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)

        # Causal intra-chunk attention plus state read-out.
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        # State update uses only the chunk's last learning rate.
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)

        # Output layer norm with residual connection.
        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / rstd * w + b

    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb
71
+
72
+
73
def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = True,
):
    """Reference entry point for chunked TTT-linear.

    Normalizes layouts/dtypes, pads the time dimension to a multiple of
    `mini_batch_size`, delegates to :func:`ttt_linear`, and strips the
    padding from the output again.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    # A scalar learning rate is broadcast to one value per token.
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    # ttt_linear expects head-first (B, H, T, D) inputs.
    if not head_first:
        q, k, v, eta = (t.transpose(1, 2) for t in (q, k, v, eta))
    seq_len = q.shape[-2]
    # Right-pad T up to the next multiple of the mini-batch size.
    pad_len = (mini_batch_size - (seq_len % mini_batch_size)) % mini_batch_size
    if pad_len > 0:
        q = F.pad(q, (0, 0, 0, pad_len))
        k = F.pad(k, (0, 0, 0, pad_len))
        v = F.pad(v, (0, 0, 0, pad_len))
        eta = F.pad(eta, (0, 0, 0, pad_len))
        # The state update reads the chunk's last eta; make the padded tail
        # reuse the last real token's learning rate.
        eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    q, k, v, eta, w, b = (t.to(torch.float32) for t in (q, k, v, eta, w, b))
    o, final_state, final_state_bias = ttt_linear(
        q,
        k,
        v,
        w,
        b,
        eta,
        scale,
        eps,
        mini_batch_size,
        initial_state,
        initial_state_bias,
        output_final_state,
    )
    o = o[:, :, :seq_len, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias