zaydzuhri committed · Commit 237dac0 · verified · Parent: fc57272

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fla/layers/__pycache__/abc.cpython-312.pyc +0 -0
  2. fla/layers/__pycache__/forgetting_attn.cpython-312.pyc +0 -0
  3. fla/ops/__pycache__/__init__.cpython-312.pyc +0 -0
  4. fla/ops/abc/__init__.py +7 -0
  5. fla/ops/attn/parallel.py +629 -0
  6. fla/ops/based/naive.py +72 -0
  7. fla/ops/based/parallel.py +410 -0
  8. fla/ops/common/chunk_delta_h.py +399 -0
  9. fla/ops/common/chunk_h_parallel.py +650 -0
  10. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  11. fla/ops/delta_rule/__init__.py +11 -0
  12. fla/ops/delta_rule/chunk.py +373 -0
  13. fla/ops/delta_rule/fused_recurrent.py +607 -0
  14. fla/ops/delta_rule/parallel.py +394 -0
  15. fla/ops/gated_delta_rule/__init__.py +7 -0
  16. fla/ops/gated_delta_rule/chunk.py +392 -0
  17. fla/ops/gated_delta_rule/fused_recurrent.py +321 -0
  18. fla/ops/generalized_delta_rule/README.md +37 -0
  19. fla/ops/generalized_delta_rule/__init__.py +9 -0
  20. fla/ops/gla/fused_chunk.py +631 -0
  21. fla/ops/gsa/__init__.py +9 -0
  22. fla/ops/hgrn/__init__.py +9 -0
  23. fla/ops/hgrn/naive.py +63 -0
  24. fla/ops/lightning_attn/__pycache__/chunk.cpython-312.pyc +0 -0
  25. fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  26. fla/ops/linear_attn/fused_chunk.py +318 -0
  27. fla/ops/linear_attn/fused_recurrent.py +251 -0
  28. fla/ops/linear_attn/utils.py +10 -0
  29. fla/ops/nsa/utils.py +92 -0
  30. fla/ops/rebased/naive.py +27 -0
  31. fla/ops/rebased/parallel.py +466 -0
  32. fla/ops/retention/__init__.py +13 -0
  33. fla/ops/retention/fused_chunk.py +365 -0
  34. fla/ops/retention/fused_recurrent.py +42 -0
  35. fla/ops/retention/naive.py +15 -0
  36. fla/ops/rwkv6/recurrent_naive.py +103 -0
  37. fla/ops/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  38. fla/ops/rwkv7/__pycache__/chunk.cpython-312.pyc +0 -0
  39. fla/ops/simple_gla/chunk.py +302 -0
  40. fla/ops/simple_gla/parallel.py +722 -0
  41. fla/ops/titans/naive.py +375 -0
  42. fla/ops/ttt/chunk.py +1539 -0
  43. fla/ops/utils/asm.py +17 -0
  44. fla/ops/utils/logcumsumexp.py +52 -0
  45. fla/ops/utils/testing.py +26 -0
  46. profile_trace/iteration_10240/rank2_trace.json +0 -0
  47. profile_trace/iteration_10240/rank3_trace.json +0 -0
  48. profile_trace/iteration_10240/rank4_trace.json +0 -0
  49. profile_trace/iteration_10240/rank5_trace.json +0 -0
  50. profile_trace/iteration_10240/rank6_trace.json +0 -0
fla/layers/__pycache__/abc.cpython-312.pyc ADDED
Binary file (9.59 kB).

fla/layers/__pycache__/forgetting_attn.cpython-312.pyc ADDED
Binary file (5.33 kB).

fla/ops/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.93 kB).

fla/ops/abc/__init__.py ADDED
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_abc

__all__ = [
    'chunk_abc'
]

fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl
from einops import rearrange, reduce

from fla.ops.common.utils import prepare_chunk_indices
from fla.ops.utils.op import exp, log
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit
def parallel_attn_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # the Q block is kept in shared memory throughout the whole kernel
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV]
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # running max and normalizer of the online softmax
    b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([BT], dtype=tl.float32)
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BS]
        b_s = tl.dot(b_q, b_k)

        # [BT]
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m

    # [BT]
    o_q = i_t * BT + tl.arange(0, BT)
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))

        # [BS]
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BS]
        b_s = tl.dot(b_q, b_k)
        b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))

        # [BT]
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m
    b_o = b_o / b_acc[:, None]
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))


@triton.jit
def parallel_attn_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,
    V: tl.constexpr
):
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_attn_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BT]
    b_lse = tl.load(p_lse, boundary_check=(0,))
    b_delta = tl.load(p_delta, boundary_check=(0,))

    # [BT, BK]
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [BT, BS]
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])

        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))

    # [BT]
    o_q = i_t * BT + tl.arange(0, BT)
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BS]
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [BT, BS]
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])
        b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)

        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))

    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_attn_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    o_k = i_t * BT + tl.arange(0, BT)
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        # [BT, BS]
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        # [BT, BS]
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


def parallel_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BT = chunk_size
    if check_shared_mem('hopper', q.device.index):
        BS = min(64, max(16, triton.next_power_of_2(T)))
        BK = min(256, max(16, triton.next_power_of_2(K)))
        BV = min(256, max(16, triton.next_power_of_2(V)))
    elif check_shared_mem('ampere', q.device.index):
        BS = min(32, max(16, triton.next_power_of_2(T)))
        BK = min(256, max(16, triton.next_power_of_2(K)))
        BV = min(128, max(16, triton.next_power_of_2(V)))
    else:
        BS = min(32, max(16, triton.next_power_of_2(T)))
        BK = min(256, max(16, triton.next_power_of_2(K)))
        BV = min(64, max(16, triton.next_power_of_2(V)))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    assert NK == 1, "The key dimension can not be larger than 256"

    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    grid = (NV, NT, B * HQ)
    parallel_attn_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        offsets=offsets,
        indices=indices,
        B=B,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse


def parallel_attn_bwd_preprocess(
    o: torch.Tensor,
    do: torch.Tensor
):
    V = o.shape[-1]
    delta = torch.empty_like(o[..., 0], dtype=torch.float32)
    parallel_attn_bwd_kernel_preprocess[(delta.numel(),)](
        o=o,
        do=do,
        delta=delta,
        B=triton.next_power_of_2(V),
        V=V,
    )
    return delta


def parallel_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    scale: Optional[float] = None,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BT = chunk_size
    BS = max(16, triton.next_power_of_2(T))
    BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
    BK = max(16, triton.next_power_of_2(K))
    BV = max(16, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    delta = parallel_attn_bwd_preprocess(o, do)

    dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
    dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
    dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
    grid = (NV, NT, B * HQ)
    parallel_attn_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    parallel_attn_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
    dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
    return dq, dk, dv


@torch.compile
class ParallelAttentionFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale, offsets):
        ctx.dtype = q.dtype

        chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
        # 2-d indices denoting the offsets of chunks in each sequence.
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively,
        # and `indices` will be [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        o, lse = parallel_attn_fwd(
            q=q,
            k=k,
            v=v,
            scale=scale,
            chunk_size=chunk_size,
            offsets=offsets,
            indices=indices
        )
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.chunk_size = chunk_size
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_attn_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            scale=ctx.scale,
            chunk_size=ctx.chunk_size,
            offsets=ctx.offsets,
            indices=ctx.indices
        )
        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None


def parallel_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA will be applied if HQ is divisible by H.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if head_first:
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
    o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o
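
For orientation, a minimal usage sketch for `parallel_attn` (not part of the commit; it assumes a CUDA device and that this fla snapshot is importable, and picks arbitrary GQA shapes consistent with the docstring above):

import torch
from fla.ops.attn.parallel import parallel_attn

B, T, H, HQ, K, V = 2, 256, 2, 8, 64, 64
q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16, requires_grad=True)

# standard fixed-length call; scale defaults to K ** -0.5
o = parallel_attn(q, k, v)   # [B, T, HQ, V]
o.sum().backward()           # exercises the dq and dkv backward kernels

# variable-length call: sequences packed into a single batch row,
# delimited by cumulative offsets as in the FlashAttention API
q2, k2, v2 = (x.detach().flatten(0, 1).unsqueeze(0) for x in (q, k, v))
cu_seqlens = torch.tensor([0, T, 2 * T], device='cuda')
o2 = parallel_attn(q2, k2, v2, cu_seqlens=cu_seqlens)  # [1, B*T, HQ, V]
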
fla/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch
from einops import rearrange


def naive_parallel_based(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True
):
    if scale is None:
        scale = q.shape[-1] ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = 1 + attn + 1/2 * (attn ** 2)
    attn.masked_fill_(~torch.tril(torch.ones(
        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
    o = attn @ v
    if use_norm:
        z = attn.sum(-1)
        return o / (z[..., None] + 1e-6)
    else:
        return o


def naive_chunk_based(q, k, v, chunk_size=256):
    q = q * (q.shape[-1] ** -0.5)
    # compute the normalizer
    k_cumsum = torch.cumsum(k, dim=-2)
    kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
    # first-order term
    z = (q * k_cumsum).sum(-1)
    # second-order term
    z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
    # zeroth-order term
    z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]

    # compute o
    # constant term
    _o = v.cumsum(-2)

    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)

    intra_chunk_attn = q @ k.transpose(-2, -1)
    intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
    intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
    o = intra_chunk_attn @ v

    # quadratic term
    kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)

    o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)

    # linear term
    kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)

    o = rearrange(o, 'b h n c d -> b h (n c) d')
    o = o + _o
    return o / (z[..., None] + 1e-6)
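
Both functions above realize the same second-order Taylor approximation of softmax attention, phi(s) = 1 + s + s^2/2 applied to the causal q.k scores, so they should agree up to numerical tolerance whenever `chunk_size` divides T. A minimal consistency check (a sketch, not part of the commit; shapes are arbitrary and the layout is head-first [B, H, T, D]):

import torch
from fla.ops.based.naive import naive_chunk_based, naive_parallel_based

B, H, T, D = 2, 4, 512, 16
q, k, v = (torch.randn(B, H, T, D, dtype=torch.float64) for _ in range(3))

ref = naive_parallel_based(q, k, v)               # materializes the full T x T attention
out = naive_chunk_based(q, k, v, chunk_size=256)  # chunked cumulative-sum formulation
torch.testing.assert_close(ref, out, rtol=1e-6, atol=1e-6)
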
fla/ops/based/parallel.py ADDED
@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard

# Based: An Educational and Effective Sequence Mixer
# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based


@triton.jit(do_not_specialize=['T'])
def parallel_based_fwd_kernel(
    q,
    k,
    v,
    o,
    z,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    # i_c: chunk index, used for sequence parallelism
    i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NV = tl.cdiv(V, BV)
    i_k = i_kv // NV
    i_v = i_kv % NV

    p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0))

    # [BTL, BK] block of Q, kept in shared memory throughout the whole kernel
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    b_o = tl.zeros([BTL, BV], dtype=tl.float32)
    b_z = tl.zeros([BTL], dtype=tl.float32)

    # the Q block and K block have no overlap here,
    # so no mask is needed, which saves FLOPs
    for _ in range(0, i_c * BTL, BTS):
        # [BK, BTS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BTS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        b_s = 1 + b_s + 0.5 * b_s * b_s
        b_z += tl.sum(b_s, axis=1)
        # [BTL, BV]
        b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
        p_k = tl.advance(p_k, (0, BTS))
        p_v = tl.advance(p_v, (BTS, 0))

    # rescale the inter-chunk output
    tl.debug_barrier()
    o_q = tl.arange(0, BTL)
    # sync threads, easy for the compiler to optimize
    # tl.debug_barrier()

    o_k = tl.arange(0, BTS)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
    # the Q block and K block overlap, so masks are required
    for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
        # [BK, BTS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BTS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS]
        m_s = o_q[:, None] >= o_k[None, :]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        b_s = 1 + b_s + 0.5 * b_s * b_s
        b_s = tl.where(m_s, b_s, 0)
        b_z += tl.sum(b_s, axis=1)
        # [BTL, BV]
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)

        p_k = tl.advance(p_k, (0, BTS))
        p_v = tl.advance(p_v, (BTS, 0))
        o_k += BTS

    p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
    p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))


@triton.jit
def _parallel_based_bwd_dq(
    i_bh,
    i_c,
    i_k,
    i_v,
    q,
    k,
    v,
    do,
    dz,
    dq,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
):
    p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
    p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
    b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1))
    p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
    b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)

    for _ in range(0, i_c * BTL, BTS):
        # [BTS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BTS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS]
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[:, None]
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
        # [BTL, BK]
        b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
        p_k = tl.advance(p_k, (BTS, 0))
        p_v = tl.advance(p_v, (0, BTS))

    b_dq *= scale
    o_q = tl.arange(0, BTL)
    o_k = tl.arange(0, BTS)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
    # the Q block and K block overlap, so masks are required
    for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
        # [BTS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BTS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BTL, BTS]
        m_s = o_q[:, None] >= o_k[None, :]
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[:, None]
        b_ds = tl.where(m_s, b_ds, 0) * scale
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0)
        # [BTL, BK]
        b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False)
        p_k = tl.advance(p_k, (BTS, 0))
        p_v = tl.advance(p_v, (0, BTS))
        o_k += BTS
    p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    return


@triton.jit
def _parallel_based_bwd_dkv(
    i_bh,
    i_c,
    i_k,
    i_v,
    q,
    k,
    v,
    do,
    dz,
    dk,
    dv,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
):
    # compute dk, dv
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
    b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
    b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32)

    for i in range((tl.cdiv(T, BTS) * BTS) - BTS, (i_c + 1) * BTL - BTS, -BTS):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
        p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
        b_q = tl.load(p_q, boundary_check=(0, 1))  # [BK, BTS]
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)  # [BV, BTS]
        b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
        b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale  # [BTL, BTS]
        b_s2 = 1 + b_s + 0.5 * b_s * b_s
        b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
        b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
        if i_v == 0:
            b_ds += b_dz[None, :] * scale
        b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)

    tl.debug_barrier()
    o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
    for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
        p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
        b_q = tl.load(p_q, boundary_check=(0, 1))  # [BK, BTS]
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
        b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
        # [BTL, BTS]
        m_s = o_k[:, None] <= o_q[None, :]
        b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
        b_s2 = 1 + b_s + 0.5 * b_s * b_s
        b_s = tl.where(m_s, b_s, 0)
        b_s2 = tl.where(m_s, b_s2, 0)

        b_ds = tl.dot(b_v, b_do, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[None, :]
        b_ds = tl.where(m_s, b_ds, 0) * scale
        # [BTL, BV] and [BTL, BK]
        b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
        b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
        o_q += BTS

    p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    return


@triton.jit(do_not_specialize=['T'])
def parallel_based_bwd_kernel(
    q,
    k,
    v,
    do,
    dz,
    dq,
    dk,
    dv,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BTL: tl.constexpr,
    BTS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NV = tl.cdiv(V, BV)
    i_k = i_kv // NV
    i_v = i_kv % NV
    _parallel_based_bwd_dq(
        i_bh, i_c, i_k, i_v,
        q, k, v, do, dz, dq,
        scale, T, B, H, BTL, BTS, BK, BV, K, V
    )
    tl.debug_barrier()
    _parallel_based_bwd_dkv(
        i_bh, i_c, i_k, i_v,
        q, k, v, do, dz, dk, dv,
        scale, T, B, H, BTL, BTS, BK, BV, K, V
    )


class ParallelBasedFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale):
        BTL, BTS = 128, 32
        assert BTL % BTS == 0
        # assert q.shape[-1] % 16 == 0
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
        BK, BV = max(BK, 16), max(BV, 16)
        B, H, T, K, V = *k.shape, v.shape[-1]
        num_stages = 2
        num_warps = 4
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        grid = (NK * NV, triton.cdiv(T, BTL), B * H)

        assert NK == 1, "synchronization issues will arise if NK > 1."

        o = torch.empty(NK, B, H, T, V, device=q.device)
        z = torch.empty(NK, B, H, T, device=q.device)
        parallel_based_fwd_kernel[grid](
            q, k, v, o, z,
            scale,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BTL=BTL,
            BTS=BTS,
            BK=BK,
            BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ctx.save_for_backward(q, k, v)
        ctx.scale = scale
        return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dz):
        q, k, v = ctx.saved_tensors
        scale = ctx.scale
        BTL, BTS = 64, 32
        assert BTL % BTS == 0
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
        BK, BV = max(BK, 16), max(BV, 16)
        B, H, T, K, V = *k.shape, v.shape[-1]
        num_stages = 2
        num_warps = 4
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        grid = (NK * NV, triton.cdiv(T, BTL), B * H)

        assert NK == 1, "synchronization issues will arise if NK > 1."

        dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
        dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
        dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)

        parallel_based_bwd_kernel[grid](
            q, k, v, do, dz, dq, dk, dv,
            scale,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BTL=BTL,
            BTS=BTS,
            BK=BK,
            BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )

        return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None


triton_parallel_based = ParallelBasedFunction.apply


def parallel_based(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True,
    head_first: bool = True
):
    assert q.shape[-1] <= 128, "only feature dims up to 128 are supported"
    if scale is None:
        scale = q.shape[-1] ** -0.5
    if not head_first:
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
    o, z = triton_parallel_based(q, k, v, scale)
    if use_norm:
        o = o / (z[..., None] + 1e-6)
    if not head_first:
        o = o.transpose(1, 2)
    return o.to(q.dtype)
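
A quick sanity check of the Triton path against the pure-PyTorch reference in fla/ops/based/naive.py above (a sketch, not part of the commit; it assumes a CUDA device, and since the inputs are float32 the tolerances are deliberately loose):

import torch
from fla.ops.based.naive import naive_parallel_based
from fla.ops.based.parallel import parallel_based

B, H, T, D = 2, 4, 512, 16
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))

ref = naive_parallel_based(q, k, v)           # head-first [B, H, T, D] layout
out = parallel_based(q, k, v, use_norm=True)  # head_first defaults to True here
torch.testing.assert_close(ref, out, rtol=1e-3, atol=1e-3)
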
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
20
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
30
+ use_cuda_graph=use_cuda_graph,
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gated_delta_rule_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ d,
37
+ v_new,
38
+ g,
39
+ h,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ chunk_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BC: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NT: tl.constexpr,
53
+ USE_G: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr,
58
+ ):
59
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_n, i_h = i_nh // H, i_nh % H
61
+ if USE_OFFSETS:
62
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
63
+ T = eos - bos
64
+ NT = tl.cdiv(T, BT)
65
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
66
+ else:
67
+ bos, eos = i_n * T, i_n * T + T
68
+ NT = tl.cdiv(T, BT)
69
+ boh = i_n * NT
70
+
71
+ # [BK, BV]
72
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
73
+ if USE_INITIAL_STATE:
74
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
75
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
76
+
77
+ for i_t in range(NT):
78
+ if HEAD_FIRST:
79
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ else:
81
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
83
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
84
+ if USE_G:
85
+ last_idx = min((i_t + 1) * BT, T) - 1
86
+ if HEAD_FIRST:
87
+ b_g_last = tl.load(g + i_nh * T + last_idx)
88
+ else:
89
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
90
+ else:
91
+ b_g_last = None
92
+ last_idx = None
93
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
94
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
95
+ if HEAD_FIRST:
96
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
97
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
98
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
100
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
103
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
104
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
105
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
106
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
107
+ b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
108
+ # [BK, BC]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
111
+ # [BC, BK]
112
+ b_d = tl.load(p_d, boundary_check=(0, 1))
113
+ b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
114
+ # [BC, BV]
115
+ b_v = tl.load(p_v, boundary_check=(0, 1))
116
+ b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
117
+ # [BK, BV]
118
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
119
+ b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
120
+ b_h *= exp(b_g_last) if USE_G else 1
121
+ b_h += b_hc
122
+
123
+ if STORE_FINAL_STATE:
124
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
125
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
+ @triton.heuristics({
129
+ 'USE_G': lambda args: args['g'] is not None,
130
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
131
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in NUM_WARPS
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BT', 'BK', 'BV', 'USE_G'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
145
+ q,
146
+ k,
147
+ d,
148
+ g,
149
+ dht,
150
+ dh0,
151
+ do,
152
+ dh,
153
+ dv,
154
+ dv2,
155
+ offsets,
156
+ chunk_offsets,
157
+ scale,
158
+ T,
159
+ H: tl.constexpr,
160
+ K: tl.constexpr,
161
+ V: tl.constexpr,
162
+ BT: tl.constexpr,
163
+ BC: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr,
166
+ USE_G: tl.constexpr,
167
+ USE_INITIAL_STATE: tl.constexpr,
168
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr
171
+ ):
172
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
173
+ i_n, i_h = i_nh // H, i_nh % H
174
+ if USE_OFFSETS:
175
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
176
+ T = eos - bos
177
+ NT = tl.cdiv(T, BT)
178
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
179
+ else:
180
+ bos, eos = i_n * T, i_n * T + T
181
+ NT = tl.cdiv(T, BT)
182
+ boh = i_n * NT
183
+
184
+ # [BK, BV]
185
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
186
+ if USE_FINAL_STATE_GRADIENT:
187
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
188
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
189
+
190
+ for i_t in range(NT - 1, -1, -1):
191
+ if HEAD_FIRST:
192
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
193
+ else:
194
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
195
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
196
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
197
+ if USE_G:
198
+ last_idx = min((i_t + 1) * BT, T) - 1
199
+ if HEAD_FIRST:
200
+ bg_last = tl.load(g + i_nh * T + last_idx)
201
+ else:
202
+ bg_last = tl.load(g + (bos + last_idx) * H + i_h)
203
+ else:
204
+ bg_last = None
205
+ last_idx = None
206
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
207
+ if HEAD_FIRST:
208
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
209
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
210
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
211
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
212
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
213
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
214
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
215
+ else:
216
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
217
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
218
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
219
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
220
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
221
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
222
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
223
+ b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
224
+ # [BK, BT]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
227
+ # [BT, BK]
228
+ b_k = tl.load(p_k, boundary_check=(0, 1))
229
+ b_d = tl.load(p_d, boundary_check=(0, 1))
230
+ b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
231
+ b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
232
+ # [BT, V]
233
+ b_do = tl.load(p_do, boundary_check=(0, 1))
234
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
235
+ b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
236
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
237
+ # [BK, BV]
238
+ b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
239
+ b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
240
+ b_dh *= exp(bg_last) if USE_G else 1
241
+ b_dh += b_dh_tmp
242
+
243
+ if USE_INITIAL_STATE:
244
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
+ def chunk_gated_delta_rule_fwd_h(
249
+ k: torch.Tensor,
250
+ w: torch.Tensor,
251
+ u: torch.Tensor,
252
+ g: Optional[torch.Tensor] = None,
253
+ initial_state: Optional[torch.Tensor] = None,
254
+ output_final_state: bool = False,
255
+ offsets: Optional[torch.LongTensor] = None,
256
+ indices: Optional[torch.LongTensor] = None,
257
+ head_first: bool = True,
258
+ chunk_size: int = 64
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ if head_first:
261
+ B, H, T, K, V = *k.shape, u.shape[-1]
262
+ else:
263
+ B, T, H, K, V = *k.shape, u.shape[-1]
264
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
265
+ # N: the actual number of sequences in the batch with either equal or variable lengths
266
+ if offsets is None:
267
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
268
+ else:
269
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
270
+ BK = triton.next_power_of_2(K)
271
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
272
+ # H100 can have larger block size
273
+ if check_shared_mem('hopper', k.device.index):
274
+ BV = 64
275
+ BC = 64 if K <= 128 else 32
276
+ # A100
277
+ elif check_shared_mem('ampere', k.device.index):
278
+ BV = 32
279
+ BC = 64
280
+ else:
281
+ BV = 32
282
+ BC = 32 if K <= 128 else 16
283
+ BC = min(BT, BC)
284
+ NK = triton.cdiv(K, BK)
285
+ NV = triton.cdiv(V, BV)
286
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
287
+
288
+ if head_first:
289
+ h = k.new_empty(B, H, NT, K, V)
290
+ else:
291
+ h = k.new_empty(B, NT, H, K, V)
292
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
293
+
294
+ v_new = torch.empty_like(u)
295
+ grid = (NK, NV, N * H)
296
+
297
+ chunk_gated_delta_rule_fwd_kernel_h[grid](
298
+ k=k,
299
+ v=u,
300
+ d=w,
301
+ v_new=v_new,
302
+ g=g,
303
+ h=h,
304
+ h0=initial_state,
305
+ ht=final_state,
306
+ offsets=offsets,
307
+ chunk_offsets=chunk_offsets,
308
+ T=T,
309
+ H=H,
310
+ K=K,
311
+ V=V,
312
+ BT=BT,
313
+ BC=BC,
314
+ BK=BK,
315
+ BV=BV,
316
+ NT=NT,
317
+ HEAD_FIRST=head_first
318
+ )
319
+ return h, v_new, final_state
+
+
+ def chunk_gated_delta_rule_bwd_dhu(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ w: torch.Tensor,
+ g: torch.Tensor,
+ h0: torch.Tensor,
+ dht: Optional[torch.Tensor],
+ do: torch.Tensor,
+ dv: torch.Tensor,
+ scale: float,
+ offsets: Optional[torch.LongTensor] = None,
+ indices: Optional[torch.LongTensor] = None,
+ head_first: bool = True,
+ chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if head_first:
+ B, H, T, K, V = *q.shape, do.shape[-1]
+ else:
+ B, T, H, K, V = *q.shape, do.shape[-1]
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+ # N: the actual number of sequences in the batch with either equal or variable lengths
+ if offsets is None:
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+ else:
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
+
+ BK = triton.next_power_of_2(K)
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
+
+ # H100
+ if check_shared_mem('hopper', q.device.index):
+ BV = 64
+ BC = 64 if K <= 128 else 32
+ # A100
+ elif check_shared_mem('ampere', q.device.index):
+ BV = 32
+ BC = 64 if K <= 128 else 32
+ else:
+ BV = 32 if K <= 128 else 16
+ BC = 16
+
+ BC = min(BT, BC)
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+ if head_first:
+ dh = q.new_empty(B, H, NT, K, V)
+ else:
+ dh = q.new_empty(B, NT, H, K, V)
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
+ dv2 = torch.empty_like(dv)
+
+ grid = (NK, NV, N * H)
+ chunk_gated_delta_rule_bwd_kernel_dhu[grid](
+ q=q,
+ k=k,
+ d=w,
+ g=g,
+ dht=dht,
+ dh0=dh0,
+ do=do,
+ dh=dh,
+ dv=dv,
+ dv2=dv2,
+ offsets=offsets,
+ chunk_offsets=chunk_offsets,
+ scale=scale,
+ T=T,
+ H=H,
+ K=K,
+ V=V,
+ BT=BT,
+ BC=BC,
+ BK=BK,
+ BV=BV,
+ HEAD_FIRST=head_first
+ )
+ return dh, dh0, dv2
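[editor's note] `chunk_gated_delta_rule_bwd_dhu` is the reverse-time counterpart of the forward pass above: the state gradient is accumulated from the last chunk back to the first (written per chunk into `dh`, with `dh0` receiving the gradient with respect to the initial state), while the incoming local `dv` is augmented with the component that flows through the recurrent state and returned as `dv2`.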
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ """
+ Fully parallelized state passing.
+ """
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.utils.op import exp
+
+
+ @triton.heuristics({
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+ configs=[
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+ for BK in [32, 64, 128]
+ for BV in [32, 64, 128]
+ for num_warps in [2, 4, 8]
+ for num_stages in [2, 3, 4]
+ ],
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_fwd_kernel_h_parallel(
+ k,
+ v,
+ h,
+ g,
+ gk,
+ gv,
+ h0,
+ ht,
+ offsets,
+ indices,
+ T,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ USE_G: tl.constexpr,
+ USE_GK: tl.constexpr,
+ USE_GV: tl.constexpr,
+ USE_INITIAL_STATE: tl.constexpr,
+ STORE_FINAL_STATE: tl.constexpr,
+ USE_OFFSETS: tl.constexpr,
+ HEAD_FIRST: tl.constexpr
+ ):
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+ NV = tl.cdiv(V, BV)
+ # i_b: batch index
+ # i_h: head index
+ # i_n: sequence index
+ # i_t: chunk index within current sequence
+ # i_tg: (global) chunk index across all sequences
+ i_k, i_v = i_kv // NV, i_kv % NV
+ i_b, i_h = i_bh // H, i_bh % H
+ if USE_OFFSETS:
+ i_tg = i_t
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+ T = eos - bos
+ NT = tl.cdiv(T, BT)
+ else:
+ bos, eos = i_b * T, i_b * T + T
+ NT = tl.cdiv(T, BT)
+ i_n, i_tg = i_b, i_b * NT + i_t
+ i_nh = i_n * H + i_h
+
+ if HEAD_FIRST:
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ else:
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+
+ if i_t == 0:
+ if USE_INITIAL_STATE:
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+ else:
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+
+ # [BK, BT]
+ b_k = tl.load(p_k, boundary_check=(0, 1))
+ # [BT, BV]
+ b_v = tl.load(p_v, boundary_check=(0, 1))
+
+ last_idx = min(i_t * BT + BT, T) - 1
+ # scalar decay
+ if USE_G:
+ if HEAD_FIRST:
+ b_g_last = tl.load(g + i_bh * T + last_idx)
+ p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
+ else:
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
+
+ # vector decay, h = Diag(gk) @ h
+ if USE_GK:
+ if HEAD_FIRST:
+ p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
+ else:
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
+
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
+
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
+
+ # vector decay, h = h @ Diag(gv)
+ if USE_GV:
+ if HEAD_FIRST:
+ p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
+ else:
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
+
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
+
+ b_h = tl.dot(b_k, b_v)
+ if i_t < NT - 1:
+ if HEAD_FIRST:
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ else:
+ p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+ elif STORE_FINAL_STATE:
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
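[editor's note] This parallel kernel deliberately does no cross-chunk accumulation: for chunk `i_t` it computes only the local, decay-adjusted `kᵀv` contribution and writes it into the slot of chunk `i_t + 1` (slot 0 is seeded with `h0` or zeros), so all chunks can be processed independently across the grid. The sequential prefix accumulation is deferred to `chunk_fwd_kernel_h_reduction` below.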
+
+
+ @triton.heuristics({
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+ configs=[
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+ for BK in [32, 64, 128]
+ for BV in [32, 64, 128]
+ for num_warps in [2, 4, 8, 16]
+ for num_stages in [2, 3]
+ ],
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_fwd_kernel_h_reduction(
+ h,
+ g,
+ gk,
+ gv,
+ kvt,
+ ht,
+ offsets,
+ chunk_offsets,
+ T,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ USE_G: tl.constexpr,
+ USE_GK: tl.constexpr,
+ USE_GV: tl.constexpr,
+ STORE_FINAL_STATE: tl.constexpr,
+ USE_OFFSETS: tl.constexpr,
+ HEAD_FIRST: tl.constexpr
+ ):
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_n, i_h = i_nh // H, i_nh % H
+ if USE_OFFSETS:
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+ T = eos - bos
+ NT = tl.cdiv(T, BT)
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+ else:
+ bos, eos = i_n * T, i_n * T + T
+ NT = tl.cdiv(T, BT)
+ boh = i_n * NT
+
+ # [BK, BV]
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
+ for i_t in range(NT):
+ if HEAD_FIRST:
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ else:
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+ if i_t > 0:
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+
+ last_idx = min(i_t * BT + BT, T) - 1
+ # scalar decay
+ if USE_G:
+ if HEAD_FIRST:
+ b_g_last = tl.load(g + i_nh * T + last_idx)
+ else:
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
+ b_h *= exp(b_g_last)
+
+ # vector decay, h = Diag(gk) @ h
+ if USE_GK:
+ if HEAD_FIRST:
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
+ else:
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
+
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
+ b_h *= exp(b_gk_last)[:, None]
+
+ # vector decay, h = h @ Diag(gv)
+ if USE_GV:
+ if HEAD_FIRST:
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
+ else:
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
+ b_h *= exp(b_gv_last)[None, :]
+
+ if STORE_FINAL_STATE:
+ p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+ configs=[
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+ for BK in [32, 64, 128]
+ for BV in [32, 64, 128]
+ for num_warps in [2, 4, 8]
+ for num_stages in [2, 3, 4]
+ ],
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_bwd_kernel_dh_parallel(
+ q,
+ g,
+ gk,
+ gv,
+ do,
+ dh,
+ dht,
+ dh0,
+ offsets,
+ indices,
+ scale,
+ T,
+ HQ: tl.constexpr,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ NG: tl.constexpr,
+ USE_G: tl.constexpr,
+ USE_GK: tl.constexpr,
+ USE_GV: tl.constexpr,
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
+ USE_OFFSETS: tl.constexpr,
+ HEAD_FIRST: tl.constexpr
+ ):
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+ NV = tl.cdiv(V, BV)
+ i_k, i_v = i_kv // NV, i_kv % NV
+ i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
+ i_h = i_hq // NG
+ if USE_OFFSETS:
+ i_tg = i_t
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+ T = eos - bos
+ NT = tl.cdiv(T, BT)
+ else:
+ bos, eos = i_b * T, i_b * T + T
+ NT = tl.cdiv(T, BT)
+ i_n, i_tg = i_b, i_b * NT + i_t
+ i_nh = i_n * HQ + i_hq
+
+ if HEAD_FIRST:
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ else:
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+
+ if i_t == NT - 1:
+ if USE_FINAL_STATE_GRADIENT:
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
+ else:
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+
+ # [BK, BT]
+ b_q = tl.load(p_q, boundary_check=(0, 1))
+ b_q = (b_q * scale).to(b_q.dtype)
+ # [BT, BV]
+ b_do = tl.load(p_do, boundary_check=(0, 1))
+
+ if USE_G:
+ if HEAD_FIRST:
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
+ else:
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
+
+ if USE_GK:
+ if HEAD_FIRST:
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ else:
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
+
+ if USE_GV:
+ if HEAD_FIRST:
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ else:
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
+
+ b_dh = tl.dot(b_q, b_do)
+ if i_t > 0:
+ if HEAD_FIRST:
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ else:
+ p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+ elif STORE_INITIAL_STATE_GRADIENT:
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+ configs=[
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+ for BK in [32, 64, 128]
+ for BV in [32, 64, 128]
+ for num_warps in [2, 4, 8, 16]
+ for num_stages in [2, 3]
+ ],
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_bwd_kernel_dh_reduction(
+ g,
+ gk,
+ gv,
+ dh,
+ doq0,
+ dh0,
+ offsets,
+ chunk_offsets,
+ T,
+ HQ: tl.constexpr,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ NG: tl.constexpr,
+ USE_G: tl.constexpr,
+ USE_GK: tl.constexpr,
+ USE_GV: tl.constexpr,
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
+ USE_OFFSETS: tl.constexpr,
+ HEAD_FIRST: tl.constexpr
+ ):
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_bg = i_nh // NG
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
+ i_h = i_hq // NG
+ if USE_OFFSETS:
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+ T = eos - bos
+ NT = tl.cdiv(T, BT)
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+ else:
+ bos, eos = i_n * T, i_n * T + T
+ NT = tl.cdiv(T, BT)
+ boh = i_n * NT
+
+ # [BK, BV]
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+ for i_t in range(NT - 1, -1, -1):
+ if HEAD_FIRST:
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ else:
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
+ if i_t < NT - 1:
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+
+ last_idx = min(i_t * BT + BT, T) - 1
+ if USE_G:
+ if HEAD_FIRST:
+ b_g_last = tl.load(g + i_bg * T + last_idx)
+ else:
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
+ b_dh *= exp(b_g_last)
+
+ if USE_GK:
+ if HEAD_FIRST:
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
+ else:
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
+
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
+ b_dh *= exp(b_gk_last)[:, None]
+
+ if USE_GV:
+ if HEAD_FIRST:
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
+ else:
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
+ b_dh *= exp(b_gv_last)[None, :]
+
+ if STORE_INITIAL_STATE_GRADIENT:
+ p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
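[editor's note] The backward pair mirrors the forward one in reverse time: `chunk_bwd_kernel_dh_parallel` writes each chunk's local `scale * qᵀdo` contribution (decay-adjusted) into the slot of the preceding chunk, with the last slot seeded from `dht` (or zeros) and chunk 0's contribution parked in the `dh0` buffer; this reduction kernel then scans the slots from right to left applying the decays, finally adding the parked `doq0` term to form the initial-state gradient `dh0`.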
+
+
+ def chunk_fwd_h(
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ gk: torch.Tensor,
+ gv: torch.Tensor,
+ h0: torch.Tensor,
+ output_final_state: bool,
+ states_in_fp32: bool = False,
+ offsets: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ head_first: bool = True,
+ chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if head_first:
+ B, H, T, K, V = *k.shape, v.shape[-1]
+ else:
+ B, T, H, K, V = *k.shape, v.shape[-1]
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+ # N: the actual number of sequences in the batch with either equal or variable lengths
+ if offsets is None:
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+ else:
+ if indices is None:
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
+ N, NT = len(offsets) - 1, len(indices)
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
+
+ h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
+ chunk_fwd_kernel_h_parallel[grid](
+ k=k,
+ v=v,
+ h=h,
+ g=g,
+ gk=gk,
+ gv=gv,
+ h0=h0,
+ ht=ht,
+ offsets=offsets,
+ indices=indices,
+ T=T,
+ H=H,
+ K=K,
+ V=V,
+ BT=BT,
+ USE_G=g is not None,
+ USE_GK=gk is not None,
+ USE_GV=gv is not None,
+ HEAD_FIRST=head_first
+ )
+ kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
+ chunk_fwd_kernel_h_reduction[grid](
+ h=h,
+ g=g,
+ gk=gk,
+ gv=gv,
+ kvt=kvt,
+ ht=ht,
+ offsets=offsets,
+ chunk_offsets=chunk_offsets,
+ T=T,
+ H=H,
+ K=K,
+ V=V,
+ BT=BT,
+ USE_G=g is not None,
+ USE_GK=gk is not None,
+ USE_GV=gv is not None,
+ HEAD_FIRST=head_first
+ )
+ h = h.to(k.dtype) if not states_in_fp32 else h
+ return h, ht
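[editor's note] A compact PyTorch reference for what `chunk_fwd_h` computes in the scalar-decay case (`gk=gv=None`, `head_first=False`, equal-length sequences, `T` divisible by `chunk_size`) is sketched below; it is an editorial illustration under those assumptions, not part of this commit, and the `_ref` name is hypothetical.

def _chunk_fwd_h_ref(k, v, g=None, h0=None, chunk_size=64):
    # k: [B, T, H, K], v: [B, T, H, V], g: per-token log-decay of shape [B, T, H]
    B, T, H, K = k.shape
    V = v.shape[-1]
    NT = T // chunk_size
    S = k.new_zeros(B, H, K, V, dtype=torch.float) if h0 is None else h0.float()
    h = k.new_empty(B, NT, H, K, V, dtype=torch.float)
    for i in range(NT):
        s = slice(i * chunk_size, (i + 1) * chunk_size)
        h[:, i] = S                             # state entering chunk i
        kc = k[:, s].transpose(1, 2).float()    # [B, H, C, K]
        vc = v[:, s].transpose(1, 2).float()    # [B, H, C, V]
        if g is None:
            S = S + kc.transpose(-1, -2) @ vc
        else:
            gc = g[:, s].transpose(1, 2).float()                 # [B, H, C]
            g_last = gc[:, :, -1]                                # decay over the whole chunk
            kc = kc * (g_last[:, :, None] - gc).exp().unsqueeze(-1)
            S = S * g_last.exp()[:, :, None, None] + kc.transpose(-1, -2) @ vc
    return h, S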
+
+
+ def chunk_bwd_dh(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ gk: torch.Tensor,
+ gv: torch.Tensor,
+ do: torch.Tensor,
+ h0: torch.Tensor,
+ dht: torch.Tensor,
+ scale: float,
+ states_in_fp32: bool = False,
+ offsets: Optional[torch.Tensor] = None,
+ indices: Optional[torch.Tensor] = None,
+ head_first: bool = True,
+ chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if head_first:
+ B, H, T, K, V = *k.shape, v.shape[-1]
+ HQ = q.shape[1]
+ else:
+ B, T, H, K, V = *k.shape, v.shape[-1]
+ HQ = q.shape[2]
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+ # N: the actual number of sequences in the batch with either equal or variable lengths
+ # NG: number of groups in GQA
+ if offsets is None:
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+ else:
+ if indices is None:
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
+ N, NT = len(offsets) - 1, len(indices)
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
+ NG = HQ // H
+
+ if head_first:
+ dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
+ else:
+ dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
+
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
+ chunk_bwd_kernel_dh_parallel[grid](
+ q=q,
+ g=g,
+ gk=gk,
+ gv=gv,
+ do=do,
+ dh=dh,
+ dht=dht,
+ dh0=dh0,
+ offsets=offsets,
+ indices=indices,
+ scale=scale,
+ T=T,
+ HQ=HQ,
+ H=H,
+ K=K,
+ V=V,
+ BT=BT,
+ NG=NG,
+ USE_G=g is not None,
+ USE_GK=gk is not None,
+ USE_GV=gv is not None,
+ HEAD_FIRST=head_first
+ )
+
+ doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
+ chunk_bwd_kernel_dh_reduction[grid](
+ g=g,
+ gk=gk,
+ gv=gv,
+ dh=dh,
+ doq0=doq0,
+ dh0=dh0,
+ offsets=offsets,
+ chunk_offsets=chunk_offsets,
+ T=T,
+ HQ=HQ,
+ H=H,
+ K=K,
+ V=V,
+ BT=BT,
+ NG=NG,
+ USE_G=g is not None,
+ USE_GK=gk is not None,
+ USE_GV=gv is not None,
+ HEAD_FIRST=head_first
+ )
+ dh = dh.to(q.dtype) if not states_in_fp32 else dh
+ return dh, dh0
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.common.utils import prepare_chunk_indices
+
+
+ @triton.heuristics({
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+ configs=[
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
+ for BK in [32, 64, 128]
+ for num_warps in [2, 4, 8]
+ for num_stages in [2, 3, 4]
+ ],
+ key=['H', 'K', 'BT', 'USE_OFFSETS'],
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_scaled_dot_kkt_fwd_kernel(
+ k,
+ beta,
+ A,
+ offsets,
+ indices,
+ T,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ HEAD_FIRST: tl.constexpr,
+ USE_OFFSETS: tl.constexpr,
+ ):
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
+ i_b, i_h = i_bh // H, i_bh % H
+ if USE_OFFSETS:
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+ T = eos - bos
+ else:
+ bos, eos = i_b * T, i_b * T + T
+ o_t = tl.arange(0, BT)
+
+ if HEAD_FIRST:
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+ else:
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+ b_beta = tl.load(p_beta, boundary_check=(0,))
+
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
+ for i_k in range(tl.cdiv(K, BK)):
+ if HEAD_FIRST:
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ else:
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ b_k = tl.load(p_k, boundary_check=(0, 1))
+ b_kb = b_k * b_beta[:, None]
+ b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
+
+ b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
+ if HEAD_FIRST:
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+ else:
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def chunk_scaled_dot_kkt_fwd(
+ k: torch.Tensor,
+ beta: torch.Tensor,
+ cu_seqlens: Optional[torch.LongTensor],
+ head_first: bool = False,
+ chunk_size: int = 64,
+ output_dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ r"""
+ Compute beta * K * K^T for each chunk, keeping only the strictly lower-triangular part.
+
+ Args:
+ k (torch.Tensor):
+ The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
+ beta (torch.Tensor):
+ The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
+ cu_seqlens (torch.LongTensor):
+ The cumulative sequence lengths of the input tensor.
+ Default: None
+ head_first (bool):
+ If False, the input/output tensor is in the shape of `[B, T, H, K]`.
+ If True, the input/output tensor is in the shape of `[B, H, T, K]`.
+ Default: False
+ chunk_size (int):
+ The chunk size. Default: 64.
+ output_dtype (torch.dtype):
+ The dtype of the output tensor. Default: `torch.float32`
+
+ Returns:
+ beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
+ where `BT` is the chunk size.
+ """
+ if head_first:
+ B, H, T, K = k.shape
+ else:
+ B, T, H, K = k.shape
+ BT = chunk_size
+ indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
+ chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+ k=k,
+ beta=beta,
+ A=A,
+ offsets=cu_seqlens,
+ indices=indices,
+ T=T,
+ H=H,
+ K=K,
+ BT=BT,
+ HEAD_FIRST=head_first
+ )
+ return A
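[editor's note] For readers checking the kernel, a dense PyTorch equivalent of `chunk_scaled_dot_kkt_fwd` could look like the sketch below. It assumes fixed-length inputs, `head_first=False` and `T` divisible by `chunk_size`; the `_ref` name is hypothetical and this is not part of the commit.

def _chunk_scaled_dot_kkt_ref(k, beta, chunk_size=64):
    B, T, H, K = k.shape
    N = T // chunk_size
    kc = k.view(B, N, chunk_size, H, K).permute(0, 1, 3, 2, 4).float()  # [B, N, H, C, K]
    bc = beta.view(B, N, chunk_size, H).permute(0, 1, 3, 2).float()     # [B, N, H, C]
    A = (kc * bc.unsqueeze(-1)) @ kc.transpose(-1, -2)                  # [B, N, H, C, C]
    A = A.tril(-1)                                # keep the strictly lower-triangular part
    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, chunk_size)        # [B, T, H, BT]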
fla/ops/delta_rule/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # -*- coding: utf-8 -*-
+
+ from .chunk import chunk_delta_rule
+ from .fused_chunk import fused_chunk_delta_rule
+ from .fused_recurrent import fused_recurrent_delta_rule
+
+ __all__ = [
+ 'fused_chunk_delta_rule',
+ 'fused_recurrent_delta_rule',
+ 'chunk_delta_rule'
+ ]
fla/ops/delta_rule/chunk.py ADDED
@@ -0,0 +1,373 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ from einops import rearrange
+
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
+ from fla.ops.common.utils import prepare_chunk_indices
+ from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+
+ def chunk_delta_rule_fwd(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ output_final_state: bool,
+ offsets: Optional[torch.LongTensor] = None,
+ indices: Optional[torch.LongTensor] = None,
+ head_first: bool = True,
+ chunk_size: int = 64
+ ):
+ T = q.shape[2] if head_first else q.shape[1]
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+ # obtain WY representation. u is actually the new v.
+ w, u, A = fwd_prepare_wy_repr(
+ k=k,
+ v=v,
+ beta=beta,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
+ k=k,
+ w=w,
+ u=u,
+ g=None,
+ initial_state=initial_state,
+ output_final_state=output_final_state,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ o = chunk_fwd_o(
+ q=q,
+ k=k,
+ v=v_new,
+ h=h,
+ g=None,
+ scale=scale,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ return o, A, final_state
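[editor's note] As used above, `fwd_prepare_wy_repr` factors the chunk's sequence of per-token delta updates into a compact WY-style form: it returns `w` and `u` (plus the reusable intra-chunk matrix `A`) such that processing a whole chunk against the running state reduces to the two matmuls in `chunk_gated_delta_rule_fwd_h`, namely `v_new = u - w @ S` followed by `S += kᵀ v_new`, instead of one rank-1 update per token.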
+
+
+ def chunk_delta_rule_bwd(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ beta: torch.Tensor,
+ A: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ do: torch.Tensor,
+ dht: torch.Tensor,
+ offsets: Optional[torch.LongTensor] = None,
+ indices: Optional[torch.LongTensor] = None,
+ head_first: bool = True,
+ chunk_size: int = 64
+ ):
+ T = q.shape[2] if head_first else q.shape[1]
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+ w, u = fwd_recompute_w_u(
+ k=k,
+ v=v,
+ beta=beta,
+ A=A,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
+ k=k,
+ w=w,
+ u=u,
+ g=None,
+ initial_state=initial_state,
+ output_final_state=False,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ dv = chunk_bwd_dv_local(
+ q=q,
+ k=k,
+ do=do,
+ g=None,
+ dh=None,
+ scale=scale,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
+ q=q,
+ k=k,
+ w=w,
+ g=None,
+ h0=initial_state,
+ dht=dht,
+ do=do,
+ dv=dv,
+ scale=scale,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ dq, dk, dw, _ = chunk_bwd_dqkwg(
+ q=q,
+ k=k,
+ v=v_new,
+ h=h,
+ w=w,
+ dv=dv,
+ do=do,
+ dh=dh,
+ g=None,
+ scale=scale,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ dk2, dv, db = bwd_prepare_wy_repr(
+ k=k,
+ v=v,
+ beta=beta,
+ A=A,
+ dw=dw,
+ du=dv,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=BT
+ )
+ dk.add_(dk2)
+ return dq, dk, dv, db, dh0
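[editor's note] The backward pass above retraces the forward structure step by step: `w`/`u` are recomputed cheaply from the saved intra-chunk matrix `A`, the per-chunk states `h` and `v_new` are rematerialized rather than stored, the local `dv` is formed, the reverse state recurrence yields `dh`/`dh0` and the augmented `dv`, the main `dq`/`dk`/`dw` gradients are assembled, and finally the WY construction itself is backpropagated, with its extra `dk` contribution accumulated via `dk.add_(dk2)`.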
+
+
+ class ChunkDeltaRuleFunction(torch.autograd.Function):
+
+ @staticmethod
+ @input_guard
+ @autocast_custom_fwd
+ def forward(
+ ctx,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ output_final_state: bool,
+ offsets: Optional[torch.LongTensor] = None,
+ head_first: bool = True,
+ use_qk_l2norm_in_kernel: bool = True
+ ):
+ T = q.shape[2] if head_first else q.shape[1]
+ chunk_size = min(64, max(triton.next_power_of_2(T), 16))
+
+ q_orig = q
+ k_orig = k
+
+ if use_qk_l2norm_in_kernel:
+ q = l2norm_fwd(q)
+ k = l2norm_fwd(k)
+
+ # 2-d indices denoting the offsets of chunks in each sequence
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
+
+ o, A, final_state = chunk_delta_rule_fwd(
+ q=q,
+ k=k,
+ v=v,
+ beta=beta,
+ scale=scale,
+ initial_state=initial_state,
+ output_final_state=output_final_state,
+ offsets=offsets,
+ indices=indices,
+ head_first=head_first,
+ chunk_size=chunk_size
+ )
+ ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
+ ctx.chunk_size = chunk_size
+ ctx.scale = scale
+ ctx.offsets = offsets
+ ctx.indices = indices
+ ctx.head_first = head_first
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
+ return o.to(q.dtype), final_state
+
+ @staticmethod
+ @input_guard
+ @autocast_custom_bwd
+ def backward(
+ ctx,
+ do: torch.Tensor,
+ dht: torch.Tensor
+ ):
+ q, k, v, beta, A, initial_state = ctx.saved_tensors
+ use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel
+ if use_qk_l2norm_in_kernel:
+ q, q_orig = l2norm_fwd(q), q
+ k, k_orig = l2norm_fwd(k), k
+
+ dq, dk, dv, db, dh0 = chunk_delta_rule_bwd(
+ q=q,
+ k=k,
+ v=v,
+ beta=beta,
+ A=A,
+ scale=ctx.scale,
+ initial_state=initial_state,
+ do=do,
+ dht=dht,
+ offsets=ctx.offsets,
+ indices=ctx.indices,
+ head_first=ctx.head_first,
+ chunk_size=ctx.chunk_size
+ )
+ if use_qk_l2norm_in_kernel:
+ dq = l2norm_bwd(q_orig, dq)
+ dk = l2norm_bwd(k_orig, dk)
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None, None, None
+
+
+ @torch.compiler.disable
+ def chunk_delta_rule(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float = None,
+ initial_state: torch.Tensor = None,
+ output_final_state: bool = False,
+ cu_seqlens: Optional[torch.LongTensor] = None,
+ head_first: bool = False,
+ use_qk_l2norm_in_kernel: bool = False
+ ):
+ r"""
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+ v (torch.Tensor):
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ beta (torch.Tensor):
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[torch.Tensor]):
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+ head_first (Optional[bool]):
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+ Default: `False`.
+ use_qk_l2norm_in_kernel (Optional[bool]):
+ Whether to use qk l2norm within the kernel for saving GPU memory.
+ Default: `False`.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ final_state (torch.Tensor):
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+ Examples::
+ >>> import torch
+ >>> import torch.nn.functional as F
+ >>> from einops import rearrange
+ >>> from fla.ops.delta_rule import chunk_delta_rule
+ # inputs with equal lengths
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+ >>> o, ht = chunk_delta_rule(
+ q, k, v, beta,
+ initial_state=h0,
+ output_final_state=True
+ )
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> o_var, ht_var = chunk_delta_rule(
+ q, k, v, beta,
+ initial_state=h0,
+ output_final_state=True,
+ cu_seqlens=cu_seqlens
+ )
+ """
+ assert q.dtype == k.dtype == v.dtype
+ assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+ assert len(beta.shape) == 3, "beta must be 3-dimensional, i.e., [B, T, H] if head_first=False else [B, H, T]."
+
+ if cu_seqlens is not None:
+ if q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+ f"Please flatten variable-length inputs before processing."
+ )
+ if head_first:
+ raise RuntimeError(
+ "Sequences with variable lengths are not supported for head-first mode"
+ )
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+ raise ValueError(
+ f"The number of initial states is expected to be equal to the number of input sequences, "
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+ )
+ if head_first:
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
+ beta = rearrange(beta, 'b h t -> b t h')
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
+ o, final_state = ChunkDeltaRuleFunction.apply(
+ q,
+ k,
+ v,
+ beta,
+ scale,
+ initial_state,
+ output_final_state,
+ cu_seqlens,
+ False,
+ use_qk_l2norm_in_kernel
+ )
+ if head_first:
+ o = rearrange(o, 'b t h v -> b h t v')
+ return o, final_state
fla/ops/delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,607 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+ from einops import rearrange
+
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
+ from fla.utils import input_guard
+
+
+ @triton.heuristics({
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.jit(do_not_specialize=['T'])
+ def fused_recurrent_delta_rule_fwd_kernel(
+ q,
+ k,
+ v,
+ u,
+ beta,
+ o,
+ h0,
+ ht,
+ offsets,
+ scale,
+ T,
+ B: tl.constexpr,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ USE_INITIAL_STATE: tl.constexpr,
+ STORE_FINAL_STATE: tl.constexpr,
+ IS_BETA_HEADWISE: tl.constexpr,
+ USE_OFFSETS: tl.constexpr,
+ HEAD_FIRST: tl.constexpr
+ ):
+ i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_n, i_h = i_nh // H, i_nh % H
+ if USE_OFFSETS:
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
+ all = T
+ T = eos - bos
+ else:
+ bos, eos = i_n * T, i_n * T + T
+ all = B * T
+
+ if HEAD_FIRST:
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
+ p_u = u + i_nh * T*V + i_v * BV + tl.arange(0, BV)
+ if IS_BETA_HEADWISE:
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
+ else:
+ p_beta = beta + i_nh * T
+ p_o = o + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
+ else:
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+ p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+ if IS_BETA_HEADWISE:
+ p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+ else:
+ p_beta = beta + bos * H + i_h
+ p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
+ mask_h = mask_k[None, :] & mask_v[:, None]
+
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
+ if USE_INITIAL_STATE:
+ p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+ for _ in range(0, T):
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
+ b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
+ b_v -= b_v_minus
+ if IS_BETA_HEADWISE:
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
+ else:
+ b_beta = tl.load(p_beta).to(tl.float32)
+ tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
+ b_v *= b_beta
+ b_h += b_k[None, :] * b_v[:, None]
+ b_o = b_h * b_q[None, :]
+ b_o = tl.sum(b_o, axis=1)
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+ p_q += K if HEAD_FIRST else H*K
+ p_k += K if HEAD_FIRST else H*K
+ p_o += V if HEAD_FIRST else H*V
+ p_v += V if HEAD_FIRST else H*V
+ p_u += V if HEAD_FIRST else H*V
+ p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
+
+ if STORE_FINAL_STATE:
+ p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+ @triton.heuristics({
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.jit(do_not_specialize=['T'])
+ def fused_recurrent_delta_rule_bwd_kernel(
+ q,
+ k,
+ v,
+ beta,
+ h0,
+ dh0,
+ dht,
+ do,
+ dq,
+ dk,
+ dv,
+ db,
+ offsets,
+ scale,
+ B: tl.constexpr,
+ T,
+ H: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ NK: tl.constexpr,
+ IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
+ USE_INITIAL_STATE: tl.constexpr,  # whether to use dh0
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,  # whether to use dht
+ USE_OFFSETS: tl.constexpr,
+ HEAD_FIRST: tl.constexpr
+ ):
+ i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_n, i_h = i_nh // H, i_nh % H
+ if USE_OFFSETS:
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
+ all = T
+ T = eos - bos
+ else:
+ bos, eos = i_n * T, i_n * T + T
+ all = B * T
+
+ mask_k = i_k * BK + tl.arange(0, BK) < K
+ mask_v = i_v * BV + tl.arange(0, BV) < V
+
+ if HEAD_FIRST:
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+ p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+ if IS_BETA_HEADWISE:
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+ p_dbeta = db + (i_v * NK*B*H + i_k * B*H + i_nh) * T*V + tl.arange(0, BV) + (T - 1) * V
+ else:
+ p_beta = beta + i_nh * T + T - 1
+ p_dbeta = db + (i_v * B*H + i_nh) * T + T - 1
+ else:
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
+ p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
+ p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
+ p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
+ if IS_BETA_HEADWISE:
+ p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+ p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV)
+ else:
+ p_beta = beta + (bos + T - 1) * H + i_h
+ p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h
+
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+ if USE_FINAL_STATE_GRADIENT:
+ p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+ b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
+
+ for _ in range(T):
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
+ if IS_BETA_HEADWISE:
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
+ else:
+ b_beta = tl.load(p_beta).to(tl.float32)
+ b_dh += b_q[:, None] * b_do[None, :]
+ b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1)
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
+
+ b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v)
+ b_dv = b_dv * b_beta
+
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
+ if IS_BETA_HEADWISE:
+ tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v)
+ else:
+ tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty))
+
+ b_dh -= b_k[:, None] * b_dv[None, :]
+
+ p_q -= K if HEAD_FIRST else H*K
+ p_k -= K if HEAD_FIRST else H*K
+ p_v -= V if HEAD_FIRST else H*V
+ p_do -= V if HEAD_FIRST else H*V
+ p_dk -= K if HEAD_FIRST else H*K
+ p_dv -= V if HEAD_FIRST else H*V
+ p_dbeta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
+ p_beta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
+
+ if USE_INITIAL_STATE:
+ p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
+
+ tl.debug_barrier()
+
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+ if HEAD_FIRST:
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
+ if IS_BETA_HEADWISE:
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
+ else:
+ p_beta = beta + i_nh * T
+ p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV)
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
+ else:
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+ if IS_BETA_HEADWISE:
+ p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+ else:
+ p_beta = beta + bos * H + i_h
+ p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+ p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
+ p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
+ p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+
+ if USE_INITIAL_STATE:
+ mask_h = mask_k[:, None] & mask_v[None, :]
+ p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+ for _ in range(0, T):
+ b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32)
+ b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32)
+ b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1)
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
+
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
+ if IS_BETA_HEADWISE:
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
+ else:
+ b_beta = tl.load(p_beta).to(tl.float32)
+ b_v *= b_beta
+
+ b_h += b_k[:, None] * b_v[None, :]
+ b_dq = b_h * b_do[None, :]
+ d_q = tl.sum(b_dq, axis=1) * scale
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
+
+ p_k += K if HEAD_FIRST else H*K
+ p_v += V if HEAD_FIRST else H*V
+ p_do += V if HEAD_FIRST else H*V
+ p_dq += K if HEAD_FIRST else H*K
+ p_dk += K if HEAD_FIRST else H*K
+ p_dv += V if HEAD_FIRST else H*V
+ p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
+
+
+ def fused_recurrent_delta_rule_fwd(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ output_final_state: bool,
+ offsets: Optional[torch.LongTensor] = None,
+ head_first: bool = True
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if head_first:
+ B, H, T, K, V = *k.shape, v.shape[-1]
+ else:
+ B, T, H, K, V = *k.shape, v.shape[-1]
+ N = B if offsets is None else len(offsets) - 1
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+ assert NK == 1, "NK > 1 is not supported yet"
+ num_stages = 1
+ num_warps = 1
+
+ o = q.new_empty(NK, *v.shape)
+ if output_final_state:
+ final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
+ else:
+ final_state = None
+
+ grid = (NV, NK, N * H)
+ u = torch.empty_like(v)
+ fused_recurrent_delta_rule_fwd_kernel[grid](
+ q,
+ k,
+ v,
+ u,
+ beta,
+ o,
+ initial_state,
+ final_state,
+ offsets,
+ scale,
+ T=T,
+ B=B,
+ H=H,
+ K=K,
+ V=V,
+ BK=BK,
+ BV=BV,
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
+ HEAD_FIRST=head_first,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ o = o.squeeze(0)
+ return o, u, final_state
350
+
351
+
352
+ def fused_recurrent_delta_rule_bwd(
353
+ q: torch.Tensor,
354
+ k: torch.Tensor,
355
+ v: torch.Tensor,
356
+ beta: torch.Tensor,
357
+ dht: torch.Tensor,
358
+ do: torch.Tensor,
359
+ scale: float,
360
+ initial_state: torch.Tensor,
361
+ offsets: Optional[torch.LongTensor] = None,
362
+ head_first: bool = True
363
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
364
+ if head_first:
365
+ B, H, T, K, V = *k.shape, v.shape[-1]
366
+ else:
367
+ B, T, H, K, V = *k.shape, v.shape[-1]
368
+ N = B if offsets is None else len(offsets) - 1
369
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
370
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
371
+ assert NK == 1, "NK > 1 is not supported yet"
372
+ num_stages = 1
373
+ num_warps = 2
374
+
375
+ beta_vector = beta.ndim == v.ndim
376
+
377
+ dq = q.new_empty(NV, *q.shape)
378
+ dk = q.new_empty(NV, *k.shape)
379
+ dv = q.new_empty(NK, *v.shape)
380
+ if beta_vector:
381
+ db = q.new_empty(NV, NK, B, H, T, V) if head_first else q.new_empty(NV, NK, B, T, H, V)
382
+ else:
383
+ db = q.new_empty(NV, B, H, T) if head_first else q.new_empty(NV, B, T, H)
384
+ grid = (NV, NK, N * H)
385
+
386
+ if initial_state is not None and initial_state.requires_grad:
387
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
388
+ else:
389
+ dh0 = None
390
+
391
+ fused_recurrent_delta_rule_bwd_kernel[grid](
392
+ q,
393
+ k,
394
+ v,
395
+ beta,
396
+ initial_state,
397
+ dh0,
398
+ dht,
399
+ do,
400
+ dq,
401
+ dk,
402
+ dv,
403
+ db,
404
+ offsets,
405
+ scale,
406
+ T=T,
407
+ B=B,
408
+ H=H,
409
+ K=K,
410
+ V=V,
411
+ BK=BK,
412
+ BV=BV,
413
+ NK=NK,
414
+ IS_BETA_HEADWISE=beta_vector,
415
+ HEAD_FIRST=head_first,
416
+ num_warps=num_warps,
417
+ num_stages=num_stages
418
+ )
419
+ dq = dq.sum(0)
420
+ dk = dk.sum(0)
421
+ dv = dv.sum(0)
422
+ db = db.sum((0, 1)) if beta_vector else db.sum(0)
423
+
424
+ return dq, dk, dv, db, dh0
425
+
426
+
427
+ class FusedRecurrentFunction(torch.autograd.Function):
428
+
429
+ @staticmethod
430
+ @input_guard
431
+ def forward(
432
+ ctx,
433
+ q: torch.Tensor,
434
+ k: torch.Tensor,
435
+ v: torch.Tensor,
436
+ beta: torch.Tensor,
437
+ scale: float,
438
+ initial_state: torch.Tensor,
439
+ output_final_state: bool,
440
+ offsets: Optional[torch.LongTensor] = None,
441
+ head_first: bool = True,
442
+ use_qk_l2norm_in_kernel: bool = False
443
+ ):
444
+ q_orig = q
445
+ k_orig = k
446
+
447
+ if use_qk_l2norm_in_kernel:
448
+ q = l2norm_fwd(q)
449
+ k = l2norm_fwd(k)
450
+
451
+ o, u, final_state = fused_recurrent_delta_rule_fwd(
452
+ q=q,
453
+ k=k,
454
+ v=v,
455
+ beta=beta,
456
+ scale=scale,
457
+ initial_state=initial_state,
458
+ output_final_state=output_final_state,
459
+ offsets=offsets,
460
+ head_first=head_first
461
+ )
462
+
463
+ ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state)
464
+ ctx.scale = scale
465
+ ctx.offsets = offsets
466
+ ctx.head_first = head_first
467
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
468
+ return o, final_state
469
+
470
+ @staticmethod
471
+ @input_guard
472
+ def backward(ctx, do, dht):
473
+ q, k, v, beta, initial_state = ctx.saved_tensors
474
+ if ctx.use_qk_l2norm_in_kernel:
475
+ q, q_orig = l2norm_fwd(q), q
476
+ k, k_orig = l2norm_fwd(k), k
477
+ dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd(
478
+ q=q,
479
+ k=k,
480
+ v=v,
481
+ beta=beta,
482
+ dht=dht,
483
+ do=do,
484
+ scale=ctx.scale,
485
+ initial_state=initial_state,
486
+ offsets=ctx.offsets,
487
+ head_first=ctx.head_first
488
+ )
489
+ if ctx.use_qk_l2norm_in_kernel:
490
+ dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk)
491
+ return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None, None
492
+
493
+
494
+ @torch.compiler.disable
495
+ def fused_recurrent_delta_rule(
496
+ q: torch.Tensor,
497
+ k: torch.Tensor,
498
+ v: torch.Tensor,
499
+ beta: torch.Tensor = None,
500
+ scale: float = None,
501
+ initial_state: torch.Tensor = None,
502
+ output_final_state: bool = False,
503
+ cu_seqlens: Optional[torch.LongTensor] = None,
504
+ head_first: bool = True,
505
+ use_qk_l2norm_in_kernel: bool = False
506
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
507
+ r"""
508
+ Args:
509
+ q (torch.Tensor):
510
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
511
+ k (torch.Tensor):
512
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
513
+ v (torch.Tensor):
514
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
515
+ beta (torch.Tensor):
516
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
519
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
520
+ initial_state (Optional[torch.Tensor]):
521
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
522
+ For equal-length input sequences, `N` equals the batch size `B`.
523
+ Default: `None`.
524
+ output_final_state (Optional[bool]):
525
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
526
+ cu_seqlens (torch.LongTensor):
527
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
528
+ consistent with the FlashAttention API.
529
+ head_first (Optional[bool]):
530
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
531
+ Default: `True`.
532
+
533
+ Returns:
534
+ o (torch.Tensor):
535
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
536
+ final_state (torch.Tensor):
537
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
538
+
539
+ Examples::
540
+ >>> import torch
541
+ >>> import torch.nn.functional as F
542
+ >>> from einops import rearrange
543
+ >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
544
+ # inputs with equal lengths
545
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
546
+ >>> q = torch.randn(B, T, H, K, device='cuda')
547
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
548
+ >>> v = torch.randn(B, T, H, V, device='cuda')
549
+ >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
550
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
551
+ >>> o, ht = fused_recurrent_delta_rule(
552
+ q, k, v, beta,
553
+ initial_state=h0,
554
+ output_final_state=True
555
+ )
556
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
557
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
558
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
559
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
560
+ >>> o_var, ht_var = fused_recurrent_delta_rule(
561
+ q, k, v, beta,
562
+ initial_state=h0,
563
+ output_final_state=True,
564
+ cu_seqlens=cu_seqlens
565
+ )
566
+ >>> assert o.allclose(o_var.view(o.shape))
567
+ >>> assert ht.allclose(ht_var)
568
+ """
569
+ if cu_seqlens is not None:
570
+ if q.shape[0] != 1:
571
+ raise ValueError(
572
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
573
+ f"Please flatten variable-length inputs before processing."
574
+ )
575
+ if head_first:
576
+ raise RuntimeError(
577
+ "Sequences with variable lengths are not supported for head-first mode"
578
+ )
579
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
580
+ raise ValueError(
581
+ f"The number of initial states is expected to be equal to the number of input sequences, "
582
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
583
+ )
584
+ if scale is None:
585
+ scale = k.shape[-1] ** -0.5
586
+ else:
587
+ assert scale > 0, "scale must be positive"
588
+ if beta is None:
589
+ beta = torch.ones_like(q[..., 0])
590
+ if head_first:
591
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
592
+ beta = rearrange(beta, 'b h t ... -> b t h ...')
593
+ o, final_state = FusedRecurrentFunction.apply(
594
+ q,
595
+ k,
596
+ v,
597
+ beta,
598
+ scale,
599
+ initial_state,
600
+ output_final_state,
601
+ cu_seqlens,
602
+ False,
603
+ use_qk_l2norm_in_kernel
604
+ )
605
+ if head_first:
606
+ o = rearrange(o, 'b t h v -> b h t v')
607
+ return o, final_state
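+
+
+ # A minimal pure-PyTorch sketch of the recurrence computed by the fused kernel above,
+ # intended purely as a readability/testing aid (the function name and the O(T) Python
+ # loop are illustrative and not part of the optimized path). It assumes the
+ # `head_first=False` layout `[B, T, H, ...]` and a scalar (per-head) beta.
+ def naive_recurrent_delta_rule_reference(q, k, v, beta, scale, initial_state=None):
+     B, T, H, K = q.shape
+     V = v.shape[-1]
+     S = q.new_zeros(B, H, K, V, dtype=torch.float32) if initial_state is None else initial_state.float().clone()
+     o = torch.empty_like(v)
+     for t in range(T):
+         k_t = k[:, t].float()  # [B, H, K]
+         # delta update: remove the value currently stored under k_t, then write back, gated by beta
+         u_t = (v[:, t].float() - torch.einsum('bhkv,bhk->bhv', S, k_t)) * beta[:, t].float().unsqueeze(-1)
+         S = S + torch.einsum('bhk,bhv->bhkv', k_t, u_t)
+         o[:, t] = torch.einsum('bhkv,bhk->bhv', S, q[:, t].float() * scale).to(o.dtype)
+     return o, S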
fla/ops/delta_rule/parallel.py ADDED
@@ -0,0 +1,394 @@
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.delta_rule.wy_fast import fwd_prepare_T
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.autotune(
16
+ configs=[
17
+ triton.Config({}, num_warps=num_warps)
18
+ for num_warps in [1, 2, 4]
19
+ ],
20
+ key=['BT', 'K', 'V'],
21
+ )
22
+ @triton.jit(do_not_specialize=['T'])
23
+ def chunk_transform_qk_fwd_kernel(
24
+ q,
25
+ k,
26
+ v,
27
+ beta,
28
+ o,
29
+ A,
30
+ q_new,
31
+ k_new,
32
+ A_local,
33
+ scale,
34
+ T,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ BT: tl.constexpr,
40
+ OUTPUT_ATTENTIONS: tl.constexpr
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+
44
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
45
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
46
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
47
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(p_q.dtype.element_ty)
48
+ b_k = tl.load(p_k, boundary_check=(0, 1))
49
+ b_v = tl.load(p_v, boundary_check=(0, 1))
50
+
51
+ p_T = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_T = tl.load(p_T, boundary_check=(0, 1))
53
+
54
+ o_i = tl.arange(0, BT)
55
+ m_t = o_i[:, None] >= o_i[None, :]
56
+ b_qk = tl.where(m_t, tl.dot(b_q, tl.trans(b_k), allow_tf32=False), 0).to(b_q.dtype)
57
+ m_t = o_i[:, None] > o_i[None, :]
58
+ b_kk = tl.where(m_t, tl.dot(b_k, tl.trans(b_k), allow_tf32=False), 0).to(b_k.dtype)
59
+
60
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (i_t * BT, ), (BT, ), (0, ))
61
+ b_beta = tl.load(p_beta, boundary_check=(0, ))
62
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
63
+
64
+ b_qkT = tl.dot(b_qk, b_T, allow_tf32=False).to(b_k.dtype)
65
+
66
+ if OUTPUT_ATTENTIONS:
67
+ p_a = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
68
+ tl.store(p_a, b_qkT.to(p_a.dtype.element_ty), boundary_check=(0, 1))
69
+
70
+ b_kkT = tl.dot(b_kk, b_T, allow_tf32=False).to(b_k.dtype)
71
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
72
+ tl.store(p_o, tl.dot(b_qkT, b_v).to(p_o.dtype.element_ty), boundary_check=(0, 1))
73
+
74
+ p_q_new = tl.make_block_ptr(q_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
75
+ tl.store(p_q_new, (b_q - tl.dot(b_qkT, b_k_beta, allow_tf32=False)).to(p_q_new.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+ p_k_new = tl.make_block_ptr(k_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
78
+ b_k_new = b_k - tl.dot(tl.trans(b_kkT), b_k_beta, allow_tf32=False)
79
+ tl.store(p_k_new, b_k_new.to(p_k_new.dtype.element_ty), boundary_check=(0, 1))
80
+
81
+
82
+ def chunk_transform_qk_fwd(
83
+ q: torch.Tensor,
84
+ k: torch.Tensor,
85
+ v: torch.Tensor,
86
+ beta: torch.Tensor,
87
+ A: torch.Tensor,
88
+ scale: float,
89
+ chunk_size: int,
90
+ output_attentions: bool
91
+ ):
92
+ B, H, T, K = k.shape
93
+ BT = chunk_size
94
+ q_new = torch.empty_like(q)
95
+ k_new = torch.empty_like(k)
96
+ o = torch.empty_like(v)
97
+ grid = (triton.cdiv(T, BT), B*H)
98
+ V = v.shape[-1]
99
+ A_local = torch.empty_like(A) if output_attentions else None
100
+ chunk_transform_qk_fwd_kernel[grid](
101
+ q,
102
+ k,
103
+ v,
104
+ beta,
105
+ o,
106
+ A,
107
+ q_new,
108
+ k_new,
109
+ A_local,
110
+ scale=scale,
111
+ T=T,
112
+ K=K,
113
+ V=V,
114
+ BT=BT,
115
+ BK=triton.next_power_of_2(K),
116
+ BV=triton.next_power_of_2(V),
117
+ OUTPUT_ATTENTIONS=output_attentions
118
+ )
119
+ return q_new, k_new, o, A_local
120
+
121
+
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=1),
125
+ triton.Config({}, num_warps=2),
126
+ ],
127
+ key=['BT'],
128
+ )
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def save_intra_chunk_attn(
131
+ A,
132
+ A_local,
133
+ T,
134
+ BT: tl.constexpr,
135
+ ):
136
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
137
+ p_A = tl.make_block_ptr(A + i_bh * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BT), (1, 0))
138
+ p_A_local = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
139
+ b_A_local = tl.load(p_A_local, boundary_check=(0, 1))
140
+ tl.store(p_A, b_A_local.to(p_A.dtype.element_ty), boundary_check=(0, 1))
141
+
142
+
143
+ @triton.heuristics({
144
+ 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None
145
+ })
146
+ @triton.jit(do_not_specialize=['T'])
147
+ def parallel_delta_rule_fwd_kernel(
148
+ q,
149
+ k,
150
+ k2, # original k
151
+ v,
152
+ beta,
153
+ o,
154
+ o_new,
155
+ attn,
156
+ T,
157
+ K: tl.constexpr,
158
+ V: tl.constexpr,
159
+ BT: tl.constexpr,
160
+ BS: tl.constexpr,
161
+ BK: tl.constexpr,
162
+ BV: tl.constexpr,
163
+ OUTPUT_ATTENTIONS: tl.constexpr
164
+ ):
165
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
166
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
167
+
168
+ # the Q block is kept in the shared memory throughout the whole kernel
169
+ # [BT, BK]
170
+ b_q = tl.zeros([BT, BK], dtype=tl.float32)
171
+ b_q += tl.load(p_q, boundary_check=(0, 1))
172
+
173
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
174
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
175
+ b_o += tl.load(p_o, boundary_check=(0, 1))
176
+
177
+ # As opposed to Flashattention, this kernel requires scanning the KV blocks from right to left
178
+ # Q block and K block have overlap.
179
+ # masks required
180
+ for offset in range((i_t + 1) * BT - 2 * BS, i_t * BT - BS, -BS):
181
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
182
+ p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
183
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
184
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
185
+ # [BK, BS]
186
+ b_k = tl.load(p_k, boundary_check=(0, 1))
187
+ # [BS, BV]
188
+ b_v = tl.load(p_v, boundary_check=(0, 1))
189
+ # [BS]
190
+ b_beta = tl.load(p_beta, boundary_check=(0,))
191
+ # [BT, BS]
192
+ m_s = tl.arange(0, BT) >= (offset - i_t*BT + BS)
193
+ b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)
194
+ b_s = tl.where(m_s[:, None], b_s, 0)
195
+
196
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
197
+ b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
198
+ b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False)
199
+
200
+ if OUTPUT_ATTENTIONS:
201
+ p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
202
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
203
+
204
+ # Q block and K block have no overlap
205
+ # no need for mask, thereby saving flops
206
+ for offset in range(i_t * BT - BS, -BS, -BS):
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
208
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
209
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
210
+ p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
211
+
212
+ # [BK, BS]
213
+ b_k = tl.load(p_k, boundary_check=(0, 1))
214
+ # [BS, BV]
215
+ b_v = tl.load(p_v, boundary_check=(0, 1))
216
+ # [BS]
217
+ b_beta = tl.load(p_beta, boundary_check=(0,))
218
+ # [BT, BS]
219
+ b_s = (tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False))
220
+ # [BT, BV]
221
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
222
+ b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
223
+ b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False).to(b_q.dtype)
224
+
225
+ if OUTPUT_ATTENTIONS:
226
+ p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
227
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ p_o_new = tl.make_block_ptr(o_new + i_bh * T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
230
+ tl.store(p_o_new, b_o.to(p_o_new.dtype.element_ty), boundary_check=(0, 1))
231
+
232
+
233
+ class ParallelDeltaRuleFunction(torch.autograd.Function):
234
+
235
+ @staticmethod
236
+ @input_guard
237
+ @autocast_custom_fwd
238
+ def forward(ctx, q, k, v, beta, scale, output_attentions):
239
+ B, H, T, K, V = *k.shape, v.shape[-1]
240
+ assert q.shape[-1] <= 128, 'The maximum supported head dimension is 128.'
241
+ BT, BS = 128, 32
242
+ BK = triton.next_power_of_2(k.shape[-1])
243
+ BV = triton.next_power_of_2(v.shape[-1])
244
+ assert BT % BS == 0
245
+
246
+ A = fwd_prepare_T(k, beta, BS)
247
+ attn = q.new_zeros(B, H, T, T) if output_attentions else None
248
+ q_new, k_new, o, A_local = chunk_transform_qk_fwd(
249
+ q,
250
+ k,
251
+ v,
252
+ beta,
253
+ A,
254
+ scale,
255
+ BS,
256
+ output_attentions
257
+ )
258
+
259
+ num_stages = 3 if K <= 64 else 2
260
+ num_warps = 4
261
+ grid = (triton.cdiv(T, BT), B * H)
262
+ o_new = torch.empty_like(o)
263
+
264
+ parallel_delta_rule_fwd_kernel[grid](
265
+ q=q_new,
266
+ k=k_new,
267
+ k2=k,
268
+ v=v,
269
+ beta=beta,
270
+ o=o,
271
+ o_new=o_new,
272
+ attn=attn,
273
+ T=T,
274
+ K=K,
275
+ V=V,
276
+ BT=BT,
277
+ BS=BS,
278
+ BK=BK,
279
+ BV=BV,
280
+ num_stages=num_stages,
281
+ num_warps=num_warps
282
+ )
283
+
284
+ if output_attentions:
285
+ grid = (triton.cdiv(T, BS), B * H)
286
+ save_intra_chunk_attn[grid](
287
+ A=attn,
288
+ A_local=A_local,
289
+ T=T,
290
+ BT=BS
291
+ )
292
+ return o_new.to(q.dtype), attn
293
+
294
+ @staticmethod
295
+ @input_guard
296
+ @autocast_custom_bwd
297
+ def backward(ctx, do, d_attn=None):
298
+ raise NotImplementedError('Backward pass is not implemented. Stay tuned!')
299
+
300
+
301
+ def parallel_delta_rule(
302
+ q: torch.Tensor,
303
+ k: torch.Tensor,
304
+ v: torch.Tensor,
305
+ beta: torch.Tensor,
306
+ scale: float = None,
307
+ output_attentions: bool = False,
308
+ head_first: bool = True
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ r"""
311
+ Args:
312
+ q (torch.Tensor):
313
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
314
+ k (torch.Tensor):
315
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
316
+ v (torch.Tensor):
317
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
318
+ beta (torch.Tensor):
319
+ betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
320
+ scale (Optional[float]):
321
+ Scale factor for attention scores.
322
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
323
+ output_attentions (bool):
324
+ Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
325
+ head_first (Optional[bool]):
326
+ Whether the inputs are in the head-first format.
327
+ Default: `True`.
328
+
329
+ Returns:
330
+ o (torch.Tensor):
331
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
332
+ attn (torch.Tensor):
333
+ Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`.
334
+ """
335
+ if scale is None:
+ scale = k.shape[-1] ** -0.5
+ if not head_first:
336
+ q, k, v, beta = map(lambda x: x.transpose(1, 2), (q, k, v, beta))
337
+ o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions)
338
+ if not head_first:
339
+ o = o.transpose(1, 2)
340
+ return o, attn
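+
+ # Example usage (an illustrative sketch; shapes follow the default head-first format,
+ # and k is typically L2-normalized along the last dimension for stability):
+ #   B, H, T, K, V = 2, 4, 1024, 64, 64
+ #   q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
+ #   k = torch.nn.functional.normalize(torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16), p=2, dim=-1)
+ #   v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
+ #   beta = torch.rand(B, H, T, device='cuda', dtype=torch.bfloat16).sigmoid()
+ #   o, attn = parallel_delta_rule(q, k, v, beta, output_attentions=True)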
341
+
342
+
343
+ def naive_delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
344
+ b, h, l, d_k = q.shape
345
+ q = q * (d_k ** -0.5)
346
+ v = v * beta[..., None]
347
+ k_beta = k * beta[..., None]
348
+ # compute (I - tri(diag(beta) KK^T))^{-1}
349
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
350
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
351
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
352
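+ # row-by-row forward substitution: row i of the inverse factor depends only on
+ # rows < i, so the strictly lower-triangular part can be accumulated in place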
+ for i in range(1, BN):
353
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
354
+ T = T + torch.eye(BN, dtype=q.dtype, device=q.device)
355
+
356
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
357
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
358
+ o_intra = A_local @ v
359
+
360
+ # apply cumprod transition matrices on k to the last position within the chunk
361
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
362
+ # apply cumprod transition matrices on q to the first position within the chunk
363
+ q = q - A_local @ k_beta
364
365
+
366
+ A = torch.zeros(b, h, l, l, device=q.device)
367
+
368
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
369
+ o = torch.empty_like(v)
370
+ for i in range(0, l, BM):
371
+ q_i = q[:, :, i:i+BM]
372
+ o_i = o_intra[:, :, i:i+BM]
373
+ # intra block
374
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
375
+ k_j = k[:, :, j:j+BN]
376
+ A_ij = q_i @ k_j.transpose(-1, -2)
377
+ mask = torch.arange(i, i+BM) >= (j + BN)
378
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
379
+ A[:, :, i:i+BM, j:j+BN] = A_ij
380
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
381
+ o_i += A_ij @ v[:, :, j:j+BN]
382
+ # inter block
383
+ for j in range(i - BN, -BN, -BN):
384
+ k_j = k[:, :, j:j+BN]
385
+ A_ij = q_i @ k_j.transpose(-1, -2)
386
+ A[:, :, i:i+BM, j:j+BN] = A_ij
387
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
388
+ o_i += A_ij @ v[:, :, j:j+BN]
389
+ o[:, :, i:i+BM] = o_i
390
+
391
+ for i in range(0, l//BN):
392
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
393
+
394
+ return o, A
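+
+ # As a quick sanity check (illustrative only, not part of the library API), the naive
+ # implementation above can be compared against the Triton path, e.g.:
+ #   o_ref, A_ref = naive_delta_rule_parallel(q.float(), k.float(), v.float(), beta.float())
+ #   o_tri, attn = parallel_delta_rule(q, k, v, beta, output_attentions=True)
+ #   torch.testing.assert_close(o_ref, o_tri.float(), atol=1e-2, rtol=1e-2)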
fla/ops/gated_delta_rule/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .chunk import chunk_gated_delta_rule
+ from .fused_recurrent import fused_recurrent_gated_delta_rule
+
+ __all__ = [
+     "chunk_gated_delta_rule",
+     "fused_recurrent_gated_delta_rule"
+ ]
fla/ops/gated_delta_rule/chunk.py ADDED
@@ -0,0 +1,392 @@
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
14
+ from fla.ops.utils import chunk_local_cumsum
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_gated_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ g: torch.Tensor,
23
+ beta: torch.Tensor,
24
+ scale: float,
25
+ initial_state: torch.Tensor,
26
+ output_final_state: bool,
27
+ offsets: Optional[torch.LongTensor] = None,
28
+ indices: Optional[torch.LongTensor] = None,
29
+ head_first: bool = True,
30
+ chunk_size: int = 64
31
+ ):
32
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
33
+ # obtain WY representation. u is actually the new v.
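+ # (Loosely: within each chunk, the running product of rank-1 state transitions is
+ # kept in a compact WY-style form so that applying it reduces to a few matmuls with
+ # w and u; this reading is a sketch, see fla/ops/gated_delta_rule/wy_fast.py.)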
34
+ w, u, Aw, Au = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ g=g,
39
+ offsets=offsets,
40
+ indices=indices,
41
+ head_first=head_first,
42
+ chunk_size=chunk_size
43
+ )
44
+
45
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
46
+ k=k,
47
+ w=w,
48
+ u=u,
49
+ g=g,
50
+ initial_state=initial_state,
51
+ output_final_state=output_final_state,
52
+ offsets=offsets,
53
+ indices=indices,
54
+ head_first=head_first,
55
+ chunk_size=chunk_size
56
+ )
57
+
58
+ # obtain output
59
+ o = chunk_fwd_o(
60
+ q=q,
61
+ k=k,
62
+ v=v_new,
63
+ h=h,
64
+ g=g,
65
+ scale=scale,
66
+ offsets=offsets,
67
+ indices=indices,
68
+ head_first=head_first,
69
+ chunk_size=chunk_size
70
+ )
71
+ return g, o, Aw, Au, final_state
72
+
73
+
74
+ def chunk_gated_delta_rule_bwd(
75
+ q: torch.Tensor,
76
+ k: torch.Tensor,
77
+ v: torch.Tensor,
78
+ g: torch.Tensor,
79
+ beta: torch.Tensor,
80
+ Aw: torch.Tensor,
81
+ Au: torch.Tensor,
82
+ scale: float,
83
+ initial_state: torch.Tensor,
84
+ do: torch.Tensor,
85
+ dht: torch.Tensor,
86
+ offsets: Optional[torch.LongTensor] = None,
87
+ indices: Optional[torch.LongTensor] = None,
88
+ head_first: bool = True,
89
+ chunk_size: int = 64
90
+ ):
91
+ T = q.shape[2] if head_first else q.shape[1]
92
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
93
+ w, u = fwd_recompute_w_u(
94
+ k=k,
95
+ v=v,
96
+ beta=beta,
97
+ Aw=Aw,
98
+ Au=Au,
99
+ offsets=offsets,
100
+ indices=indices,
101
+ head_first=head_first,
102
+ chunk_size=BT
103
+ )
104
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
105
+ k=k,
106
+ w=w,
107
+ u=u,
108
+ g=g,
109
+ initial_state=initial_state,
110
+ output_final_state=False,
111
+ offsets=offsets,
112
+ indices=indices,
113
+ head_first=head_first,
114
+ chunk_size=BT
115
+ )
116
+ dv = chunk_bwd_dv_local(
117
+ q=q,
118
+ k=k,
119
+ g=g,
120
+ do=do,
121
+ dh=None,
122
+ scale=scale,
123
+ offsets=offsets,
124
+ indices=indices,
125
+ head_first=head_first,
126
+ chunk_size=BT
127
+ )
128
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
129
+ q=q,
130
+ k=k,
131
+ w=w,
132
+ g=g,
133
+ h0=initial_state,
134
+ dht=dht,
135
+ do=do,
136
+ dv=dv,
137
+ scale=scale,
138
+ offsets=offsets,
139
+ indices=indices,
140
+ head_first=head_first,
141
+ chunk_size=BT
142
+ )
143
+ dq, dk, dw, dg = chunk_bwd_dqkwg(
144
+ q=q,
145
+ k=k,
146
+ v=v_new,
147
+ w=w,
148
+ g=g,
149
+ h=h,
150
+ dv=dv,
151
+ do=do,
152
+ dh=dh,
153
+ scale=scale,
154
+ offsets=offsets,
155
+ indices=indices,
156
+ head_first=head_first,
157
+ chunk_size=BT
158
+ )
159
+ dk2, dv, db, dg2 = bwd_prepare_wy_repr(
160
+ k=k,
161
+ v=v,
162
+ beta=beta,
163
+ g=g,
164
+ Aw=Aw,
165
+ Au=Au,
166
+ dw=dw,
167
+ du=dv,
168
+ offsets=offsets,
169
+ indices=indices,
170
+ head_first=head_first,
171
+ chunk_size=BT
172
+ )
173
+ dk.add_(dk2)
174
+ dg.add_(dg2)
175
+ assert dg.dtype == torch.float32, "dg should be fp32"
176
+ dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first)
177
+ return dq, dk, dv, db, dg, dh0
178
+
179
+
180
+ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
181
+
182
+ @staticmethod
183
+ @input_guard
184
+ @autocast_custom_fwd
185
+ def forward(
186
+ ctx,
187
+ q: torch.Tensor,
188
+ k: torch.Tensor,
189
+ v: torch.Tensor,
190
+ g: torch.Tensor,
191
+ beta: torch.Tensor,
192
+ scale: float,
193
+ initial_state: torch.Tensor,
194
+ output_final_state: bool,
195
+ offsets: Optional[torch.LongTensor] = None,
196
+ head_first: bool = True,
197
+ use_qk_l2norm_in_kernel: bool = False
198
+ ):
199
+ chunk_size = 64
200
+ q_orig = q
201
+ k_orig = k
202
+
203
+ if use_qk_l2norm_in_kernel:
204
+ q = l2norm_fwd(q)
205
+ k = l2norm_fwd(k)
206
+
207
+ # 2-d indices denoting the offsets of chunks in each sequence
208
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
209
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
210
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
211
+ indices = None
212
+ if offsets is not None:
213
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
214
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
215
+
216
+ g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd(
217
+ q=q,
218
+ k=k,
219
+ v=v,
220
+ g=g,
221
+ beta=beta,
222
+ scale=scale,
223
+ initial_state=initial_state,
224
+ output_final_state=output_final_state,
225
+ offsets=offsets,
226
+ indices=indices,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ )
230
+ ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices)
231
+ ctx.chunk_size = chunk_size
232
+ ctx.scale = scale
233
+ ctx.head_first = head_first
234
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
235
+ return o.to(q.dtype), final_state
236
+
237
+ @staticmethod
238
+ @input_guard
239
+ @autocast_custom_bwd
240
+ def backward(
241
+ ctx,
242
+ do: torch.Tensor,
243
+ dht: torch.Tensor
244
+ ):
245
+ q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors
246
+ if ctx.use_qk_l2norm_in_kernel:
247
+ q, q_orig = l2norm_fwd(q), q
248
+ k, k_orig = l2norm_fwd(k), k
249
+ dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ g=g,
254
+ beta=beta,
255
+ Aw=Aw,
256
+ Au=Au,
257
+ scale=ctx.scale,
258
+ initial_state=initial_state,
259
+ do=do,
260
+ dht=dht,
261
+ offsets=offsets,
262
+ indices=indices,
263
+ head_first=ctx.head_first,
264
+ chunk_size=ctx.chunk_size
265
+ )
266
+ if ctx.use_qk_l2norm_in_kernel:
267
+ dq = l2norm_bwd(q_orig, dq)
268
+ dk = l2norm_bwd(k_orig, dk)
269
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
270
+
271
+
272
+ @torch.compiler.disable
273
+ def chunk_gated_delta_rule(
274
+ q: torch.Tensor,
275
+ k: torch.Tensor,
276
+ v: torch.Tensor,
277
+ g: torch.Tensor,
278
+ beta: torch.Tensor,
279
+ scale: float = None,
280
+ initial_state: torch.Tensor = None,
281
+ output_final_state: bool = False,
282
+ cu_seqlens: Optional[torch.LongTensor] = None,
283
+ head_first: bool = False,
284
+ use_qk_l2norm_in_kernel: bool = False
285
+ ):
286
+ r"""
287
+ Args:
288
+ q (torch.Tensor):
289
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
290
+ k (torch.Tensor):
291
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
292
+ v (torch.Tensor):
293
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
294
+ g (torch.Tensor):
295
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
296
+ beta (torch.Tensor):
297
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
298
+ scale (Optional[float]):
+ Scale factor for the attention scores.
300
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
301
+ initial_state (Optional[torch.Tensor]):
302
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
303
+ For equal-length input sequences, `N` equals the batch size `B`.
304
+ Default: `None`.
305
+ output_final_state (Optional[bool]):
306
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
307
+ cu_seqlens (torch.LongTensor):
308
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
309
+ consistent with the FlashAttention API.
310
+ head_first (Optional[bool]):
311
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
312
+ Default: `False`.
313
+
314
+ Returns:
315
+ o (torch.Tensor):
316
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
317
+ final_state (torch.Tensor):
318
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
319
+
320
+ Examples::
321
+ >>> import torch
322
+ >>> import torch.nn.functional as F
323
+ >>> from einops import rearrange
324
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
325
+ # inputs with equal lengths
326
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
327
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
328
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
329
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
330
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
331
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
332
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
333
+ >>> o, ht = chunk_gated_delta_rule(
334
+ q, k, v, g, beta,
335
+ initial_state=h0,
336
+ output_final_state=True,
337
+ head_first=False
338
+ )
339
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
340
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
341
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
342
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
343
+ >>> o_var, ht_var = chunk_gated_delta_rule(
344
+ q, k, v, g, beta,
345
+ initial_state=h0,
346
+ output_final_state=True,
347
+ cu_seqlens=cu_seqlens,
348
+ head_first=False
349
+ )
350
+ """
351
+ assert q.dtype == k.dtype == v.dtype
352
+ assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
353
+ assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False."
354
+
355
+ if cu_seqlens is not None:
356
+ if q.shape[0] != 1:
357
+ raise ValueError(
358
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
359
+ f"Please flatten variable-length inputs before processing."
360
+ )
361
+ if head_first:
362
+ raise RuntimeError(
363
+ "Sequences with variable lengths are not supported for head-first mode"
364
+ )
365
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
366
+ raise ValueError(
367
+ f"The number of initial states is expected to be equal to the number of input sequences, "
368
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
369
+ )
370
+ if head_first:
371
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
372
+ beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g))
373
+ if scale is None:
374
+ scale = k.shape[-1] ** -0.5
375
+ else:
376
+ assert scale > 0, "Scale must be positive."
377
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
378
+ q,
379
+ k,
380
+ v,
381
+ g,
382
+ beta,
383
+ scale,
384
+ initial_state,
385
+ output_final_state,
386
+ cu_seqlens,
387
+ False,
388
+ use_qk_l2norm_in_kernel
389
+ )
390
+ if head_first:
391
+ o = rearrange(o, 'b t h v -> b h t v')
392
+ return o, final_state
fla/ops/gated_delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,321 @@
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.jit(do_not_specialize=['T'])
21
+ def fused_recurrent_gated_delta_rule_fwd_kernel(
22
+ q,
23
+ k,
24
+ v,
25
+ g,
26
+ beta,
27
+ o,
28
+ h0,
29
+ ht,
30
+ offsets,
31
+ scale,
32
+ T,
33
+ B: tl.constexpr,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
40
+ STORE_FINAL_STATE: tl.constexpr, # whether to store final state
41
+ IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
42
+ USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
43
+ USE_OFFSETS: tl.constexpr
44
+ ):
45
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
46
+ i_n, i_h = i_nh // H, i_nh % H
47
+ if USE_OFFSETS:
48
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
49
+ all = T
50
+ T = eos - bos
51
+ else:
52
+ bos, eos = i_n * T, i_n * T + T
53
+ all = B * T
54
+ o_k = i_k * BK + tl.arange(0, BK)
55
+ o_v = i_v * BV + tl.arange(0, BV)
56
+
57
+ p_q = q + (bos * H + i_h) * K + o_k
58
+ p_k = k + (bos * H + i_h) * K + o_k
59
+ p_v = v + (bos * H + i_h) * V + o_v
60
+ if IS_BETA_HEADWISE:
61
+ p_beta = beta + (bos * H + i_h) * V + o_v
62
+ else:
63
+ p_beta = beta + bos * H + i_h
64
+ p_g = g + bos * H + i_h
65
+ p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v
66
+
67
+ mask_k = o_k < K
68
+ mask_v = o_v < V
69
+ mask_h = mask_k[:, None] & mask_v[None, :]
70
+
71
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
72
+ if USE_INITIAL_STATE:
73
+ p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
74
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
75
+
76
+ for _ in range(0, T):
77
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
78
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
79
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
80
+ b_g = tl.load(p_g).to(tl.float32)
81
+
82
+ if USE_QK_L2NORM_IN_KERNEL:
83
+ b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
84
+ b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
85
+ b_q = b_q * scale
86
+ # [BK, BV]
87
+ b_h *= exp(b_g)
88
+ # [BV]
89
+ b_v -= tl.sum(b_h * b_k[:, None], 0)
90
+ if IS_BETA_HEADWISE:
91
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
92
+ else:
93
+ b_beta = tl.load(p_beta).to(tl.float32)
94
+ b_v *= b_beta
95
+ # [BK, BV]
96
+ b_h += b_k[:, None] * b_v[None, :]
97
+ # [BV]
98
+ b_o = tl.sum(b_h * b_q[:, None], 0)
99
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
100
+
101
+ p_q += H*K
102
+ p_k += H*K
103
+ p_o += H*V
104
+ p_v += H*V
105
+ p_g += H
106
+ p_beta += H * (V if IS_BETA_HEADWISE else 1)
107
+
108
+ if STORE_FINAL_STATE:
109
+ p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :]
110
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
111
+
112
+
113
+ def fused_recurrent_gated_delta_rule_fwd(
114
+ q: torch.Tensor,
115
+ k: torch.Tensor,
116
+ v: torch.Tensor,
117
+ g: torch.Tensor,
118
+ beta: torch.Tensor,
119
+ scale: float,
120
+ initial_state: torch.Tensor,
121
+ output_final_state: bool,
122
+ use_qk_l2norm_in_kernel: bool = False,
123
+ offsets: Optional[torch.LongTensor] = None,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
125
+ B, T, H, K, V = *k.shape, v.shape[-1]
126
+ N = B if offsets is None else len(offsets) - 1
127
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
128
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
129
+ assert NK == 1, "NK > 1 is not supported yet"
130
+ num_stages = 3
131
+ num_warps = 1
132
+
133
+ o = q.new_empty(NK, *v.shape)
134
+ if output_final_state:
135
+ final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
136
+ else:
137
+ final_state = None
138
+
139
+ grid = (NK, NV, N * H)
140
+ fused_recurrent_gated_delta_rule_fwd_kernel[grid](
141
+ q=q,
142
+ k=k,
143
+ v=v,
144
+ g=g,
145
+ beta=beta,
146
+ o=o,
147
+ h0=initial_state,
148
+ ht=final_state,
149
+ offsets=offsets,
150
+ scale=scale,
151
+ T=T,
152
+ B=B,
153
+ H=H,
154
+ K=K,
155
+ V=V,
156
+ BK=BK,
157
+ BV=BV,
158
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
159
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
160
+ num_warps=num_warps,
161
+ num_stages=num_stages,
162
+ )
163
+ o = o.squeeze(0)
164
+ return o, final_state
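+
+
+ # Per-step recurrence implemented by the kernel above, written out as a sketch
+ # (S is the per-head [K, V] state, g_t the log-space decay, beta_t a scalar gate):
+ #   S_t = exp(g_t) * S_{t-1}
+ #   u_t = beta_t * (v_t - S_t^T k_t)   # delta-rule correction on the decayed state
+ #   S_t = S_t + k_t u_t^T
+ #   o_t = S_t^T (scale * q_t)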
165
+
166
+
167
+ class FusedRecurrentFunction(torch.autograd.Function):
168
+
169
+ @staticmethod
170
+ @input_guard
171
+ def forward(
172
+ ctx,
173
+ q: torch.Tensor,
174
+ k: torch.Tensor,
175
+ v: torch.Tensor,
176
+ g: torch.Tensor,
177
+ beta: torch.Tensor,
178
+ scale: float,
179
+ initial_state: torch.Tensor,
180
+ output_final_state: bool,
181
+ offsets: Optional[torch.LongTensor] = None,
182
+ use_qk_l2norm_in_kernel: bool = False
183
+ ):
184
+ o, final_state = fused_recurrent_gated_delta_rule_fwd(
185
+ q=q,
186
+ k=k,
187
+ v=v,
188
+ g=g,
189
+ beta=beta,
190
+ scale=scale,
191
+ initial_state=initial_state,
192
+ output_final_state=output_final_state,
193
+ use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
194
+ offsets=offsets
195
+ )
196
+
197
+ return o, final_state
198
+
199
+ @staticmethod
200
+ @input_guard
201
+ def backward(ctx, do, dht):
202
+ raise NotImplementedError(
203
+ "Backward pass is not implemented yet and we do not have plans to implement it "
204
+ "because we haven't figured out how to compute dg without materializing the full "
205
+ "hidden states for all time steps."
206
+ )
207
+
208
+
209
+ def fused_recurrent_gated_delta_rule(
210
+ q: torch.Tensor,
211
+ k: torch.Tensor,
212
+ v: torch.Tensor,
213
+ g: torch.Tensor,
214
+ beta: torch.Tensor = None,
215
+ scale: float = None,
216
+ initial_state: torch.Tensor = None,
217
+ output_final_state: bool = False,
218
+ cu_seqlens: Optional[torch.LongTensor] = None,
219
+ use_qk_l2norm_in_kernel: bool = False,
220
+ head_first: bool = False,
221
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
222
+ r"""
223
+ Args:
224
+ q (torch.Tensor):
225
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
226
+ k (torch.Tensor):
227
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
228
+ v (torch.Tensor):
229
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
230
+ g (torch.Tensor):
231
+ g (decays) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+ beta (torch.Tensor):
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
236
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
237
+ initial_state (Optional[torch.Tensor]):
238
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
239
+ For equal-length input sequences, `N` equals the batch size `B`.
240
+ Default: `None`.
241
+ output_final_state (Optional[bool]):
242
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
243
+ cu_seqlens (torch.LongTensor):
244
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
245
+ consistent with the FlashAttention API.
246
+
247
+ Returns:
248
+ o (torch.Tensor):
249
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
250
+ final_state (torch.Tensor):
251
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
252
+
253
+ Examples::
254
+ >>> import torch
255
+ >>> import torch.nn.functional as F
256
+ >>> from einops import rearrange
257
+ >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
258
+ # inputs with equal lengths
259
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
260
+ >>> q = torch.randn(B, T, H, K, device='cuda')
261
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
262
+ >>> v = torch.randn(B, T, H, V, device='cuda')
263
+ >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda'))
264
+ >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
265
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
266
+ >>> o, ht = fused_recurrent_gated_delta_rule(
267
+ q, k, v, g, beta,
268
+ initial_state=h0,
269
+ output_final_state=True,
270
+ )
271
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
272
+ >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
273
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
274
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
275
+ >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
276
+ q, k, v, g, beta,
277
+ initial_state=h0,
278
+ output_final_state=True,
279
+ cu_seqlens=cu_seqlens
280
+ )
281
+ >>> assert o.allclose(o_var.view(o.shape))
282
+ >>> assert ht.allclose(ht_var)
283
+ """
284
+ if cu_seqlens is not None:
285
+ if q.shape[0] != 1:
286
+ raise ValueError(
287
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
288
+ f"Please flatten variable-length inputs before processing."
289
+ )
290
+ if head_first:
291
+ raise RuntimeError(
292
+ "Sequences with variable lengths are not supported for head-first mode"
293
+ )
294
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
295
+ raise ValueError(
296
+ f"The number of initial states is expected to be equal to the number of input sequences, "
297
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
298
+ )
299
+ if scale is None:
300
+ scale = k.shape[-1] ** -0.5
301
+ else:
302
+ assert scale > 0, "scale must be positive"
303
+ if beta is None:
304
+ beta = torch.ones_like(q[..., 0])
305
+ if head_first:
306
+ q, k, v, g, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g, beta))
307
+ o, final_state = FusedRecurrentFunction.apply(
308
+ q,
309
+ k,
310
+ v,
311
+ g,
312
+ beta,
313
+ scale,
314
+ initial_state,
315
+ output_final_state,
316
+ cu_seqlens,
317
+ use_qk_l2norm_in_kernel
318
+ )
319
+ if head_first:
320
+ o = rearrange(o, 'b t h v -> b h t v')
321
+ return o, final_state
fla/ops/generalized_delta_rule/README.md ADDED
@@ -0,0 +1,37 @@
+ # Generalized Delta Rule
+
+ In the delta rule we have the recurrence:
+
+ ```math
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T
+ ```
+
+ This repository implements a delta rule variant in which $\mathbf{I}$ is not necessarily an identity matrix, and the $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ may differ from the input $\mathbf{k}_t$ in $\mathbf{v}_t\mathbf{k}_t^T$.
+
+ ## IPLR (Identity Plus Low Rank)
+
+ The first variant is IPLR, where we have:
+
+ ```math
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
+ ```
+
+ When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$ and $\mathbf{v}_t = \beta_t \mathbf{v}_t$, we recover the original delta rule. Since the transition matrix here is identity plus low rank, we refer to this variant as IPLR.
+
+ ### Numerical Stability
+
+ $\mathbf{a}_t$ and $\mathbf{b}_t$ must point in opposite directions, i.e., $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ with $\lambda_t < 0$. To see why, derive the eigenvalues of the transition matrix: $\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T$ has eigenvalue $1 + \mathbf{b}_t^T\mathbf{a}_t = 1 + \lambda_t \|\mathbf{a}_t\|^2$ along $\mathbf{a}_t$ (and eigenvalue $1$ elsewhere), so only $\lambda_t \le 0$ keeps its spectral radius at most $1$.
+
+ ## DPLR (Diagonal Plus Low Rank)
+
+ The second variant is DPLR, where we have:
+
+ ```math
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
+ ```
+
+ Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7.
+
+ ## Efficient Chunkwise Implementation
+
+ For details of the efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing).
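+
+ As a small illustrative check (a sketch, not part of the kernels in this repo), the IPLR parameterization above can be verified in a few lines of PyTorch:
+
+ ```python
+ import torch
+
+ def iplr_step(S, a, b, k, v):
+     # One step of S_t = S_{t-1} (I + a b^T) + v k^T, with S of shape [V, K].
+     S = S + torch.outer(S @ a, b)  # S (a b^T) = (S a) b^T
+     return S + torch.outer(v, k)
+
+ def delta_rule_step(S, k, v, beta):
+     # Recovered via a = -beta * k, b = k, v = beta * v.
+     return iplr_step(S, -beta * k, k, k, beta * v)
+ ```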
fla/ops/generalized_delta_rule/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule
+ from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule
+
+ __all__ = [
+     'chunk_dplr_delta_rule',
+     'fused_recurrent_dplr_delta_rule',
+     'chunk_iplr_delta_rule',
+     'fused_recurrent_iplr_delta_rule'
+ ]
fla/ops/gla/fused_chunk.py ADDED
@@ -0,0 +1,631 @@
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+ from packaging import version
12
+
13
+ from fla.ops.utils import chunk_local_cumsum
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
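+ # Pre-scales q and k with the chunk-local cumulative gates so that the subsequent
+ # chunkwise matmuls need no per-step gating. Per chunk of length BT (g is the
+ # log-space cumulative gate, last_decay its value at the chunk's final position):
+ #   qg_t = q_t * exp(g_t) * scale
+ #   kg_t = k_t * exp(last_decay - g_t)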
+ @triton.jit(do_not_specialize=['T'])
19
+ def prepare_qg_kg(
20
+ q,
21
+ k,
22
+ g,
23
+ qg,
24
+ kg,
25
+ scale,
26
+ T,
27
+ K: tl.constexpr,
28
+ BT: tl.constexpr,
29
+ BK: tl.constexpr
30
+ ):
31
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+ p_q = q + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
33
+ p_g = g + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
34
+ p_k = k + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
35
+ p_qg = qg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
36
+ p_kg = kg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
37
+
38
+ mask = (i_k * BK + tl.arange(0, BK)) < K
39
+
40
+ last_decay = tl.load(g + i_bh * T*K + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK), mask=mask, other=0)
41
+
42
+ for _ in range(BT):
43
+ b_q = tl.load(p_q, mask=mask, other=0)
44
+ b_k = tl.load(p_k, mask=mask, other=0)
45
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
46
+ b_q *= exp(b_g) * scale
47
+ b_k *= exp(last_decay - b_g)
48
+ tl.store(p_kg, b_k.to(p_kg.dtype.element_ty), mask=mask)
49
+ tl.store(p_qg, b_q.to(p_qg.dtype.element_ty), mask=mask)
50
+ p_q += K
51
+ p_g += K
52
+ p_k += K
53
+ p_kg += K
54
+ p_qg += K
55
+
56
+
57
+ @triton.jit(do_not_specialize=['T'])
58
+ def bwd_decay_global_cumsum(
59
+ dq_inner,
60
+ dq_inter,
61
+ dk_inner,
62
+ dk_inter,
63
+ q,
64
+ k,
65
+ g,
66
+ dg,
67
+ T,
68
+ K: tl.constexpr,
69
+ BT: tl.constexpr,
70
+ BK: tl.constexpr
71
+ ):
72
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
73
+ p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
74
+ p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
75
+ p_g = g + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
76
+ p_dg = dg + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
77
+ p_dq_inner = dq_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
78
+ p_dk_inner = dk_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
79
+ p_dq_inter = dq_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
80
+ p_dk_inter = dk_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
81
+ cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
82
+ mask = (i_k * BK + tl.arange(0, BK)) < K
83
+ last_g = tl.zeros([BK], dtype=tl.float32)
84
+ for j in range(BT-1, -1, -1):
85
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
86
+ if j == (BT-1):
87
+ last_g = b_g
88
+ b_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
89
+ b_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
90
+ b_dq2 *= exp(b_g)
91
+ b_dq = b_dq1 + b_dq2
92
+ tl.store(p_dq_inter, b_dq, mask=mask)
93
+ b_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
94
+ b_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
95
+ b_dk2 *= safe_exp(last_g - b_g)
96
+ b_dk = b_dk1 + b_dk2
97
+ tl.store(p_dk_inter, b_dk, mask=mask)
98
+ b_q = tl.load(p_q, mask=mask, other=0)
99
+ b_k = tl.load(p_k, mask=mask, other=0)
100
+ b_dg = b_dq * b_q - b_dk * b_k
101
+ cum_grad_dg += b_dg
102
+ tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
103
+ p_g -= K
104
+ p_k -= K
105
+ p_q -= K
106
+ p_dq_inner -= K
107
+ p_dk_inner -= K
108
+ p_dq_inter -= K
109
+ p_dk_inter -= K
110
+ p_dg -= K
111
+
112
+
113
+ @triton.jit(do_not_specialize=['T'])
114
+ def fused_chunk_gla_fwd_kernel(
115
+ q,
116
+ k,
117
+ v,
118
+ g,
119
+ o,
120
+ h0,
121
+ ht,
122
+ T,
123
+ B: tl.constexpr,
124
+ H: tl.constexpr,
125
+ K: tl.constexpr,
126
+ V: tl.constexpr,
127
+ BT: tl.constexpr,
128
+ BK: tl.constexpr,
129
+ BV: tl.constexpr,
130
+ USE_INITIAL_STATE: tl.constexpr,
131
+ STORE_FINAL_STATE: tl.constexpr,
132
+ CHECK: tl.constexpr
133
+ ):
134
+ # indices
135
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
136
+
137
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
138
+
139
+ # make block pointers
140
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
141
+ p_gn = g + i_bh * T*K + (BT - 1) * K + i_k * BK + tl.arange(0, BK)
142
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
143
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
144
+ p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
145
+
146
+ if USE_INITIAL_STATE:
147
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
148
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
149
+
150
+ mask = (i_k * BK + tl.arange(0, BK)) < K
151
+
152
+ for i in range(0, tl.cdiv(T, BT)):
153
+ # [BK, BT]
154
+ b_k = tl.load(p_k, boundary_check=(0, 1))
155
+ # [BT, BV]
156
+ b_v = tl.load(p_v, boundary_check=(0, 1))
157
+ # [BT, BK]
158
+ b_q = tl.load(p_q, boundary_check=(0, 1))
159
+ b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
160
+ if CHECK and i == 0:
161
+ b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
162
+ b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
163
+ else:
164
+ b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
165
+ b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
166
+
167
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
168
+ p_q = tl.advance(p_q, (BT, 0))
169
+ p_k = tl.advance(p_k, (0, BT))
170
+ p_v = tl.advance(p_v, (BT, 0))
171
+ p_o = tl.advance(p_o, (BT, 0))
172
+ p_gn += BT * K
173
+
174
+ if STORE_FINAL_STATE:
175
+ p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
176
+ tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
177
+
178
+
179
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
180
+ @triton.jit(do_not_specialize=['T'])
181
+ def fused_chunk_gla_bwd_kernel(
182
+ q, k, v, g,
183
+ do,
184
+ dq,
185
+ dk,
186
+ dv,
187
+ h0,
188
+ scale,
189
+ T,
190
+ B: tl.constexpr,
191
+ H: tl.constexpr,
192
+ K: tl.constexpr,
193
+ V: tl.constexpr,
194
+ # clamp_min, # minimum log value of the gate for numerical stability. default: -5
195
+ BT: tl.constexpr,
196
+ BK: tl.constexpr,
197
+ BV: tl.constexpr,
198
+ USE_INITIAL_STATE: tl.constexpr,
199
+ CHECK: tl.constexpr
200
+ ):
201
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ # [BV, BK]
203
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
204
+
205
+ if USE_INITIAL_STATE:
206
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
207
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
208
+
209
+ mask = (i_k * BK + tl.arange(0, BK)) < K
210
+ for i in range(0, tl.cdiv(T, BT)):
211
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
212
+ p_gn = g + i_bh * T*K + ((i+1) * BT - 1) * K + i_k * BK + tl.arange(0, BK)
213
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
214
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
215
+ p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
216
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
217
+ # [BT, K]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
220
+
221
+ # [V, BT]
222
+ b_v = tl.load(p_v, boundary_check=(0, 1))
223
+ # [BT, V]
224
+ b_do = tl.load(p_do, boundary_check=(0, 1))
225
+ # [V, K]
226
+ if CHECK and i == 0:
227
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
228
+ b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
229
+ else:
230
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
231
+ b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
232
+ b_dq *= scale
233
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
234
+
235
+ # sync threads
236
+ b_h = None
237
+ tl.debug_barrier()
238
+ # [BK, BV]
239
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
240
+
241
+ # cum = tl.zeros([BK], dtype=tl.float32)
242
+ for i in range(1, tl.cdiv(T, BT) + 1):
243
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
244
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
245
+ p_gn = g + i_bh * T*K + (T - (i-1) * BT - 1) * K + i_k * BK + tl.arange(0, BK)
246
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
247
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
248
+ p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * T*K, (T, K),
249
+ (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
250
+ p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * T*V, (T, V),
251
+ (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
252
+ # [K, BT]
253
+ b_q = tl.load(p_q, boundary_check=(0, 1))
254
+ # [BT, K]
255
+ b_k = tl.load(p_k, boundary_check=(0, 1))
256
+ # [BT, V]
257
+ b_v = tl.load(p_v, boundary_check=(0, 1))
258
+ b_do = tl.load(p_do, boundary_check=(0, 1))
259
+ b_db = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
260
+
261
+ # inter-chunk
262
+ # [K, V]
263
+ if CHECK and i == 1:
264
+ b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
265
+ b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
266
+ b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
267
+ else:
268
+ b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
269
+ b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
270
+ b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
271
+
272
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
273
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
274
+
275
+
276
+ @triton.jit
277
+ def fwd_inner_chunk(
278
+ q, k, g, A,
279
+ scale, # K ** -0.5
280
+ B: tl.constexpr, # B
281
+ H: tl.constexpr, # H
282
+ T, # T
283
+ K: tl.constexpr, # K
284
+ BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
285
+ BK: tl.constexpr # BLOCK SIZE along the K dimension
286
+ ):
287
+
288
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
289
+
290
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
291
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
292
+
293
+ b_k = tl.load(p_k, boundary_check=(0, 1))
294
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
295
+
296
+ mask = (i_k * BK + tl.arange(0, BK)) < K
297
+ o_i = tl.arange(0, BT)
298
+
299
+ p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
300
+ p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
301
+ p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
302
+
303
+ for i in range(BT):
304
+ b_q = tl.load(p_q, mask=mask, other=0) * scale
305
+ b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
306
+ s = b_q[None, :] * b_k * safe_exp(b_gq[None, :] - b_g)
307
+ score = tl.sum(s, axis=1)
308
+ score = tl.where(o_i <= i, score, 0)
309
+ tl.store(p_A, score.to(p_A.dtype.element_ty))
310
+ p_q += K
311
+ p_gq += K
312
+ p_A += BT
313
+
314
+
315
+ @triton.jit
316
+ def bwd_inner_chunk(
317
+ q,
318
+ k,
319
+ g,
320
+ dA,
321
+ dq,
322
+ dk,
323
+ T, # T
324
+ K: tl.constexpr, # K
325
+ # clamp_min, # minimum log value of the gate for numerical stability. default: -5
326
+ BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
327
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
328
+ ):
329
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
330
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
331
+ b_k = tl.load(p_k, boundary_check=(0, 1))
332
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
333
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
334
+
335
+ mask = (i_k * BK + tl.arange(0, BK)) < K
336
+ o_i = tl.arange(0, BT)
337
+
338
+ p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
339
+ p_dq = dq + (i_bh) * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
340
+ p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
341
+ p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
342
+
343
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
344
+
345
+ for i in range(BT):
346
+ b_q = tl.load(p_q, mask=mask, other=0)
347
+ b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
348
+ score = safe_exp(b_gq[None, :] - b_g)
349
+ score = tl.where(o_i[:, None] <= i, score, 0)
350
+ b_dA = tl.load(p_dA)
351
+ b_dA = tl.where(o_i <= i, b_dA, 0)
352
+ b_dk += (b_dA[:, None] * score * b_q[None, :])
353
+ b_dq = tl.sum(b_dA[:, None] * score * b_k, axis=0)
354
+ tl.store(p_dq, b_dq, mask=mask)
355
+ p_q += K
356
+ p_dq += K
357
+ p_gq += K
358
+ p_dA += BT
359
+
360
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
361
+ tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
362
+
363
+
364
+ class FusedChunkGLAFunction(torch.autograd.Function):
365
+
366
+ @staticmethod
367
+ @input_guard
368
+ @autocast_custom_fwd
369
+ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
370
+ ctx.g_dtype = g.dtype
371
+ ctx.scale = scale
372
+ B, H, T, K, V = *k.shape, v.shape[-1]
373
+ BT = 16 # chunk_size
374
+ BK, BV = min(K, 64), min(V, 64)
375
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
376
+ num_stages = 1
377
+ num_warps = 2
378
+
379
+ g_org = g
380
+ # cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
381
+ g = chunk_local_cumsum(g_org, chunk_size=BT)
382
+ o = q.new_empty(NK, B, H, T, V)
383
+ q_g = torch.empty_like(q)
384
+ k_g = torch.empty_like(k)
385
+
386
+ grid = (NK, triton.cdiv(T, BT), B * H)
387
+ prepare_qg_kg[grid](
388
+ q,
389
+ k,
390
+ g,
391
+ q_g,
392
+ k_g,
393
+ scale,
394
+ T=T,
395
+ K=K,
396
+ BT=BT,
397
+ BK=BK,
398
+ num_warps=1
399
+ )
400
+
401
+ if output_final_state:
402
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
403
+ else:
404
+ final_state = None
405
+ # the bug still exists even for Triton 2.2 on H100 GPUs
406
+ # so we always enable initial checks
407
+ CHECK = True
408
+ if version.parse(triton.__version__) < version.parse('2.2.0'):
409
+ import warnings
410
+ warnings.warn(
411
+ "Triton<2.2.0 detected for running this kernel, "
412
+ "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
413
+ "that lead to significant precision loss. "
414
+ "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
415
+ "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
416
+ )
417
+ CHECK = True
418
+
419
+ grid = (NV, NK, B * H)
420
+ fused_chunk_gla_fwd_kernel[grid](
421
+ q_g, k_g, v, g, o, initial_state, final_state,
422
+ T=T,
423
+ B=B,
424
+ H=H,
425
+ K=K,
426
+ V=V,
427
+ BT=BT,
428
+ BK=BK,
429
+ BV=BV,
430
+ USE_INITIAL_STATE=initial_state is not None,
431
+ STORE_FINAL_STATE=output_final_state,
432
+ CHECK=CHECK,
433
+ num_warps=num_warps,
434
+ num_stages=num_stages
435
+ )
436
+
437
+ o = o.sum(0)
438
+
439
+ # intra-chunk
440
+ chunk_size = 16
441
+ num_chunk = T // chunk_size
442
+ v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
443
+ BK = min(K, 64)
444
+ NK = triton.cdiv(K, BK)
445
+ A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT)
446
+ grid = (NK, triton.cdiv(T, BT), B * H)
447
+ fwd_inner_chunk[grid](
448
+ q, k, g, A,
449
+ scale,
450
+ B=B,
451
+ H=H,
452
+ T=T,
453
+ K=K,
454
+ BT=BT,
455
+ BK=BK,
456
+ num_stages=3,
457
+ num_warps=4
458
+ )
459
+ A = A.sum(0)
460
+ o2 = A @ v2
461
+ o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
462
+ # combine inner and inter
463
+ o.add_(o2)
464
+ ctx.save_for_backward(q, k, v, g_org, A, initial_state)
465
+ ctx.CHECK = CHECK
466
+ return o.to(v), final_state
467
+
468
+ @staticmethod
469
+ @input_guard
470
+ @autocast_custom_bwd
471
+ def backward(ctx, do, dht=None):
472
+ q, k, v, g_org, A, initial_state = ctx.saved_tensors
473
+ B, H, T, K, V = *k.shape, v.shape[-1]
474
+ scale = ctx.scale
475
+
476
+ # recomputation
477
+ # inter-chunk
478
+ BT = 16 # chunk_size
479
+ g = chunk_local_cumsum(g_org, chunk_size=BT)
480
+ BK, BV = min(K, 64), min(V, 64)
481
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
482
+ q_g = torch.empty_like(q)
483
+ k_g = torch.empty_like(k)
484
+ grid = (NK, triton.cdiv(T, BT), B * H)
485
+ prepare_qg_kg[grid](
486
+ q,
487
+ k,
488
+ g,
489
+ q_g,
490
+ k_g,
491
+ scale,
492
+ T=T,
493
+ K=K,
494
+ BT=BT,
495
+ BK=BK,
496
+ num_warps=1
497
+ )
498
+
499
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
500
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
501
+ num_stages = 1
502
+ num_warps = 2
503
+ dq = q.new_empty(NV, B, H, T, K)
504
+ dk = q.new_empty(NV, B, H, T, K)
505
+ dv = q.new_empty(NK, B, H, T, V)
506
+
507
+ grid = (NV, NK, B * H)
508
+
509
+ fused_chunk_gla_bwd_kernel[grid](
510
+ q_g,
511
+ k_g,
512
+ v,
513
+ g,
514
+ do,
515
+ dq,
516
+ dk,
517
+ dv,
518
+ initial_state,
519
+ scale,
520
+ T=T,
521
+ B=B,
522
+ H=H,
523
+ K=K,
524
+ V=V,
525
+ BT=BT,
526
+ BK=BK,
527
+ BV=BV,
528
+ USE_INITIAL_STATE=initial_state is not None,
529
+ CHECK=ctx.CHECK,
530
+ num_warps=num_warps,
531
+ num_stages=num_stages,
532
+ )
533
+ dq = dq.sum(0)
534
+ dk = dk.sum(0)
535
+ dv = dv.sum(0)
536
+
537
+ # intra chunk
538
+ NT = T // BT
539
+ v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=NT)
540
+ do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=NT)
541
+ dA2 = (do2 @ v2.transpose(-2, -1)) * scale
542
+ dv2 = A.transpose(-1, -2) @ do2
543
+ dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=NT)
544
+
545
+ BK = min(triton.next_power_of_2(K), 16)
546
+ NK = triton.cdiv(K, BK)
547
+ dk2 = torch.empty_like(k)
548
+ dq2 = torch.empty_like(q)
549
+
550
+ grid = (NK, NT, B * H)
551
+ bwd_inner_chunk[grid](
552
+ q, k, g,
553
+ dA2,
554
+ dq2,
555
+ dk2,
556
+ T=T,
557
+ K=K,
558
+ BT=BT,
559
+ BK=BK,
560
+ num_warps=1,
561
+ num_stages=3
562
+ )
563
+
564
+ BK = min(triton.next_power_of_2(K), 32)
565
+ NK = triton.cdiv(K, BK)
566
+ dg = torch.empty_like(g, dtype=torch.float32)
567
+ grid = (NK, triton.cdiv(T, BT), B * H)
568
+ bwd_decay_global_cumsum[grid](
569
+ dq2,
570
+ dq,
571
+ dk2,
572
+ dk,
573
+ q,
574
+ k,
575
+ g,
576
+ dg,
577
+ T=T,
578
+ K=K,
579
+ BT=BT,
580
+ BK=BK,
581
+ num_warps=1,
582
+ num_stages=1
583
+ )
584
+ dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
585
+
586
+ def rev_cumsum_exclusive(x):
587
+ cumsum_x = x.cumsum(-2)
588
+ rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
589
+ return rev_cumsum_x
590
+
591
+ rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
592
+ dg.add_(rev_cumsum_dg.unsqueeze(-2))
593
+ dv.add_(dv2)
594
+ dg = rearrange(dg, 'b h n c d -> b h (n c) d')
595
+
596
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
597
+
598
+
599
+ def ceildiv(a, b):
600
+ return -(a // -b)
601
+
602
+
603
+ def pad(x, chunk_size=16):
604
+ T = x.shape[-2]
605
+ padded_seq_len = ceildiv(T, chunk_size) * chunk_size
606
+ if x.shape[-2] % chunk_size != 0:
607
+ x = F.pad(x, (0, 0, 0, padded_seq_len - T))
608
+ return x
609
+
610
+
611
+ def fused_chunk_gla(
612
+ q: torch.Tensor,
613
+ k: torch.Tensor,
614
+ v: torch.Tensor,
615
+ g: torch.Tensor,
616
+ scale: int = -1,
617
+ initial_state: torch.Tensor = None,
618
+ output_final_state: bool = False,
619
+ head_first: bool = True
620
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
621
+ if scale == -1:
622
+ scale = q.shape[-1] ** -0.5
623
+ if not head_first:
624
+ q, k, v, g = map(lambda x: x.transpose(1, 2), (q, k, v, g))
625
+ seq_len = q.shape[-2]
626
+ q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
627
+ o, final_state = FusedChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
628
+ o = o[..., :seq_len, :].contiguous()
629
+ if not head_first:
630
+ o = o.transpose(1, 2)
631
+ return o, final_state
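For orientation, here is a minimal usage sketch of `fused_chunk_gla` as defined above. The shapes follow the head-first convention of the wrapper; the CUDA device, bfloat16 activations, and the log-sigmoid gate initialization are illustrative assumptions, not requirements stated in this file.

import torch
import torch.nn.functional as F
from fla.ops.gla.fused_chunk import fused_chunk_gla

B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# per-step log-space forget gates; logsigmoid keeps them <= 0 so exp(g) lies in (0, 1]
g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda')).requires_grad_()

o, final_state = fused_chunk_gla(q, k, v, g, output_final_state=True)
assert o.shape == (B, H, T, V) and final_state.shape == (B, H, K, V)

Note that the wrapper pads the sequence to a multiple of the internal chunk size (16) and slices the output back, so `T` need not be chunk-aligned.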
fla/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # -*- coding: utf-8 -*-
+
+ from .chunk import chunk_gsa
+ from .fused_recurrent import fused_recurrent_gsa
+
+ __all__ = [
+     'chunk_gsa',
+     'fused_recurrent_gsa'
+ ]
fla/ops/hgrn/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # -*- coding: utf-8 -*-
+
+ from .chunk import chunk_hgrn
+ from .fused_recurrent import fused_recurrent_hgrn
+
+ __all__ = [
+     'chunk_hgrn',
+     'fused_recurrent_hgrn'
+ ]
fla/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Optional
+
+ import torch
+
+
+ def naive_recurrent_hgrn(
+     x: torch.Tensor,
+     g: torch.Tensor,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: Optional[bool] = False
+ ) -> torch.Tensor:
+     dtype = x.dtype
+     x, g = map(lambda i: i.float(), (x, g))
+     B, T, D = x.shape
+
+     h = torch.zeros(B, D, dtype=torch.float, device=x.device)
+     o = torch.zeros_like(x)
+
+     final_state = None
+     if initial_state is not None:
+         h += initial_state
+
+     for i in range(T):
+         h = g[:, i].exp() * h + x[:, i]
+         o[:, i] = h
+
+     if output_final_state:
+         final_state = h
+     return o.to(dtype), final_state
+
+
+ def naive_chunk_hgrn(
+     x: torch.Tensor,
+     g: torch.Tensor,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: Optional[bool] = False,
+     chunk_size: int = 64
+ ) -> torch.Tensor:
+     dtype = x.dtype
+     x, g = map(lambda i: i.float(), (x, g))
+     B, T, D = x.shape
+
+     # per-chunk cumulative log-gates; assumes T is a multiple of chunk_size
+     gc = g.view(B, -1, chunk_size, D).cumsum(-2).view_as(g)
+     h = torch.zeros(B, D, dtype=torch.float, device=x.device)
+     o = torch.zeros_like(x)
+
+     final_state = None
+     if initial_state is not None:
+         h += initial_state
+
+     for i in range(0, T, chunk_size):
+         hp = h
+         h = torch.zeros(B, D, dtype=torch.float, device=x.device)
+         for j in range(i, i + chunk_size):
+             # intra-chunk recurrence starting from a zero state
+             h = g[:, j].exp() * h + x[:, j]
+             # add the decayed carry from previous chunks
+             o[:, j] = hp * gc[:, j].exp() + h
+         # carry the true state (last output of the chunk) into the next chunk
+         h = o[:, j].clone()
+
+     if output_final_state:
+         final_state = h
+     return o.to(dtype), final_state
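The two reference implementations above realize the same recurrence h_t = exp(g_t) * h_{t-1} + x_t, so a quick self-check is natural; this sketch assumes a sequence length that is a multiple of `chunk_size`, which the chunked variant requires.

import torch
import torch.nn.functional as F
from fla.ops.hgrn.naive import naive_chunk_hgrn, naive_recurrent_hgrn

B, T, D = 2, 128, 32
x = torch.randn(B, T, D)
g = F.logsigmoid(torch.randn(B, T, D))  # log-gates, always <= 0

o_ref, _ = naive_recurrent_hgrn(x, g)
o_chunk, _ = naive_chunk_hgrn(x, g, chunk_size=64)
assert torch.allclose(o_ref, o_chunk, atol=1e-4)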
fla/ops/lightning_attn/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
fla/ops/linear_attn/fused_chunk.py ADDED
@@ -0,0 +1,318 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+ from packaging import version
+
+ from fla.ops.linear_attn.utils import normalize_output
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+
+ @triton.jit
+ def fused_chunk_linear_attn_fwd_kernel(
+     q,  # query [B, H, T, K]
+     k,  # key [B, H, T, K]
+     v,  # value [B, H, T, V]
+     o,  # output [B, H, T, V]
+     h0,
+     ht,
+     scale,
+     B,  # batch size
+     H,  # number of heads
+     T,  # sequence length
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+     BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+     BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+     USE_INITIAL_STATE: tl.constexpr,
+     STORE_FINAL_STATE: tl.constexpr,
+     CHECK: tl.constexpr
+ ):
+     # indices
+     i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+     o_i = tl.arange(0, BT)
+
+     # [BT, BT]
+     m_s = o_i[:, None] >= o_i[None, :]
+     # [BK, BV]
+     b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+     # make block pointers
+     p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
+     p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
+     p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
+     p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
+
+     if USE_INITIAL_STATE:
+         p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+     for i in range(0, tl.cdiv(T, BT)):
+         # [BT, BK]
+         b_q = tl.load(p_q, boundary_check=(0, 1))
+         b_q = (b_q * scale).to(b_q.dtype)
+         # [BK, BT]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         # [BT, BV]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+
+         # [BT, BT]
+         b_s = tl.dot(b_q, b_k, allow_tf32=False)
+         b_s = tl.where(m_s, b_s, 0)
+         # [BT, BV]
+         b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
+         if CHECK and i == 0:
+             b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
+             b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
+         else:
+             b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
+             b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
+         tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+         p_q = tl.advance(p_q, (BT, 0))
+         p_k = tl.advance(p_k, (0, BT))
+         p_v = tl.advance(p_v, (BT, 0))
+         p_o = tl.advance(p_o, (BT, 0))
+
+     if STORE_FINAL_STATE:
+         p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.jit
+ def fused_chunk_linear_attn_bwd_kernel(
+     q,  # query [B, H, T, K]
+     k,  # key [B, H, T, K]
+     v,  # value [B, H, T, V]
+     do,  # gradient of output [B, H, T, V]
+     dq,  # gradient of query [NV, B, H, T, K]
+     dk,  # gradient of key [NV, B, H, T, K]
+     dv,  # gradient of value [NK, B, H, T, V]
+     h0,  # initial state [B, H, K, V]
+     scale,  # K ** -0.5
+     B,  # batch size
+     H,  # number of heads
+     T,  # sequence length
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+     BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+     BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+     USE_INITIAL_STATE: tl.constexpr,
+     CHECK: tl.constexpr
+ ):
+     i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     o_i = tl.arange(0, BT)
+
+     m_s = o_i[:, None] >= o_i[None, :]
+     # [BV, BK]
+     b_h = tl.zeros([BV, BK], dtype=tl.float32)
+     if USE_INITIAL_STATE:
+         p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
+         b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+
+     for i in range(0, tl.cdiv(T, BT)):
+         p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
+         p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
+         p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
+         p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
+
+         # [BT, BK]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         # [BV, BT]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         # [BT, BV]
+         b_do = tl.load(p_do, boundary_check=(0, 1))
+
+         # [BT, BT]
+         b_ds = tl.dot(b_do, b_v, allow_tf32=False)
+         b_ds = tl.where(m_s, b_ds, 0)
+         # [BT, BK]
+         b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
+         # [BV, BK]
+         if CHECK and i == 0:
+             b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
+             b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
+         else:
+             b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
+             b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
+         b_dq *= scale
+         tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+
+     # sync threads
+     b_h = None
+     tl.debug_barrier()
+     # [BK, BV]
+     b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+     m_s = o_i[:, None] <= o_i[None, :]
+     for i in range(1, tl.cdiv(T, BT) + 1):
+         p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
+         p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
+         p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
+         p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
+         p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
+         p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
+         # [BK, BT]
+         b_q = tl.load(p_q, boundary_check=(0, 1))
+         b_q = (b_q * scale).to(b_q.dtype)
+         # [BT, BK]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         # [BT, BV]
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         b_do = tl.load(p_do, boundary_check=(0, 1))
+
+         # [BT, BT]
+         b_s = tl.dot(b_k, b_q, allow_tf32=False)
+         b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
+         # [BT, BT]
+         b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
+         b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
+         # [BT, BK]
+         b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
+         # [BT, BV]
+         b_dv = tl.dot(b_s, b_do, allow_tf32=False)
+         if CHECK and i == 1:
+             b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
+             b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
+             b_dh += tl.dot(b_q, b_do, allow_tf32=False)
+         else:
+             b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
+             b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
+             b_dh += tl.dot(b_q, b_do, allow_tf32=False)
+
+         tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+         tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+ class FusedChunkLinearAttentionFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(ctx, q, k, v, scale, initial_state, output_final_state):
+         B, H, T, K, V = *k.shape, v.shape[-1]
+         BT = 64
+         BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
+         NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+         num_warps = 4
+         num_stages = 1
+
+         o = q.new_empty(NK, B, H, T, V)
+         final_state = q.new_empty(B, H, K, V, dtype=torch.float) if output_final_state else None
+         # the bug still exists even for Triton 2.2 on H100 GPUs,
+         # so we always enable the initial checks
+         CHECK = True
+         if version.parse(triton.__version__) < version.parse('2.2.0'):
+             import warnings
+             warnings.warn(
+                 "Triton<2.2.0 detected when running this kernel, "
+                 "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
+                 "that lead to significant precision loss. "
+                 "We've added some initial condition checks to work around this, sadly at the cost of some speed. "
+                 "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
+             )
+             CHECK = True
+
+         grid = (NV, NK, B * H)
+         fused_chunk_linear_attn_fwd_kernel[grid](
+             q, k, v, o, initial_state, final_state,
+             scale,
+             B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
+             USE_INITIAL_STATE=initial_state is not None,
+             STORE_FINAL_STATE=output_final_state,
+             CHECK=CHECK,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+         o = o.sum(0) if NK > 1 else o[0]
+
+         ctx.save_for_backward(q, k, v, initial_state)
+         ctx.scale = scale
+         ctx.CHECK = CHECK
+         return o.to(q.dtype), final_state
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(ctx, do, dht=None):
+         q, k, v, initial_state = ctx.saved_tensors
+         B, H, T, K, V = *k.shape, v.shape[-1]
+         scale = ctx.scale
+
+         BT = 64
+         BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
+         NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+         num_warps = 4
+         num_stages = 1
+
+         dq = q.new_empty(NV, B, H, T, K)
+         dk = q.new_empty(NV, B, H, T, K)
+         dv = q.new_empty(NK, B, H, T, V)
+         grid = (NV, NK, B * H)
+
+         fused_chunk_linear_attn_bwd_kernel[grid](
+             q, k, v, do, dq, dk, dv, initial_state,
+             scale,
+             B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
+             USE_INITIAL_STATE=initial_state is not None,
+             CHECK=ctx.CHECK,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+         dq = dq.sum(0)
+         dk = dk.sum(0)
+         dv = dv.sum(0)
+         return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
+
+
+ def fused_chunk_linear_attn(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = False,
+     normalize: bool = True,
+     head_first: bool = True
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
+         k (torch.Tensor):
+             keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
+         v (torch.Tensor):
+             values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
+         scale (Optional[float]):
+             Scale factor for linear attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[B, H, K, V]`. Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
+         normalize (bool):
+             Whether to normalize the output. Default: `True`.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format. Default: `True`.
+
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
+         final_state (torch.Tensor):
+             Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
+     """
+     if scale is None:
+         scale = q.shape[-1] ** -0.5
+     if not head_first:
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+     o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
+     if normalize:
+         o = normalize_output(q * scale, k, o)
+     if not head_first:
+         o = o.transpose(1, 2)
+     return o, final_state
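A short usage sketch matching the docstring above; the CUDA device and bfloat16 dtype are illustrative assumptions. The returned `final_state` can seed a subsequent segment, which is the typical pattern for chunk-by-chunk inference.

import torch
from fla.ops.linear_attn.fused_chunk import fused_chunk_linear_attn

B, H, T, K, V = 2, 4, 256, 64, 64
q, k, v = (torch.randn(B, H, T, d, device='cuda', dtype=torch.bfloat16) for d in (K, K, V))
o, state = fused_chunk_linear_attn(q, k, v, output_final_state=True)
# continue from the carried state on the next segment
o2, state = fused_chunk_linear_attn(q, k, v, initial_state=state, output_final_state=True)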
fla/ops/linear_attn/fused_recurrent.py ADDED
@@ -0,0 +1,251 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.linear_attn.utils import normalize_output
+ from fla.utils import input_guard
+
+
+ @triton.jit
+ def fused_recurrent_linear_attn_fwd_kernel(
+     q,  # query [B, H, L, K]
+     k,  # key [B, H, L, K]
+     v,  # value [B, H, L, V]
+     o,  # output [B, H, L, V]
+     h0,
+     ht,  # final hidden state [B, H, K, V]
+
+     s_k_h,  # stride size: L * K
+     s_v_h,  # stride size: L * V
+
+     scale,
+     B,  # batch size
+     H,  # number of heads
+     T,  # sequence length
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+     BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+     USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+     STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
+ ):
+     # indices
+     i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+     p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
+     p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
+     p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
+     p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV)
+
+     mask_bk = (i_k * BK + tl.arange(0, BK)) < K
+     mask_bv = (i_v * BV + tl.arange(0, BV)) < V
+     mask_kv = mask_bk[None, :] & mask_bv[:, None]
+
+     b_h = tl.zeros([BV, BK], dtype=tl.float32)
+
+     if USE_INITIAL_STATE:
+         p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+         b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
+
+     for _ in range(0, T):
+         b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+         b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+         b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
+
+         b_h += b_k[None, :] * b_v[:, None]
+         b_o = b_h * b_q[None, :]
+         b_o = tl.sum(b_o, axis=1)
+         tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
+
+         p_q += K
+         p_k += K
+         p_o += V
+         p_v += V
+
+     if STORE_FINAL_STATE:
+         p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
+
+
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
+ @triton.jit
+ def fused_recurrent_linear_attn_bwd_kernel(
+     q,  # query [B, H, L, K]
+     k,  # key [B, H, L, K]
+     v,  # value [B, H, L, V]
+
+     do,  # gradient of output [B, H, L, V]
+     dq,  # gradient of query [NV, B, H, L, K]
+     dk,  # gradient of key [NV, B, H, L, K]
+     dv,  # gradient of value [NK, B, H, L, V]
+     h0,  # initial hidden state [B, H, K, V]
+
+     s_k_h,  # stride size: L * K
+     s_v_h,  # stride size: L * V
+     scale,  # K ** -0.5
+
+     B,  # batch size
+     H,  # number of heads
+     T,  # sequence length
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+     BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+     USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+ ):
+     i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+     p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
+     p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
+     p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
+     p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
+
+     p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK)
+     mask_bk = i_k * BK + tl.arange(0, BK) < K
+     mask_bv = i_v * BV + tl.arange(0, BV) < V
+
+     b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+     if USE_INITIAL_STATE:
+         mask_kv = mask_bk[:, None] & mask_bv[None, :]
+         p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+         b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
+
+     for _ in range(0, T):
+         b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+         b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+         b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
+
+         b_h += b_k[:, None] * b_v[None, :]
+         _d_q = b_h * b_do[None, :]
+         d_q = tl.sum(_d_q, axis=1) * scale
+         tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
+
+         p_k += K
+         p_do += V
+         p_v += V
+         p_dq += K
+
+     # sync threads
+     tl.debug_barrier()
+
+     p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+     p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+     p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+     p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+     p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
+     p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
+     d_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+     for _ in range(T):
+         b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
+         b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
+         b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+         b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+         d_h += b_q[:, None] * b_do[None, :]
+         d_k = tl.sum(d_h * b_v[None, :], axis=1)
+         d_v = tl.sum(d_h * b_k[:, None], axis=0)
+
+         tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
+         tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
+
+         p_do -= V
+         p_q -= K
+         p_k -= K
+         p_v -= V
+         p_dk -= K
+         p_dv -= V
+
+
+ class FusedRecurrentLinearAttentionFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     def forward(ctx, q, k, v, scale, initial_state=None, output_final_state=False):
+         B, H, T, K = q.shape
+         V = v.shape[-1]
+
+         BK, BV = min(K, 32), min(V, 32)
+         NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+         num_warps = 1
+         num_stages = 1
+
+         o = q.new_empty(NK, B, H, T, V)
+         final_state = q.new_empty(B, H, K, V) if output_final_state else None
+
+         grid = (NV, NK, B * H)
+         fused_recurrent_linear_attn_fwd_kernel[grid](
+             q, k, v, o, initial_state, final_state,
+             q.stride(1),
+             v.stride(1), scale,
+             B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
+             USE_INITIAL_STATE=initial_state is not None,
+             STORE_FINAL_STATE=final_state is not None,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+
+         o = o.sum(0)
+         ctx.save_for_backward(q, k, v, initial_state)
+         ctx.scale = scale
+         return o, final_state
+
+     @staticmethod
+     @input_guard
+     def backward(ctx, do, dht=None):
+         q, k, v, initial_state = ctx.saved_tensors
+         B, H, T, K = q.shape
+         V = v.shape[-1]
+         scale = ctx.scale
+
+         BK, BV = min(K, 32), min(V, 32)
+         NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+         num_warps = 1
+         num_stages = 1
+
+         dq = q.new_empty(NV, B, H, T, K)
+         dk = q.new_empty(NV, B, H, T, K)
+         dv = q.new_empty(NK, B, H, T, V)
+         grid = (NV, NK, B * H)
+
+         fused_recurrent_linear_attn_bwd_kernel[grid](
+             q, k, v, do, dq, dk, dv, initial_state,
+             q.stride(1),
+             v.stride(1),
+             scale,
+             B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
+             USE_INITIAL_STATE=initial_state is not None,
+             num_warps=num_warps,
+             num_stages=num_stages
+         )
+         dq = dq.sum(0)
+         dk = dk.sum(0)
+         dv = dv.sum(0)
+         return dq, dk, dv, None, None, None
+
+
+ def fused_recurrent_linear_attn(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = False,
+     normalize: bool = False,
+     head_first: bool = True
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     if scale is None:
+         scale = q.shape[-1] ** -0.5
+     if not head_first:
+         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+     o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
+     if normalize:
+         o = normalize_output(q * scale, k, o)
+     if not head_first:
+         o = o.transpose(1, 2)
+     return o, final_state
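Since the recurrent and chunked kernels realize the same linear-attention recurrence, a cross-check between the two is a natural smoke test. This is a sketch assuming float32 inputs (for tight tolerances) and a CUDA device; note that `normalize` defaults to `False` here but `True` in the chunked wrapper, so it is passed explicitly.

import torch
from fla.ops.linear_attn.fused_chunk import fused_chunk_linear_attn
from fla.ops.linear_attn.fused_recurrent import fused_recurrent_linear_attn

B, H, T, K, V = 2, 4, 128, 32, 32
q, k, v = (torch.randn(B, H, T, d, device='cuda') for d in (K, K, V))
o_rec, s_rec = fused_recurrent_linear_attn(q, k, v, output_final_state=True, normalize=True)
o_chk, s_chk = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=True)
assert torch.allclose(o_rec, o_chk, atol=1e-3)
assert torch.allclose(s_rec, s_chk, atol=1e-3)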
fla/ops/linear_attn/utils.py ADDED
@@ -0,0 +1,10 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+
+
+ @torch.jit.script
+ def normalize_output(q, k, o):
+     k = k.cumsum(-2)
+     z = (q * k).sum(-1, keepdim=True)
+     return o / (z + 1e-10)
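`normalize_output` divides by z_t = q_t · (sum_{s<=t} k_s), which equals the row sums of the causal attention matrix q @ k^T, because sum_{s<=t} q_t · k_s = q_t · sum_{s<=t} k_s. A tiny numerical demonstration of that identity (positive inputs chosen only to keep the row sums well away from zero):

import torch

B, H, T, K, V = 1, 1, 8, 4, 4
q, k, v = (torch.rand(B, H, T, d) for d in (K, K, V))
attn = (q @ k.transpose(-2, -1)).tril()           # causal linear-attention weights
o = attn @ v
z = (q * k.cumsum(-2)).sum(-1, keepdim=True)      # same row sums, as computed above
assert torch.allclose(attn.sum(-1, keepdim=True), z, atol=1e-5)
o_normalized = o / (z + 1e-10)                    # rows of the implied attention now sum to 1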
fla/ops/nsa/utils.py ADDED
@@ -0,0 +1,92 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ # Implements argsort based on bitonic sort.
+ # [What is bitonic sort?](https://en.wikipedia.org/wiki/Bitonic_sorter)
+
+ # Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396
+
+
+ import triton
+ import triton.language as tl
+
+ from fla.ops.utils.op import log2
+
+
+ @triton.jit
+ def _compare_and_swap(
+     x,
+     ids,
+     flip,
+     i: tl.constexpr,
+     n_dims: tl.constexpr,
+ ):
+     n_outer: tl.constexpr = x.numel >> n_dims
+     shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
+     y = tl.reshape(x, shape)
+     # slice left/right with 'stride' 2**(n_dims - i - 1)
+     mask = tl.arange(0, 2)[None, :, None]
+     left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
+     right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
+     left = tl.reshape(left, x.shape)
+     right = tl.reshape(right, x.shape)
+     # idx
+     y_idx = tl.reshape(ids, shape)
+     left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
+     right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
+     left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype)
+     right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype)
+     # actual compare-and-swap
+     idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+     ileft = left.to(idtype, bitcast=True)
+     iright = right.to(idtype, bitcast=True)
+     ix = x.to(idtype, bitcast=True)
+
+     cond = (left > right) != flip
+     ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
+     new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids))
+     return ret.to(x.dtype, bitcast=True), new_ids
+
+
+ @triton.jit
+ def _bitonic_merge(
+     x,
+     ids,
+     stage: tl.constexpr,
+     order: tl.constexpr,
+     n_dims: tl.constexpr,
+ ):
+     n_outer: tl.constexpr = x.numel >> n_dims
+     tl.static_assert(stage <= n_dims)
+     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+     # descending order.
+     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+     # a stride of 2) at this stage
+     if order == 2:
+         shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
+         flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
+     else:
+         flip = order
+     # perform `stage` rounds of `compare-and-swap`
+     for i in tl.static_range(stage):
+         x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
+     return x, ids
+
+
+ @triton.jit
+ def argsort(
+     x,
+     ids,
+     dim: tl.constexpr = None,
+     descending: tl.constexpr = tl.core.CONSTEXPR_0,
+ ):
+     # handle default dimension or check that it is the most minor dim
+     _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
+     tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+     # iteratively run bitonic merge-sort steps
+     n_dims: tl.constexpr = log2(x.shape[_dim])
+
+     for i in tl.static_range(1, n_dims + 1):
+         x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
+     return x, ids
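To make the stage/flip structure above concrete, here is a tiny pure-Python reference of the same bitonic sorting network (an illustration only, not the kernel itself): the outer loop doubles the sorted-run length per stage, the inner loop halves the compare-and-swap distance, and the `(i & k) == 0` test plays the role of the `flip` bits.

def bitonic_argsort_ref(values):
    """Reference bitonic argsort; the input length must be a power of two."""
    x = list(values)
    ids = list(range(len(x)))
    n = len(x)
    assert n & (n - 1) == 0, "bitonic sort needs a power-of-two length"
    k = 2
    while k <= n:              # stage: doubles the length of sorted runs
        j = k // 2
        while j >= 1:          # distance of each compare-and-swap round
            for i in range(n):
                l = i ^ j
                if l > i:
                    ascending = (i & k) == 0   # direction alternates every k elements
                    if (x[i] > x[l]) == ascending:
                        x[i], x[l] = x[l], x[i]
                        ids[i], ids[l] = ids[l], ids[i]
            j //= 2
        k *= 2
    return x, ids

# e.g. bitonic_argsort_ref([3, 1, 4, 1, 5, 9, 2, 6]) returns the sorted values
# together with the permutation of original indices.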
fla/ops/rebased/naive.py ADDED
@@ -0,0 +1,27 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+
+
+ def naive_parallel_rebased(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     use_norm: bool = True,
+ ) -> torch.Tensor:
+     if scale is None:
+         scale = q.shape[-1] ** -0.5
+     q = q * scale
+     attn = q @ k.transpose(-2, -1)
+     attn = attn ** 2
+     attn.masked_fill_(~torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
+     o = attn @ v
+     if use_norm:
+         z = attn.sum(-1)
+         return o / (z[..., None] + 1e-6)
+     else:
+         return o
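This reference materializes the full `T x T` causal matrix with a squared (second-order) kernel, i.e. attention weights `(q_t · k_s)^2` for `s <= t`, optionally normalized by the row sums. A quick shape-level usage sketch:

import torch
from fla.ops.rebased.naive import naive_parallel_rebased

B, H, T, K, V = 2, 4, 64, 16, 16
q, k, v = (torch.randn(B, H, T, d) for d in (K, K, V))
o = naive_parallel_rebased(q, k, v)                    # normalized by row sums of attn
o_unnorm = naive_parallel_rebased(q, k, v, use_norm=False)
assert o.shape == (B, H, T, V)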
fla/ops/rebased/parallel.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
10
+
11
+ # Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
12
+ # https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def parallel_rebased_fwd_kernel(
17
+ q,
18
+ k,
19
+ v,
20
+ o,
21
+ z,
22
+ scale,
23
+ T,
24
+ B: tl.constexpr,
25
+ H: tl.constexpr,
26
+ K: tl.constexpr,
27
+ V: tl.constexpr,
28
+ BTL: tl.constexpr,
29
+ BTS: tl.constexpr,
30
+ BK: tl.constexpr,
31
+ BV: tl.constexpr,
32
+ ):
33
+ # i_c: chunk index. used for sequence parallelism
34
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+ NV = tl.cdiv(V, BV)
36
+ i_k = i_kv // (NV)
37
+ i_v = i_kv % (NV)
38
+
39
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
40
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, 0), (BK, BTS), (0, 1))
41
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v*BV), (BTS, BV), (1, 0))
42
+
43
+ # [BQ, BD] block Q, in the shared memory throughout the whole kernel
44
+ b_q = tl.load(p_q, boundary_check=(0, 1))
45
+ b_q = (b_q * scale).to(b_q.dtype)
46
+ b_o = tl.zeros([BTL, BV], dtype=tl.float32)
47
+ b_z = tl.zeros([BTL], dtype=tl.float32)
48
+
49
+ # Q block and K block have no overlap
50
+ # no need for mask, thereby saving flops
51
+ for _ in range(0, i_c*BTL, BTS):
52
+ # [BK, BTS]
53
+ b_k = tl.load(p_k, boundary_check=(0, 1))
54
+
55
+ # [BTS, BV]
56
+ b_v = tl.load(p_v, boundary_check=(0, 1))
57
+ # [BTL, BTS]
58
+ b_s = tl.dot(b_q, (b_k), allow_tf32=False)
59
+ b_s = b_s * b_s
60
+ b_z += tl.sum(b_s, axis=1)
61
+
62
+ # [BQ, BD]
63
+ b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
64
+ p_k = tl.advance(p_k, (0, BTS))
65
+ p_v = tl.advance(p_v, (BTS, 0))
66
+
67
+ # # rescale interchunk output
68
+ tl.debug_barrier()
69
+ o_q = tl.arange(0, BTL)
70
+ # # sync threads, easy for compiler to optimize
71
+ # tl.debug_barrier()
72
+
73
+ o_k = tl.arange(0, BTS)
74
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, i_c*BTL), (BK, BTS), (0, 1))
75
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTS, BV), (1, 0))
76
+ # Q block and K block have overlap. masks required
77
+ for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
78
+ # [BK, BTS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BTS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BTL, BTS]
83
+ m_s = o_q[:, None] >= o_k[None, :]
84
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
85
+ b_s = b_s * b_s
86
+ b_s = tl.where(m_s, b_s, 0)
87
+ b_z += tl.sum(b_s, axis=1)
88
+ # [BTL, BV]
89
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
90
+ p_k = tl.advance(p_k, (0, BTS))
91
+ p_v = tl.advance(p_v, (BTS, 0))
92
+ o_k += BTS
93
+
94
+ p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
95
+ p_z = z + (i_bh + B * H * i_k) * T + i_c*BTL + tl.arange(0, BTL)
96
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
97
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c*BTL + tl.arange(0, BTL)) < T))
98
+
99
+
100
+ @triton.jit(do_not_specialize=['T'])
101
+ def _parallel_rebased_bwd_dq(
102
+ i_bh,
103
+ i_c,
104
+ i_k,
105
+ i_v,
106
+ i_h,
107
+ q,
108
+ k,
109
+ v,
110
+ do,
111
+ dz,
112
+ dq,
113
+ scale,
114
+ T,
115
+ B: tl.constexpr,
116
+ H: tl.constexpr,
117
+ K: tl.constexpr,
118
+ V: tl.constexpr,
119
+ BTL: tl.constexpr,
120
+ BTS: tl.constexpr,
121
+ BK: tl.constexpr,
122
+ BV: tl.constexpr
123
+ ):
124
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
125
+ p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
126
+ b_q = tl.load(p_q, boundary_check=(0, 1))
127
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
128
+ b_q = (b_q * scale).to(b_q.dtype)
129
+ b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
130
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k*BK), (BTS, BK), (1, 0))
131
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, 0), (BV, BTS), (0, 1))
132
+ p_dz = dz + i_bh * T + i_c*BTL + tl.arange(0, BTL)
133
+ b_dz = tl.load(p_dz, mask=(i_c*BTL + tl.arange(0, BTL)) < T)
134
+
135
+ for _ in range(0, i_c*BTL, BTS):
136
+ # [BTS, BK]
137
+ b_k = tl.load(p_k, boundary_check=(0, 1))
138
+ # [BV, BTS]
139
+ b_v = tl.load(p_v, boundary_check=(0, 1))
140
+ # [BTL, BTS]
141
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
142
+ if i_v == 0:
143
+ b_ds += b_dz[:, None]
144
+ else:
145
+ b_ds = b_ds
146
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
147
+ # [BQ, BD]
148
+ b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)
149
+ p_k = tl.advance(p_k, (BTS, 0))
150
+ p_v = tl.advance(p_v, (0, BTS))
151
+
152
+ b_dq *= scale
153
+ o_q = tl.arange(0, BTL)
154
+ o_k = tl.arange(0, BTS)
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTS, BK), (1, 0))
156
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, i_c*BTL), (BV, BTS), (0, 1))
157
+ # Q block and K block have overlap. masks required
158
+ for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
159
+ # [BTS, BK]
160
+ b_k = tl.load(p_k, boundary_check=(0, 1))
161
+ # [BV, BTS]
162
+ b_v = tl.load(p_v, boundary_check=(0, 1))
163
+ # [BTL, BTS]
164
+ m_s = o_q[:, None] >= o_k[None, :]
165
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
166
+ if i_v == 0:
167
+ b_ds += b_dz[:, None]
168
+ else:
169
+ b_ds = b_ds
170
+ b_ds = tl.where(m_s, b_ds, 0) * scale
171
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
172
+ b_s = tl.where(m_s, b_s, 0)
173
+ # [BTL, BK]
174
+ b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),
175
+ b_k, allow_tf32=False)
176
+ p_k = tl.advance(p_k, (BTS, 0))
177
+ p_v = tl.advance(p_v, (0, BTS))
178
+ o_k += BTS
179
+ p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
180
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
181
+ return
182
+
183
+
184
+ @triton.jit(do_not_specialize=['T'])
185
+ def _parallel_rebased_bwd_dkv(
186
+ i_bh,
187
+ i_c,
188
+ i_k,
189
+ i_v,
190
+ i_h,
191
+ q,
192
+ k,
193
+ v,
194
+ do,
195
+ dz,
196
+ dk,
197
+ dv,
198
+ scale,
199
+ T,
200
+ B: tl.constexpr,
201
+ H: tl.constexpr,
202
+ K: tl.constexpr,
203
+ V: tl.constexpr,
204
+ BTL: tl.constexpr,
205
+ BTS: tl.constexpr,
206
+ BK: tl.constexpr,
207
+ BV: tl.constexpr,
208
+ ):
209
+ # compute dk dv
210
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
211
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
212
+ b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
213
+ b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
214
+ [BTL, BV], dtype=tl.float32)
215
+
216
+ for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
217
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
218
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
219
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
220
+ # [BK, BTS]
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ # [BV, BTS]
223
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
224
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
225
+ # [BTL, BTS]
226
+ b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale
227
+ b_s2 = b_s * b_s
228
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
229
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
230
+ if i_v == 0:
231
+ b_ds += b_dz[None, :] * scale
232
+ else:
233
+ b_ds = b_ds
234
+ b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
235
+
236
+ tl.debug_barrier()
237
+ o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
238
+ for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
239
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
240
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
241
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
242
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
243
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
244
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
245
+ # [BK, BQ]
246
+ m_s = o_k[:, None] <= o_q[None, :]
247
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
248
+ b_s2 = b_s * b_s
249
+ b_s = tl.where(m_s, b_s, 0)
250
+ b_s2 = tl.where(m_s, b_s2, 0)
251
+
252
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False)
253
+ if i_v == 0:
254
+ b_ds += b_dz[None, :]
255
+ else:
256
+ b_ds = b_ds
257
+ b_ds = tl.where(m_s, b_ds, 0) * scale
258
+ # [BK, BD]
259
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
260
+ b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
261
+ o_q += BTS
262
+
263
+ p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
264
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
265
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
266
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
267
+ return
268
+
269
+
270
+ @triton.jit(do_not_specialize=['T'])
271
+ def parallel_rebased_bwd_kernel(
272
+ q,
273
+ k,
274
+ v,
275
+ do,
276
+ dz,
277
+ dq,
278
+ dk,
279
+ dv,
280
+ scale,
281
+ T,
282
+ B: tl.constexpr,
283
+ H: tl.constexpr,
284
+ K: tl.constexpr,
285
+ V: tl.constexpr,
286
+ BTL: tl.constexpr,
287
+ BTS: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr
290
+ ):
291
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
292
+ NV = tl.cdiv(V, BV)
293
+ i_k = i_kv // (NV)
294
+ i_v = i_kv % (NV)
295
+ i_h = i_bh % H
296
+ _parallel_rebased_bwd_dq(
297
+ i_bh,
298
+ i_c,
299
+ i_k,
300
+ i_v,
301
+ i_h,
302
+ q,
303
+ k,
304
+ v,
305
+ do,
306
+ dz,
307
+ dq,
308
+ scale,
309
+ B=B,
310
+ H=H,
311
+ T=T,
312
+ K=K,
313
+ V=V,
314
+ BTL=BTL,
315
+ BTS=BTS,
316
+ BK=BK,
317
+ BV=BV
318
+ )
319
+ tl.debug_barrier()
320
+ _parallel_rebased_bwd_dkv(
321
+ i_bh,
322
+ i_c,
323
+ i_k,
324
+ i_v,
325
+ i_h,
326
+ q,
327
+ k,
328
+ v,
329
+ do,
330
+ dz,
331
+ dk,
332
+ dv,
333
+ scale,
334
+ B=B,
335
+ H=H,
336
+ T=T,
337
+ K=K,
338
+ V=V,
339
+ BTL=BTL,
340
+ BTS=BTS,
341
+ BK=BK,
342
+ BV=BV
343
+ )
344
+
345
+
346
+ class ParallelBasedFunction(torch.autograd.Function):
347
+
348
+ @staticmethod
349
+ @input_guard
350
+ @autocast_custom_fwd
351
+ def forward(ctx, q, k, v, scale):
352
+ BTL, BTS = 128, 32
353
+ assert BTL % BTS == 0
354
+ # assert q.shape[-1] % 16 == 0
355
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
356
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
357
+ BK, BV = max(BK, 16), max(BV, 16)
358
+ B, H, T, K, V = *k.shape, v.shape[-1]
359
+ num_stages = 2
360
+ num_warps = 4
361
+ NK = triton.cdiv(K, BK)
362
+ NV = triton.cdiv(V, BV)
363
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
364
+
365
+ assert NK == 1, "NK > 1 is not supported, as it would require synchronization across thread blocks."
366
+
367
+ o = torch.empty(NK, B, H, T, V, device=q.device)
368
+ z = torch.empty(NK, B, H, T, device=q.device)
369
+ parallel_rebased_fwd_kernel[grid](
370
+ q,
371
+ k,
372
+ v,
373
+ o,
374
+ z,
375
+ scale,
376
+ T=T,
377
+ B=B,
378
+ H=H,
379
+ K=K,
380
+ V=V,
381
+ BTL=BTL,
382
+ BTS=BTS,
383
+ BK=BK,
384
+ BV=BV,
385
+ num_warps=num_warps,
386
+ num_stages=num_stages
387
+ )
388
+ ctx.save_for_backward(q, k, v)
389
+ ctx.scale = scale
390
+ return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
391
+
392
+ @staticmethod
393
+ @input_guard
394
+ @autocast_custom_bwd
395
+ def backward(ctx, do, dz):
396
+ q, k, v = ctx.saved_tensors
397
+ scale = ctx.scale
398
+ BTL, BTS = 64, 32
399
+ assert BTL % BTS == 0
400
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
401
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
402
+ BK, BV = max(BK, 16), max(BV, 16)
403
+ B, H, T, K, V = *k.shape, v.shape[-1]
404
+ num_stages = 2
405
+ num_warps = 4
406
+ NK = triton.cdiv(K, BK)
407
+ NV = triton.cdiv(V, BV)
408
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
409
+
410
+ assert NK == 1, "NK > 1 is not supported, as it would require synchronization across thread blocks."
411
+
412
+ dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
413
+ dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
414
+ dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
415
+
416
+ parallel_rebased_bwd_kernel[grid](
417
+ q,
418
+ k,
419
+ v,
420
+ do,
421
+ dz,
422
+ dq,
423
+ dk,
424
+ dv,
425
+ scale,
426
+ T=T,
427
+ B=B,
428
+ H=H,
429
+ K=K,
430
+ V=V,
431
+ BTL=BTL,
432
+ BTS=BTS,
433
+ BK=BK,
434
+ BV=BV,
435
+ num_warps=num_warps,
436
+ num_stages=num_stages
437
+ )
438
+
439
+ return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
440
+
441
+
442
+ def parallel_rebased(
443
+ q: torch.Tensor,
444
+ k: torch.Tensor,
445
+ v: torch.Tensor,
446
+ eps: float = 1e-5,
447
+ use_scale: bool = True,
448
+ use_normalize: bool = True,
449
+ return_both: bool = False,
450
+ head_first: bool = True
451
+ ):
452
+ assert q.shape[-1] <= 128, "only feature dimensions up to 128 are supported"
453
+ if use_scale:
454
+ scale = q.shape[-1] ** -0.5
455
+ else:
456
+ scale = 1
457
+ if not head_first:
458
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
459
+ o, z = ParallelBasedFunction.apply(q, k, v, scale)
460
+ if return_both:
461
+ return o, z
462
+ if use_normalize:
463
+ o = o / (z[..., None] + eps)
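+ # z is the normalizer (the sum of unnormalized attention weights); eps guards against division by zero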
464
+ if not head_first:
465
+ o = o.transpose(1, 2)
466
+ return o.to(q.dtype)
fla/ops/retention/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_retention
4
+ from .fused_chunk import fused_chunk_retention
5
+ from .fused_recurrent import fused_recurrent_retention
6
+ from .parallel import parallel_retention
7
+
8
+ __all__ = [
9
+ 'chunk_retention',
10
+ 'fused_chunk_retention',
11
+ 'parallel_retention',
12
+ 'fused_recurrent_retention'
13
+ ]
fla/ops/retention/fused_chunk.py ADDED
@@ -0,0 +1,365 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from packaging import version
10
+
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
+ @triton.jit(do_not_specialize=['T'])
15
+ def fused_chunk_retention_fwd_kernel(
16
+ q,
17
+ k,
18
+ v,
19
+ o,
20
+ h0,
21
+ ht,
22
+ scale,
23
+ T,
24
+ B: tl.constexpr,
25
+ H: tl.constexpr,
26
+ K: tl.constexpr,
27
+ V: tl.constexpr,
28
+ BT: tl.constexpr,
29
+ BK: tl.constexpr,
30
+ BV: tl.constexpr,
31
+ USE_INITIAL_STATE: tl.constexpr,
32
+ STORE_FINAL_STATE: tl.constexpr,
33
+ CHECK: tl.constexpr
34
+ ):
35
+ # indices
36
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
37
+ i_h = i_bh % H
38
+
39
+ o_i = tl.arange(0, BT)
40
+ # decay rate given the head index
41
+ b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))
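+ # i.e. gamma = 1 - 2^(-5 - i_h), the head-wise decay rate of retention; b_b = log2(gamma)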
42
+
43
+ # d_b: overall decay for the entire chunk
44
+ # d_o: cumulative decay from the start of the chunk
45
+ # d_h: cumulative decay from the end of the chunk
46
+ d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
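+ # e.g. writing gamma = 2^b_b: d_b = gamma^BT, d_o[i] = gamma^(i+1), d_h[i] = gamma^(BT-i-1)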
47
+
48
+ # [BT, BT]
49
+ m_s = o_i[:, None] >= o_i[None, :]
50
+ d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
51
+ # [BK, BV]
52
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
53
+
54
+ # make block pointers
55
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
56
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
57
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
58
+ p_o = tl.make_block_ptr(o + (i_k*B*H+i_bh).to(tl.int64) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
59
+
60
+ if USE_INITIAL_STATE:
61
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
62
+ b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
63
+
64
+ NT = tl.cdiv(T, BT)
65
+ for i in range(0, NT):
66
+ # [BT, BK]
67
+ b_q = tl.load(p_q, boundary_check=(0, 1))
68
+ b_q = (b_q * scale).to(b_q.dtype)
69
+ # [BK, BT]
70
+ b_k = tl.load(p_k, boundary_check=(0, 1))
71
+ # [BT, BV]
72
+ b_v = tl.load(p_v, boundary_check=(0, 1))
73
+
74
+ # [BT, BT]
75
+ b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
76
+ # [BT, BV]
77
+ b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
78
+ if CHECK and i == 0:
79
+ b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
80
+ b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
81
+ else:
82
+ b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
83
+ if i == NT - 1 and (T % BT) != 0:
84
+ d_b = tl.math.exp2((T % BT) * b_b)
85
+ d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
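+ # the final chunk may be shorter than BT, so its decays are recomputed with the true length T % BT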
86
+ b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
87
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
88
+
89
+ p_q = tl.advance(p_q, (BT, 0))
90
+ p_k = tl.advance(p_k, (0, BT))
91
+ p_v = tl.advance(p_v, (BT, 0))
92
+ p_o = tl.advance(p_o, (BT, 0))
93
+
94
+ if STORE_FINAL_STATE:
95
+ p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
96
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+
99
+ @triton.jit(do_not_specialize=['T'])
100
+ def fused_chunk_retention_bwd_kernel(
101
+ q,
102
+ k,
103
+ v,
104
+ do,
105
+ dq,
106
+ dk,
107
+ dv,
108
+ h0,
109
+ scale,
110
+ T,
111
+ B: tl.constexpr,
112
+ H: tl.constexpr,
113
+ K: tl.constexpr,
114
+ V: tl.constexpr,
115
+ BT: tl.constexpr,
116
+ BK: tl.constexpr,
117
+ BV: tl.constexpr,
118
+ USE_INITIAL_STATE: tl.constexpr,
119
+ CHECK: tl.constexpr
120
+ ):
121
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
122
+ i_h = i_bh % H
123
+
124
+ o_i = tl.arange(0, BT)
125
+ b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))
126
+ d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)
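+ # d_q: decay from the chunk start (with the scale fused in); d_k: decay up to the chunk end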
127
+ d_b = tl.math.exp2(BT * b_b)
128
+
129
+ m_s = o_i[:, None] >= o_i[None, :]
130
+ d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
131
+ # [BV, BK]
132
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
133
+ if USE_INITIAL_STATE:
134
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
135
+ b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
136
+
137
+ for i in range(0, tl.cdiv(T, BT)):
138
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
139
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
140
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
141
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
142
+
143
+ # [BT, K]
144
+ b_k = tl.load(p_k, boundary_check=(0, 1))
145
+ # [V, BT]
146
+ b_v = tl.load(p_v, boundary_check=(0, 1))
147
+ # [BT, V]
148
+ b_do = tl.load(p_do, boundary_check=(0, 1))
149
+ b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
150
+
151
+ # [BT, BT]
152
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
153
+ b_ds = (b_ds * d_s).to(b_k.dtype)
154
+ # [BT, K]
155
+ b_dq = tl.dot(b_ds, b_k, allow_tf32=False)
156
+ # [V, K]
157
+ if CHECK and i == 0:
158
+ b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
159
+ b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
160
+ else:
161
+ b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
162
+ b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
163
+
164
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
165
+
166
+ # sync threads
167
+ b_h = None
168
+ tl.debug_barrier()
169
+ d_s = tl.trans(d_s)
170
+ # [BK, BV]
171
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
172
+ for i in range(1, tl.cdiv(T, BT) + 1):
173
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
174
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
175
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
176
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
177
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
178
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H).to(tl.int64) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
179
+ # [K, BT]
180
+ b_q = tl.load(p_q, boundary_check=(0, 1))
181
+ # [BT, BK]
182
+ b_k = tl.load(p_k, boundary_check=(0, 1))
183
+ # [BT, BV]
184
+ b_v = tl.load(p_v, boundary_check=(0, 1))
185
+ b_do = tl.load(p_do, boundary_check=(0, 1))
186
+ b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
187
+
188
+ # [BT, BT]
189
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
190
+ b_ds = (b_ds * d_s).to(b_k.dtype)
191
+
192
+ # [BT, BT]
193
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
194
+ # [BT, BK]
195
+ b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
196
+ # [BT, BV]
197
+ b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
198
+ if CHECK and i == 1:
199
+ b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
200
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
201
+ b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
202
+ else:
203
+ b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
204
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
205
+ b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
206
+
207
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
208
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
209
+
210
+
211
+ class FusedChunkRetentionFunction(torch.autograd.Function):
212
+
213
+ @staticmethod
214
+ @input_guard
215
+ @autocast_custom_fwd
216
+ def forward(ctx, q, k, v, scale, initial_state, output_final_state):
217
+ B, H, T, K, V = *k.shape, v.shape[-1]
218
+
219
+ BT = 64
220
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
221
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
222
+ num_stages = 1
223
+ num_warps = 4
224
+
225
+ o = q.new_empty(NK, B, H, T, V)
226
+
227
+ if output_final_state:
228
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
229
+ else:
230
+ final_state = None
231
+ # the bug still exists even for Triton 2.2 on H100 GPUs
232
+ # so we always enable initial checks
233
+ CHECK = True
234
+ if version.parse(triton.__version__) < version.parse('2.2.0'):
235
+ import warnings
236
+ warnings.warn(
237
+ "Triton<2.2.0 detected for running this kernel, "
238
+ "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
239
+ "that lead to significant precision loss. "
240
+ "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
241
+ "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
242
+ )
243
+ CHECK = True
244
+
245
+ grid = (NV, NK, B * H)
246
+ fused_chunk_retention_fwd_kernel[grid](
247
+ q,
248
+ k,
249
+ v,
250
+ o,
251
+ initial_state,
252
+ final_state,
253
+ scale,
254
+ T=T,
255
+ B=B,
256
+ H=H,
257
+ K=K,
258
+ V=V,
259
+ BT=BT,
260
+ BK=BK,
261
+ BV=BV,
262
+ USE_INITIAL_STATE=initial_state is not None,
263
+ STORE_FINAL_STATE=output_final_state,
264
+ CHECK=CHECK,
265
+ num_warps=num_warps,
266
+ num_stages=num_stages
267
+ )
268
+
269
+ o = o.sum(0)
270
+ ctx.save_for_backward(q, k, v, initial_state)
271
+ ctx.scale = scale
+ ctx.CHECK = CHECK
272
+ return o.to(q.dtype), final_state
273
+
274
+ @staticmethod
275
+ @input_guard
276
+ @autocast_custom_bwd
277
+ def backward(ctx, do, dht=None):
278
+ q, k, v, initial_state = ctx.saved_tensors
279
+ B, H, T, K, V = *k.shape, v.shape[-1]
280
+ scale = ctx.scale
281
+
282
+ BT = 64
283
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
284
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
285
+ num_stages = 1
286
+ num_warps = 4
287
+
288
+ dq = q.new_empty(NV, B, H, T, K)
289
+ dk = q.new_empty(NV, B, H, T, K)
290
+ dv = q.new_empty(NK, B, H, T, V)
291
+ grid = (NV, NK, B * H)
292
+
293
+ fused_chunk_retention_bwd_kernel[grid](
294
+ q,
295
+ k,
296
+ v,
297
+ do,
298
+ dq,
299
+ dk,
300
+ dv,
301
+ initial_state,
302
+ scale,
303
+ T=T,
304
+ B=B,
305
+ H=H,
306
+ K=K,
307
+ V=V,
308
+ BT=BT,
309
+ BK=BK,
310
+ BV=BV,
311
+ USE_INITIAL_STATE=initial_state is not None,
312
+ CHECK=ctx.CHECK,
313
+ num_warps=num_warps,
314
+ num_stages=num_stages
315
+ )
316
+ dq = dq.sum(0)
317
+ dk = dk.sum(0)
318
+ dv = dv.sum(0)
319
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
320
+
321
+
322
+ def fused_chunk_retention(
323
+ q: torch.Tensor,
324
+ k: torch.Tensor,
325
+ v: torch.Tensor,
326
+ scale: Optional[float] = None,
327
+ initial_state: Optional[torch.Tensor] = None,
328
+ output_final_state: bool = False,
329
+ head_first: bool = True
330
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
331
+ r"""
332
+ Args:
333
+ q (torch.Tensor):
334
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
335
+ k (torch.Tensor):
336
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
337
+ v (torch.Tensor):
338
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
339
+ scale (Optional[float]):
340
+ Scale factor for the attention scores.
341
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
342
+ initial_state (Optional[torch.Tensor]):
343
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
344
+ output_final_state (Optional[bool]):
345
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
346
+ head_first (Optional[bool]):
347
+ Whether the inputs are in the head-first format.
348
+ Default: `True`.
349
+
350
+ Returns:
351
+ o (torch.Tensor):
352
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
353
+ final_state (torch.Tensor):
354
+ Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
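+
+ Examples::
+ # a minimal usage sketch derived from the shapes documented above
+ >>> import torch
+ >>> from fla.ops.retention import fused_chunk_retention
+ >>> B, H, T, K, V = 4, 4, 2048, 64, 64
+ >>> q = torch.randn(B, H, T, K, device='cuda')
+ >>> k = torch.randn(B, H, T, K, device='cuda')
+ >>> v = torch.randn(B, H, T, V, device='cuda')
+ >>> o, ht = fused_chunk_retention(q, k, v, output_final_state=True)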
355
+ """
356
+ if scale is None:
357
+ scale = k.shape[-1] ** -0.5
358
+ if not head_first:
359
+ q = q.transpose(1, 2)
360
+ k = k.transpose(1, 2)
361
+ v = v.transpose(1, 2)
362
+ o, final_state = FusedChunkRetentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
363
+ if not head_first:
364
+ o = o.transpose(1, 2)
365
+ return o, final_state
fla/ops/retention/fused_recurrent.py ADDED
@@ -0,0 +1,42 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla
9
+
10
+
11
+ def fused_recurrent_retention(
12
+ q: torch.Tensor,
13
+ k: torch.Tensor,
14
+ v: torch.Tensor,
15
+ scale: Optional[float] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: bool = False,
18
+ reverse: bool = False,
19
+ cu_seqlens: Optional[torch.LongTensor] = None,
20
+ head_first: bool = True
21
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
22
+ if head_first:
23
+ n_heads = q.shape[1]
24
+ else:
25
+ n_heads = q.shape[2]
26
+ s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log()
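+ # retention is recovered as simple GLA with a constant per-head log decay g = log(1 - 2^(-5 - h))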
27
+ if head_first:
28
+ g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
29
+ else:
30
+ g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
31
+ return fused_recurrent_simple_gla(
32
+ q=q,
33
+ k=k,
34
+ v=v,
35
+ g=g,
36
+ scale=scale,
37
+ initial_state=initial_state,
38
+ output_final_state=output_final_state,
39
+ reverse=reverse,
40
+ cu_seqlens=cu_seqlens,
41
+ head_first=head_first
42
+ )
fla/ops/retention/naive.py ADDED
@@ -0,0 +1,15 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+
5
+
6
+ def naive_retention(q, k, v):
7
+ orig_type = q.dtype
8
+ q, k, v = q.float(), k.float(), v.float()
9
+ _, n_heads, seq_len, d_head = q.shape
10
+ s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
11
+ n = q.new_tensor(range(seq_len), dtype=torch.float)
12
+ n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
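+ # n[h, i, j] = gamma_h^(i - j) for i >= j and 0 otherwise, i.e. the causal decay mask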
13
+ s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
14
+ o = torch.einsum('bhqk,bhkd->bhqd', s, v)
15
+ return o.to(orig_type)
fla/ops/rwkv6/recurrent_naive.py ADDED
@@ -0,0 +1,103 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ def naive_recurrent_rwkv6(
9
+ q: torch.Tensor,
10
+ k: torch.Tensor,
11
+ v: torch.Tensor,
12
+ w: torch.Tensor,
13
+ u: torch.Tensor,
14
+ scale: Optional[float] = None,
15
+ initial_state: Optional[torch.Tensor] = None,
16
+ output_final_state: Optional[bool] = False
17
+ ):
18
+ orig_dtype = q.dtype
19
+ B, H, T, K, V = *q.shape, v.shape[-1]
20
+ q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
21
+ h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
22
+ o = torch.zeros_like(v)
23
+
24
+ if scale is None:
25
+ scale = K ** -0.5
26
+
27
+ if initial_state is not None:
28
+ h += initial_state
29
+
30
+ for i in range(T):
31
+ q_i = q[:, :, i, :] * scale
32
+ k_i = k[:, :, i]
33
+ v_i = v[:, :, i, :]
34
+ w_i = w[:, :, i].exp()
35
+ kv_i = k_i[..., None] * v_i[..., None, :]
36
+ o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
37
+ o[:, :, i] = o_i.sum(-2)
38
+ h = h * w_i[..., None] + kv_i
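+ # recurrence: h_t = exp(w_t) * h_{t-1} + k_t v_t^T; the bonus u applies to the current token only in the output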
39
+ ht = h if output_final_state else None
40
+ return o.to(orig_dtype), ht
41
+
42
+
43
+ @torch.no_grad
44
+ @torch.jit.script
45
+ def naive_recurrent_rwkv6_bwd(
46
+ q: torch.Tensor,
47
+ k: torch.Tensor,
48
+ v: torch.Tensor,
49
+ w: torch.Tensor,
50
+ u: torch.Tensor,
51
+ o: torch.Tensor,
52
+ do: torch.Tensor,
53
+ initial_state: Optional[torch.Tensor] = None
54
+ ):
55
+ q, k, v, w, u, o, do = (x.to(dtype=torch.float32) for x in (q, k, v, w, u, o, do))
56
+ B, H, T, K, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1]
57
+ h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
58
+ dq = torch.zeros_like(q)
59
+ dq_aux = torch.zeros_like(q)
60
+
61
+ if initial_state is not None:
62
+ h += initial_state
63
+
64
+ for i in range(T):
65
+ k_i = k[:, :, i]
66
+ v_i = v[:, :, i]
67
+ w_i = w[:, :, i].exp()
68
+ kv_i = k_i[..., None] * v_i[..., None, :]
69
+ h_i = (h + u[None, ..., None] * kv_i)
70
+ dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
71
+ dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
72
+ dq[:, :, i] = dq_i
73
+ dq_aux[:, :, i] = dq_aux_i
74
+ h = h * w_i[..., None] + kv_i
75
+
76
+ du = torch.zeros_like(u)
77
+ dh = torch.zeros_like(h)
78
+ dk = torch.zeros_like(k)
79
+ dk_aux = torch.zeros_like(k)
80
+ dv = torch.zeros_like(v)
81
+
82
+ for i in range(T - 1, -1, -1):
83
+ d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
84
+ k_i = k[:, :, i]
85
+ v_i = v[:, :, i]
86
+ du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
87
+ du += du_i.sum(0)
88
+ dk_i = (dh * v_i[..., None, :]).sum(-1)
89
+ dk_aux[:, :, i] = dk_i
90
+ dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
91
+ dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
92
+ dv_i += (dh * k_i[..., None]).sum(-2)
93
+
94
+ dk[:, :, i] = dk_i
95
+ dv[:, :, i] = dv_i
96
+ dh = dh * w[:, :, i, :, None].exp() + d_kv_i
97
+
98
+ # dw = q * dq_aux - k * dk_aux
99
+ dw = torch.zeros_like(w)
100
+ for i in range(T - 2, -1, -1):
101
+ dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
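+ # reverse accumulation: w at step i only influences outputs at steps > i through the state, hence the suffix sum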
102
+
103
+ return dq, dk, dv, dw, du, dh
fla/ops/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (320 Bytes). View file
 
fla/ops/rwkv7/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (2.52 kB). View file
 
fla/ops/simple_gla/chunk.py ADDED
@@ -0,0 +1,302 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+
9
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
10
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv, chunk_fwd_o
11
+ from fla.ops.utils import chunk_local_cumsum
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ def chunk_simple_gla_fwd(
16
+ q: torch.Tensor,
17
+ k: torch.Tensor,
18
+ v: torch.Tensor,
19
+ g: torch.Tensor,
20
+ scale: float,
21
+ initial_state: torch.Tensor,
22
+ output_final_state: bool,
23
+ offsets: Optional[torch.LongTensor] = None,
24
+ indices: Optional[torch.LongTensor] = None,
25
+ head_first: bool = True,
26
+ chunk_size: int = 64
27
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
28
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) if g is not None else None
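+ # g now holds within-chunk cumulative log decays, which is what the chunkwise kernels below expect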
29
+ h, ht = chunk_fwd_h(
30
+ k=k,
31
+ v=v,
32
+ g=g,
33
+ gk=None,
34
+ gv=None,
35
+ h0=initial_state,
36
+ output_final_state=output_final_state,
37
+ states_in_fp32=False,
38
+ offsets=offsets,
39
+ head_first=head_first,
40
+ chunk_size=chunk_size
41
+ )
42
+ o = chunk_fwd_o(
43
+ q=q,
44
+ k=k,
45
+ v=v,
46
+ g=g,
47
+ h=h,
48
+ scale=scale,
49
+ offsets=offsets,
50
+ indices=indices,
51
+ head_first=head_first,
52
+ chunk_size=chunk_size
53
+ )
54
+ return g, o, ht
55
+
56
+
57
+ def chunk_simple_gla_bwd(
58
+ q: torch.Tensor,
59
+ k: torch.Tensor,
60
+ v: torch.Tensor,
61
+ g: torch.Tensor,
62
+ initial_state: torch.Tensor,
63
+ do: torch.Tensor,
64
+ dht: torch.Tensor,
65
+ scale: float,
66
+ offsets: Optional[torch.LongTensor] = None,
67
+ indices: Optional[torch.LongTensor] = None,
68
+ head_first: bool = True,
69
+ chunk_size: int = 64
70
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
71
+ # (SY 09/22) states_in_fp32 does not seem to affect the error of dg, but it is set to True for safety
72
+ h, _ = chunk_fwd_h(
73
+ k=k,
74
+ v=v,
75
+ g=g,
76
+ gk=None,
77
+ gv=None,
78
+ h0=initial_state,
79
+ output_final_state=False,
80
+ states_in_fp32=True,
81
+ offsets=offsets,
82
+ head_first=head_first,
83
+ chunk_size=chunk_size
84
+ )
85
+ dh, dh0 = chunk_bwd_dh(
86
+ q=q,
87
+ k=k,
88
+ v=v,
89
+ g=g,
90
+ gk=None,
91
+ gv=None,
92
+ do=do,
93
+ h0=initial_state,
94
+ dht=dht,
95
+ scale=scale,
96
+ states_in_fp32=True,
97
+ offsets=offsets,
98
+ head_first=head_first,
99
+ chunk_size=chunk_size
100
+ )
101
+ dq, dk, _, dg = chunk_bwd_dqkwg(
102
+ q=q,
103
+ k=k,
104
+ v=v,
105
+ g=g,
106
+ h=h,
107
+ do=do,
108
+ dh=dh,
109
+ scale=scale,
110
+ offsets=offsets,
111
+ indices=indices,
112
+ head_first=head_first,
113
+ chunk_size=chunk_size
114
+ )
115
+ dv = chunk_bwd_dv(
116
+ q=q,
117
+ k=k,
118
+ g=g,
119
+ do=do,
120
+ dh=dh,
121
+ scale=scale,
122
+ offsets=offsets,
123
+ indices=indices,
124
+ head_first=head_first,
125
+ chunk_size=chunk_size
126
+ )
127
+ return dq, dk, dv, dg, dh0
128
+
129
+
130
+ class ChunkSimpleGLAFunction(torch.autograd.Function):
131
+
132
+ @staticmethod
133
+ @input_guard
134
+ @autocast_custom_fwd
135
+ def forward(
136
+ ctx,
137
+ q,
138
+ k,
139
+ v,
140
+ g,
141
+ scale,
142
+ initial_state,
143
+ output_final_state,
144
+ offsets,
145
+ head_first
146
+ ):
147
+ T = q.shape[2] if head_first else q.shape[1]
148
+ chunk_size = min(64, max(16, triton.next_power_of_2(T)))
149
+
150
+ # 2-d indices denoting the offsets of chunks in each sequence
151
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
152
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
153
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
154
+ indices = None
155
+ if offsets is not None:
156
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
157
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
158
+
159
+ g, o, ht = chunk_simple_gla_fwd(
160
+ q=q,
161
+ k=k,
162
+ v=v,
163
+ g=g,
164
+ scale=scale,
165
+ initial_state=initial_state,
166
+ output_final_state=output_final_state,
167
+ offsets=offsets,
168
+ indices=indices,
169
+ head_first=head_first,
170
+ chunk_size=chunk_size
171
+ )
172
+ ctx.save_for_backward(q, k, v, g, initial_state)
173
+ ctx.chunk_size = chunk_size
174
+ ctx.scale = scale
175
+ ctx.offsets = offsets
176
+ ctx.indices = indices
177
+ ctx.head_first = head_first
178
+ return o.to(q.dtype), ht
179
+
180
+ @staticmethod
181
+ @input_guard
182
+ @autocast_custom_bwd
183
+ def backward(ctx, do, dht):
184
+ chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first
185
+ q, k, v, g, initial_state = ctx.saved_tensors
186
+ dq, dk, dv, dg, dh0 = chunk_simple_gla_bwd(
187
+ q=q,
188
+ k=k,
189
+ v=v,
190
+ g=g,
191
+ initial_state=initial_state,
192
+ do=do,
193
+ dht=dht,
194
+ scale=scale,
195
+ offsets=offsets,
196
+ indices=indices,
197
+ head_first=head_first,
198
+ chunk_size=chunk_size
199
+ )
200
+ if g is not None:
201
+ dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets,
202
+ indices=indices, head_first=head_first).to(g.dtype)
203
+ else:
204
+ dg = None
205
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, None, dh0, None, None, None
206
+
207
+
208
+ @torch.compiler.disable
209
+ def chunk_simple_gla(
210
+ q: torch.Tensor,
211
+ k: torch.Tensor,
212
+ v: torch.Tensor,
213
+ g: torch.Tensor, # log decay
214
+ scale: Optional[float] = None,
215
+ initial_state: Optional[torch.Tensor] = None,
216
+ output_final_state: bool = False,
217
+ cu_seqlens: Optional[torch.LongTensor] = None,
218
+ head_first: bool = True
219
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
220
+ r"""
221
+ Args:
222
+ q (torch.Tensor):
223
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
224
+ k (torch.Tensor):
225
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
226
+ v (torch.Tensor):
227
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
228
+ g (torch.Tensor):
229
+ Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
230
+ Compared to GLA, the gating is head-wise instead of elementwise.
231
+ scale (Optional[float]):
232
+ Scale factor for the attention scores.
233
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
234
+ initial_state (Optional[torch.Tensor]):
235
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
236
+ For equal-length input sequences, `N` equals the batch size `B`.
237
+ Default: `None`.
238
+ output_final_state (Optional[bool]):
239
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
240
+ cu_seqlens (torch.LongTensor):
241
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
242
+ consistent with the FlashAttention API.
243
+ head_first (Optional[bool]):
244
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
245
+ Default: `True`.
246
+
247
+ Returns:
248
+ o (torch.Tensor):
249
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
250
+ final_state (torch.Tensor):
251
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
252
+
253
+ Examples::
254
+ >>> import torch
255
+ >>> import torch.nn.functional as F
256
+ >>> from einops import rearrange
257
+ >>> from fla.ops.simple_gla import chunk_simple_gla
258
+ # inputs with equal lengths
259
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
260
+ >>> q = torch.randn(B, T, H, K, device='cuda')
261
+ >>> k = torch.randn(B, T, H, K, device='cuda')
262
+ >>> v = torch.randn(B, T, H, V, device='cuda')
263
+ >>> g = F.logsigmoid(torch.randn(B, T, H, device='cuda'))
264
+ >>> o, ht = chunk_simple_gla(q, k, v, g,
265
+ initial_state=None,
266
+ output_final_state=True,
267
+ head_first=False)
268
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
269
+ >>> q, k, v, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g))
270
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
271
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
272
+ >>> o_var, ht_var = chunk_simple_gla(q, k, v, g,
273
+ initial_state=None,
274
+ output_final_state=True,
275
+ cu_seqlens=cu_seqlens,
276
+ head_first=False)
277
+ >>> assert o.allclose(o_var.view(o.shape))
278
+ >>> assert ht.allclose(ht_var)
279
+ """
280
+ if cu_seqlens is not None:
281
+ if q.shape[0] != 1:
282
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
283
+ f"Please flatten variable-length inputs before processing.")
284
+ if head_first:
285
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
286
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
287
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
288
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
289
+ if scale is None:
290
+ scale = k.shape[-1] ** -0.5
291
+ o, final_state = ChunkSimpleGLAFunction.apply(
292
+ q,
293
+ k,
294
+ v,
295
+ g,
296
+ scale,
297
+ initial_state,
298
+ output_final_state,
299
+ cu_seqlens,
300
+ head_first
301
+ )
302
+ return o, final_state
fla/ops/simple_gla/parallel.py ADDED
@@ -0,0 +1,722 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
11
+ from fla.ops.utils.op import safe_exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, is_intel_alchemist
13
+
14
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
15
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
16
+
17
+
18
+ @triton.heuristics({
19
+ 'NV': lambda args: triton.cdiv(args['V'], args['BV']),
20
+ 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ 'USE_G': lambda args: args['g'] is not None
23
+ })
24
+ @triton.autotune(
25
+ configs=[
26
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
27
+ for num_warps in [2, 4, 8, 16]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=["BT", "BS", "BK", "BV", "USE_G"],
31
+ )
32
+ @triton.jit
33
+ def parallel_simple_gla_fwd_kernel(
34
+ q,
35
+ k,
36
+ v,
37
+ g,
38
+ o,
39
+ attn,
40
+ scale,
41
+ offsets,
42
+ indices,
43
+ T,
44
+ B: tl.constexpr,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BS: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NV: tl.constexpr,
53
+ OUTPUT_ATTENTIONS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr,
55
+ USE_OFFSETS: tl.constexpr,
56
+ USE_G: tl.constexpr
57
+ ):
58
+ tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time")
59
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_k, i_v = i_kv // NV, i_kv % NV
61
+ i_b, i_h = i_bh // H, i_bh % H
62
+ o += i_k * B * T * H * V
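+ # each K-block writes its own slice of o; the NK partial outputs are summed on the host afterwards (o.sum(0))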
63
+
64
+ if USE_OFFSETS:
65
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
66
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
67
+ T = eos - bos
68
+ else:
69
+ bos, eos = i_b * T, i_b * T + T
70
+
71
+ q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
72
+ k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
73
+ v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
74
+ o += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
75
+ if USE_G:
76
+ g += i_bh * T if HEAD_FIRST else bos * H + i_h
77
+ if OUTPUT_ATTENTIONS:
78
+ attn += (bos * H + i_h * T) * T + i_k * B * H * T * T
79
+ stride_qk = K if HEAD_FIRST else H * K
80
+ stride_vo = V if HEAD_FIRST else H * V
81
+ stride_g = 1 if HEAD_FIRST else H
82
+
83
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
84
+
85
+ # the Q block is kept in the shared memory throughout the whole kernel
86
+ # [BT, BK]
87
+ b_q = tl.load(p_q, boundary_check=(0, 1))
88
+ b_q = (b_q * scale).to(b_q.dtype)
89
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
90
+
91
+ # [BT]
92
+ o_q = i_t * BT + tl.arange(0, BT)
93
+ # [BS]
94
+ o_k = i_t * BT + tl.arange(0, BS)
95
+ # Q block and K block have overlap.
96
+ # masks required
97
+ if USE_G:
98
+ p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
99
+ # [BT,]
100
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
101
+ # rescale interchunk output
102
+ else:
103
+ b_gq = None
104
+
105
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
106
+ p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
107
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
108
+ # [BK, BS]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ # [BS, BV]
111
+ b_v = tl.load(p_v, boundary_check=(0, 1))
112
+ # [BT, BS]
113
+ m_s = o_q[:, None] >= o_k[None, :]
114
+ b_s = tl.dot(b_q, b_k)
115
+ if USE_G:
116
+ p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
117
+ b_gk = tl.load(p_gk, boundary_check=(0,))
118
+ b_s *= safe_exp(b_gq[:, None] - b_gk[None, :])
119
+ b_s = tl.where(m_s, b_s, 0)
120
+ else:
121
+ b_s = tl.where(m_s, b_s, 0)
122
+ # [BT, BV]
123
+ if i_s >= 0:
124
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v)
125
+ if OUTPUT_ATTENTIONS:
126
+ p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
127
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
128
+ o_k += BS
129
+
130
+ for i_s in range(i_t * BT - BS, -BS, -BS):
131
+ p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
132
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
133
+ # [BK, BS]
134
+ b_k = tl.load(p_k, boundary_check=(0, 1))
135
+ # [BS, BV]
136
+ b_v = tl.load(p_v, boundary_check=(0, 1))
137
+ b_s = tl.dot(b_q, b_k)
138
+ if USE_G:
139
+ p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
140
+ b_g = tl.load(p_g, boundary_check=(0,))
141
+ b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
142
+ b_gp = tl.load(g + (i_s-1) * stride_g) if i_s % BT > 0 else 0.
143
+ # this particular formulation has no special meaning; it just avoids some Triton layout bugs
144
+ b_s *= safe_exp(b_gq[:, None] + (b_gn - b_g)[None, :])
145
+ b_gq += (b_gn - b_gp)
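+ # extend the running log decay across this key block so that earlier blocks see the full decay up to the queries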
146
+ if OUTPUT_ATTENTIONS:
147
+ p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
148
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
149
+ if i_s >= 0:
150
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v)
151
+ p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
152
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
153
+
154
+
155
+ @triton.jit(do_not_specialize=['T'])
156
+ def parallel_simple_gla_bwd_kernel_dq(
157
+ i_t,
158
+ i_k,
159
+ i_v,
160
+ q,
161
+ k,
162
+ v,
163
+ g,
164
+ do,
165
+ dq,
166
+ dg,
167
+ stride_qk,
168
+ stride_vo,
169
+ stride_g,
170
+ scale,
171
+ T,
172
+ K: tl.constexpr,
173
+ V: tl.constexpr,
174
+ BT: tl.constexpr,
175
+ BS: tl.constexpr,
176
+ BK: tl.constexpr,
177
+ BV: tl.constexpr,
178
+ USE_G: tl.constexpr
179
+ ):
180
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
181
+ # [BT, BV]
182
+ b_do = tl.load(p_do, boundary_check=(0, 1))
183
+ # [BT, BK]
184
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
185
+
186
+ for i_s in range(0, i_t * BT, BS):
187
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
188
+ p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
189
+ # [BS, BK]
190
+ b_k = tl.load(p_k, boundary_check=(0, 1))
191
+ # [BV, BS]
192
+ b_v = tl.load(p_v, boundary_check=(0, 1))
193
+ # [BT, BV] @ [BV, BS] = [BT, BS]
194
+ b_ds = tl.dot(b_do, b_v)
195
+ if USE_G:
196
+ p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
197
+ b_g = tl.load(p_g, boundary_check=(0,))
198
+ b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
199
+ b_gp = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
200
+ b_ds *= safe_exp(b_gn - b_g)[None, :]
201
+ if i_s > 0:
202
+ b_dq *= safe_exp(b_gn - b_gp)
203
+ # [BT, BS] @ [BS, BK] = [BT, BK]
204
+ b_dq += tl.dot(b_ds.to(b_v.dtype), b_k)
205
+
206
+ if USE_G:
207
+ p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
208
+ # [BT,]
209
+ b_gq = tl.load(p_gq, boundary_check=(0,))
210
+ # [BT, BK]
211
+ b_dq *= safe_exp(b_gq)[:, None]
212
+
213
+ # [BT]
214
+ o_q = i_t * BT + tl.arange(0, BT)
215
+ # [BS]
216
+ o_k = i_t * BT + tl.arange(0, BS)
217
+ # Q block and K block have overlap. masks required
218
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
219
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
220
+ p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
221
+ # [BS, BK]
222
+ b_k = tl.load(p_k, boundary_check=(0, 1))
223
+ # [BV, BS]
224
+ b_v = tl.load(p_v, boundary_check=(0, 1))
225
+ # [BT, BV] @ [BV, BS] = [BT, BS]
226
+ b_ds = tl.dot(b_do, b_v)
227
+ if USE_G:
228
+ p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
229
+ b_gk = tl.load(p_gk, boundary_check=(0,))
230
+ b_ds *= safe_exp(b_gq[:, None] - b_gk[None, :])
231
+ b_ds = tl.where(o_q[:, None] >= o_k[None, :], b_ds, 0)
232
+ # [BT, BK]
233
+ b_dq += tl.dot(b_ds.to(b_k.dtype), b_k)
234
+ o_k += BS
235
+
236
+ b_dq *= scale
237
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
238
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
239
+ if USE_G:
240
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
241
+ b_q = tl.load(p_q, boundary_check=(0, 1))
242
+ b_dg = tl.sum(b_dq * b_q, 1)
243
+ p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
244
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
245
+
246
+
247
+ @triton.jit(do_not_specialize=['T'])
248
+ def parallel_simple_gla_bwd_kernel_dkv(
249
+ i_t,
250
+ i_k,
251
+ i_v,
252
+ q,
253
+ k,
254
+ v,
255
+ g,
256
+ do,
257
+ dk,
258
+ dv,
259
+ dg,
260
+ scale,
261
+ stride_qk,
262
+ stride_vo,
263
+ stride_g,
264
+ T,
265
+ K: tl.constexpr,
266
+ V: tl.constexpr,
267
+ BT: tl.constexpr,
268
+ BS: tl.constexpr,
269
+ BK: tl.constexpr,
270
+ BV: tl.constexpr,
271
+ USE_G: tl.constexpr
272
+ ):
273
+ # [BT, BK]
274
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
275
+ b_k = tl.load(p_k, boundary_check=(0, 1))
276
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
277
+ # [BT, BV]
278
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
279
+ b_v = tl.load(p_v, boundary_check=(0, 1))
280
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
281
+ if USE_G:
282
+ p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
283
+ b_gk = tl.load(p_gk, boundary_check=(0,))
284
+ NTS = tl.cdiv(T, BS)
285
+ # [BT, BK]
286
+ for i_s in range(NTS * BS - BS, (i_t + 1) * BT - BS, -BS):
287
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
288
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
289
+ b_q = tl.load(p_q, boundary_check=(0, 1))
290
+ b_do = tl.load(p_do, boundary_check=(0, 1))
291
+ b_ds = tl.dot(b_v, tl.trans(b_do))
292
+ b_s = tl.dot(b_k, tl.trans(b_q))
293
+ if USE_G:
294
+ p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
295
+ b_gq = tl.load(p_gq, boundary_check=(0,))
296
+ b_gp = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
297
+ b_gn = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
298
+ if i_s >= 0:
299
+ tmp = safe_exp(b_gp - b_gn)
300
+ b_dk *= tmp
301
+ b_dv *= tmp
302
+ tmp2 = safe_exp(b_gq - b_gn)
303
+ b_ds *= tmp2[None, :]
304
+ b_s *= tmp2[None, :]
305
+ # [BT, BK]
306
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
307
+ # [BT, BV]
308
+ b_dv += tl.dot(b_s.to(b_do.dtype), b_do)
309
+
310
+ if USE_G:
311
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * stride_g)
312
+ if i_t >= 0:
313
+ tmp2 = safe_exp(b_g_last - b_gk)[:, None]
314
+ b_dk *= tmp2
315
+ b_dv *= tmp2
316
+
317
+ o_q = i_t * BT + tl.arange(0, BS)
318
+ o_k = i_t * BT + tl.arange(0, BT)
319
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
320
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
321
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
322
+ # [BS, BK]
323
+ b_q = tl.load(p_q, boundary_check=(0, 1))
324
+ # [BS, BV]
325
+ b_do = tl.load(p_do, boundary_check=(0, 1))
326
+ # [BS]
327
+ b_ds = tl.dot(b_v, tl.trans(b_do))
328
+ b_s = tl.dot(b_k, tl.trans(b_q))
329
+ if USE_G:
330
+ p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
331
+ b_gq = tl.load(p_gq, boundary_check=(0,))
332
+ if i_s >= 0:
333
+ tmp = safe_exp(-b_gk[:, None] + b_gq[None, :])
334
+ b_ds *= tmp
335
+ b_s *= tmp
336
+ m_s = o_k[:, None] <= o_q[None, :]
337
+ b_s = tl.where(m_s, b_s, 0)
338
+ b_ds = tl.where(m_s, b_ds, 0)
339
+ # [BT, BK]
340
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
341
+ b_dv += tl.dot(b_s.to(b_do.dtype), b_do)
342
+ o_q += BS
343
+ b_dk *= scale
344
+ b_dv *= scale
345
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
346
+ p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
347
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
348
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
349
+ if USE_G:
350
+ p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
351
+ b_dg = tl.load(p_dg, boundary_check=(0,))
352
+ b_dg -= tl.sum(b_dk * b_k, 1)
353
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
354
+
355
+
356
+ @triton.heuristics({
357
+ 'NV': lambda args: triton.cdiv(args['V'], args['BV']),
358
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
359
+ 'USE_G': lambda args: args['g'] is not None
360
+ })
361
+ @triton.autotune(
362
+ configs=[
363
+ triton.Config(triton_config, num_warps=num_warps)
364
+ for num_warps in [2, 4, 8, 16]
365
+ ],
366
+ key=['BT', 'BS', 'BK', 'BV', 'USE_G'],
367
+ )
368
+ @triton.jit(do_not_specialize=['T'])
369
+ def parallel_simple_gla_bwd_kernel(
370
+ q,
371
+ k,
372
+ v,
373
+ g,
374
+ do,
375
+ dq,
376
+ dk,
377
+ dv,
378
+ dg,
379
+ scale,
380
+ offsets,
381
+ indices,
382
+ T,
383
+ B: tl.constexpr,
384
+ H: tl.constexpr,
385
+ K: tl.constexpr,
386
+ V: tl.constexpr,
387
+ BT: tl.constexpr,
388
+ BS: tl.constexpr,
389
+ BK: tl.constexpr,
390
+ BV: tl.constexpr,
391
+ NV: tl.constexpr,
392
+ USE_OFFSETS: tl.constexpr,
393
+ HEAD_FIRST: tl.constexpr,
394
+ USE_G: tl.constexpr
395
+ ):
396
+ tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time")
397
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
398
+ i_k, i_v = i_kv // NV, i_kv % NV
399
+ i_b, i_h = i_bh // H, i_bh % H
400
+ dq += i_v * B * H * T * K
401
+ dk += i_v * B * H * T * K
402
+ dv += i_k * B * H * T * V
403
+ if USE_G:
404
+ dg += i_kv * B * H * T
405
+
406
+ if USE_OFFSETS:
407
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
408
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
409
+ T = eos - bos
410
+ else:
411
+ bos, eos = i_b * T, i_b * T + T
412
+
413
+ q += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K
414
+ k += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K
415
+ v += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V
416
+ do += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V
417
+ dq += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K
418
+ dk += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K
419
+ dv += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V
420
+ if USE_G:
421
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
422
+ dg += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
423
+ stride_qk = K if HEAD_FIRST else H * K
424
+ stride_vo = V if HEAD_FIRST else H * V
425
+ stride_g = 1 if HEAD_FIRST else H
426
+
427
+ parallel_simple_gla_bwd_kernel_dq(
428
+ i_t=i_t,
429
+ i_k=i_k,
430
+ i_v=i_v,
431
+ q=q,
432
+ k=k,
433
+ v=v,
434
+ g=g,
435
+ do=do,
436
+ dq=dq,
437
+ dg=dg,
438
+ scale=scale,
439
+ stride_qk=stride_qk,
440
+ stride_vo=stride_vo,
441
+ stride_g=stride_g,
442
+ T=T,
443
+ K=K,
444
+ V=V,
445
+ BT=BT,
446
+ BS=BS,
447
+ BK=BK,
448
+ BV=BV,
449
+ USE_G=USE_G
450
+ )
451
+ tl.debug_barrier()
452
+ parallel_simple_gla_bwd_kernel_dkv(
453
+ i_t=i_t,
454
+ i_k=i_k,
455
+ i_v=i_v,
456
+ q=q,
457
+ k=k,
458
+ v=v,
459
+ g=g,
460
+ do=do,
461
+ dk=dk,
462
+ dv=dv,
463
+ dg=dg,
464
+ scale=scale,
465
+ stride_qk=stride_qk,
466
+ stride_vo=stride_vo,
467
+ stride_g=stride_g,
468
+ T=T,
469
+ K=K,
470
+ V=V,
471
+ BT=BT,
472
+ BS=BS,
473
+ BK=BK,
474
+ BV=BV,
475
+ USE_G=USE_G
476
+ )
477
+
478
+
479
+ def parallel_simple_gla_fwd(
480
+ q: torch.Tensor,
481
+ k: torch.Tensor,
482
+ v: torch.Tensor,
483
+ g: torch.Tensor,
484
+ scale: float,
485
+ output_attentions: bool = False,
486
+ chunk_size: int = 128,
487
+ head_first: bool = True,
488
+ offsets: Optional[torch.LongTensor] = None,
489
+ indices: Optional[torch.LongTensor] = None,
490
+ ):
491
+ if head_first:
492
+ B, H, T, K, V = *k.shape, v.shape[-1]
493
+ else:
494
+ B, T, H, K, V = *k.shape, v.shape[-1]
495
+ BT, BS = chunk_size, 32
496
+ if check_shared_mem('hopper', k.device.index):
497
+ BK = min(256, triton.next_power_of_2(K))
498
+ BV = min(256, triton.next_power_of_2(V))
499
+ elif check_shared_mem('ampere', k.device.index):
500
+ BK = min(128, triton.next_power_of_2(K))
501
+ BV = min(128, triton.next_power_of_2(V))
502
+ else:
503
+ BK = min(64, triton.next_power_of_2(K))
504
+ BV = min(64, triton.next_power_of_2(V))
505
+
506
+ NK = triton.cdiv(K, BK)
507
+ NV = triton.cdiv(V, BV)
508
+ assert BT % BS == 0
509
+
510
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
511
+
512
+ # local cumulative decay in log space
513
+ if g is not None:
514
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
515
+ grid = (NK * NV, NT, B * H)
516
+ o = torch.empty(NK, *v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device)
517
+ attn = q.new_zeros(NK, B, H, T, T) if output_attentions else None
518
+
519
+ parallel_simple_gla_fwd_kernel[grid](
520
+ q=q,
521
+ k=k,
522
+ v=v,
523
+ g=g,
524
+ o=o,
525
+ attn=attn,
526
+ scale=scale,
527
+ offsets=offsets,
528
+ indices=indices,
529
+ B=B,
530
+ H=H,
531
+ T=T,
532
+ K=K,
533
+ V=V,
534
+ BT=BT,
535
+ BS=BS,
536
+ BK=BK,
537
+ BV=BV,
538
+ HEAD_FIRST=head_first,
539
+ )
540
+ o = o.sum(0)
541
+
542
+ if output_attentions:
543
+ attn = attn.sum(0)
544
+ return o, g, attn
545
+
546
+
547
+ def parallel_simple_gla_bwd(
548
+ q: torch.Tensor,
549
+ k: torch.Tensor,
550
+ v: torch.Tensor,
551
+ g: torch.Tensor,
552
+ do: torch.Tensor,
553
+ scale: float,
554
+ chunk_size: int = 128,
555
+ head_first: bool = True,
556
+ offsets: Optional[torch.LongTensor] = None,
557
+ indices: Optional[torch.LongTensor] = None,
558
+ ):
559
+ if head_first:
560
+ B, H, T, K, V = *k.shape, v.shape[-1]
561
+ else:
562
+ B, T, H, K, V = *k.shape, v.shape[-1]
563
+ BT, BS = chunk_size, 32
564
+ if check_shared_mem('hopper', k.device.index):
565
+ BK = min(256, triton.next_power_of_2(K))
566
+ BV = min(256, triton.next_power_of_2(V))
567
+ elif check_shared_mem('ampere', k.device.index):
568
+ BK = min(128, triton.next_power_of_2(K))
569
+ BV = min(128, triton.next_power_of_2(V))
570
+ elif check_shared_mem('ada', k.device.index):
571
+ BK = min(64, triton.next_power_of_2(K))
572
+ BV = min(64, triton.next_power_of_2(V))
573
+ else:
574
+ BK = min(32, triton.next_power_of_2(K))
575
+ BV = min(32, triton.next_power_of_2(V))
576
+
577
+ NK = triton.cdiv(K, BK)
578
+ NV = triton.cdiv(V, BV)
579
+ assert BT % BS == 0
580
+
581
+ dq = torch.empty(NV, * q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
582
+ dk = torch.empty(NV, * k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
583
+ dv = torch.empty(NK, * v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device)
584
+ dg = torch.empty(NK*NV, *g.shape, dtype=torch.float, device=q.device) if g is not None else None
585
+
586
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
587
+
588
+ grid = (NK * NV, NT, B * H)
589
+ parallel_simple_gla_bwd_kernel[grid](
590
+ q=q,
591
+ k=k,
592
+ v=v,
593
+ g=g,
594
+ do=do,
595
+ dq=dq,
596
+ dk=dk,
597
+ dv=dv,
598
+ dg=dg,
599
+ offsets=offsets,
600
+ indices=indices,
601
+ scale=scale,
602
+ T=T,
603
+ B=B,
604
+ H=H,
605
+ K=K,
606
+ V=V,
607
+ BT=BT,
608
+ BS=BS,
609
+ BK=BK,
610
+ BV=BV,
611
+ HEAD_FIRST=head_first
612
+ )
613
+ dq = dq.sum(0)
614
+ dk = dk.sum(0)
615
+ dv = dv.sum(0)
616
+ dg = chunk_global_cumsum(dg.sum(0), reverse=True, head_first=head_first, offsets=offsets) if g is not None else None
617
+ return dq, dk, dv, dg
618
+
619
+
620
+ class ParallelSimpleGLAFunction(torch.autograd.Function):
621
+
622
+ @staticmethod
623
+ @input_guard
624
+ @autocast_custom_fwd
625
+ def forward(ctx, q, k, v, g, scale, output_attentions, head_first, offsets):
626
+ chunk_size = 128
627
+ ctx.dtype = q.dtype
628
+
629
+ # 2-d indices denoting the offsets of chunks in each sequence
630
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
631
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
632
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
633
+ indices = None
634
+ if offsets is not None:
635
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
636
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
637
+
638
+ o, g, attn = parallel_simple_gla_fwd(
639
+ q=q,
640
+ k=k,
641
+ v=v,
642
+ g=g,
643
+ scale=scale,
644
+ output_attentions=output_attentions,
645
+ head_first=head_first,
646
+ offsets=offsets,
647
+ indices=indices,
648
+ chunk_size=chunk_size)
649
+ ctx.save_for_backward(q, k, v, g, offsets, indices)
650
+ ctx.scale = scale
651
+ ctx.chunk_size = chunk_size
652
+ ctx.head_first = head_first
653
+ return o.to(q.dtype), attn
654
+
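Concretely, for the example in the comment above (`offsets=[0, 100, 356]`, `chunk_size=64`), the two-column `indices` tensor can be reproduced on CPU; `triton.cdiv` is plain ceiling division:

```python
import torch
import triton

offsets, chunk_size = torch.tensor([0, 100, 356]), 64
indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1)
print(indices.tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
```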
655
+ @staticmethod
656
+ @input_guard
657
+ @autocast_custom_bwd
658
+ def backward(ctx, do, da=None):
659
+ q, k, v, g, offsets, indices = ctx.saved_tensors
660
+ dq, dk, dv, dg = parallel_simple_gla_bwd(
661
+ q=q,
662
+ k=k,
663
+ v=v,
664
+ g=g,
665
+ do=do,
666
+ scale=ctx.scale,
667
+ chunk_size=ctx.chunk_size,
668
+ offsets=offsets,
669
+ indices=indices,
670
+ head_first=ctx.head_first)
671
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.dtype) if dg is not None else None, None, None, None, None
672
+
673
+
674
+ def parallel_simple_gla(
675
+ q: torch.Tensor,
676
+ k: torch.Tensor,
677
+ v: torch.Tensor,
678
+ g: Optional[torch.Tensor] = None,
679
+ scale: Optional[float] = None,
680
+ output_attentions: bool = False,
681
+ cu_seqlens: Optional[torch.LongTensor] = None,
682
+ head_first: bool = True
683
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
684
+ r"""
685
+ Args:
686
+ q (torch.Tensor):
687
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
688
+ k (torch.Tensor):
689
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
690
+ v (torch.Tensor):
691
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
692
+ g (torch.Tensor):
693
+ Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
694
+ Compared to GLA, the gating is head-wise instead of elementwise.
695
+ scale (Optional[float]):
696
+ Scale factor for attention scores.
697
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
698
+ output_attentions (bool):
699
+ Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
700
+ cu_seqlens (torch.LongTensor):
701
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
702
+ consistent with the FlashAttention API.
703
+ head_first (Optional[bool]):
704
+ Whether the inputs are in the head-first format. Default: `True`.
705
+
706
+ Returns:
707
+ o (torch.Tensor):
708
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
709
+ attn (torch.Tensor):
710
+ Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`
711
+ """
712
+ if scale is None:
713
+ scale = k.shape[-1] ** -0.5
714
+ if cu_seqlens is not None:
715
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
716
+ assert not head_first, "head_first must be False when cu_seqlens are provided"
717
+ if g is not None:
718
+ g = g.float()
719
+ if output_attentions:
720
+ assert cu_seqlens is None, "output_attentions=True is not supported with variable-length sequences"
721
+ o, attn = ParallelSimpleGLAFunction.apply(q, k, v, g, scale, output_attentions, head_first, cu_seqlens)
722
+ return o, attn
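A minimal smoke test of this interface; a sketch that assumes a CUDA device with Triton available (sizes arbitrary; the gates are log-decay values, so zeros mean no decay):

```python
import torch

B, H, T, K, V = 2, 4, 256, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
g = torch.zeros(B, H, T, device='cuda')  # head-wise forget gates (log space)

o, attn = parallel_simple_gla(q, k, v, g=g, output_attentions=True)
assert o.shape == (B, H, T, V) and attn.shape == (B, H, T, T)
```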
fla/ops/titans/naive.py ADDED
@@ -0,0 +1,375 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from fla.ops.titans.log_impl import combine_params_log
7
+
8
+
9
+ def cal_n(theta, eta, seq_len):
10
+ n = torch.zeros(*theta.shape, seq_len, dtype=theta.dtype,
11
+ device=theta.device) # [batch_size, num_heads, seq_len, seq_len]
13
+
14
+ # 1. deal with diagonal elements
15
+ indices = torch.arange(seq_len, device=theta.device)
16
+ n[..., indices, indices] = theta[..., indices]
17
+
18
+ # 2. Create a cumulative product matrix
19
+ # First create a mask to mark the positions where eta needs to be multiplied
20
+ mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=theta.device), diagonal=1)
23
+ # Expand eta to match the target shape
24
+ eta_expanded = eta.unsqueeze(-2).expand(*theta.shape[:-1], seq_len, seq_len)
25
+ # Create a matrix filled with 1s for cumulative product
26
+ cumulative = torch.ones_like(eta_expanded)
27
+ cumulative = torch.where(mask, eta_expanded, cumulative)
28
+ # Calculate the cumulative product
29
+ cumulative_prod = torch.cumprod(cumulative, dim=-1)
30
+
31
+ # 3. Calculate non-diagonal elements
32
+ # Create an expanded version of theta
33
+ theta_expanded = theta.unsqueeze(-1).expand(*theta.shape[:-1], seq_len, seq_len)
34
+ # Create a mask to keep only the upper triangular part (excluding the diagonal)
35
+ upper_triangular = torch.triu(torch.ones_like(n), diagonal=1).bool()
36
+ # Combine theta and cumulative product
37
+ n = torch.where(upper_triangular, theta_expanded * cumulative_prod, n)
38
+ return n
39
+
40
+
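`cal_n` vectorizes `n[..., j, i] = theta[..., j] * prod(eta[..., j+1:i+1])` (cf. the commented-out loop in `combine_params` below). A brute-force check of that construction on small, hypothetical sizes:

```python
import torch

def cal_n_ref(theta, eta, seq_len):
    # naive O(T^2) definition; torch.prod over an empty slice yields 1
    n = torch.zeros(*theta.shape, seq_len, dtype=theta.dtype, device=theta.device)
    for i in range(seq_len):
        for j in range(i + 1):
            n[..., j, i] = theta[..., j] * torch.prod(eta[..., j + 1:i + 1], dim=-1)
    return n

theta, eta = torch.rand(1, 2, 8), torch.rand(1, 2, 8)
assert torch.allclose(cal_n(theta, eta, 8), cal_n_ref(theta, eta, 8), atol=1e-6)
```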
41
+ def cal_f(beta, seq_len, m):
42
+ a = torch.tril(beta.to(torch.float32).unsqueeze(-1).expand(*beta.shape, seq_len), 0)
43
+ ratio = (m.to(torch.float32) / beta.to(torch.float32)).unsqueeze(-1)
44
+ f = torch.matmul(a, ratio).squeeze(-1)
45
+ return f.to(beta.dtype)
46
+
47
+
48
+ def cal_G(beta, n, seq_len):
49
+ i_indices = torch.arange(seq_len, device=beta.device)
50
+ j_indices = torch.arange(seq_len, device=beta.device)
51
+ k_indices = torch.arange(seq_len, device=beta.device)
52
+ beta_ratio = beta[..., :, None] / beta[..., None, :] # [..., i, k]
53
+
54
+ # create mask
55
+ k_mask = (k_indices[None, None, :] >= j_indices[None, :, None]) & (
56
+ k_indices[None, None, :] <= i_indices[:, None, None]
57
+ )
58
+
59
+ # use mask to filter out invalid values
60
+ masked_beta_ratio = beta_ratio[..., :, None, :] * k_mask # [..., i, j, k]
61
+ masked_n = n[..., None, :, :] * k_mask # [..., i, j, k]
62
+ # calculate G
63
+ G = torch.sum(masked_beta_ratio * masked_n, dim=-1) # [..., i, j]
64
+ return G
65
+
66
+
67
+ def combine_params(theta, alpha, eta, seq_len):
68
+ theta = theta.squeeze(-1)
69
+ eta = eta.squeeze(-1)
70
+ alpha = alpha.squeeze(-1)
71
+ beta = torch.cumprod(1 - alpha, dim=-1) # β_t = ∏(1 - α_t) in titans paper
72
+ beta_T = beta[..., -1] # β_T
73
+ # Calculate m_i = ∏(k=1 to i) η_k
74
+ m = torch.cumprod(eta, dim=-1) # [batch_size, num_heads, seq_len]
75
+ m_T = m[..., -1] # m_T
76
+ # Calculate n_{i,j}
77
+ # We need to calculate ∏(k=j+1 to i) η_k for each i,j pair
78
+ # # this may be optimized
79
+ # n = torch.zeros(*theta.shape, seq_len, dtype = theta.dtype).to(
80
+ # theta.device) # [batch_size, num_heads, seq_len, seq_len]
81
+ # for i in range(seq_len):
82
+ # for j in range(i + 1):
83
+ # if i == j:
84
+ # n[..., j, i] = theta[..., j]
85
+ # else:
86
+ # # Calculate product of eta from j+1 to i
87
+ # eta_product = torch.prod(eta[..., j + 1:i + 1], dim = -1)
88
+ # n[..., j, i] = theta[..., j] * eta_product
89
+
90
+ n = cal_n(theta, eta, seq_len)
91
+ n_T = n[..., -1] # [batch_size, num_heads, seq_len]
92
+ # Calculate f_t = ∑(i=1 to t) (β_t/β_i) m_i
93
+ # f = torch.zeros_like(theta)
94
+ # for t in range(seq_len):
95
+ # for i in range(t + 1):
96
+ # f[..., t] += (beta[..., t] / beta[..., i]) * m[..., i]
97
+ f = cal_f(beta, seq_len, m)
98
+ f_T = f[..., -1] # [batch_size, num_heads, seq_len]
99
+ # Calculate g_j = ∑(i=j to t) (β_t/β_i) n_{i,j}
100
+ # g = torch.zeros_like(theta) # [batch_size, num_heads, seq_len]
101
+ # for j in range(seq_len):
102
+ # for i in range(j, seq_len):
103
+ # g[..., j] += (beta[..., -1] / beta[..., i]) * n[..., j, i]
104
+ # G = torch.zeros(*beta.shape[:-1], seq_len, seq_len, device = beta.device)
105
+ # # Fill in the lower triangular part
106
+ # for i in range(seq_len): # row
107
+ # for j in range(i + 1): # column
108
+ # # Sum from k=j to i
109
+ # for k in range(j, i + 1):
110
+ # G[..., i, j] += (beta[..., i] / beta[..., k]) * n[..., j, k]
111
+ G = cal_G(beta, n, seq_len)
112
+ g = G[:, :, -1, :] # [batch_size, num_heads, seq_len]
113
+ # g2, G2 = compute_g_and_G(beta, n, seq_len)
114
+ return beta, beta_T, f, f_T, g, G, m_T, n_T
115
+
116
+
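For reference, the closed-form chunk quantities assembled by `combine_params` (matching the inline comments; `beta_T`, `m_T`, `n_T`, `f_T`, and `g` are the corresponding values at the chunk's last step):

```latex
\beta_t = \prod_{k \le t} (1 - \alpha_k), \quad
m_i = \prod_{k \le i} \eta_k, \quad
n_{j,i} = \theta_j \prod_{k=j+1}^{i} \eta_k, \quad
f_t = \sum_{i \le t} \frac{\beta_t}{\beta_i}\, m_i, \quad
G_{i,j} = \sum_{k=j}^{i} \frac{\beta_i}{\beta_k}\, n_{j,k} \;\; (j \le i).
```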
117
+ def titans_linear(
118
+ q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
119
+ ):
120
+ """
121
+ Implementation of Titans Linear function based on the update rules:
122
+ M_t = (1 - alpha_t) * M_{t-1} + S_t
123
+ S_t = eta_t * S_{t-1} - theta_t * nabla_l(M_{t-1}; x_t)
124
+
125
+ Args:
126
+ q: Query tensor
127
+ k: Key tensor
128
+ v: Value tensor
129
+ w: Weight tensor
130
+ b: Bias tensor
131
+ theta: Learning rate tensor
132
+ alpha: Forget-gate (weight-decay) tensor, as in M_t = (1 - alpha_t) * M_{t-1} + S_t
133
+ eta: Momentum-decay tensor, as in S_t = eta_t * S_{t-1} - theta_t * grad
134
+ eps: Epsilon for numerical stability
+ chunk_size: Interval (in steps) at which the frozen gradient state M_prev_nabla is refreshed
135
+ initial_state: Initial state M_0
136
+ output_final_state: Whether to output the final state
137
+
138
+ Returns:
139
+ Tuple of (output tensor, final state)
140
+ """
141
+ B, H, T, D = q.shape
142
+ device = q.device
143
+ w = w.reshape(H, 1, D).to(torch.float32)
144
+ b = b.reshape(H, 1, D).to(torch.float32)
145
+ # Initialize states
146
+ if initial_state is None:
147
+ M_prev = torch.zeros(B, H, D, D, device=device)
148
+ else:
149
+ M_prev = initial_state
150
+ M_prev_nabla = M_prev.clone()
151
+ S_prev = torch.zeros_like(M_prev)
152
+ outputs = []
153
+
154
+ # Process sequence step by step
155
+ for t in range(T):
156
+ # Get current step inputs
157
+ q_t = q[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
158
+ k_t = k[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
159
+ v_t = v[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
160
+ theta_t = theta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
161
+ alpha_t = alpha[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
162
+ eta_t = eta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
163
+
164
+ # Compute gradient
165
+ km = k_t @ M_prev_nabla # (batch_size, num_heads, 1, dim)
166
+ reconstruction_target = v_t - k_t
167
+ mean = km.mean(-1, keepdim=True)
168
+ var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
169
+ rstd = torch.sqrt(var + eps).to(torch.float32) # NB: this is the standard deviation, despite the name
170
+ km_hat = (km - mean) / rstd
171
+
172
+ grad = w * km_hat + b - reconstruction_target
173
+ grad = grad * w
174
+ # v_new = (D * grad - grad.sum(-1, keepdim = True) - km_hat * (grad * km_hat).sum(-1, keepdim = True)) / (
175
+ # rstd * D)
176
+ v_new = (D * grad - grad.sum(-1, keepdim=True)
177
+ - km_hat * (grad * km_hat).sum(-1, keepdim=True)) / (rstd * D)
178
+ # grouped to match the commented reference above and the fused kernel's LN backward
179
+ # v_new = grad
180
+
181
+ # Update S_t
182
+ S_t = eta_t * S_prev - 2 * theta_t * k_t.transpose(-2, -1) @ v_new
183
+
184
+ # Update M_t
185
+ M_t = (1 - alpha_t) * M_prev + S_t
186
+
187
+ # Store output
188
+ output_t = q_t @ M_t # (batch_size, num_heads, 1, dim)
189
+ mean = output_t.mean(dim=-1, keepdim=True)
190
+ var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
191
+ rstd = torch.sqrt(var + eps).to(torch.float32)
192
+ output_t = output_t + (output_t - mean) / rstd * w + b
193
+ outputs.append(output_t)
194
+
195
+ # Update states for next step
196
+ if (t + 1) % chunk_size == 0:
197
+ M_prev_nabla = M_t.clone()
198
+ M_prev = M_t
199
+ S_prev = S_t
200
+
201
+ # Stack outputs along sequence dimension
202
+ output = torch.stack(outputs, dim=-2).squeeze(
203
+ -3
204
+ ) # (batch_size, num_heads, seq_len, dim)
205
+
206
+ if output_final_state:
207
+ return output, M_prev
208
+ return output, None
209
+
210
+
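The closed-form `v_new` in `titans_linear` is the gradient of the squared reconstruction loss through the layer norm; a double-precision autograd check of that identity (a self-contained sketch, with `rstd` naming the standard deviation as in the code above):

```python
import torch

D, eps = 8, 1e-6
km = torch.randn(1, 1, 1, D, dtype=torch.double, requires_grad=True)
w = torch.randn(D, dtype=torch.double)
b = torch.randn(D, dtype=torch.double)
target = torch.randn(1, 1, 1, D, dtype=torch.double)

# loss = 1/2 * || w * LN(km) + b - target ||^2, differentiated w.r.t. km
mean = km.mean(-1, keepdim=True)
var = km.var(-1, unbiased=False, keepdim=True)
km_hat = (km - mean) / torch.sqrt(var + eps)
loss = 0.5 * ((w * km_hat + b - target) ** 2).sum()
loss.backward()

# closed form, as in titans_linear
rstd = torch.sqrt(var + eps)
grad = (w * km_hat + b - target) * w
v_new = (D * grad - grad.sum(-1, keepdim=True)
         - km_hat * (grad * km_hat).sum(-1, keepdim=True)) / (rstd * D)
assert torch.allclose(km.grad, v_new.detach(), atol=1e-9)
```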
211
+ def chunk_titans_linear(
212
+ q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
213
+ ):
214
+ B, H, T, D = q.shape
215
+ num_batch = T // chunk_size
216
+ # [num_batch, B, num_heads, mini_batch_size, head_dim]
217
+ _q = q.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
218
+ _k = k.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
219
+ _v = v.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
220
+ # [num_batch, B, num_heads, mini_batch_size, 1]
221
+ _eta = eta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
222
+ _theta = theta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
223
+ _alpha = alpha.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
224
+ # [H, 1, D]
225
+ w = w.reshape(H, 1, D).to(torch.float32)
226
+ b = b.reshape(H, 1, D).to(torch.float32)
227
+ # state M_prev: [B, num_heads, head_dim, head_dim]
228
+ if initial_state is None:
229
+ M_prev = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32)
232
+ else:
233
+ M_prev = initial_state
234
+
235
+ S_prev = torch.zeros_like(M_prev)
236
+
237
+ # [num_batch, B, num_heads, mini_batch_size, head_dim]
238
+ o = torch.empty_like(_v)
239
+
240
+ for i in range(num_batch):
241
+ q_i, k_i, v_i, eta_i, theta_i, alpha_i = [
242
+ x[i] for x in [_q, _k, _v, _eta, _theta, _alpha]
243
+ ]
244
+
245
+ # beta, beta_T, f, f_T, g, G, m_T, n = combine_params(theta_i, alpha_i, eta_i, chunk_size)
246
+ beta, beta_T, f, f_T, g, G, m_T, n = combine_params_log(
247
+ theta_i, alpha_i, eta_i, chunk_size
248
+ )
249
+
250
+ m_T = m_T.unsqueeze(-1).unsqueeze(-1)
251
+ beta_T = beta_T.unsqueeze(-1).unsqueeze(-1)
252
+ f_T = f_T.unsqueeze(-1).unsqueeze(-1)
253
+ g_diag = torch.diag_embed(g).to(q_i.dtype)
254
+ n = torch.diag_embed(n).to(q_i.dtype)
255
+ beta = torch.diag_embed(beta).to(q_i.dtype)
256
+ f = torch.diag_embed(f).to(q_i.dtype)
257
+ km = k_i @ M_prev
258
+ reconstruction_target = v_i - k_i
259
+
260
+ mean = km.mean(-1, True)
261
+ var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
262
+ rstd = torch.sqrt(var + eps).to(torch.float32)
263
+ km_hat = (km - mean) / rstd
264
+
265
+ grad = w * km_hat + b - reconstruction_target
266
+ grad *= w
267
+ v_new = (D * grad - grad.sum(-1, keepdim=True)
268
+ - km_hat * (grad * km_hat).sum(-1, keepdim=True)) / (rstd * D)
269
+ # grouped as in the fused kernel's LN backward
270
+ # v_new = (D * grad - grad.sum(-1, True))
271
+ # print(f"Projection term stats: min={torch.abs(beta_T).min()}")
272
+
273
+ # v_new = grad
274
+
275
+ Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) * G
276
+
277
+ # o_i
278
+ output_t = beta @ q_i @ M_prev + f @ q_i @ S_prev - 2 * Attn @ v_new
279
+
280
+ M_t = (
281
+ beta_T * M_prev
282
+ + f_T * S_prev
283
+ - 2 * (g_diag @ k_i).transpose(-1, -2) @ v_new
284
+ )
285
+ # cal S_T from S_0
286
+ S_t = m_T * S_prev - 2 * (n @ k_i).transpose(-1, -2) @ v_new
287
+ # layer norm with residuals
288
+ mean = output_t.mean(dim=-1, keepdim=True)
289
+ var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
290
+ rstd = torch.sqrt(var + eps).to(torch.float32)
291
+ output_t = output_t + (output_t - mean) / rstd * w + b
292
+ o[i] = output_t
293
+ S_prev = S_t
294
+ M_prev = M_t
295
+
296
+ # [num_batch, B, num_heads, mini_batch_size, head_dim] -> [B, num_heads, T, head_dim]
297
+ o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
298
+ M_prev = M_prev if output_final_state else None
299
+ return o, M_prev
300
+
301
+
302
+ # most of the code is copied from ttt
303
+ def chunk_titans_linear_ref(
304
+ q: torch.Tensor,
305
+ k: torch.Tensor,
306
+ v: torch.Tensor,
307
+ w: torch.Tensor,
308
+ b: torch.Tensor,
309
+ theta: torch.Tensor,
310
+ alpha: torch.Tensor,
311
+ eta: torch.Tensor,
312
+ eps: float = 1e-6,
313
+ chunk_size: int = 16, # chunk size
314
+ initial_state: torch.Tensor = None,
315
+ output_final_state: bool = False,
316
+ head_first: bool = True,
317
+ use_chunk: bool = True,
318
+ ):
319
+ assert q.dtype == k.dtype == v.dtype
320
+ assert k.shape[-1] == v.shape[-1], "DK must equal DV."
321
+ if not head_first:
322
+ q = q.transpose(1, 2)
323
+ k = k.transpose(1, 2)
324
+ v = v.transpose(1, 2)
325
+ eta = eta.transpose(1, 2)
326
+ alpha = alpha.transpose(1, 2)
327
+ theta = theta.transpose(1, 2)
328
+ seq_len = q.shape[-2]
329
+ pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
330
+ if pad_len > 0:
331
+ q = F.pad(q, (0, 0, 0, pad_len))
332
+ k = F.pad(k, (0, 0, 0, pad_len))
333
+ v = F.pad(v, (0, 0, 0, pad_len))
334
+ theta = F.pad(theta, (0, 0, 0, pad_len))
335
+ alpha = F.pad(alpha, (0, 0, 0, pad_len))
336
+ eta = F.pad(eta, (0, 0, 0, pad_len))
337
+ theta[:, :, -1, :] = theta[:, :, -(pad_len + 1), :]
338
+ alpha[:, :, -1, :] = alpha[:, :, -(pad_len + 1), :]
339
+ eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
340
+ assert q.shape[-2] % chunk_size == 0, "Sequence length should be a multiple of BT."
341
+ q, k, v, w, b = map(lambda x: x.to(torch.float32), [q, k, v, w, b])
342
+ if use_chunk:
343
+ o, final_state = chunk_titans_linear(
344
+ q,
345
+ k,
346
+ v,
347
+ w,
348
+ b,
349
+ theta,
350
+ alpha,
351
+ eta,
352
+ eps,
353
+ chunk_size,
354
+ initial_state,
355
+ output_final_state,
356
+ )
357
+ else:
358
+ o, final_state = titans_linear(
359
+ q,
360
+ k,
361
+ v,
362
+ w,
363
+ b,
364
+ theta,
365
+ alpha,
366
+ eta,
367
+ eps,
368
+ chunk_size,
369
+ initial_state,
370
+ output_final_state,
371
+ )
372
+ o = o[:, :, :seq_len, :]
373
+ if not head_first:
374
+ o = o.transpose(1, 2)
375
+ return o, final_state
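The chunked and step-by-step paths are meant to implement the same recurrence; a quick way to exercise both through the reference wrapper (hypothetical sizes; gates are squashed into (0, 1) so the cumulative products stay well conditioned):

```python
import torch

B, H, T, D = 1, 2, 32, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
w, b = torch.randn(H, D), torch.randn(H, D)
theta, alpha, eta = (torch.sigmoid(torch.randn(B, H, T, 1)) for _ in range(3))

o_chunk, _ = chunk_titans_linear_ref(q, k, v, w, b, theta, alpha, eta,
                                     chunk_size=16, use_chunk=True)
o_step, _ = chunk_titans_linear_ref(q, k, v, w, b, theta, alpha, eta,
                                    chunk_size=16, use_chunk=False)
print((o_chunk - o_step).abs().max())  # expected to be small
```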
fla/ops/ttt/chunk.py ADDED
@@ -0,0 +1,1539 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.modules.layernorm import group_norm
12
+ from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
18
+ 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=num_warps)
25
+ for num_warps in [1, 2, 4, 8]
26
+ ],
27
+ key=['BT', 'BK', 'BV']
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_ttt_linear_fwd_kernel_h(
31
+ k,
32
+ v,
33
+ v_new,
34
+ eta,
35
+ w,
36
+ b,
37
+ eps,
38
+ h,
39
+ hb,
40
+ h0,
41
+ hb0,
42
+ ht,
43
+ hbt,
44
+ offsets,
45
+ chunk_offsets,
46
+ T,
47
+ H: tl.constexpr,
48
+ K: tl.constexpr,
49
+ V: tl.constexpr,
50
+ BT: tl.constexpr,
51
+ BK: tl.constexpr,
52
+ BV: tl.constexpr,
53
+ NT: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ USE_INITIAL_STATE_B: tl.constexpr,
56
+ STORE_FINAL_STATE: tl.constexpr,
57
+ USE_OFFSETS: tl.constexpr,
58
+ HEAD_FIRST: tl.constexpr,
59
+ ):
60
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
61
+ i_n, i_h = i_nh // H, i_nh % H
62
+ if USE_OFFSETS:
63
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
64
+ T = eos - bos
65
+ NT = tl.cdiv(T, BT)
66
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
67
+ else:
68
+ bos, eos = i_n * T, i_n * T + T
69
+ NT = tl.cdiv(T, BT)
70
+ boh = i_n * NT
71
+
72
+ # [BK, BV]
73
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
74
+ # [BV]
75
+ b_hb = tl.zeros([BV], dtype=tl.float32)
76
+ if USE_INITIAL_STATE:
77
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
79
+ if USE_INITIAL_STATE_B:
80
+ p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
81
+ b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)
82
+
83
+ offs = tl.arange(0, BV)
84
+ b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.)
85
+ b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.)
86
+
87
+ for i_t in range(NT):
88
+ if HEAD_FIRST:
89
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
90
+ p_hb = tl.make_block_ptr(hb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
91
+ else:
92
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
93
+ p_hb = tl.make_block_ptr(hb + ((boh + i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
94
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
95
+ tl.store(p_hb, b_hb.to(p_hb.dtype.element_ty), boundary_check=(0,))
96
+ if HEAD_FIRST:
97
+ p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
98
+ p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
100
+ p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
103
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
104
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
105
+ p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
106
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
107
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
108
+
109
+ b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
110
+ b_kh = tl.where((offs < V)[None, :], b_kh, 0.)
111
+ mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
112
+ xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.)
113
+ var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
114
+ rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
115
+ b_kh_hat = (b_kh - mean) * rstd
116
+
117
+ b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
118
+ b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
119
+ b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
120
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
121
+ * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
122
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
123
+ b_eta_last = tl.load(p_eta_last)
124
+ b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
125
+ b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
129
+ p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
130
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
131
+ tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))
132
+
133
+
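Per chunk, the kernel above stores the incoming state, forms the LN-backward targets `b_v2`, and then applies a rank-`BT` state update gated by the chunk's last `eta`. The same transition for one (chunk, head) in plain PyTorch (a sketch; names hypothetical):

```python
import torch

def ttt_chunk_state_update(h, hb, k_c, v2_c, eta_last):
    # h: [K, V] fast-weight state; hb: [V] bias state
    # k_c: [BT, K] chunk keys; v2_c: [BT, V] LN-backward targets
    # eta_last: the chunk's last learning-rate gate (a scalar)
    h = h - eta_last * (k_c.transpose(0, 1) @ v2_c)  # rank-BT update
    hb = hb - eta_last * v2_c.sum(0)
    return h, hb
```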
134
+ @triton.heuristics({
135
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
136
+ })
137
+ @triton.autotune(
138
+ configs=[
139
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
140
+ for num_warps in [2, 4, 8]
141
+ for num_stages in [2, 3]
142
+ ],
143
+ key=['BT'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def chunk_ttt_linear_fwd_kernel_o(
147
+ q,
148
+ k,
149
+ v,
150
+ eta,
151
+ h,
152
+ hb,
153
+ o,
154
+ offsets,
155
+ indices,
156
+ scale,
157
+ T,
158
+ H: tl.constexpr,
159
+ K: tl.constexpr,
160
+ V: tl.constexpr,
161
+ BT: tl.constexpr,
162
+ BK: tl.constexpr,
163
+ BV: tl.constexpr,
164
+ USE_OFFSETS: tl.constexpr,
165
+ HEAD_FIRST: tl.constexpr,
166
+ ):
167
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
168
+ i_b, i_h = i_bh // H, i_bh % H
169
+
170
+ if USE_OFFSETS:
171
+ i_tg = i_t
172
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
173
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
174
+ T = eos - bos
175
+ NT = tl.cdiv(T, BT)
176
+ else:
177
+ NT = tl.cdiv(T, BT)
178
+ i_tg = i_b * NT + i_t
179
+ bos, eos = i_b * T, i_b * T + T
180
+
181
+ # offset calculation
182
+ q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
183
+ k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
184
+ v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
185
+ eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
186
+ o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
187
+ h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
188
+ hb += ((i_bh * NT + i_t) * V) if HEAD_FIRST else ((i_tg * H + i_h) * V)
189
+ stride_qk = K if HEAD_FIRST else H*K
190
+ stride_vo = V if HEAD_FIRST else H*V
191
+ stride_eta = 1 if HEAD_FIRST else H
192
+
193
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, BK), (1, 0))
194
+ p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (0, i_t * BT), (BK, BT), (0, 1))
195
+ p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,))
196
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0))
197
+ p_hb = tl.make_block_ptr(hb, (V,), (1,), (i_v * BV,), (BV,), (0,))
198
+ # [BT, BK]
199
+ b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
200
+ # [BK, BT]
201
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
202
+ # [BT]
203
+ b_eta = tl.load(p_eta, boundary_check=(0,), padding_option="zero")
204
+ # [BK, BV]
205
+ b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
206
+ # [BV]
207
+ b_hb = tl.load(p_hb, boundary_check=(0,), padding_option="zero")
208
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
209
+ b_o = tl.dot(b_q, b_h, allow_tf32=False)
210
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
211
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
212
+
213
+ o_i = tl.arange(0, BT)
214
+ m_A = o_i[:, None] >= o_i[None, :]
215
+ b_A = tl.where(m_A, b_A, 0)
216
+ b_Ae = tl.where(m_A, b_eta[:, None], 0.0)
217
+
218
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
219
+ p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
220
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
221
+ b_o = (b_o - tl.dot(b_eta[:, None] * b_A.to(b_v.dtype), b_v, allow_tf32=False)) * scale
222
+ b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v.dtype), b_v, allow_tf32=False)
223
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
224
+
225
+
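The output kernel combines the inter-chunk contribution `q @ h` with a causal intra-chunk correction built from `tril(q kᵀ)` and the per-row `eta` gates; the `v` it reads is the transformed value block `v_new` produced by the state kernel. A per-chunk PyTorch sketch of the same computation (names hypothetical):

```python
import torch

def ttt_chunk_output(q_c, k_c, v2_c, h, hb, eta_c, scale):
    # q_c, k_c: [BT, K]; v2_c: [BT, V]; h: [K, V]; hb: [V]; eta_c: [BT]
    A = torch.tril(q_c @ k_c.transpose(0, 1))            # causal intra-chunk scores
    o = (q_c @ h - (eta_c[:, None] * A) @ v2_c) * scale  # inter-chunk + correction
    return o + hb - torch.tril(eta_c[:, None].expand(*A.shape)) @ v2_c
```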
226
+ @triton.heuristics({
227
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
228
+ 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
229
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
230
+ })
231
+ @triton.autotune(
232
+ configs=[
233
+ triton.Config({}, num_warps=num_warps)
234
+ for num_warps in [1, 2, 4, 8]
235
+ ],
236
+ key=['BT', 'BK', 'BV'],
237
+ )
238
+ @triton.jit(do_not_specialize=['T'])
239
+ def chunk_ttt_linear_bwd_kernel_h(
240
+ k,
241
+ v,
242
+ v_new,
243
+ eta,
244
+ w,
245
+ b,
246
+ eps,
247
+ h,
248
+ h0,
249
+ hb0,
250
+ x,
251
+ y,
252
+ r,
253
+ offsets,
254
+ chunk_offsets,
255
+ T,
256
+ H: tl.constexpr,
257
+ K: tl.constexpr,
258
+ V: tl.constexpr,
259
+ BT: tl.constexpr,
260
+ BK: tl.constexpr,
261
+ BV: tl.constexpr,
262
+ NT: tl.constexpr,
263
+ USE_INITIAL_STATE: tl.constexpr,
264
+ USE_INITIAL_STATE_B: tl.constexpr,
265
+ USE_OFFSETS: tl.constexpr,
266
+ HEAD_FIRST: tl.constexpr,
267
+ ):
268
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
269
+ i_n, i_h = i_nh // H, i_nh % H
270
+ if USE_OFFSETS:
271
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
272
+ T = eos - bos
273
+ NT = tl.cdiv(T, BT)
274
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
275
+ else:
276
+ bos, eos = i_n * T, i_n * T + T
277
+ NT = tl.cdiv(T, BT)
278
+ boh = i_n * NT
279
+
280
+ # [BK, BV]
281
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
282
+ # [BV]
283
+ b_hb = tl.zeros([BV], dtype=tl.float32)
284
+ if USE_INITIAL_STATE:
285
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
286
+ b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
287
+ if USE_INITIAL_STATE_B:
288
+ p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
289
+ b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)
290
+
291
+ offs = tl.arange(0, BV)
292
+ b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.)
293
+ b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.)
294
+
295
+ for i_t in range(NT):
296
+ if HEAD_FIRST:
297
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
298
+ else:
299
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
300
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
301
+ if HEAD_FIRST:
302
+ p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
303
+ p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
304
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
305
+ p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
306
+ p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
307
+ p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0))
308
+ p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
309
+ else:
310
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
311
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
312
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
313
+ p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
314
+ p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
315
+ p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
316
+ p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
317
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
318
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
319
+
320
+ b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
321
+ b_kh = tl.where((offs < V)[None, :], b_kh, 0.)
322
+ mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
323
+ xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.)
324
+ var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
325
+ rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
326
+ b_kh_hat = (b_kh - mean) * rstd
327
+
328
+ b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
329
+ b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
330
+ b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
331
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
332
+ * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
333
+ tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
334
+ tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
335
+ tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
336
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
337
+ b_eta_last = tl.load(p_eta_last)
338
+ b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
339
+ b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0)
340
+
341
+
342
+ @triton.heuristics({
343
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
344
+ })
345
+ @triton.autotune(
346
+ configs=[
347
+ triton.Config({}, num_warps=num_warps)
348
+ for num_warps in [4]
349
+ ],
350
+ key=['BT', 'BK', 'BV'],
351
+ )
352
+ @triton.jit(do_not_specialize=['T'])
353
+ def chunk_ttt_linear_bwd_kernel_dv_local(
354
+ q,
355
+ k,
356
+ eta,
357
+ do,
358
+ dv,
359
+ offsets,
360
+ indices,
361
+ scale,
362
+ T,
363
+ H: tl.constexpr,
364
+ K: tl.constexpr,
365
+ V: tl.constexpr,
366
+ BT: tl.constexpr,
367
+ BK: tl.constexpr,
368
+ BV: tl.constexpr,
369
+ USE_OFFSETS: tl.constexpr,
370
+ HEAD_FIRST: tl.constexpr,
371
+ ):
372
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
373
+ i_b, i_h = i_bh // H, i_bh % H
374
+ if USE_OFFSETS:
375
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
376
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
377
+ T = eos - bos
378
+ else:
379
+ bos, eos = i_b * T, i_b * T + T
380
+
381
+ # offset calculation
382
+ q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
383
+ k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
384
+ eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
385
+ do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
386
+ dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
387
+ stride_qk = K if HEAD_FIRST else H*K
388
+ stride_vo = V if HEAD_FIRST else H*V
389
+ stride_eta = 1 if HEAD_FIRST else H
390
+
391
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
392
+ for i_k in range(tl.cdiv(K, BK)):
393
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
394
+ p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
395
+ b_q = tl.load(p_q, boundary_check=(0, 1))
396
+ b_k = tl.load(p_k, boundary_check=(0, 1))
397
+ b_A += tl.dot(b_k, b_q)
398
+
399
+ p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,))
400
+ b_eta = tl.load(p_eta, boundary_check=(0,))
401
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
402
+ b_A = - tl.where(mask, b_A * scale * b_eta[None, :], 0).to(do.dtype.element_ty)
403
+ b_Ae = - tl.where(mask, b_eta[None, :], 0).to(do.dtype.element_ty)
404
+
405
+ for i_v in range(tl.cdiv(V, BV)):
406
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
407
+ p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
408
+ b_do = tl.load(p_do, boundary_check=(0, 1))
409
+ b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
410
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
411
+
412
+
413
+ @triton.heuristics({
414
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
415
+ 'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
416
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
417
+ 'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
418
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
419
+ })
420
+ @triton.autotune(
421
+ configs=[
422
+ triton.Config({}, num_warps=num_warps)
423
+ for num_warps in [2, 4, 8, 16]
424
+ ],
425
+ key=['BT', 'BK', 'BV'],
426
+ )
427
+ @triton.jit(do_not_specialize=['T'])
428
+ def chunk_ttt_linear_bwd_kernel_norm(
429
+ q,
430
+ k,
431
+ v,
432
+ v_new,
433
+ x,
434
+ y,
435
+ r,
436
+ w,
437
+ b,
438
+ eta,
439
+ h,
440
+ dht,
441
+ dhbt,
442
+ dh0,
443
+ dhb0,
444
+ do,
445
+ dh,
446
+ dhb,
447
+ dv,
448
+ dv_new,
449
+ dk,
450
+ dw,
451
+ db,
452
+ offsets,
453
+ chunk_offsets,
454
+ scale,
455
+ T,
456
+ H: tl.constexpr,
457
+ K: tl.constexpr,
458
+ V: tl.constexpr,
459
+ BT: tl.constexpr,
460
+ BK: tl.constexpr,
461
+ BV: tl.constexpr,
462
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
463
+ USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
464
+ USE_INITIAL_STATE: tl.constexpr,
465
+ USE_INITIAL_STATE_B: tl.constexpr,
466
+ USE_OFFSETS: tl.constexpr,
467
+ HEAD_FIRST: tl.constexpr
468
+ ):
469
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
470
+ i_n, i_h = i_nh // H, i_nh % H
471
+ if USE_OFFSETS:
472
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
473
+ T = eos - bos
474
+ NT = tl.cdiv(T, BT)
475
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
476
+ else:
477
+ bos, eos = i_n * T, i_n * T + T
478
+ NT = tl.cdiv(T, BT)
479
+ boh = i_n * NT
480
+
481
+ # [BK, BV]
482
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
483
+ # [BV]
484
+ b_dhb = tl.zeros([BV], dtype=tl.float32)
485
+ if USE_FINAL_STATE_GRADIENT:
486
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
487
+ b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
488
+ if USE_FINAL_STATE_GRADIENT_B:
489
+ p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
490
+ b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")
491
+
492
+ # [BV]
493
+ offs_v = tl.arange(0, BV)
494
+ offs_t = tl.arange(0, BT)
495
+ b_w = tl.load(w + i_h * V + offs_v, mask=offs_v < V, other=0.)
496
+ b_b = tl.load(b + i_h * V + offs_v, mask=offs_v < V, other=0.)
497
+ b_dw = tl.zeros([BV,], dtype=b_w.dtype)
498
+ b_db = tl.zeros([BV,], dtype=b_b.dtype)
499
+ p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
500
+ p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
501
+
502
+ for i_t in range(NT - 1, -1, -1):
503
+ if HEAD_FIRST:
504
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
505
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
506
+ p_dhb = tl.make_block_ptr(dhb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
507
+ else:
508
+ p_h = tl.make_block_ptr(h + ((boh+i_t) * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
509
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
510
+ p_dhb = tl.make_block_ptr(dhb + ((boh+i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
511
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
512
+ tl.store(p_dhb, b_dhb.to(p_dhb.dtype.element_ty), boundary_check=(0,))
513
+ if HEAD_FIRST:
514
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
515
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
516
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
517
+ p_v_new = tl.make_block_ptr(v_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
518
+ p_x = tl.make_block_ptr(x + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
519
+ p_y = tl.make_block_ptr(y + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
520
+ p_dv_new = tl.make_block_ptr(dv_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
521
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
522
+ p_dk = tl.make_block_ptr(dk + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
523
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
524
+ p_r = tl.make_block_ptr(r + i_nh * T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0))
525
+ p_eta_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1
526
+ else:
527
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
528
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
529
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
530
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
531
+ p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
532
+ p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
533
+ p_dv_new = tl.make_block_ptr(dv_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
534
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
535
+ p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, i_k * BK), (BT, BK), (1, 0))
536
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
537
+ p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
538
+ p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
539
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
540
+ b_dv_new = tl.load(p_dv_new, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
541
+ b_eta_last = tl.load(p_eta_last)
542
+ b_dv_new -= tl.dot(b_eta_last * b_k, b_dh.to(b_k.dtype))
543
+ b_dv_new -= b_eta_last * b_dhb.to(b_k.dtype)[None, :]
544
+
545
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1), padding_option="zero")
546
+ b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
547
+ b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
548
+ b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
549
+ b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
550
+ b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
551
+ b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
552
+ b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
553
+ b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v_new.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)
554
+
555
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
556
+ b_w = b_w.to(b_k.dtype)
557
+ b_b = b_b.to(b_k.dtype)
558
+ b_dv = -b_w * b_dy.to(b_k.dtype)
559
+ b_dk = b_w * b_dy.to(b_k.dtype)
560
+ b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
561
+ (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
562
+ b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
563
+ b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)
564
+
565
+ # d_rstd, dx --> dkh --> dk, dh
566
+ b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
567
+ b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
568
+ b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
569
+ b_q = (b_q * scale).to(b_q.dtype)
570
+ b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
571
+ b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
572
+ b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
573
+ b_dkh = tl.where((offs_v < V)[None, :] * (offs_t < T-i_t*BT)[:, None], b_dkh, 0.)
574
+ b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)
575
+ b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
576
+ b_dhb += tl.sum(b_do + b_dkh, axis=0)
577
+ b_dh = tl.where((offs_v < V)[None, :], b_dh, 0.)
578
+ b_dhb = tl.where((offs_v < V), b_dhb, 0.)
579
+
580
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
581
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
582
+ tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
583
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))
584
+
585
+ if USE_INITIAL_STATE:
586
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
587
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
588
+ if USE_INITIAL_STATE_B:
589
+ p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (i_v * BV,), (BV,), (0,))
590
+ tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))
591
+
592
+
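The `b_dy` formula in this kernel uses the fact that, with `x` and `rstd` held fixed, the map `y -> v2 = rstd * (V*y - sum(y) - x * sum(y*x)) / V` is linear and symmetric, so its adjoint has the identical form. A quick double-precision autograd check of that:

```python
import torch

V = 8
x = torch.randn(V, dtype=torch.double)           # plays the role of kh_hat
rstd = torch.rand((), dtype=torch.double) + 0.5  # held fixed
y = torch.randn(V, dtype=torch.double, requires_grad=True)
dv2 = torch.randn(V, dtype=torch.double)

v2 = rstd * (V * y - y.sum() - x * (y * x).sum()) / V
v2.backward(dv2)
dy = rstd * (V * dv2 - dv2.sum() - x * (dv2 * x).sum()) / V
assert torch.allclose(y.grad, dy)
```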
593
+ @triton.heuristics({
594
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
595
+ })
596
+ @triton.autotune(
597
+ configs=[
598
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
599
+ for num_warps in [2, 4, 8]
600
+ for num_stages in [2, 3]
601
+ ],
602
+ key=['BT', 'BK', 'BV'],
603
+ )
604
+ @triton.jit(do_not_specialize=['T'])
605
+ def chunk_bwd_kernel_dqke(
606
+ q,
607
+ k,
608
+ v,
609
+ e,
610
+ h,
611
+ do,
612
+ dh,
613
+ dhb,
614
+ dq,
615
+ dk,
616
+ de,
617
+ offsets,
618
+ indices,
619
+ scale,
620
+ T,
621
+ B: tl.constexpr,
622
+ H: tl.constexpr,
623
+ K: tl.constexpr,
624
+ V: tl.constexpr,
625
+ BT: tl.constexpr,
626
+ BK: tl.constexpr,
627
+ BV: tl.constexpr,
628
+ USE_OFFSETS: tl.constexpr,
629
+ HEAD_FIRST: tl.constexpr,
630
+ ):
631
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
632
+ i_b, i_h = i_bh // H, i_bh % H
633
+ if USE_OFFSETS:
634
+ i_tg = i_t
635
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
636
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
637
+ T = eos - bos
638
+ NT = tl.cdiv(T, BT)
639
+ else:
640
+ NT = tl.cdiv(T, BT)
641
+ i_tg = i_b * NT + i_t
642
+ bos, eos = i_b * T, i_b * T + T
643
+
644
+ # offset calculation
645
+ v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
646
+ do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
647
+ h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
648
+ dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
649
+ dhb += (i_bh * NT + i_t) * V if HEAD_FIRST else (i_tg * H + i_h) * V
650
+ q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
651
+ k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
652
+ dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
653
+ dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
654
+ e += i_bh * T if HEAD_FIRST else (bos * H + i_h)
655
+ de += i_bh * T if HEAD_FIRST else (bos * H + i_h)
656
+ stride_qk = K if HEAD_FIRST else H*K
657
+ stride_vo = V if HEAD_FIRST else H*V
658
+ stride_e = 1 if HEAD_FIRST else H
659
+
660
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
661
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
662
+ b_ds = tl.zeros([BT, BT], dtype=tl.float32)
663
+ b_de = tl.zeros([BT,], dtype=tl.float32)
664
+
665
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
666
+ b_k = tl.load(p_k, boundary_check=(0, 1))
667
+ p_e_last = (e + (i_t*BT+BT-1)*stride_e) if (i_t*BT+BT) <= T else (e + (T-1)*stride_e)
668
+ i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT - 1)
669
+ mask = (tl.arange(0, BT) == i_last)
670
+ b_e_last = tl.load(p_e_last)
671
+
672
+ for i_v in range(tl.cdiv(V, BV)):
673
+ p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
674
+ p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
675
+ p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
676
+ p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
677
+ p_dhb = tl.make_block_ptr(dhb, (V,), (1,), (i_v * BV,), (BV,), (0,))
678
+ # [BT, BV]
679
+ b_v = tl.load(p_v, boundary_check=(0, 1))
680
+ b_do = tl.load(p_do, boundary_check=(0, 1))
681
+ # [BV, BK]
682
+ b_h = tl.load(p_h, boundary_check=(0, 1))
683
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
684
+ # [BV]
685
+ b_dhb = tl.load(p_dhb, boundary_check=(0,))
686
+ # [BT, BV] @ [BV, BT] -> [BT, BT]
687
+ b_ds += tl.dot(b_do, tl.trans(b_v))
688
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
689
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
690
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
691
+ b_dk -= b_e_last * tl.dot(b_v, b_dh.to(b_v.dtype))
692
+ b_de -= mask * tl.sum(tl.trans(b_dh) * tl.dot(tl.trans(b_k), b_v.to(b_k.dtype)))
693
+ b_de -= mask * tl.sum(b_dhb * tl.sum(b_v, axis=0).to(b_k.dtype))
694
+
695
+ o_i = tl.arange(0, BT)
696
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
697
+ p_e = tl.make_block_ptr(e, (T,), (stride_e,), (i_t * BT,), (BT,), (0,))
698
+ b_q = tl.load(p_q, boundary_check=(0, 1))
699
+ b_e = tl.load(p_e, boundary_check=(0,))
700
+
701
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
702
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
703
+ p_de = tl.make_block_ptr(de, (T,), (stride_e,), (i_t * BT,), (BT,), (0,))
704
+
705
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
706
+ b_ds = b_ds.to(b_k.dtype)
707
+ b_dq -= tl.dot(b_ds, b_k) * b_e[:, None]
708
+ b_dk -= tl.dot(tl.trans(b_ds), b_q * b_e[:, None]) * scale
709
+ b_de -= tl.sum(scale * tl.dot(b_ds, b_k) * b_q, axis=1)
710
+ b_de -= tl.sum(b_ds, axis=1)
711
+ b_dq *= scale
712
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
713
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
714
+ tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))
715
+
716
+
717
+ def chunk_ttt_linear_fwd_h(
718
+ k: torch.Tensor,
719
+ v: torch.Tensor,
720
+ w: torch.Tensor,
721
+ b: torch.Tensor,
722
+ eta: torch.Tensor,
723
+ eps: float,
724
+ initial_state: Optional[torch.Tensor] = None,
725
+ initial_state_bias: Optional[torch.Tensor] = None,
726
+ output_final_state: bool = False,
727
+ offsets: Optional[torch.LongTensor] = None,
728
+ indices: Optional[torch.LongTensor] = None,
729
+ head_first: bool = True,
730
+ chunk_size: int = 16,
731
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
732
+ if head_first:
733
+ B, H, T, K, V = *k.shape, v.shape[-1]
734
+ else:
735
+ B, T, H, K, V = *k.shape, v.shape[-1]
736
+ BT = chunk_size
737
+ # N: the actual number of sequences in the batch with either equal or variable lengths
738
+ if offsets is None:
739
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
740
+ else:
741
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
742
+ BK = triton.next_power_of_2(K)
743
+ BV = triton.next_power_of_2(V)
744
+ assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
745
+ NK = triton.cdiv(K, BK)
746
+ NV = triton.cdiv(V, BV)
747
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
748
+ assert NV == 1, 'NV > 1 is not supported by TTT update rule.'
749
+
750
+ if head_first:
751
+ h = k.new_empty(B, H, NT, K, V)
752
+ hb = k.new_empty(B, H, NT, 1, V)
753
+ else:
754
+ h = k.new_empty(B, NT, H, K, V)
755
+ hb = k.new_empty(B, NT, H, 1, V)
756
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
757
+ final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32) if output_final_state else None
758
+
759
+ v_new = torch.empty_like(v)
760
+ grid = (NK, NV, N * H)
761
+
762
+ chunk_ttt_linear_fwd_kernel_h[grid](
763
+ k=k,
764
+ v=v,
765
+ v_new=v_new,
766
+ eta=eta,
767
+ w=w,
768
+ b=b,
769
+ eps=eps,
770
+ h=h,
771
+ hb=hb,
772
+ h0=initial_state,
773
+ hb0=initial_state_bias,
774
+ ht=final_state,
775
+ hbt=final_state_bias,
776
+ offsets=offsets,
777
+ chunk_offsets=chunk_offsets,
778
+ T=T,
779
+ H=H,
780
+ K=K,
781
+ V=V,
782
+ BT=BT,
783
+ BK=BK,
784
+ BV=BV,
785
+ NT=NT,
786
+ HEAD_FIRST=head_first
787
+ )
788
+ return h, hb, v_new, final_state, final_state_bias
789
+
790
+
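For the variable-length path, `chunk_offsets` marks where each sequence's chunks begin in the flattened chunk dimension. A minimal equivalent of that bookkeeping, assuming `prepare_chunk_offsets` follows this convention:

```python
import torch

def chunk_offsets_ref(offsets, chunk_size):
    # offsets: [N+1] cumulative sequence lengths, e.g. [0, 100, 356]
    seqlens = offsets[1:] - offsets[:-1]
    nchunks = (seqlens + chunk_size - 1) // chunk_size  # ceil division
    return torch.cat([offsets.new_zeros(1), nchunks.cumsum(0)])

print(chunk_offsets_ref(torch.tensor([0, 100, 356]), 64))  # tensor([0, 2, 6])
```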
791
+ def chunk_ttt_linear_fwd_o(
792
+ q: torch.Tensor,
793
+ k: torch.Tensor,
794
+ v: torch.Tensor,
795
+ eta: torch.Tensor,
796
+ h: torch.Tensor,
797
+ hb: torch.Tensor,
798
+ scale: Optional[float] = None,
799
+ offsets: Optional[torch.LongTensor] = None,
800
+ indices: Optional[torch.LongTensor] = None,
801
+ head_first: bool = True,
802
+ chunk_size: int = 64
803
+ ) -> torch.Tensor:
804
+ if head_first:
805
+ B, H, T, K, V = *q.shape, v.shape[-1]
806
+ else:
807
+ B, T, H, K, V = *q.shape, v.shape[-1]
808
+ if scale is None:
809
+ scale = k.shape[-1] ** -0.5
810
+ BT = chunk_size
811
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
812
+ BK = triton.next_power_of_2(K)
813
+ BV = triton.next_power_of_2(V)
814
+ NK = triton.cdiv(K, BK)
815
+ NV = triton.cdiv(V, BV)
816
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
817
+ assert NV == 1, 'NV > 1 is not supported by TTT update rule.'
818
+
819
+ o = torch.empty_like(v)
820
+
821
+ grid = (NV, NT, B * H)
822
+ chunk_ttt_linear_fwd_kernel_o[grid](
823
+ q,
824
+ k,
825
+ v,
826
+ eta,
827
+ h,
828
+ hb,
829
+ o,
830
+ offsets,
831
+ indices,
832
+ scale,
833
+ T=T,
834
+ H=H,
835
+ K=K,
836
+ V=V,
837
+ BT=BT,
838
+ BK=BK,
839
+ BV=BV,
840
+ HEAD_FIRST=head_first
841
+ )
842
+ return o
843
+
844
+
845
+ def chunk_ttt_linear_bwd_h(
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eta: torch.Tensor,
+     eps: float,
+     initial_state: Optional[torch.Tensor] = None,
+     initial_state_bias: Optional[torch.Tensor] = None,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 16,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *k.shape, v.shape[-1]
+     else:
+         B, T, H, K, V = *k.shape, v.shape[-1]
+     BT = chunk_size
+     # N: the actual number of sequences in the batch with either equal or variable lengths
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
+     BK = triton.next_power_of_2(K)
+     BV = triton.next_power_of_2(V)
+     assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
+     NK = triton.cdiv(K, BK)
+     NV = triton.cdiv(V, BV)
+     assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+     assert NV == 1, 'NV > 1 is not supported by TTT update rule.'
+
+     if head_first:
+         h = k.new_empty(B, H, NT, K, V)
+         rstd = v.new_empty(B, H, T, 1, dtype=torch.float32)
+     else:
+         h = k.new_empty(B, NT, H, K, V)
+         rstd = v.new_empty(B, T, H, 1, dtype=torch.float32)
+     x = torch.empty_like(v)
+     y = torch.empty_like(v)
+
+     v_new = torch.empty_like(v)
+     grid = (NK, NV, N * H)
+
+     chunk_ttt_linear_bwd_kernel_h[grid](
+         k=k,
+         v=v,
+         v_new=v_new,
+         eta=eta,
+         w=w,
+         b=b,
+         eps=eps,
+         h=h,
+         h0=initial_state,
+         hb0=initial_state_bias,
+         x=x,
+         y=y,
+         r=rstd,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         NT=NT,
+         HEAD_FIRST=head_first
+     )
+     return h, v_new, x, y, rstd
+
+
+ def chunk_ttt_linear_bwd_dv_local(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     eta: torch.Tensor,
+     do: torch.Tensor,
+     scale: float,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 16
+ ) -> torch.Tensor:
+     if head_first:
+         B, H, T, K, V = *k.shape, do.shape[-1]
+     else:
+         B, T, H, K, V = *k.shape, do.shape[-1]
+     BT = chunk_size
+     NT = triton.cdiv(T, BT) if offsets is None else len(indices)
+     BK = min(triton.next_power_of_2(K), 128)
+     BV = min(triton.next_power_of_2(V), 128)
+
+     dv = torch.empty_like(do)
+     grid = (NT, B * H)
+     chunk_ttt_linear_bwd_kernel_dv_local[grid](
+         q,
+         k,
+         eta,
+         do,
+         dv,
+         offsets,
+         indices,
+         scale,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         HEAD_FIRST=head_first
+     )
+     return dv
+
+
+ def chunk_ttt_linear_bwd_norm(
+     q: torch.Tensor,  # [B, H, L, D]
+     k: torch.Tensor,  # [B, H, L, D]
+     v: torch.Tensor,  # [B, H, L, D]
+     v_new: torch.Tensor,  # [B, H, L, D]
+     x: torch.Tensor,  # [B, H, L, D]
+     y: torch.Tensor,  # [B, H, L, D]
+     rstd: torch.Tensor,  # [B, H, L, 1]
+     w: torch.Tensor,  # [H, D]
+     b: torch.Tensor,  # [H, D]
+     eta: torch.Tensor,  # [B, H, L, 1]
+     h0: torch.Tensor,  # [B, H, D, D]
+     hb0: torch.Tensor,  # [B, H, 1, D]
+     h: torch.Tensor,  # [B, H, NT, D, D]
+     dht: Optional[torch.Tensor],  # [B, H, D, D]
+     dhbt: Optional[torch.Tensor],  # [B, H, 1, D]
+     dv_new: Optional[torch.Tensor],  # [B, H, L, D]
+     do: torch.Tensor,  # [B, H, L, D]
+     scale: float,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 16
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
+            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     # fused Triton implementation of `dkh, dw, db, dk, dv` for LN^2
+     assert offsets is None, "bwd of varlen is not implemented yet."
+     if head_first:
+         B, H, T, K, V = *q.shape, do.shape[-1]
+     else:
+         B, T, H, K, V = *q.shape, do.shape[-1]
+     BT = chunk_size
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
+
+     BK = triton.next_power_of_2(K)
+     BV = triton.next_power_of_2(V)
+     NK = triton.cdiv(K, BK)
+     NV = triton.cdiv(V, BV)
+     assert NK == 1, 'NK > 1 is not supported by TTT.'
+     assert NV == 1, 'NV > 1 is not supported by TTT.'
+
+     if head_first:
+         dh = q.new_empty(B, H, NT, K, V)
+         dhb = q.new_empty(B, H, NT, 1, V)
+     else:
+         dh = q.new_empty(B, NT, H, K, V)
+         dhb = q.new_empty(B, NT, H, 1, V)
+     dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
+     dhb0 = torch.empty_like(hb0, dtype=torch.float32) if hb0 is not None else None
+     dv = torch.empty_like(v)
+     dk = torch.empty_like(k)
+     dw = w.new_empty(B, H, V)
+     db = b.new_empty(B, H, V)
+
+     grid = (NK, NV, N * H)
+     chunk_ttt_linear_bwd_kernel_norm[grid](
+         q=q,
+         k=k,
+         v=v,
+         v_new=v_new,
+         x=x,
+         y=y,
+         r=rstd,
+         w=w,
+         b=b,
+         eta=eta,
+         h=h,
+         dht=dht,
+         dhbt=dhbt,
+         dh0=dh0,
+         dhb0=dhb0,
+         do=do,
+         dh=dh,
+         dhb=dhb,
+         dv=dv,
+         dv_new=dv_new,
+         dk=dk,
+         dw=dw,
+         db=db,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         HEAD_FIRST=head_first
+     )
+     dw = dw.sum(dim=0)
+     db = db.sum(dim=0)
+     return dh, dhb, dh0, dhb0, dv, dk, dw, db
+
+
+ def chunk_ttt_linear_bwd_norm_ref(
+     q: torch.Tensor,  # [B, H, L, D]
+     k: torch.Tensor,  # [B, H, L, D]
+     v: torch.Tensor,  # [B, H, L, D]
+     v_new: torch.Tensor,  # [B, H, L, D]
+     kh: torch.Tensor,  # [B, H, L, D]
+     y: torch.Tensor,  # [B, H, L, D]
+     w: torch.Tensor,  # [H, D]
+     b: torch.Tensor,  # [H, D]
+     eta: torch.Tensor,  # [B, H, L, 1]
+     h0: torch.Tensor,  # [B, H, D, D]
+     h: torch.Tensor,  # [B, H, NT, D, D]
+     dht: Optional[torch.Tensor],  # [B, H, D, D]
+     dv_new: Optional[torch.Tensor],  # [B, H, L, D]
+     do: torch.Tensor,  # [B, H, L, D]
+     scale: float,
+     eps: float,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 16
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     # torch implementation of `dkh, dw, db, dk, dv` for LN^2
+     assert offsets is None, "bwd of varlen is not implemented yet."
+     if head_first:
+         B, H, T, K, V = *q.shape, do.shape[-1]
+     else:
+         B, T, H, K, V = *q.shape, do.shape[-1]
+         # [B, L, H, D] -> [B, H, L, D]
+         q, k, v, v_new, kh, y, h, eta, dv_new, do = [
+             x.transpose(1, 2) for x in
+             [q, k, v, v_new, kh, y, h, eta, dv_new, do]
+         ]
+     BT = chunk_size
+     NT = triton.cdiv(T, BT) if offsets is None else len(indices)
+     pad_len = (BT - (T % BT)) % BT
+     if pad_len > 0:
+         q, k, v, v_new, kh, y, eta, dv_new, do = [
+             F.pad(x, (0, 0, 0, pad_len)) for x in
+             [q, k, v, v_new, kh, y, eta, dv_new, do]
+         ]
+         eta[:, :, -1, :] = eta[:, :, -(pad_len+1), :]
+     # [NT, B, H, BT, D]
+     q, k, v, v_new, kh, y, eta, dv_new, do = [
+         x.reshape(B, H, NT, BT, -1).permute(2, 0, 1, 3, 4) for x in
+         [q, k, v, v_new, kh, y, eta, dv_new, do]
+     ]
+     h = h.permute(2, 0, 1, 3, 4)
+
+     # allocate
+     dh = q.new_zeros(NT, B, H, K, V)
+     dv = torch.zeros_like(v)
+     dk = torch.zeros_like(k)
+     dw = torch.zeros_like(w)
+     db = torch.zeros_like(b)
+     # recurrent state
+     b_dh = dht if dht is not None else torch.zeros_like(dh[0])
+     b_dh = b_dh.to(torch.float32)
+
+     # [H, 1, D]
+     _w = w.reshape(H, 1, V).to(torch.float32)
+     _b = b.reshape(H, 1, V).to(torch.float32)
+
+     # d_state passing
+     for i_t in range(NT - 1, -1, -1):
+         dh[i_t] = b_dh.to(dh.dtype)
+         # [B, H, BT, D]
+         _q, _k, _v, _v_new, _kh, _y, _h, _eta, _dv_new, _do = [
+             x[i_t].to(torch.float32) for x in
+             (q, k, v, v_new, kh, y, h, eta, dv_new, do)
+         ]
+         _dv_new -= (_eta[:, :, -1, :, None] * _k) @ b_dh
+
+         mean = _kh.mean(dim=-1, keepdim=True)
+         var = _kh.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
+         rstd = 1 / torch.sqrt(var + eps).to(torch.float32)
+         x = (_kh - mean) * rstd
+         # [B, H, BT, D]
+         dy = rstd * (_dv_new*V - _dv_new.sum(dim=-1, keepdim=True) - x*(x*_dv_new).sum(dim=-1, keepdim=True)) / V
+         dx = -rstd * (_dv_new*(x*_y).sum(dim=-1, keepdim=True) + _y*(x*_dv_new).sum(dim=-1, keepdim=True)) / V
+         d_rstd = (_dv_new * _v_new / rstd).sum(dim=-1, keepdim=True)
+
+         dv[i_t] = (-_w*dy).to(dv.dtype)
+         dk[i_t] += (_w*dy).to(dk.dtype)
+         dw += (2*_w*x*dy+(_b-_v+_k)*dy).sum(dim=(0, 2)).to(dw.dtype)
+         db += (_w*dy).sum(dim=(0, 2)).to(db.dtype)
+         dx += _w*_w*dy
+
+         # d_rstd, dx --> dkh --> dk, dh
+         dkh = rstd * (V * dx - dx.sum(dim=-1, keepdim=True) - x * (x * dx).sum(dim=-1, keepdim=True)) / V
+         dkh -= rstd**2 * d_rstd * x / V
+         dk[i_t] += (dkh @ _h.transpose(-2, -1)).to(dk.dtype)
+         b_dh += (_q.transpose(-2, -1) * scale) @ _do + _k.transpose(-2, -1) @ dkh
+     dh0 = b_dh.to(torch.float32) if h0 is not None else None
+
+     # [NT, B, H, BT, D] -> [B, H, T, D]
+     dv = dv.permute(1, 2, 0, 3, 4).reshape(B, H, -1, V)[:, :, :T, :]
+     dk = dk.permute(1, 2, 0, 3, 4).reshape(B, H, -1, K)[:, :, :T, :]
+     # [B, H, NT, D, D]
+     dh = dh.permute(1, 2, 0, 3, 4)
+     if not head_first:
+         dv, dk, dh = [x.transpose(1, 2) for x in (dv, dk, dh)]
+     dh, dv, dk, dw, db = [x.contiguous() for x in (dh, dv, dk, dw, db)]
+     dh0 = dh0.contiguous() if h0 is not None else None
+     return dh, dh0, dv, dk, dw, db
+
+
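Both `dy` and `dkh` in the reference above instantiate the standard backward pass through the normalization step: with normalized values $x_i$, reciprocal standard deviation $r$, $D$ channels, and upstream gradient $g_i$, the gradient with respect to the pre-normalization input is

$$\frac{\partial \mathcal{L}}{\partial u_i} \;=\; \frac{r}{D}\Bigl(D\,g_i \;-\; \sum_{j=1}^{D} g_j \;-\; x_i \sum_{j=1}^{D} x_j\,g_j\Bigr),$$

which is exactly `rstd * (V * dx - dx.sum(...) - x * (x * dx).sum(...)) / V` with $D = V$.
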
+ def chunk_ttt_linear_bwd_dqke(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     eta: torch.Tensor,
+     h: torch.Tensor,
+     do: torch.Tensor,
+     dh: torch.Tensor,
+     dhb: torch.Tensor,
+     scale: float,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 16,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+     if head_first:
+         B, H, T, K, V = *k.shape, v.shape[-1]
+     else:
+         B, T, H, K, V = *k.shape, v.shape[-1]
+     BT = chunk_size
+     NT = triton.cdiv(T, BT) if offsets is None else len(indices)
+
+     BK = triton.next_power_of_2(K)
+     BV = min(triton.next_power_of_2(V), 64)
+     NK = triton.cdiv(K, BK)
+     assert NK == 1, "NK > 1 is not supported."
+
+     dq = torch.empty_like(q)
+     dk = torch.empty_like(k)
+     de = torch.empty_like(eta)
+     grid = (NK, NT, B * H)
+
+     chunk_bwd_kernel_dqke[grid](
+         q=q,
+         k=k,
+         v=v,
+         e=eta,
+         h=h,
+         do=do,
+         dh=dh,
+         dhb=dhb,
+         dq=dq,
+         dk=dk,
+         de=de,
+         offsets=offsets,
+         indices=indices,
+         scale=scale,
+         B=B,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         HEAD_FIRST=head_first
+     )
+     return dq, dk, de
+
+
+ def chunk_ttt_linear_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eta: torch.Tensor,
+     scale: float,
+     eps: float,
+     initial_state: torch.Tensor,
+     initial_state_bias: torch.Tensor,
+     output_final_state: bool,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     BT: int = 16
+ ):
+     h, hb, v_new, final_state, final_state_bias = chunk_ttt_linear_fwd_h(
+         k=k,
+         v=v,
+         w=w,
+         b=b,
+         eta=eta,
+         eps=eps,
+         initial_state=initial_state,
+         initial_state_bias=initial_state_bias,
+         output_final_state=output_final_state,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     o = chunk_ttt_linear_fwd_o(
+         q=q,
+         k=k,
+         v=v_new,
+         eta=eta,
+         h=h,
+         hb=hb,
+         scale=scale,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     return o, final_state, final_state_bias
+
+
+ def chunk_ttt_linear_bwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eta: torch.Tensor,
+     scale: float,
+     eps: float,
+     do: torch.Tensor,
+     dht: torch.Tensor,
+     dhbt: torch.Tensor,
+     BT: int = 16,
+     initial_state: torch.Tensor = None,
+     initial_state_bias: torch.Tensor = None,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True
+ ):
+     h, v_new, x, y, rstd = chunk_ttt_linear_bwd_h(
+         k=k,
+         v=v,
+         w=w,
+         b=b,
+         eta=eta,
+         eps=eps,
+         initial_state=initial_state,
+         initial_state_bias=initial_state_bias,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     dv_new = chunk_ttt_linear_bwd_dv_local(
+         q=q,
+         k=k,
+         eta=eta,
+         do=do,
+         scale=scale,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     dh, dhb, dh0, dhb0, dv, dk, dw, db = chunk_ttt_linear_bwd_norm(
+         q=q,
+         k=k,
+         v=v,
+         v_new=v_new,
+         x=x,
+         y=y,
+         rstd=rstd,
+         w=w,
+         b=b,
+         eta=eta,
+         h0=initial_state,
+         hb0=initial_state_bias,
+         h=h,
+         dht=dht,
+         dhbt=dhbt,
+         dv_new=dv_new,
+         do=do,
+         scale=scale,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     dq, dk2, de = chunk_ttt_linear_bwd_dqke(
+         q=q,
+         k=k,
+         v=v_new,
+         eta=eta,
+         h=h,
+         do=do,
+         dh=dh,
+         dhb=dhb,
+         scale=scale,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     dk.add_(dk2)
+     return dq, dk, dv, de, dw, db, dh0, dhb0
+
+
+ class ChunkTTTLinearFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
+                 initial_state_bias, output_final_state, offsets, head_first):
+         # 2-d indices denoting the offsets of chunks in each sequence
+         # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
+         # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
+         # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+         indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None
+         o, final_state, final_state_bias = chunk_ttt_linear_fwd(
+             q=q,
+             k=k,
+             v=v,
+             w=w,
+             b=b,
+             eta=eta,
+             scale=scale,
+             eps=eps,
+             BT=BT,
+             initial_state=initial_state,
+             initial_state_bias=initial_state_bias,
+             output_final_state=output_final_state,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+         )
+         ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
+         ctx.BT = BT
+         ctx.scale = scale
+         ctx.eps = eps
+         ctx.offsets = offsets
+         ctx.indices = indices
+         ctx.head_first = head_first
+         return o.to(q.dtype), final_state, final_state_bias
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(ctx, do, dht, dhbt):
+         q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
+         dq, dk, dv, de, dw, db, dh0, dhb0 = chunk_ttt_linear_bwd(
+             q=q,
+             k=k,
+             v=v,
+             w=w,
+             b=b,
+             eta=eta,
+             scale=ctx.scale,
+             eps=ctx.eps,
+             do=do,
+             dht=dht,
+             dhbt=dhbt,
+             BT=ctx.BT,
+             initial_state=initial_state,
+             initial_state_bias=initial_state_bias,
+             offsets=ctx.offsets,
+             indices=ctx.indices,
+             head_first=ctx.head_first
+         )
+         return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None
+
+
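The comment in `forward` above describes the chunk-index layout produced by `prepare_chunk_indices` for variable-length batches. For intuition, a minimal pure-PyTorch sketch of that layout (`chunk_indices_ref` is a hypothetical stand-in for illustration, not the library helper):

import torch

def chunk_indices_ref(offsets: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    # sequence lengths recovered from cumulative offsets, e.g. [0, 100, 356] -> [100, 256]
    lens = (offsets[1:] - offsets[:-1]).tolist()
    # one (sequence id, chunk id) pair per chunk, in sequence order
    return torch.tensor([[i, j] for i, n in enumerate(lens)
                         for j in range(-(-n // chunk_size))])

# chunk_indices_ref(torch.tensor([0, 100, 356]), 64)
# -> [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
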
+ def norm_residual(x, weight, bias, eps, head_first):
+     # GroupNorm and Residual
+     if head_first:
+         B, H, T, D = x.shape
+         x = x.transpose(1, 2)
+         x += group_norm(
+             x.reshape(B, T, -1).clone(),
+             weight=weight.reshape(-1).clone(),
+             bias=bias.reshape(-1).clone(),
+             eps=eps,
+             num_groups=H,
+         ).reshape(x.shape)
+         x = x.transpose(1, 2)
+     else:
+         B, T, H, D = x.shape
+         x += group_norm(
+             x.reshape(B, T, -1).clone(),
+             weight=weight.reshape(-1).clone(),
+             bias=bias.reshape(-1).clone(),
+             eps=eps,
+             num_groups=H,
+         ).reshape(x.shape)
+     return x
+
+
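`norm_residual` applies a GroupNorm over the flattened head dimension and adds the result back to the input. A hedged reference for the non-head-first layout, assuming `fla`'s `group_norm` matches `torch.nn.functional.group_norm` semantics with `H` groups (`norm_residual_ref` is an illustrative stand-in):

import torch
import torch.nn.functional as F

def norm_residual_ref(x, weight, bias, eps):
    # x: [B, T, H, D]; each head's D channels form one normalization group
    B, T, H, D = x.shape
    y = F.group_norm(x.reshape(B * T, H * D), num_groups=H,
                     weight=weight.reshape(-1), bias=bias.reshape(-1), eps=eps)
    return x + y.reshape(B, T, H, D)
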
+ def chunk_ttt_linear(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     b: torch.Tensor,
+     eta: torch.Tensor,
+     scale: Optional[float] = None,
+     eps: float = 1e-6,
+     chunk_size: int = 16,
+     initial_state: torch.Tensor = None,
+     initial_state_bias: torch.Tensor = None,
+     output_final_state: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+ ):
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `(B, H, T, K)`
+         k (torch.Tensor):
+             keys of shape `(B, H, T, K)`
+         v (torch.Tensor):
+             values of shape `(B, H, T, V)`
+         w (torch.Tensor):
+             layer norm weight of shape `(H, V)`
+         b (torch.Tensor):
+             layer norm bias of shape `(H, V)`
+         eta (torch.Tensor):
+             Learning rate for the hidden state, of shape `(B, H, T, 1)`.
+         scale (Optional[float]):
+             Scale factor for attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         chunk_size (int):
+             chunk size. Default: `16`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `(B, H, K, V)`. Default: `None`.
+         initial_state_bias (Optional[torch.Tensor]):
+             Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Default: `True`.
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, H, T, V]`.
+         final_state (torch.Tensor):
+             Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
+         final_state_bias (torch.Tensor):
+             Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`.
+     """
+     assert q.dtype == k.dtype == v.dtype
+     assert k.shape[-1] == v.shape[-1], "DK must be equal to DV."
+     if isinstance(eta, float):
+         eta = torch.full_like(q[:, :, :, :1], eta)
+     if cu_seqlens is not None:
+         if q.shape[0] != 1:
+             raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                              f"Please flatten variable-length inputs before processing.")
+         if head_first:
+             raise RuntimeError("Sequences with variable lengths are not supported for head-first mode.")
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
+                              f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
+     else:
+         assert scale > 0, "Scale must be positive."
+     o, final_state, final_state_bias = ChunkTTTLinearFunction.apply(
+         q,
+         k,
+         v,
+         w,
+         b,
+         chunk_size,
+         eta,
+         scale,
+         eps,
+         initial_state,
+         initial_state_bias,
+         output_final_state,
+         cu_seqlens,
+         head_first,
+     )
+     o = norm_residual(o, w, b, eps, head_first)
+     return o, final_state, final_state_bias
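A hedged usage sketch of the entry point above; shapes follow the docstring, while the values, device, dtype, and the constant learning rate are illustrative only:

import torch

B, H, T, D = 2, 4, 128, 64
dtype, device = torch.bfloat16, 'cuda'
q, k, v = (torch.randn(B, H, T, D, device=device, dtype=dtype) for _ in range(3))
w = torch.ones(H, D, device=device, dtype=dtype)   # layer norm weight
b = torch.zeros(H, D, device=device, dtype=dtype)  # layer norm bias
eta = torch.full((B, H, T, 1), 5e-3, device=device, dtype=dtype)

o, ht, hbt = chunk_ttt_linear(q, k, v, w, b, eta, output_final_state=True)
assert o.shape == (B, H, T, D) and ht.shape == (B, H, D, D) and hbt.shape == (B, H, 1, D)
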
fla/ops/utils/asm.py ADDED
@@ -0,0 +1,17 @@
+ # -*- coding: utf-8 -*-
+
+ from fla.utils import device_platform
+
+
+ def fp32_to_tf32_asm() -> str:
+     """
+     Get the assembly code for converting FP32 to TF32.
+     """
+     ASM_DICT = {
+         'nvidia': 'cvt.rna.tf32.f32 $0, $1;'
+     }
+     if device_platform in ASM_DICT:
+         return ASM_DICT[device_platform]
+     else:
+         # return empty string if the device is not supported
+         return ""
fla/ops/utils/logcumsumexp.py ADDED
@@ -0,0 +1,52 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import triton
+ import triton.language as tl
+
+ from fla.ops.utils.op import exp, log
+
+
+ @triton.autotune(
+     configs=[
+         triton.Config({'BT': BT}, num_warps=num_warps)
+         for BT in [16, 32, 64]
+         for num_warps in [2, 4, 8]
+     ],
+     key=['S']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def logcumsumexp_fwd_kernel(
+     s,
+     z,
+     T,
+     S: tl.constexpr,
+     BT: tl.constexpr
+ ):
+     i_bh = tl.program_id(0)
+     o_i = tl.arange(0, BT)
+     m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
+
+     b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
+     b_zp = tl.zeros([S,], dtype=tl.float32)
+     for i_t in range(tl.cdiv(T, BT)):
+         p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
+         p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
+
+         # [BT, S]
+         b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+         # [S,]
+         b_mc = tl.max(b_s, 0)
+         b_mc = tl.maximum(b_mp, b_mc)
+         b_zp = b_zp * exp(b_mp - b_mc)
+         # [BT, S]
+         b_s = exp(b_s - b_mc)
+         b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
+         # [S,]
+         b_zc = tl.max(b_z, 0)
+         b_mp = b_mc
+         b_zp = b_zc
+         # [BT, S]
+         # small eps to prevent underflows
+         b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
+         tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
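For reference, the kernel computes a numerically stable cumulative log-sum-exp along the time axis, carrying the running maximum `b_mp` and running sum `b_zp` across blocks. A hedged PyTorch equivalent for inputs of shape `[B*H, T, S]`:

import torch

def logcumsumexp_ref(s: torch.Tensor) -> torch.Tensor:
    # cumulative log-sum-exp over T for every (batch*head, channel) pair
    return torch.logcumsumexp(s.float(), dim=1)
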
fla/ops/utils/testing.py ADDED
@@ -0,0 +1,26 @@
+ import os
+
+ compiled_mode = os.getenv("COMPILER_MODE") == "1"
+ ci_env = os.getenv("CI_ENV") == "1"
+
+
+ def get_abs_err(x, y):
+     return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+ def get_err_ratio(x, y):
+     err = (x - y).flatten().square().mean().sqrt().item()
+     base = (x).flatten().square().mean().sqrt().item()
+     return err / (base + 1e-15)
+
+
+ def assert_close(prefix, ref, tri, ratio, warning=False):
+     msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+     print(msg)
+     error_rate = get_err_ratio(ref, tri)
+     if warning or str(prefix).strip().lower() == "dh0" or (ci_env and error_rate < 0.01):
+         if error_rate > ratio:
+             import warnings
+             warnings.warn(msg)
+     else:
+         assert error_rate < ratio, msg
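A hedged usage sketch of the helpers above, comparing a reference tensor against a slightly perturbed copy:

import torch

ref = torch.randn(64, 64)
tri = ref + 1e-4 * torch.randn_like(ref)
# prints the max abs diff and the RMS error ratio, then asserts ratio < 1e-2
assert_close("o", ref, tri, ratio=1e-2)
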
profile_trace/iteration_10240/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_10240/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_10240/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_10240/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_10240/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff