zaydzuhri committed on
Commit
f3ebcf2
·
verified ·
1 Parent(s): 8cbec64

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/ops/attn/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/ops/based/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  4. fla/ops/common/chunk_h.py +422 -0
  5. fla/ops/common/chunk_h_split.py +677 -0
  6. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  7. fla/ops/delta_rule/fused_recurrent.py +607 -0
  8. fla/ops/delta_rule/wy_fast.py +340 -0
  9. fla/ops/forgetting_attn/parallel.py +708 -0
  10. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  11. fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +324 -0
  12. fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +196 -0
  13. fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +138 -0
  14. fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +292 -0
  15. fla/ops/generalized_delta_rule/iplr/chunk.py +528 -0
  16. fla/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  17. fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  18. fla/ops/gla/fused_recurrent.py +113 -0
  19. fla/ops/gsa/chunk.py +1264 -0
  20. fla/ops/gsa/naive.py +68 -0
  21. fla/ops/hgrn/chunk.py +282 -0
  22. fla/ops/hgrn/naive.py +63 -0
  23. fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  24. fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  25. fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  26. fla/ops/nsa/__init__.py +9 -0
  27. fla/ops/rebased/__pycache__/__init__.cpython-312.pyc +0 -0
  28. fla/ops/rebased/__pycache__/parallel.cpython-312.pyc +0 -0
  29. fla/ops/retention/fused_chunk.py +365 -0
  30. fla/ops/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  31. fla/ops/rwkv7/__pycache__/channel_mixing.cpython-312.pyc +0 -0
  32. fla/ops/simple_gla/naive.py +54 -0
  33. fla/ops/titans/naive.py +375 -0
  34. fla/ops/ttt/__init__.py +9 -0
  35. fla/ops/ttt/fused_chunk.py +896 -0
  36. fla/ops/ttt/naive.py +126 -0
  37. fla/ops/utils/cumsum.py +400 -0
  38. fla/ops/utils/logcumsumexp.py +52 -0
  39. fla/ops/utils/solve_tril.py +321 -0
  40. profile_trace/iteration_1024/rank0_trace.json +0 -0
  41. profile_trace/iteration_1024/rank1_trace.json +0 -0
  42. profile_trace/iteration_1024/rank2_trace.json +0 -0
  43. profile_trace/iteration_1024/rank3_trace.json +0 -0
  44. profile_trace/iteration_1024/rank4_trace.json +0 -0
  45. profile_trace/iteration_1024/rank5_trace.json +0 -0
  46. profile_trace/iteration_1024/rank7_trace.json +0 -0
  47. profile_trace/iteration_11264/rank2_trace.json +0 -0
  48. profile_trace/iteration_15360/rank2_trace.json +0 -0
  49. profile_trace/iteration_15360/rank4_trace.json +0 -0
  50. profile_trace/iteration_20992/rank0_trace.json +0 -0
fla/ops/attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (220 Bytes). View file
 
fla/ops/based/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (286 Bytes). View file
 
fla/ops/based/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (22.4 kB). View file
 
fla/ops/common/chunk_h.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem
13
+
14
# Candidate key/value tile sizes for autotuning: larger tiles when enough
# shared memory is available, smaller ones otherwise.
BKV_LIST = [32, 64] if check_shared_mem() else [16, 32]
15
+
16
+
17
# Forward kernel: scan the sequence chunk by chunk and materialize the recurrent
# state h at split boundaries. Each program owns one (key-tile, value-tile,
# batch*head) slice of the [K, V] state and accumulates k @ v per chunk,
# optionally applying scalar (g), key-wise (gk) and value-wise (gv) log-decays.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h(
    k,              # keys
    v,              # values
    h,              # out: per-split states
    g,              # scalar log-decay (optional, used iff USE_G)
    gk,             # key-wise log-decay (optional, used iff USE_GK)
    gv,             # value-wise log-decay (optional, used iff USE_GV)
    h0,             # initial state (optional)
    ht,             # out: final state (optional)
    offsets,        # per-sequence [bos, eos) offsets for variable lengths (optional)
    split_offsets,  # per-sequence cumulative split counts (optional)
    T,              # sequence length (per-sequence when USE_OFFSETS)
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # chunk (time) tile size
    BS: tl.constexpr,   # split size; a multiple of BT (asserted by the launcher)
    BK: tl.constexpr,   # key tile size
    BV: tl.constexpr,   # value tile size
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr    # [B, H, T, ...] layout if set, else [B, T, H, ...]
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # variable-length mode: recover this sequence's span and split base
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        NS = tl.cdiv(T, BS)
        boh = tl.load(split_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        NS = tl.cdiv(T, BS)
        boh = i_n * NS

    # [BK, BV] running state, accumulated in fp32
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # split that the current chunk falls into
        i_s = i_t // (BS // BT)
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

            o_h = (i_nh * NS + i_s).to(tl.int64) * K*V
            p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

            o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
            p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # the state is stored *before* processing the first chunk of each split
        if i_t % (BS // BT) == 0:
            tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # index of the last valid position of this chunk (handles the ragged tail)
        last_idx = min((i_t + 1) * BT, T) - 1

        # scalar decay
        if USE_G:
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_nh * T + last_idx)
                p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
                p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
            else:
                b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
                p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
            # decay the carried state to the end of the chunk, and rescale v
            # so each position carries decay from itself to the chunk end
            b_h *= exp(b_g_last)
            b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
            b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

        # vector decay, h = Diag(gk) @ h
        if USE_GK:
            if HEAD_FIRST:
                p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_h *= exp(b_gk_last)[:, None]

            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)

        # vector decay, h = h @ Diag(gv)
        if USE_GV:
            if HEAD_FIRST:
                p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_h *= exp(b_gv_last)[None, :]

            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)

        # accumulate this chunk's contribution into the state
        b_h += tl.dot(b_k, b_v)

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
153
+
154
+
155
# Backward kernel: scan chunks in reverse and materialize the state gradient dh
# at split boundaries, accumulating q^T @ do per chunk with the same optional
# scalar (g), key-wise (gk) and value-wise (gv) decays as the forward pass.
# Supports GQA: programs are indexed over HQ query heads; g/gk/gv use the
# shared kv head (i_h / i_bg) while q/do/dh use the query head (i_hq).
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh(
    q,              # queries
    g,              # scalar log-decay (optional)
    gk,             # key-wise log-decay (optional)
    gv,             # value-wise log-decay (optional)
    do,             # gradient w.r.t. outputs
    dh,             # out: per-split state gradients (HQ heads)
    dht,            # gradient w.r.t. final state (optional)
    dh0,            # out: gradient w.r.t. initial state (optional)
    offsets,        # per-sequence [bos, eos) offsets (optional)
    split_offsets,  # per-sequence cumulative split counts (optional)
    scale,          # query scaling factor
    T,
    HQ: tl.constexpr,   # number of query heads
    H: tl.constexpr,    # number of kv heads
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # chunk (time) tile size
    BS: tl.constexpr,   # split size; a multiple of BT
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,   # group size HQ // H
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_nh // NG
    i_n, i_hq = i_nh // HQ, i_nh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        NS = tl.cdiv(T, BS)
        boh = tl.load(split_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        NS = tl.cdiv(T, BS)
        boh = i_n * NS

    # [BK, BV] running state gradient, accumulated in fp32
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT - 1, -1, -1):
        i_s = i_t // (BS // BT)
        if HEAD_FIRST:
            o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
            p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            # BUGFIX: dh is laid out with HQ heads ([B, NS, HQ, K, V] in the
            # launcher), so it must be indexed with (HQ, i_hq); the previous
            # (H, i_h) indexing made distinct query heads of the same group
            # collide whenever NG > 1. Identical behavior when NG == 1.
            o_dh = ((boh + i_s) * HQ + i_hq).to(tl.int64) * K*V
            p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # store at the first chunk of each split (reverse scan reaches it last)
        if i_t % (BS // BT) == 0:
            tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        last_idx = min(i_t * BT + BT, T) - 1
        # [BK, BT]
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        if USE_G:
            if HEAD_FIRST:
                p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
                p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
                b_g_last = tl.load(g + i_bg * T + last_idx)
            else:
                p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
                b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
            b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)

            b_dh *= exp(b_g_last)

        if USE_GK:
            if HEAD_FIRST:
                p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_q = (b_q * exp(b_gk)).to(b_q.dtype)
            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_dh *= exp(b_gk_last)[:, None]

        if USE_GV:
            if HEAD_FIRST:
                p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_do = (b_do * exp(b_gv)).to(b_do.dtype)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_dh *= exp(b_gv_last)[None, :]

        b_dh += tl.dot(b_q, b_do)

    if STORE_INITIAL_STATE_GRADIENT:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
294
+
295
+
296
def chunk_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    h0: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    split_size: Optional[int] = None,
    states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Launch `chunk_fwd_kernel_h` to compute per-split recurrent states.

    Returns:
        h: states at split boundaries, `[B, H, NS, K, V]` if `head_first`
            else `[B, NS, H, K, V]` (fp32 when `states_in_fp32`).
        ht: final states `[N, H, K, V]` (fp32) when `output_final_state`,
            otherwise `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # clamp tile sizes to the (power-of-2, >= 16) sequence extent
    cap = max(16, triton.next_power_of_2(T))
    BT = min(chunk_size, cap)
    BS = BT if split_size is None else min(split_size, cap)
    assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        split_offsets = None
        N = B
        NS = triton.cdiv(T, BS)
    else:
        split_offsets = prepare_chunk_offsets(offsets, BS)
        N = len(offsets) - 1
        NS = split_offsets[-1]

    state_dtype = torch.float if states_in_fp32 else k.dtype
    state_shape = (B, H, NS, K, V) if head_first else (B, NS, H, K, V)
    h = k.new_empty(*state_shape, dtype=state_dtype)
    ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None

    def grid(meta):
        return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)

    chunk_fwd_kernel_h[grid](
        k=k,
        v=v,
        h=h,
        g=g,
        gk=gk,
        gv=gv,
        h0=h0,
        ht=ht,
        offsets=offsets,
        split_offsets=split_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    return h, ht
353
+
354
+
355
def chunk_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    do: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    scale: float,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    split_size: Optional[int] = None,
    states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Launch `chunk_bwd_kernel_dh` to compute per-split state gradients.

    Returns:
        dh: state gradients at split boundaries, `[B, HQ, NS, K, V]` if
            `head_first` else `[B, NS, HQ, K, V]`.
        dh0: gradient w.r.t. the initial state (fp32, same shape as `h0`)
            when `h0` is given, otherwise `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
    assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
    # N: the actual number of sequences in the batch with either equal or variable lengths
    # NG: number of groups in GQA
    if offsets is None:
        split_offsets, N, NS = None, B, triton.cdiv(T, BS)
    else:
        split_offsets = prepare_chunk_offsets(offsets, BS)
        N, NS = len(offsets) - 1, split_offsets[-1]
    NG = HQ // H

    if head_first:
        dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
    else:
        dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
    dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

    # BUGFIX: the kernel decodes program_id(2) as (i_n, i_hq) with i_hq in
    # [0, HQ), and dh/dh0 carry HQ heads, so the grid must span N * HQ
    # programs; the previous N * H left part of dh unwritten (and decoded
    # i_n incorrectly) whenever HQ > H, i.e. under GQA. Identical when HQ == H.
    def grid(meta):
        return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
    chunk_bwd_kernel_dh[grid](
        q=q,
        g=g,
        gk=gk,
        gv=gv,
        do=do,
        dh=dh,
        dht=dht,
        dh0=dh0,
        offsets=offsets,
        split_offsets=split_offsets,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    return dh, dh0
fla/ops/common/chunk_h_split.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+
12
+
13
# Split-parallel forward kernel: each program processes ONE split of a sequence
# independently (no carry between splits), producing the per-split partial
# state in `hs` and the split-boundary state in `hr`. Partial states are later
# combined by `chunk_fwd_kernel_h_reduction`.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h_split(
    k,              # keys
    v,              # values
    g,              # scalar log-decay (optional)
    gk,             # key-wise log-decay (optional)
    gv,             # value-wise log-decay (optional)
    hs,             # out: unreduced per-split states
    hr,             # out: reduced states at split starts
    h0,             # initial state (optional)
    ht,             # out: final state (optional, used when NS == 1)
    offsets,        # per-sequence [bos, eos) offsets (optional)
    split_indices,  # flat (sequence, local-split) pairs per program (optional)
    T,
    S: tl.constexpr,    # split size in time steps
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # chunk (time) tile size
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # handle one split at a time
    # i_h: head index
    # i_n: sequence index
    # i_s: local split index inside a sequence
    i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_ss, i_h = i_sh // H, i_sh % H
    if USE_OFFSETS:
        i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NS = tl.cdiv(T, S)
    else:
        NS = tl.cdiv(T, S)
        i_n, i_s = i_ss // NS, i_ss % NS
        bos, eos = i_n * T, i_n * T + T
    i_nh = i_n * H + i_h

    # [BK, BV] partial state for this split, accumulated in fp32
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # for the first split, we directly store the state as the final result
    if i_s == 0:
        if USE_INITIAL_STATE:
            p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
        p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
    # iterate over the chunks belonging to this split only
    for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # index of the last valid position of this chunk (handles the ragged tail)
        last_idx = min(i_t * BT + BT, T) - 1

        # scalar decay
        if USE_G:
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_nh * T + last_idx)
                p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
                p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
            else:
                b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
                p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
            b_h *= exp(b_g_last)
            b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
            b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

        # vector decay, h = Diag(gk) @ h
        if USE_GK:
            if HEAD_FIRST:
                p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_h *= exp(b_gk_last)[:, None]

            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)

        # vector decay, h = h @ Diag(gv)
        if USE_GV:
            if HEAD_FIRST:
                p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_h *= exp(b_gv_last)[None, :]

            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)

        b_h += tl.dot(b_k, b_v)

    # if there are more than one splits, we store the result to (unreduced) hs
    # otherwise, we store the result to ht as the final state
    if NS > 1:
        p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1))
    elif STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
# Reduction pass over the split-parallel forward: sequentially combines the
# unreduced per-split states `hs` into the split-boundary states `hr`, applying
# the decay accumulated across each split, and optionally emits the final
# state `ht`. One program per (key-tile, value-tile, batch*head).
@triton.heuristics({
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h_reduction(
    g,              # scalar log-decay (optional)
    gk,             # key-wise log-decay (optional)
    gv,             # value-wise log-decay (optional)
    hs,             # in: unreduced per-split states
    hr,             # out: reduced states at split starts
    ht,             # out: final state (optional)
    offsets,        # per-sequence [bos, eos) offsets (optional)
    split_offsets,  # per-sequence cumulative split counts (optional)
    T,
    S: tl.constexpr,    # split size in time steps
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # chunk (time) tile size
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NS = tl.cdiv(T, S)
        boh = tl.load(split_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NS = tl.cdiv(T, S)
        boh = i_n * NS

    # [BK, BV] running reduced state, accumulated in fp32
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # skip the first split
    for i_s in range(1, NS):
        # fold the previous split's partial state in, then publish the state
        # valid at the start of split i_s
        p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_hr = tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
        tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))

        # advance the carried state across split i_s by applying only its decay
        # (the k@v contributions are already inside hs)
        for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
            last_idx = min(i_t * BT + BT, T) - 1
            # scalar decay
            if USE_G:
                if HEAD_FIRST:
                    b_g_last = tl.load(g + i_nh * T + last_idx)
                else:
                    b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
                b_h *= exp(b_g_last)

            # vector decay, h = Diag(gk) @ h
            if USE_GK:
                if HEAD_FIRST:
                    p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
                    p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
                else:
                    p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

                b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
                b_h *= exp(b_gk_last)[:, None]

            # vector decay, h = h @ Diag(gv)
            if USE_GV:
                if HEAD_FIRST:
                    p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
                    p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
                else:
                    p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

                b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
                b_h *= exp(b_gv_last)[None, :]

    if NS > 1:
        if STORE_FINAL_STATE:
            # the last split's partial state still has to be folded in
            p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
250
+
251
+
252
# Backward kernel computing the gradient of the chunkwise hidden state (dh),
# parallelized over sequence splits: each program owns one
# (key-tile, value-tile, split, query-head) and walks the chunks of its split
# in reverse time order, accumulating dh += q^T @ do with decay rescaling.
# Per-split partials are written to `dhs` and later chained across splits by
# `chunk_bwd_kernel_dh_reduction`.
@triton.heuristics({
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh_split(
    q,  # queries
    g,  # scalar (per-timestep) log-decay, used when USE_G
    gk,  # key-dimension log-decay, used when USE_GK
    gv,  # value-dimension log-decay, used when USE_GV
    do,  # gradient w.r.t. the output
    dht,  # gradient w.r.t. the final state (optional)
    dhs,  # [out] unreduced per-split state gradients
    dhr,  # [out] seed gradient recorded for the last split of each sequence
    dh0,  # [out] gradient w.r.t. the initial state (optional)
    offsets,  # varlen sequence offsets (optional)
    split_indices,  # flattened (sequence, local split) pairs for varlen mode
    scale,  # scaling factor applied to q
    T,
    S: tl.constexpr,  # split size, a multiple of BT
    HQ: tl.constexpr,  # number of query heads (>= H under GQA)
    H: tl.constexpr,  # number of key/value heads
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk size
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,  # GQA group size (HQ // H)
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # handle one split at a time
    # i_h: head index
    # i_n: sequence index
    # i_s: local split index inside a sequence
    i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_ss, i_hq = i_sh // HQ, i_sh % HQ
    if USE_OFFSETS:
        # varlen: look up (sequence, local split) for this program and the
        # sequence boundaries; T becomes the per-sequence length
        i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NS = tl.cdiv(T, S)
    else:
        NS = tl.cdiv(T, S)
        i_n, i_s = i_ss // NS, i_ss % NS
        bos, eos = i_n * T, i_n * T + T
    i_nh = i_n * HQ + i_hq
    # i_ng indexes gating tensors shared within a GQA group; i_h is the kv head
    i_ng, i_h = i_nh // NG, i_hq // NG

    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if i_s == NS - 1:
        # the last split of a sequence is seeded with the final-state gradient
        # (if provided); the seed is also recorded in dhr
        if USE_FINAL_STATE_GRADIENT:
            p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
        p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))

    # iterate over the chunks of this split, backwards in time
    for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # index of the last valid timestep of this chunk (clipped at T)
        last_idx = min(i_t * BT + BT, T) - 1
        if USE_G:
            # scalar decay: rescale q by exp(g) within the chunk and carry the
            # chunk-boundary decay exp(g_last) into the running gradient
            if HEAD_FIRST:
                p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT)
                p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
                b_g_last = tl.load(g + i_ng * T + last_idx)
            else:
                p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
                b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
            b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
            b_dh *= exp(b_g_last)

        if USE_GK:
            # key-dimension decay applied analogously along K
            if HEAD_FIRST:
                p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_q = (b_q * exp(b_gk)).to(b_q.dtype)
            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_dh *= exp(b_gk_last)[:, None]

        if USE_GV:
            # value-dimension decay applied to do along V
            if HEAD_FIRST:
                p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_do = (b_do * exp(b_gv)).to(b_do.dtype)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_dh *= exp(b_gv_last)[None, :]

        # accumulate this chunk's contribution: dh += q^T @ do
        b_dh += tl.dot(b_q, b_do)

    if NS > 1:
        # partial result for this split; combined by the reduction kernel
        p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1))
    elif STORE_INITIAL_STATE_GRADIENT:
        # single split: b_dh already equals the full initial-state gradient
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
389
+
390
+
391
# Reduction pass for the split-parallel dh backward: one program per
# (key-tile, value-tile, sequence, head). It walks the splits of its sequence
# backwards, chaining per-split partial gradients (`dhs`) into the running
# gradient while applying the chunk-boundary decays, and records the chained
# values into `dhr`. Finally, it folds the first split's partial into `dh0`.
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh_reduction(
    g,  # scalar log-decay (optional)
    gk,  # key-dimension log-decay (optional)
    gv,  # value-dimension log-decay (optional)
    dhs,  # [in] per-split partial gradients from chunk_bwd_kernel_dh_split
    dhr,  # [out] reduced (chained) per-split gradients
    dh0,  # [out] initial-state gradient (optional)
    offsets,  # varlen sequence offsets (optional)
    split_offsets,  # per-sequence starting split index for varlen mode
    T,
    S: tl.constexpr,  # split size
    H: tl.constexpr,
    HQ: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk size
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,  # GQA group size
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hq = i_nh // HQ, i_nh % HQ
    # i_ng: gating index shared within a GQA group; i_h: kv-head index
    i_ng, i_h = i_nh // NG, i_hq // NG
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NS = tl.cdiv(T, S)
        boh = tl.load(split_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NS = tl.cdiv(T, S)
        boh = i_n * NS

    # NOTE(review): dhs/dhr are allocated with HQ heads by the wrapper but
    # indexed here with stride H and head i_h — equivalent when HQ == H
    # (NG == 1); confirm for the GQA (NG > 1) path.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # chain splits backwards: fold split i_s+1's partial in, store at split i_s
    for i_s in range(NS - 2, -1, -1):
        p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
        tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))

        # apply the decay of every chunk inside split i_s to the carried gradient
        for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
            last_idx = min(i_t * BT + BT, T) - 1
            # scalar decay
            if USE_G:
                if HEAD_FIRST:
                    b_g_last = tl.load(g + i_ng * T + last_idx)
                else:
                    b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
                b_dh *= exp(b_g_last)

            if USE_GK:
                # key-dimension decay at the chunk boundary
                if HEAD_FIRST:
                    p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
                    p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
                else:
                    p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

                b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
                b_dh *= exp(b_gk_last)[:, None]

            if USE_GV:
                # value-dimension decay at the chunk boundary
                if HEAD_FIRST:
                    p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
                    p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
                else:
                    p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

                b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
                b_dh *= exp(b_gv_last)[None, :]

    if NS > 1:
        if STORE_INITIAL_STATE_GRADIENT:
            # fold the first split's own partial in to get the gradient w.r.t.
            # the initial state
            p_dhs = tl.make_block_ptr(dhs + (boh * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
            tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
488
+
489
+
490
def chunk_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    h0: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    split_offsets: Optional[torch.LongTensor] = None,
    split_indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    split_size: int = 256,
    states_in_fp32: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute chunkwise forward hidden states with a split/reduce scheme.

    Stage 1 (`chunk_fwd_kernel_h_split`) produces unreduced kv states per
    split in parallel; stage 2 (`chunk_fwd_kernel_h_reduction`) chains them
    across the splits of each sequence.

    Returns:
        hr: reduced per-split states, `[NS, H, K, V]`.
        ht: final states `[N, H, K, V]` when `output_final_state`, else None.
    """
    if head_first:
        B, H, T = k.shape[:3]
    else:
        B, T, H = k.shape[:3]
    K, V = k.shape[-1], v.shape[-1]
    # S: split size (multiple of the chunk size BT); BT: chunk size
    S, BT = split_size, chunk_size
    assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
    # N: actual number of sequences; NS: total number of splits in the batch
    if offsets is None:
        N, NS = B, B * triton.cdiv(T, S)
    else:
        N, NS = len(offsets) - 1, split_offsets[-1]

    state_dtype = torch.float if states_in_fp32 else k.dtype
    # unreduced kv states per split
    hs = k.new_empty(NS, H, K, V, dtype=torch.float)
    # reduced states per split
    hr = k.new_empty(NS, H, K, V, dtype=state_dtype)
    ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None

    # stage 1: parallelized over (K-tile, V-tile, split x head)
    chunk_fwd_kernel_h_split[
        lambda meta: (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H)
    ](
        k=k,
        v=v,
        g=g,
        gk=gk,
        gv=gv,
        hs=hs,
        hr=hr,
        h0=h0,
        ht=ht,
        offsets=offsets,
        split_indices=split_indices,
        T=T,
        S=S,
        H=H,
        K=K,
        V=V,
        BT=BT,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    # stage 2: parallelized over (K-tile, V-tile, sequence x head)
    chunk_fwd_kernel_h_reduction[
        lambda meta: (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
    ](
        g=g,
        gk=gk,
        gv=gv,
        hs=hs,
        hr=hr,
        ht=ht,
        offsets=offsets,
        split_offsets=split_offsets,
        T=T,
        S=S,
        H=H,
        K=K,
        V=V,
        BT=BT,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    return hr, ht
577
+
578
+
579
def chunk_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    do: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    scale: float,
    offsets: Optional[torch.Tensor] = None,
    split_offsets: Optional[torch.Tensor] = None,
    split_indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    split_size: int = 256,
    states_in_fp32: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the backward hidden-state gradients with a split/reduce scheme.

    Stage 1 (`chunk_bwd_kernel_dh_split`) accumulates per-split partial
    gradients in parallel; stage 2 (`chunk_bwd_kernel_dh_reduction`) chains
    them backwards across the splits of each sequence.

    Returns:
        dhr: reduced per-split state gradients, `[NS, HQ, K, V]`.
        dh0: initial-state gradient (same shape as `h0`), or None.
    """
    if head_first:
        B, H, T = k.shape[:3]
        HQ = q.shape[1]
    else:
        B, T, H = k.shape[:3]
        HQ = q.shape[2]
    K, V = k.shape[-1], v.shape[-1]
    # BT: chunk size; S: split size, clamped into [BT, next_pow2(T)]
    BT = chunk_size
    S = max(BT, min(split_size, triton.next_power_of_2(T)))
    assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
    # N: actual number of sequences; NS: total number of splits in the batch
    if offsets is None:
        N, NS = B, B * triton.cdiv(T, S)
    else:
        N, NS = len(offsets) - 1, split_offsets[-1]
    # number of groups in GQA
    NG = HQ // H

    dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float)
    dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
    dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

    # stage 1: parallelized over (K-tile, V-tile, split x query-head)
    chunk_bwd_kernel_dh_split[
        lambda meta: (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ)
    ](
        q=q,
        g=g,
        gk=gk,
        gv=gv,
        do=do,
        dht=dht,
        dhs=dhs,
        dhr=dhr,
        dh0=dh0,
        offsets=offsets,
        split_indices=split_indices,
        scale=scale,
        T=T,
        S=S,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first,
    )
    # stage 2: parallelized over (K-tile, V-tile, sequence x query-head)
    chunk_bwd_kernel_dh_reduction[
        lambda meta: (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
    ](
        g=g,
        gk=gk,
        gv=gv,
        dhs=dhs,
        dhr=dhr,
        dh0=dh0,
        offsets=offsets,
        split_offsets=split_offsets,
        T=T,
        S=S,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    return dhr, dh0
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+
12
+
13
# Per-chunk kernel computing A = tril(diag(beta) @ K @ K^T, -1): each program
# handles one (chunk, batch-head) pair and accumulates the [BT, BT] product
# over key-dimension tiles of size BK, keeping only the strictly
# lower-triangular part.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'BT', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,  # keys
    beta,  # per-timestep scaling applied to the rows of K
    A,  # [out] per-chunk [BT, BT] result
    offsets,  # varlen sequence offsets (optional)
    indices,  # flattened (sequence, local chunk) pairs for varlen mode
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size
    BK: tl.constexpr,  # key-tile size (autotuned)
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # varlen: remap chunk id to (sequence, local chunk); T becomes the
        # per-sequence length
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_t = tl.arange(0, BT)

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # accumulate (beta * K) @ K^T over key-dimension tiles
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    # keep only the strictly lower-triangular part (causal, excluding diagonal)
    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
72
+
73
+
74
def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    head_first: bool = False,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    r"""Compute beta * K * K^T chunkwise on the GPU.

    Args:
        k (torch.Tensor):
            Key tensor, `[B, T, H, K]` (or `[B, H, T, K]` when `head_first`).
        beta (torch.Tensor):
            Scaling tensor, `[B, T, H]` (or `[B, H, T]` when `head_first`).
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths for variable-length batches, or None.
        head_first (bool):
            Layout flag, see `k`/`beta` above. Default: False.
        chunk_size (int):
            Chunk length `BT`. Default: 64.
        output_dtype (torch.dtype):
            Dtype of the returned tensor. Default: `torch.float32`.

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` (or `[B, H, T, BT]` when
        `head_first`), where `BT` is the chunk size.
    """
    BT = chunk_size
    if head_first:
        B, H, T, K = k.shape
        A = torch.empty(B, H, T, BT, device=k.device, dtype=output_dtype)
    else:
        B, T, H, K = k.shape
        A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    # varlen batches launch one program per (sequence-local) chunk
    if cu_seqlens is not None:
        indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(indices)
    else:
        indices = None
        NT = triton.cdiv(T, BT)
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        A=A,
        offsets=cu_seqlens,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return A
fla/ops/delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
12
+ from fla.utils import input_guard
13
+
14
+
15
# Step-by-step recurrent forward of the delta rule. Each program owns one
# (value-tile, key-tile, batch-head) and carries the state b_h (laid out
# value-major as [BV, BK]) across all T steps:
#   u_t = v_t - h k_t          (stored to `u`)
#   h  += k_t (x) (beta_t u_t) (rank-1 update)
#   o_t = h q_t
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_delta_rule_fwd_kernel(
    q,  # queries
    k,  # keys
    v,  # values
    u,  # [out] pseudo values v - h k, before beta scaling
    beta,  # update gate; per-value vector if IS_BETA_HEADWISE else scalar
    o,  # [out] outputs, with a leading NK slice dimension
    h0,  # initial state (optional)
    ht,  # [out] final state (optional)
    offsets,  # varlen sequence offsets (optional)
    scale,  # scaling applied to q
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        # `all` keeps the flattened total length for indexing o's NK slices
        # (shadows the builtin, but only locally inside this jitted fn)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # element pointers for the first timestep; advanced at the end of each step
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_u = u + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + i_nh * T
        p_o = o + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + bos * H + i_h
        p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    # bounds masks for partial tiles
    mask_k = (i_k * BK + tl.arange(0, BK)) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    # recurrent state, value-major layout [BV, BK]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        # subtract the state's current prediction h k from v
        b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
        b_v -= b_v_minus
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        # store the pre-beta pseudo value (used by the backward pass)
        tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
        b_v *= b_beta
        # rank-1 state update and output h q
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # advance all pointers by one timestep
        p_q += K if HEAD_FIRST else H*K
        p_k += K if HEAD_FIRST else H*K
        p_o += V if HEAD_FIRST else H*V
        p_v += V if HEAD_FIRST else H*V
        p_u += V if HEAD_FIRST else H*V
        p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
112
+
113
+
114
# Backward of the recurrent delta rule, in two phases separated by a barrier:
#   1. reverse sweep over time accumulating the state gradient b_dh and
#      producing dk, dv, db (and optionally dh0);
#   2. forward replay rebuilding the state b_h to produce dq and to apply the
#      h-dependent correction to dk.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_delta_rule_bwd_kernel(
    q,  # queries
    k,  # keys
    v,  # values
    beta,  # update gate; per-value vector if IS_BETA_HEADWISE else scalar
    h0,  # initial state (optional)
    dh0,  # [out] initial-state gradient (optional)
    dht,  # gradient w.r.t. the final state (optional)
    do,  # gradient w.r.t. the output
    dq,  # [out] query gradients (leading NV slice dim, summed by the wrapper)
    dk,  # [out] key gradients (leading NV slice dim)
    dv,  # [out] value gradients (leading NK slice dim)
    db,  # [out] beta gradients
    offsets,  # varlen sequence offsets (optional)
    scale,  # scaling applied to q
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NK: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_INITIAL_STATE: tl.constexpr,  # whether to use dh0
    USE_FINAL_STATE_GRADIENT: tl.constexpr,  # whether to use dht
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        # `all` keeps the flattened total length for indexing the slice dims
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    mask_k = i_k * BK + tl.arange(0, BK) < K
    mask_v = i_v * BV + tl.arange(0, BV) < V

    # --- phase 1: reverse sweep; pointers start at the last timestep ---
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
        p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
            p_dbeta = db + (i_v * NK*B*H + i_k * B*H + i_nh) * T*V + tl.arange(0, BV) + (T - 1) * V
        else:
            p_beta = beta + i_nh * T + T - 1
            p_dbeta = db + (i_v * B*H + i_nh) * T + T - 1
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
        p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
            p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV)
        else:
            p_beta = beta + (bos + T - 1) * H + i_h
            p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h

    # running gradient of the state, [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)

    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        # output-path contribution to the state gradient
        b_dh += b_q[:, None] * b_do[None, :]
        b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1)
        b_dv = tl.sum(b_dh * b_k[:, None], axis=0)

        # beta gradient before beta is folded into dv
        b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v)
        b_dv = b_dv * b_beta

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
        if IS_BETA_HEADWISE:
            tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v)
        else:
            tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty))

        # propagate the delta-rule correction backwards through the state
        b_dh -= b_k[:, None] * b_dv[None, :]

        # step all pointers back one timestep
        p_q -= K if HEAD_FIRST else H*K
        p_k -= K if HEAD_FIRST else H*K
        p_v -= V if HEAD_FIRST else H*V
        p_do -= V if HEAD_FIRST else H*V
        p_dk -= K if HEAD_FIRST else H*K
        p_dv -= V if HEAD_FIRST else H*V
        p_dbeta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
        p_beta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)

    if USE_INITIAL_STATE:
        p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])

    # make phase-1 stores (dk in particular) visible before the replay reads them
    tl.debug_barrier()

    # --- phase 2: forward replay reconstructing the state for dq and the dk fix ---
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if HEAD_FIRST:
        p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + i_nh * T
        p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
        p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        else:
            p_beta = beta + bos * H + i_h
        p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    if USE_INITIAL_STATE:
        mask_h = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        # correct dk with the previous-state term, using phase-1's dv
        b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32)
        b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32)
        b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1)
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)

        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta

        # NOTE(review): the replay updates h with beta*v without re-subtracting
        # the prediction; presumably v here is the pseudo value u from the
        # forward pass — confirm against the autograd wrapper.
        b_h += b_k[:, None] * b_v[None, :]
        b_dq = b_h * b_do[None, :]
        d_q = tl.sum(b_dq, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)

        # advance all pointers by one timestep
        p_k += K if HEAD_FIRST else H*K
        p_v += V if HEAD_FIRST else H*V
        p_do += V if HEAD_FIRST else H*V
        p_dq += K if HEAD_FIRST else H*K
        p_dk += K if HEAD_FIRST else H*K
        p_dv += V if HEAD_FIRST else H*V
        p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
293
+
294
+
295
def fused_recurrent_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent delta-rule forward kernel.

    Returns `(o, u, final_state)`: the outputs, the per-step pseudo values
    written by the kernel, and the final recurrent state (or None when
    `output_final_state` is False).
    """
    if head_first:
        B, H, T = k.shape[:3]
    else:
        B, T, H = k.shape[:3]
    K, V = k.shape[-1], v.shape[-1]
    N = len(offsets) - 1 if offsets is not None else B

    # tile sizes: whole key dim in one block, value dim in blocks of <= 8
    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"

    o = q.new_empty(NK, *v.shape)
    u = torch.empty_like(v)
    final_state = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    fused_recurrent_delta_rule_fwd_kernel[(NV, NK, N * H)](
        q,
        k,
        v,
        u,
        beta,
        o,
        initial_state,
        final_state,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        HEAD_FIRST=head_first,
        num_warps=1,
        num_stages=1,
    )
    # drop the (always singleton, since NK == 1) leading slice dimension
    return o.squeeze(0), u, final_state
350
+
351
+
352
def fused_recurrent_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    dht: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent delta-rule backward kernel and reduce its
    tile-sliced partial gradients.

    Returns `(dq, dk, dv, db, dh0)`.
    """
    if head_first:
        B, H, T = k.shape[:3]
    else:
        B, T, H = k.shape[:3]
    K, V = k.shape[-1], v.shape[-1]
    N = len(offsets) - 1 if offsets is not None else B

    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"

    # whether beta carries a per-value-channel vector rather than a scalar
    beta_vector = beta.ndim == v.ndim

    # partial gradients with leading tile-slice dimensions, reduced below
    dq = q.new_empty(NV, *q.shape)
    dk = q.new_empty(NV, *k.shape)
    dv = q.new_empty(NK, *v.shape)
    if beta_vector:
        db_shape = (NV, NK, B, H, T, V) if head_first else (NV, NK, B, T, H, V)
    else:
        db_shape = (NV, B, H, T) if head_first else (NV, B, T, H)
    db = q.new_empty(*db_shape)
    dh0 = (
        torch.empty_like(initial_state, dtype=torch.float32)
        if initial_state is not None and initial_state.requires_grad
        else None
    )

    fused_recurrent_delta_rule_bwd_kernel[(NV, NK, N * H)](
        q,
        k,
        v,
        beta,
        initial_state,
        dh0,
        dht,
        do,
        dq,
        dk,
        dv,
        db,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        NK=NK,
        IS_BETA_HEADWISE=beta_vector,
        HEAD_FIRST=head_first,
        num_warps=2,
        num_stages=1
    )
    # sum out the tile-slice dimension(s)
    dq, dk, dv = dq.sum(0), dk.sum(0), dv.sum(0)
    db = db.sum((0, 1)) if beta_vector else db.sum(0)

    return dq, dk, dv, db, dh0
425
+
426
+
427
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent delta-rule kernels.

    When `use_qk_l2norm_in_kernel` is set, q/k are L2-normalized before the
    recurrence; the *unnormalized* inputs are saved so that `backward` can
    recompute the normalization and differentiate through it.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True,
        use_qk_l2norm_in_kernel: bool = False
    ):
        # keep references to the raw inputs; backward re-applies the l2norm
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        o, u, final_state = fused_recurrent_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            head_first=head_first
        )

        # NOTE: `u` (the pseudo-value tensor produced by the forward kernel)
        # is saved in the `v` slot and is consumed as `v` by the backward
        # kernel below.
        ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state)
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.head_first = head_first
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        # `v` here is actually the pseudo-value `u` saved in forward (see above)
        q, k, v, beta, initial_state = ctx.saved_tensors
        if ctx.use_qk_l2norm_in_kernel:
            # recompute normalized q/k; keep the originals for the l2norm backward
            q, q_orig = l2norm_fwd(q), q
            k, k_orig = l2norm_fwd(k), k
        dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            beta=beta,
            dht=dht,
            do=do,
            scale=ctx.scale,
            initial_state=initial_state,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        if ctx.use_qk_l2norm_in_kernel:
            dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk)
        # one gradient per forward input; non-tensor inputs get None
        return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None, None
492
+
493
+
494
@torch.compiler.disable
def fused_recurrent_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    use_qk_l2norm_in_kernel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    # validate variable-length (cu_seqlens) constraints up front
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError(
                "Sequences with variable lengths are not supported for head-first mode"
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # default beta of all ones (plain delta-rule update)
        beta = torch.ones_like(q[..., 0])
    # the underlying autograd function works in time-first layout; convert
    # head-first inputs here and convert the output back below
    if head_first:
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
        beta = rearrange(beta, 'b h t -> b t h')
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        False,
        use_qk_l2norm_in_kernel
    )
    if head_first:
        o = rearrange(o, 'b t h v -> b h t v')
    return o, final_state
fla/ops/delta_rule/wy_fast.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
11
+ from fla.ops.utils.solve_tril import solve_tril
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Recompute the WY-representation tensors per chunk:
    #   u = A @ (v * beta),  w = A @ (k * beta)
    # Grid is (NT, B*H): one program per (chunk, batch-head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # var-len mode: map the flat chunk id to (sequence, in-seq chunk)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    # u = A @ (v * beta), tiled over V
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (k * beta), tiled over K
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
89
+
90
+
91
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    A,
    dw,
    du,
    dk,
    dv,
    dbeta,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward of the WY representation: given dw/du, produce dk/dv/dbeta.
    # Grid is (NT, B*H): one program per (chunk, batch-head) pair.
    # BUGFIX: the original unpacked three program ids into two names
    # (`i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)`),
    # which raises "too many values to unpack" when the kernel is traced.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # var-len mode: map the flat chunk id to (sequence, in-seq chunk)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        # A is loaded transposed here (note the (BT, T) shape / (1, BT) strides)
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))

    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    # accumulate contributions from du (value path), tiled over V
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
        b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
        b_dv = b_dv_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)

        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # accumulate contributions from dw (key path), tiled over K
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
        b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
        b_dk = b_dk_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # backprop through the strictly-lower-triangular inverse:
    # dA <- -tril(A^T dA A^T, -1) (A is held transposed in b_A)
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

    # second pass over K: add the dA-induced gradients into dk and dbeta
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)

        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    if HEAD_FIRST:
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
212
+
213
+
214
def fwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = False,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the WY representation `(w, u, A)` used by the chunked delta rule.

    First forms the chunkwise scaled K·Kᵀ matrix in fp32, inverts its
    unit-lower-triangular factor via `solve_tril`, then recomputes
    `w = A @ (k * beta)` and `u = A @ (v * beta)`.
    """
    # chunkwise beta-scaled K K^T, accumulated in fp32 for stability
    A = chunk_scaled_dot_kkt_fwd(
        k=k,
        beta=beta,
        cu_seqlens=offsets,
        head_first=head_first,
        chunk_size=chunk_size,
        output_dtype=torch.float32,
    )
    # invert the lower-triangular system, casting back to the input dtype
    A = solve_tril(
        A=A,
        cu_seqlens=offsets,
        head_first=head_first,
        output_dtype=k.dtype,
    )
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size,
    )
    return w, u, A
249
+
250
+
251
def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Recompute `w = A @ (k * beta)` and `u = A @ (v * beta)` per chunk.

    Launches `fwd_recompute_w_u_kernel` on a (num-chunks, batch*heads) grid;
    the outputs share shape and dtype with `k` and `v` respectively.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]

    # chunk length: capped by `chunk_size`, padded up to a power of two, >= 16
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # tile sizes shrink on GPUs with little shared memory
    tiling = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tiling)
    BV = min(triton.next_power_of_2(V), tiling)
    # var-len runs get one program per precomputed chunk index
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    w = torch.empty_like(k)
    u = torch.empty_like(v)
    grid = (NT, B * H)
    fwd_recompute_w_u_kernel[grid](
        k,
        v,
        beta,
        w,
        u,
        A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first,
    )
    return w, u
292
+
293
+
294
def bwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward of the WY representation: map `(dw, du)` to `(dk, dv, dbeta)`.

    Launches `bwd_prepare_wy_repr_kernel` on a (num-chunks, batch*heads) grid.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]

    # chunk length: capped by `chunk_size`, padded up to a power of two, >= 16
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # tile sizes shrink on GPUs with little shared memory
    tiling = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tiling)
    BV = min(triton.next_power_of_2(V), tiling)
    # var-len runs get one program per precomputed chunk index
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    grid = (NT, B * H)
    bwd_prepare_wy_repr_kernel[grid](
        k,
        v,
        beta,
        A,
        dw,
        du,
        dk,
        dv,
        dbeta,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first,
    )
    return dk, dv, dbeta
fla/ops/forgetting_attn/parallel.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
13
+ from fla.ops.utils.op import div, exp, log
14
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
15
+
16
+
17
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4, 5]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit
def parallel_forgetting_attn_fwd_kernel(
    q,
    k,
    v,
    g,
    o,
    lse,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Forward pass of forgetting attention with an online-softmax (FlashAttention
    # style) accumulation over key/value blocks; `g` holds cumulative log decay
    # gates and `lse` receives the per-row log-sum-exp for the backward pass.
    # HQ query heads share H kv heads (GQA); G = HQ // H query heads per kv head.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        # var-len mode: map the flat chunk id to (sequence, in-seq chunk)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT,]
    b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
    # [BT, BV]
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # online-softmax state: running max and running normalizer per row
    b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([BT], dtype=tl.float32)

    # [BT]
    o_q = i_t * BT + tl.arange(0, BT)
    # phase 1: diagonal blocks, where the causal mask applies elementwise
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,))
        # [BT, BS]
        b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :]
        b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))

        # [BT]
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m

    # phase 2: fully-unmasked blocks strictly before the diagonal, scanned
    # backwards; b_gq accumulates the decay across the skipped gap
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)

        # gate at the last position of this block / the one just before it
        b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
        b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
        # [BT, BS]
        b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :]

        b_gq += b_gn - b_gp
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [BT, BS]
        b_p = exp(b_s - b_m[:, None])
        # [BT]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)
        # [BT, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m

    # finalize: normalize the output and store lse = m + log(sum exp)
    b_o = div(b_o, b_acc[:, None])
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
145
+
146
+
147
@triton.jit
def parallel_forgetting_attn_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,
    V: tl.constexpr
):
    # Compute delta[i] = sum_v o[i, v] * do[i, v] for each row i,
    # the standard FlashAttention backward preprocessing term.
    # B is the block width used for the V dimension (assumed >= V — the
    # mask below handles B > V; TODO confirm B is next_power_of_2(V)).
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
164
+
165
+
166
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
        for num_stages in [2, 3, 4]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_forgetting_attn_bwd_kernel_dq(
    q,
    k,
    v,
    g,
    lse,
    delta,
    do,
    dq,
    dg,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward pass of forgetting attention w.r.t. the queries (dq) and the
    # query-side gate gradient (dg), using the stored lse/delta to rebuild
    # the softmax probabilities blockwise. Mirrors the two-phase structure
    # of the forward kernel: masked diagonal blocks, then earlier blocks.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        # var-len mode: map the flat chunk id to (sequence, in-seq chunk)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BT]
    b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
    b_lse = tl.load(p_lse, boundary_check=(0,))
    b_delta = tl.load(p_delta, boundary_check=(0,))

    # [BT]
    o_q = i_t * BT + tl.arange(0, BT)
    # [BT, BK]
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT]
    b_dg = tl.zeros([BT,], dtype=tl.float32)
    # phase 1: diagonal blocks with the causal mask applied elementwise
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_k = i_s + tl.arange(0, BS)
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,))
        # [BT, BS]; subtracting lse rebuilds softmax probs without a rescan
        b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :]
        b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')))

        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
        # [BT]
        b_dg += tl.sum(b_ds, 1)

    # phase 2: fully-unmasked earlier blocks, scanned backwards; b_gq
    # accumulates the decay across the skipped gap
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
        p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BS,]
        b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)

        # gate at the last position of this block / the one just before it
        b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
        b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
        # [BT, BS]
        b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :]
        b_p = exp(b_s)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp - b_delta[:, None])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
        # [BT]
        b_dg += tl.sum(b_ds, 1)

        b_gq += b_gn - b_gp

    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
296
+
297
+
298
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_forgetting_attn_bwd_kernel_dkv(
    q,
    k,
    v,
    g,  # forget gates, already chunk-local cumsum'ed in log space
    lse,  # log-sum-exp from the forward pass, per (b, t, hq)
    delta,  # rowsum(o * do) from the preprocess kernel
    do,
    dk,  # output: [B, T, HQ, K] key gradients (per query head, reduced later)
    dv,  # output: [B, T, HQ, V] value gradients (per query head, reduced later)
    dg,  # output: [B, T, HQ] gate gradients (dk/dv side contribution)
    offsets,  # cumulative sequence lengths for varlen mode, or None
    indices,  # (seq idx, chunk idx) pairs for varlen mode, or None
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,  # number of k/v heads
    HQ: tl.constexpr,  # number of query heads (GQA: HQ = H * G)
    G: tl.constexpr,  # query heads per k/v head
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # key-block (chunk) size along T
    BS: tl.constexpr,  # query-block size along T for the inner loops
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward pass for dk/dv/dg: each program owns one key/value chunk
    # [i_t*BT, i_t*BT+BT) of one (batch, query-head) pair and one value block.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // G

    if USE_OFFSETS:
        # Varlen: remap (chunk id) -> (sequence id, in-sequence chunk id).
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = i_n * T, i_n * T + T

    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
    p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))

    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)
    # [BT]
    b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
    b_dg = tl.zeros([BT,], dtype=tl.float32)

    o_k = i_t * BT + tl.arange(0, BT)
    m_k = o_k < T
    # Phase 1: query blocks inside this chunk (diagonal) — needs the causal mask.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)

        m_q = o_q < T
        # Causal + in-bounds mask: key position <= query position.
        m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :]
        # [BT, BS]: attention logits minus lse gives log-probabilities.
        b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :]
        b_p = tl.where(m_s, exp(b_s), 0)
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]: softmax backward, delta recenters the gradient.
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        # [BT]: gate gradient accumulates with negative sign on the key side.
        b_dg -= tl.sum(b_ds, 1)

    # Rebase gate cumsum so later chunks can be handled without overflow:
    # subtract the gate cumsum at the end of this chunk.
    b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32)
    # Phase 2: strictly-future query blocks — all positions attend, no mask.
    for i_s in range((i_t + 1) * BT, T, BS):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
        p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))

        # [BS]
        o_q = i_s + tl.arange(0, BS)
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_lse = tl.load(p_lse, boundary_check=(0,))
        b_delta = tl.load(p_delta, boundary_check=(0,))
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)

        # Gate cumsum at the last position of this query block / previous block.
        b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
        b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
        # [BT, BS]
        b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :]
        b_p = exp(b_s)
        # [BT, BS] @ [BS, BV] -> [BT, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BT, BV] @ [BV, BS] -> [BT, BS]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BT, BS]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BT, BS] @ [BS, BK] -> [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        # [BT]
        b_dg -= tl.sum(b_ds, 1)

        # Advance the rebased gate offset past this query block.
        b_gk -= b_gn - b_gp

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
444
+
445
+
446
def parallel_forgetting_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Launch the forward kernel for parallel forgetting attention.

    Returns a tuple ``(o, lse)`` where ``o`` is the attention output of shape
    ``[B, T, HQ, V]`` and ``lse`` the per-position log-sum-exp in fp32.
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    assert V <= 256, "V must be less than or equal to 256"

    # Block-size selection: larger query blocks on Hopper-class shared memory.
    BT = chunk_size
    BK = max(16, triton.next_power_of_2(K))
    BS = min(64 if check_shared_mem('hopper') else 32, max(16, triton.next_power_of_2(T)))
    BV = min(256, max(16, triton.next_power_of_2(V)))
    NV = triton.cdiv(V, BV)
    # Varlen mode derives the chunk count from the precomputed chunk indices.
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    parallel_forgetting_attn_fwd_kernel[(NV, NT, B * HQ)](
        q=q,
        k=k,
        v=v,
        g=g,
        o=o,
        lse=lse,
        scale=scale,
        offsets=offsets,
        indices=indices,
        B=B,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
497
+
498
+
499
def parallel_forgetting_attn_bwd_preprocess(
    o: torch.Tensor,
    do: torch.Tensor
):
    """Run the backward preprocess kernel, producing one fp32 scalar per
    (batch, time, head) position (``delta``), used by the dq/dkv kernels."""
    num_v = o.shape[-1]
    # One program per output position; the kernel reduces over the V dimension.
    delta = torch.empty_like(o[..., 0], dtype=torch.float)
    grid = (delta.numel(),)
    parallel_forgetting_attn_bwd_kernel_preprocess[grid](
        o=o,
        do=do,
        delta=delta,
        B=triton.next_power_of_2(num_v),
        V=num_v,
    )
    return delta
513
+
514
+
515
def parallel_forgetting_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    scale: Optional[float] = None,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Backward pass of parallel forgetting attention.

    Launches the dq and dkv kernels and reduces the per-query-head key/value
    gradients back to the k/v head layout (GQA).

    Args:
        q, k, v: forward inputs; `q` is `[B, T, HQ, K]`, `k`/`v` use `H` heads.
        g: chunk-local log-space gate cumsum of shape `[B, T, HQ]`.
        o, lse: forward outputs (attention output and log-sum-exp).
        do: gradient w.r.t. `o`.
        scale: attention score scale (passed through to the kernels).
        chunk_size: key/value chunk length `BT`.
        offsets, indices: varlen metadata (cu_seqlens and chunk indices).

    Returns:
        Tuple `(dq, dk, dv, dg)`; `dk`/`dv` are reduced to `H` heads,
        `dg` is the fp32 sum of the dq-side and dkv-side gate gradients.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BT = chunk_size
    BS = min(32, max(16, triton.next_power_of_2(T)))
    BK = max(16, triton.next_power_of_2(K))
    BV = max(16, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    delta = parallel_forgetting_attn_bwd_preprocess(o, do)
    dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
    # With GQA (H != HQ), accumulate dk/dv in fp32 before the head reduction.
    dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
    dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
    dg = q.new_empty(g.shape, dtype=torch.float)
    # NOTE: the original `dg` can be destroyed during autotuning
    # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?)
    # so we need to make a copy of `dg`
    dg2 = q.new_empty(g.shape, dtype=torch.float)
    grid = (NV, NT, B * HQ)
    parallel_forgetting_attn_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        dg=dg,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    parallel_forgetting_attn_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        dg=dg2,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV
    )
    # Sum the per-query-head gradients back onto the H k/v heads.
    dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
    dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
    dg = dg.add_(dg2)
    return dq, dk, dv, dg
603
+
604
+
605
@torch.compile
class ParallelForgettingAttentionFunction(torch.autograd.Function):
    """Autograd wrapper tying together the forward/backward Triton launchers
    for parallel forgetting attention."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, offsets):
        ctx.dtype = q.dtype
        # Larger chunks when Hopper-class shared memory is available.
        if check_shared_mem('hopper'):
            chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
        else:
            chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None

        # Gates are converted to a chunk-local cumsum once here; both the
        # forward and backward kernels consume this cumsum'ed form.
        g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
        o, lse = parallel_forgetting_attn_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            scale=scale,
            chunk_size=chunk_size,
            offsets=offsets,
            indices=indices
        )
        # NOTE: the cumsum'ed `g` (not the raw gates) is what gets saved.
        ctx.save_for_backward(q, k, v, g, o, lse)
        ctx.chunk_size = chunk_size
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, g, o, lse = ctx.saved_tensors
        dq, dk, dv, dg = parallel_forgetting_attn_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            o=o,
            lse=lse,
            do=do,
            scale=ctx.scale,
            chunk_size=ctx.chunk_size,
            offsets=ctx.offsets,
            indices=ctx.indices
        )
        # Undo the chunk-local cumsum applied in forward: reverse global
        # cumsum turns per-position gate grads back into raw-gate grads.
        dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
        # Gradients for (q, k, v, g); Nones for scale/offsets and trailing
        # slots. NOTE(review): forward takes 6 inputs but 12 values are
        # returned — extra trailing Nones look intentional; confirm.
        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None
661
+
662
+
663
def parallel_forgetting_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    r"""Parallel forgetting attention (public entry point).

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA will be applied if HQ is divisible by H.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        scale (Optional[float]):
            Scale factor for attention scores; defaults to `1 / sqrt(K)` when `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if g is not None:
        # Gates are processed in fp32 throughout.
        g = g.float()
    if head_first:
        # Normalize to time-first layout before dispatching.
        q = rearrange(q, 'b h t d -> b t h d')
        k = rearrange(k, 'b h t d -> b t h d')
        v = rearrange(v, 'b h t d -> b t h d')
        g = rearrange(g, 'b h t -> b t h')
    o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens)
    return rearrange(o, 'b t h d -> b h t d') if head_first else o
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (21.3 kB). View file
 
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BC', 'K'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_inter(
    q,
    k,
    a,
    b,
    gi,  # cumsum
    ge,  # before cumsum
    Aqk,  # output: q·k block matrix [*, T, BT]
    Aqb,  # output: q·b block matrix [*, T, BT]
    Aab,  # output: a·b block matrix [*, T, BT] (fp32 downstream)
    Aak,  # output: a·k block matrix [*, T, BT] (fp32 downstream)
    offsets,
    indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size along T
    BC: tl.constexpr,  # sub-chunk size along T
    BK: tl.constexpr,  # K-block size (autotuned)
    NC: tl.constexpr,  # number of sub-chunks per chunk
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Fills the strictly-lower off-diagonal (inter-sub-chunk) blocks of the
    # four intra-chunk attention-like matrices; sub-chunk pair is (i_i, i_j).
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if USE_OFFSETS:
        # Varlen: remap chunk id and switch to per-sequence T.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return
    # Only strictly-lower blocks (row sub-chunk after column sub-chunk).
    if i_i <= i_j:
        return

    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
    # Accumulate over the K dimension in blocks of BK.
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
        else:
            p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
        # [BK,] gate cumsum just before the row sub-chunk — the common
        # reference point for both sides of the decay factors.
        b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
        b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
        # Apply per-key decay relative to b_gn on the row side.
        b_ag = b_a * exp(b_gq_e - b_gn[None, :])
        b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
        # Inverse decay on the column side; the b_gn terms cancel in products.
        tmp = exp(b_gn[:, None] - b_gk)
        b_kg = b_k * tmp
        b_bg = b_b * tmp
        # [BC, BC] using tf32 to improve precision here.
        b_Aab += tl.dot(b_ag, b_bg)
        b_Aak += tl.dot(b_ag, b_kg)
        b_Aqk += tl.dot(b_qg, b_kg)
        b_Aqb += tl.dot(b_qg, b_bg)

    if HEAD_FIRST:
        p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    else:
        p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
129
+
130
+
131
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_intra(
    q,
    k,
    a,
    b,
    gi,  # gate cumsum
    ge,  # gate before cumsum
    qg,  # output: decay-weighted q
    kg,  # output: decay-weighted k
    ag,  # output: decay-weighted a
    bg,  # output: decay-weighted b
    Aqk,  # output: diagonal sub-chunk blocks of the q·k matrix
    Aqb,  # output: diagonal sub-chunk blocks of the q·b matrix
    Aab,  # output: diagonal sub-chunk blocks of the a·b matrix
    Aak,  # output: diagonal sub-chunk blocks of the a·k matrix
    offsets,
    indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size along T
    BC: tl.constexpr,  # sub-chunk size along T
    BK: tl.constexpr,  # K padded to a power of two (single K block)
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr
):
    # Fills the diagonal (intra-sub-chunk) blocks of the four intra-chunk
    # matrices and also materializes the decay-weighted qg/kg/ag/bg tensors.
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_j = i_i
    if USE_OFFSETS:
        # Varlen: remap chunk id and switch to per-sequence T.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # Index of the last valid position in this chunk (for the decay reference).
    last_idx = min((i_t+1) * BT, T) - 1
    if HEAD_FIRST:
        o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
        p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
        b_g_last = tl.load(p_g_last, mask=m_k, other=0)

        p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
    else:
        o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
        p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
        b_g_last = tl.load(p_g_last, mask=m_k, other=0)
        p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
        p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = b_q * scale
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_a = tl.load(p_a, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
    b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)

    # deal with decay term.
    g_exp = exp(b_gi)
    # k/b are decayed relative to the chunk's last position.
    g_exp_inv = exp(-b_gi + b_g_last[None, :])
    b_qg = b_q * g_exp
    b_kg = b_k * g_exp_inv
    b_bg = b_b * g_exp_inv
    b_ag = b_a * exp(b_ge)
    tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # tl.debug_barrier()

    b_q = b_q.to(b_k.dtype)
    # inner attn: one row of the diagonal BCxBC block per iteration.
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # a trick to index the j-th row of b_k, b_g, b_b
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BK], j, dtype=tl.int16)
            # [1, BK]
            b_k_j = gather(b_k, row_idx, axis=0)
            b_gk_j = gather(b_gi, row_idx, axis=0)
            b_b_j = gather(b_b, row_idx, axis=0)
        else:
            mask = tl.arange(0, BC) == j
            b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
            b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
            b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
        # NOTE: a redundant, unused `mask = tl.arange(0, BC) == j` that
        # followed this branch in the original has been removed.
        tmp = exp(b_gi - b_gk_j)
        b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
        # q·k / q·b include the diagonal (o_i >= j) ...
        b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
        b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
        b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
        tmp2 = exp(b_ge - b_gk_j)
        # ... while a·k / a·b are strictly lower triangular (o_i > j).
        b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
        b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
        b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
        b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
        tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
272
+
273
+
274
def chunk_fwd_intra_dplr_fn(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    scale: float,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    """Launch the intra-chunk DPLR kernels.

    First fills the off-diagonal (inter-sub-chunk) blocks of the four
    intra-chunk matrices, then the diagonal blocks plus the decay-weighted
    qg/kg/ag/bg tensors. Returns (Aab, Aqk, Aak, Aqb, qg, kg, ag, bg).
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)
    BC = min(16, BT)
    NC = triton.cdiv(BT, BC)

    mid_dims = (H, T) if head_first else (T, H)
    Aqk = q.new_empty(B, *mid_dims, BT, dtype=q.dtype)
    Aqb = q.new_empty(B, *mid_dims, BT, dtype=q.dtype)
    # involving matrix inverse and it'd be better to use float here.
    Aab = q.new_empty(B, *mid_dims, BT, dtype=torch.float)
    Aak = q.new_empty(B, *mid_dims, BT, dtype=torch.float)

    # One program per (chunk, sub-chunk pair, batch*head).
    chunk_dplr_fwd_A_kernel_intra_sub_inter[(NT, NC * NC, B * H)](
        q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
        offsets=offsets, indices=indices,
        scale=scale,
        T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
        HEAD_FIRST=head_first
    )

    BK = triton.next_power_of_2(K)
    qg = torch.empty_like(q)
    kg = torch.empty_like(k, dtype=q.dtype)
    ag = torch.empty_like(a, dtype=q.dtype)
    bg = torch.empty_like(b, dtype=q.dtype)
    # One program per (chunk, sub-chunk, batch*head) for the diagonal blocks.
    chunk_dplr_fwd_A_kernel_intra_sub_intra[(NT, NC, B * H)](
        q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
        qg=qg, kg=kg, ag=ag, bg=bg,
        offsets=offsets, indices=indices,
        scale=scale,
        T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
        GATHER_SUPPORTED=is_gather_supported
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
# Chunkwise backward kernel for the DPLR (diagonal-plus-low-rank) delta rule.
# Walks the time chunks in *reverse* order while carrying the running gradient
# of the recurrent state `b_dh` ([BK, BV]) and emits:
#   dh : the state gradient entering each chunk (stored before the chunk is
#        processed, so dh[t] is the gradient w.r.t. the state at chunk t's start),
#   dv2: dv plus the contribution flowing back through the state via `bg`,
#   dh0: gradient w.r.t. the initial state (only when USE_INITIAL_STATE).
# Launched on a (NK, NV, N*H) grid; the host wrapper asserts NK == 1.
@triton.heuristics({
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV', "V"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dhu(
    qg,  # gated queries (presumably q pre-multiplied by the gate — TODO confirm with the fwd pass)
    bg,  # gated low-rank `b` vectors
    w,
    gk,  # cumulative log-gates over K; exp(gk) rescales b_dh across chunk boundaries
    dht,  # optional gradient w.r.t. the final state
    dh0,  # optional output: gradient w.r.t. the initial state
    do,  # gradient w.r.t. the outputs
    dh,  # output: per-chunk incoming state gradients
    dv,  # incoming gradient w.r.t. v
    dv2,  # output: corrected value gradients
    offsets,  # [N + 1] sequence offsets (variable-length mode)
    chunk_offsets,  # per-sequence chunk base offsets (variable-length mode)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # time-chunk size
    BC: tl.constexpr,  # sub-chunk size within a chunk (SRAM-friendly slices)
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Variable-length mode: derive this sequence's [bos, eos) span and its
        # base chunk index from the offset tables.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running gradient of the recurrent state.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1))

    mask_k = tl.arange(0, BK) < K
    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # Store the state gradient *entering* this chunk before updating it.
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        # Process the chunk in BC-sized sub-chunks (reverse order) to bound SRAM usage.
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            if HEAD_FIRST:
                p_qg = tl.make_block_ptr(qg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_bg = tl.make_block_ptr(bg + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_w = tl.make_block_ptr(w + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            else:
                p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BT]
            b_qg = tl.load(p_qg, boundary_check=(0, 1))
            # [BT, BK]
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            # [BT, V]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            # dv2 = dv + bg @ dh: value gradient corrected by the state path.
            b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
            tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV] within-chunk contributions, applied after the boundary decay.
            b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
            b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
        # Decay across the chunk boundary by the gate at the chunk's last
        # (possibly clipped) timestep, then fold in this chunk's contributions.
        last_idx = min((i_t + 1) * BT, T) - 1
        if HEAD_FIRST:
            bg_last = tl.load(gk + (i_nh * T + last_idx) * K + tl.arange(0, BK), mask=mask_k)
        else:
            bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
        b_dh *= exp(bg_last)[:, None]
        b_dh += b_dh_tmp

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
121
+
122
+
123
def chunk_dplr_bwd_dhu(
    qg: torch.Tensor,
    bg: torch.Tensor,
    w: torch.Tensor,
    gk: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Host-side launcher for `chunk_dplr_bwd_kernel_dhu`.

    Picks chunk/block sizes based on the available shared memory, allocates the
    output buffers and launches the backward kernel over a ``(NK, NV, N*H)`` grid.

    Returns:
        dh:  per-chunk state gradients, ``[B, H, NT, K, V]`` (or ``[B, NT, H, K, V]``
             when ``head_first=False``).
        dh0: gradient w.r.t. the initial state (``None`` when ``h0`` is ``None``).
        dv2: corrected value gradients, same shape as ``dv``.
    """
    if head_first:
        B, H, T, K = qg.shape
    else:
        B, T, H, K = qg.shape
    V = do.shape[-1]

    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension being larger than 256."

    # Tile sizes tuned per shared-memory capacity (Hopper > Ampere > others).
    if check_shared_mem('hopper', qg.device.index):
        BV, BC = 64, 64 if K <= 128 else 32
    elif check_shared_mem('ampere', qg.device.index):
        BV, BC = 32, 32
    else:
        BV, BC = 16, 16

    # N: the actual number of sequences in the batch (equal- or variable-length).
    if offsets is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        N = len(offsets) - 1
        NT = len(indices)
        chunk_offsets = prepare_chunk_offsets(offsets, BT)

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh_shape = (B, H, NT, K, V) if head_first else (B, NT, H, K, V)
    dh = qg.new_empty(*dh_shape)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.zeros_like(dv)

    chunk_dplr_bwd_kernel_dhu[(NK, NV, N * H)](
        qg=qg,
        bg=bg,
        w=w,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dh, dh0, dv2
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, use_cuda_graph
11
+
12
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
13
+
14
+
15
# Forward output kernel for the chunkwise DPLR delta rule.
# For one (V-block, time-chunk, batch*head) program it computes
#   o = qg @ h  +  tril(A_qk) @ v  +  tril(A_qb) @ v_new
# i.e. the inter-chunk contribution through the recurrent state `h` plus the
# two masked intra-chunk attention terms.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    qg,  # gated queries
    v,  # values
    v_new,  # corrected values produced by the state-update kernel
    A_qk,  # intra-chunk q/k scores, [.., T, BT]
    A_qb,  # intra-chunk q/b scores, [.., T, BT]
    h,  # per-chunk recurrent states
    o,  # output buffer
    offsets,  # [N + 1] sequence offsets (variable-length mode)
    indices,  # (sequence, chunk) pairs per grid step (variable-length mode)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # Variable-length mode: map the global chunk id to (sequence, local chunk).
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Inter-chunk term: accumulate qg @ h over K-blocks.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    if HEAD_FIRST:
        p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    else:
        p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # Causal (lower-triangular, inclusive of the diagonal) mask for the
    # intra-chunk attention terms.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
98
+
99
+
100
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Host-side launcher for `chunk_dplr_fwd_kernel_o`.

    Computes the DPLR forward outputs from the gated queries, the per-chunk
    states and the two intra-chunk score matrices; ``BK``/``BV`` tiles are
    chosen by the kernel's autotuner.

    Returns:
        Output tensor with the same shape as ``v``.
    """
    if head_first:
        B, H, T, K = qg.shape
    else:
        B, T, H, K = qg.shape
    V = v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = len(indices) if offsets is not None else triton.cdiv(T, BT)

    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_fwd_kernel_o[grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return o
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
# Fused recurrent (token-by-token) forward kernel for the DPLR delta rule.
# Each program owns one V-block of one (sequence, head) pair and carries the
# full [BV, BK] state in registers, applying per step:
#   h_t = exp(gk_t) * h_{t-1} + (h_{t-1} a_t) b_t^T + v_t k_t^T
#   o_t = h_t q_t * scale
# With REVERSE the sequence is traversed from the last timestep backwards.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_dplr_delta_rule_fwd_kernel(
    q,  # queries
    k,  # keys
    v,  # values
    a,  # low-rank update directions (read side)
    b,  # low-rank update directions (write side)
    gk,  # per-step log decay gates over K
    o,  # output buffer
    h0,  # optional initial state, [N, H, K, V]
    ht,  # optional final-state buffer, [N, H, K, V]
    offsets,  # [N + 1] sequence offsets (variable-length mode)
    scale,  # query scaling factor
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # int64 program ids: offsets below can exceed 32-bit range for long inputs.
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Start pointers at the first (or, with REVERSE, the last) timestep.
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
        p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
        p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v

    else:
        p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
        p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
        p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # State block kept as [BV, BK] (V rows, K columns).
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        # tmp = h @ a: project the state onto the low-rank read direction.
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        # Decayed state plus rank-1 updates (tmp b^T and v k^T).
        b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = tl.sum(b_h * b_q[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Advance all pointers by one timestep (backwards when REVERSE).
        p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_a += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
117
+
118
+
119
def fused_recurrent_dplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Host-side launcher for `fused_recurrent_dplr_delta_rule_fwd_kernel`.

    Allocates the output (and, if requested, the float32 final-state buffer)
    and launches one program per (V-block, sequence*head); ``BV`` is picked by
    the kernel's autotuner.

    Returns:
        o:  outputs with the same shape as ``v``.
        ht: final states ``[N, H, K, V]`` in float32, or ``None``.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    V = v.shape[-1]

    N = B if offsets is None else len(offsets) - 1
    BK = triton.next_power_of_2(K)

    h0 = initial_state
    ht = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), N * H)

    fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        a,
        b,
        gk,
        o,
        h0,
        ht,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    return o, ht
170
+
171
+
172
class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd boundary for the fused recurrent DPLR delta rule.

    Forward simply delegates to :func:`fused_recurrent_dplr_delta_rule_fwd`;
    no tensors are saved for backward because this kernel is inference-only —
    ``backward`` always raises.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: Optional[float] = 1.0,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = False
    ):
        # Returns (outputs, final_state-or-None); see the launcher for shapes.
        o, ht = fused_recurrent_dplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            offsets=offsets,
            head_first=head_first
        )
        return o, ht

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        # Deliberately unsupported: use the chunked variant for training.
        raise NotImplementedError(
            "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
            "This kernel is only for inference. "
            "For training, please use `chunk_dplr_delta_rule`."
        )
217
+
218
+
219
def fused_recurrent_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Inference-only: the underlying autograd function raises on backward; use
    `chunk_dplr_delta_rule` for training.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` (`[B, H, T, K]` if `head_first=True`)
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` (`[B, H, T, K]` if `head_first=True`)
        v (torch.Tensor):
            values of shape `[B, T, H, V]` (`[B, H, T, V]` if `head_first=True`)
        a (torch.Tensor):
            as of shape `[B, T, H, K]` (`[B, H, T, K]` if `head_first=True`)
        b (torch.Tensor):
            bs of shape `[B, T, H, K]` (`[B, H, T, K]` if `head_first=True`)
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]` (`[B, H, T, K]` if `head_first=True`)
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If None, it will default to `1 / sqrt(K)`. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor): outputs with the same shape as `v`.
        final_state (torch.Tensor): final state `[N, H, K, V]` if `output_final_state` else `None`.

    Raises:
        ValueError: if `cu_seqlens` is given with a batch size != 1, or with a
            mismatching number of initial states.
        RuntimeError: if `cu_seqlens` is combined with `head_first=True`.
        AssertionError: if an explicit `scale` is not positive.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # NOTE: trailing space matters — these adjacent f-strings are
            # implicitly concatenated into a single message.
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
        head_first
    )
    return o, final_state
fla/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_delta_h import prepare_chunk_offsets
11
+ from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
13
+
14
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
15
+
16
+
17
# Forward state-recurrence kernel for the generalized IPLR (identity-plus-
# low-rank) delta rule. For each (K-block, V-block, sequence*head) it sweeps
# the time chunks, storing the state entering each chunk into `h`, computing
# the corrected values v_new = u + d @ h, and accumulating the chunk's state
# update h += k^T v + b^T v_new.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
    k,  # keys
    v,  # values
    d,  # presumably the WY-representation `w`/`d` factors from wy_fast — TODO confirm
    b,  # low-rank update directions
    u,  # pre-computed value corrections from the WY representation
    v_new,  # output: corrected values
    h,  # output: per-chunk states
    h0,  # optional initial state
    ht,  # optional final-state buffer
    offsets,  # [N + 1] sequence offsets (variable-length mode)
    chunk_offsets,  # per-sequence chunk base offsets (variable-length mode)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # sub-chunk size within a chunk
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running recurrent state.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # Store the state *entering* this chunk before applying its update.
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # Sub-chunk in BC-sized slices: keeping all of K in SRAM at once would
        # be a severe memory burden, so sub-chunking alleviates it.
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            if HEAD_FIRST:
                p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            else:
                p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_b = tl.load(p_b, boundary_check=(0, 1))
            # v2 = d @ h + u: corrected values for this sub-chunk.
            b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_k, b_v)
            b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
# Forward output kernel for the generalized IPLR delta rule.
# For one (V-block, time-chunk, batch*head) program it computes
#   o = (q @ h + tril(q k^T) @ v + tril(q b^T) @ u) * scale
# i.e. the inter-chunk term through the state plus the two causally-masked
# intra-chunk attention terms, built on the fly over K-blocks.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
    q,  # queries
    k,  # keys
    v,  # values
    u,  # corrected values (v_new from the state kernel)
    b,  # low-rank update directions
    h,  # per-chunk recurrent states
    o,  # output buffer
    offsets,  # [N + 1] sequence offsets (variable-length mode)
    indices,  # (sequence, chunk) pairs per grid step (variable-length mode)
    scale,  # output scaling factor
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # Variable-length mode: map the global chunk id to (sequence, local chunk).
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation: rebase all tensor pointers to this (batch, head).
    q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
    b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)

    # Accumulate q @ h and the two [BT, BT] score matrices over K-blocks.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqk += tl.dot(b_q, b_k)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqb += tl.dot(b_q, b_b)

    # Causal (lower-triangular, inclusive) mask for the intra-chunk terms.
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_Aqk = tl.where(m_A, b_Aqk, 0)
    b_Aqb = tl.where(m_A, b_Aqb, 0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_u = tl.load(p_u, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
210
+
211
+
212
def chunk_generalized_iplr_delta_rule_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    b: torch.Tensor,
    h: torch.Tensor,
    scale: Optional[float] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the output kernel of the chunked generalized-IPLR delta rule forward pass.

    Combines the inter-chunk contribution (queries against the per-chunk hidden
    states ``h``) with the intra-chunk contributions (causal ``q @ k^T`` applied to
    ``v`` and causal ``q @ b^T`` applied to ``v_new``), all computed inside
    ``chunk_generalized_iplr_delta_rule_fwd_kernel_o``.

    Args:
        q, k, b: projections of shape ``[B, H, T, K]`` if ``head_first`` else ``[B, T, H, K]``.
        v: values of shape ``[B, H, T, V]`` if ``head_first`` else ``[B, T, H, V]``.
        v_new: transformed values (same shape as ``v``), passed to the kernel as ``u``.
        h: per-chunk hidden states produced by the ``fwd_h`` launcher.
        scale: score scaling; defaults to ``K ** -0.5`` when ``None``.
        offsets/indices: variable-length metadata (cu_seqlens and per-chunk index pairs);
            both ``None`` for equal-length batches.
        head_first: layout flag matching the tensors above.
        chunk_size: requested time-chunk length; clamped below.

    Returns:
        Output tensor with the same shape/dtype as ``v``.
    """
    if head_first:
        B, H, T, K, V = *q.shape, v.shape[-1]
    else:
        B, T, H, K, V = *q.shape, v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    # Chunk length: at most `chunk_size`, at least 16, rounded to a power of two
    # so it is a valid Triton block dimension even for very short sequences.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # Number of time chunks; for variable-length inputs each (seq, chunk) pair
    # in `indices` is one program along the time axis.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    o = torch.empty_like(v)

    # Grid: (value blocks, time chunks, batch * heads); BV is chosen by autotuning,
    # hence the meta-parameterized callable.
    def grid(meta): return (
        triton.cdiv(V, meta['BV']),
        NT,
        B * H
    )
    chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        u=v_new,
        b=b,
        h=h,
        o=o,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return o
260
+
261
+
262
def chunk_generalized_iplr_delta_rule_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    b: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Launch the hidden-state kernel of the chunked generalized-IPLR delta rule.

    Computes the per-chunk recurrent states ``h`` and the updated values
    ``v_new`` from the WY representation (``w``, ``u``) produced by
    ``fwd_prepare_wy_repr``, optionally starting from ``initial_state``.

    Args:
        k, w, b: shape ``[B, H, T, K]`` if ``head_first`` else ``[B, T, H, K]``.
        v, u: shape ``[B, H, T, V]`` if ``head_first`` else ``[B, T, H, V]``.
        initial_state: optional ``[N, H, K, V]`` starting state (``N`` sequences).
        output_final_state: when ``True`` also materialize the final state in fp32.
        offsets/indices: variable-length metadata; ``None`` for equal-length batches.
        head_first: layout flag.
        chunk_size: requested time-chunk length.

    Returns:
        ``(h, v_new, final_state)`` where ``h`` has a leading per-chunk axis and
        ``final_state`` is ``None`` unless ``output_final_state`` is set.
    """
    if head_first:
        B, H, T, K, V = *k.shape, u.shape[-1]
    else:
        B, T, H, K, V = *k.shape, u.shape[-1]
    # Chunk length: clamped to [16, chunk_size] and a power of two.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)

    # The whole key dimension must fit in one block (see NK == 1 assert below).
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size

    # Pick value/chunk sub-block sizes by available shared memory per architecture.
    if check_shared_mem('hopper', k.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', k.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)

    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    # Per-chunk hidden states; chunk axis position follows the layout flag.
    if head_first:
        h = k.new_empty(B, H, NT, K, V)
    else:
        h = k.new_empty(B, NT, H, K, V)
    # Final state is accumulated in fp32 regardless of input dtype.
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
        k=k,
        v=v,
        d=w,
        b=b,
        u=u,
        v_new=v_new,
        h=h,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first
    )
    return h, v_new, final_state
339
+
340
+
341
def chunk_generalized_iplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Full forward pass of the chunked generalized-IPLR delta rule.

    Runs the three stages in order: (1) build the WY representation from
    ``a``/``b``/``k``/``v``, (2) compute per-chunk hidden states (and optionally
    the final state), (3) compute the outputs from queries and states.

    Returns:
        ``(o, final_state)`` — outputs shaped like ``v``, plus the final state
        (``None`` unless ``output_final_state``).
    """
    # Sequence length sits on dim 2 in head-first layout, dim 1 otherwise.
    seq_len = q.shape[2] if head_first else q.shape[1]
    # Effective chunk length shared by all three stages.
    block_size = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))

    # Stage 1: WY representation of the rank-augmented update.
    w, u, _ = fwd_prepare_wy_repr(
        a=a, b=b, k=k, v=v,
        offsets=offsets, indices=indices,
        head_first=head_first, chunk_size=block_size,
    )

    # Stage 2: chunkwise recurrent hidden states.
    h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
        k=k, v=v, b=b, w=w, u=u,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets, indices=indices,
        head_first=head_first, chunk_size=block_size,
    )

    # Stage 3: outputs from queries against states and intra-chunk terms.
    o = chunk_generalized_iplr_delta_rule_fwd_o(
        q=q, k=k, v=v, v_new=v_new, b=b, h=h,
        scale=scale,
        offsets=offsets, indices=indices,
        head_first=head_first, chunk_size=block_size,
    )
    return o, final_state
395
+
396
+
397
class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper for the chunked generalized-IPLR delta rule.

    Only the forward pass is implemented; ``backward`` raises
    ``NotImplementedError``.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True
    ):
        # Fixed chunk length for the kernel launchers below.
        chunk_size = 64

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = None
        if offsets is not None:
            # One entry per chunk: `indices` counts 0..n-1 within each sequence;
            # the cumsum of the zero markers recovers the sequence id.
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)

        o, final_state = chunk_generalized_iplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        # Outputs may be computed in a wider dtype; cast back to the input dtype.
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        # Backward is intentionally unimplemented for now.
        raise NotImplementedError(
            "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
            "Stay tuned!"
        )
454
+
455
+
456
@torch.compiler.disable
def chunk_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        a (torch.Tensor):
            activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        b (torch.Tensor):
            betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    # fp32 inputs are rejected by the underlying kernels.
    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."

    # Variable-length mode requires flattened inputs (B == 1), time-first layout,
    # and one initial state per sequence.
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        head_first
    )
    return o, final_state
fla/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (81.8 kB). View file
 
fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (5.69 kB). View file
 
fla/ops/gla/fused_recurrent.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.common.fused_recurrent import fused_recurrent
9
+
10
+
11
def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Fused recurrent Gated Linear Attention (GLA).

    Thin validation wrapper around the generic :func:`fused_recurrent` kernel,
    with the scalar gate ``g`` fixed to ``None`` and per-key/per-value gates
    ``gk``/``gv`` forwarded as-is.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        gk (torch.Tensor):
            Forget gates applied to keys, same shape as `k`. Optional.
        gv (torch.Tensor):
            Forget gates applied to values, same shape as `v`. Optional.
        scale (Optional[int]):
            Scale factor for attention scores; defaults to `1 / sqrt(K)` when `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences
            (`N == B` for equal-length batches). Default: `None`.
        output_final_state (Optional[bool]):
            Whether to return the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` for variable-length
            training, consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether inputs use the head-first layout; not supported together
            with `cu_seqlens`. Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    # Variable-length inputs must be flattened (B == 1), time-first, and carry
    # exactly one initial state per sequence.
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    return fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
fla/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import reduce
10
+
11
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
12
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
13
+ from fla.ops.utils import chunk_local_cumsum, softmax_bwd, softmax_fwd
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import input_guard
16
+
17
+
18
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_fwd_k_kernel_inter(
    q,
    k,
    h,
    g,
    o,
    A,
    offsets,
    indices,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Inter-chunk forward pass of GSA: accumulates the state contribution
    # `q @ h` into `o` (gated by exp(g)) and the causal intra-chunk logits
    # `A = q @ k^T` (stored once, by the i_v == 0 programs).
    # Grid: (value blocks, time chunks, batch * query heads).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # NG query heads share one k/g head (GQA-style grouping — presumably; confirm upstream).
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        # Variable-length mode: translate the global chunk id into
        # (sequence index, chunk index within that sequence).
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    # Causal (lower-triangular) mask for the intra-chunk logits.
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_h = tl.make_block_ptr(h + (i_bg * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]; scale is folded into q once per K-block.
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BT]
        b_A += tl.dot(b_q, b_k)
    if HEAD_FIRST:
        p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]: gate the inter-chunk output by the cumulative (log) gates.
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o = b_o * exp(b_g)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    # [BT, BT]: apply the causal mask; A is identical across value blocks,
    # so only the first value-block program writes it.
    b_A = tl.where(m_s, b_A, 0.)
    if i_v == 0:
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
116
+ })
117
+ @triton.jit(do_not_specialize=['T'])
118
+ def chunk_gsa_fwd_k_kernel_intra(
119
+ v,
120
+ g,
121
+ o,
122
+ A,
123
+ offsets,
124
+ indices,
125
+ T,
126
+ HQ: tl.constexpr,
127
+ H: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BC: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ NC: tl.constexpr,
133
+ NG: tl.constexpr,
134
+ USE_OFFSETS: tl.constexpr,
135
+ HEAD_FIRST: tl.constexpr
136
+ ):
137
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
138
+ i_bg = i_bh // NG
139
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
140
+ i_h = i_hq // NG
141
+ i_t, i_i = i_c // NC, i_c % NC
142
+ if USE_OFFSETS:
143
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
144
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
145
+ T = eos - bos
146
+ else:
147
+ bos, eos = i_b * T, i_b * T + T
148
+
149
+ o_v = i_v * BV + tl.arange(0, BV)
150
+ m_v = o_v < V
151
+
152
+ if i_t * BT + i_i * BC > T:
153
+ return
154
+
155
+ if HEAD_FIRST:
156
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
157
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + min(i_t * BT + i_i * BC, T) * V + o_v, BV), BV)
158
+ else:
159
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
160
+ p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
161
+ # [BV,]
162
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
163
+ # [BC, BV]
164
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
165
+ for i_j in range(0, i_i):
166
+ if HEAD_FIRST:
167
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
168
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
169
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
170
+ else:
171
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
172
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
173
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
174
+ # [BC, BV]
175
+ b_v = tl.load(p_v, boundary_check=(0, 1))
176
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
177
+ b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
178
+ # [BC, BC]
179
+ b_A = tl.load(p_A, boundary_check=(0, 1))
180
+ b_o += tl.dot(b_A, b_vg)
181
+ # [BC, BV]
182
+ b_g = tl.load(p_g, boundary_check=(0, 1))
183
+ b_o *= exp(b_g - b_gn[None, :])
184
+
185
+ o_i = tl.arange(0, BC)
186
+ if HEAD_FIRST:
187
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
188
+ else:
189
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
190
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
191
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
192
+ if HEAD_FIRST:
193
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
194
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
195
+ else:
196
+ p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
197
+ p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
198
+ # [BC,]
199
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
200
+ # [BV,]
201
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
202
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
203
+ # [BC, BV]
204
+ b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
205
+ # avoid 0 * inf = inf
206
+ b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
207
+ if HEAD_FIRST:
208
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
209
+ else:
210
+ p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
211
+ b_o += tl.load(p_o, boundary_check=(0, 1))
212
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
213
+
214
+
215
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8]
    ],
    key=["BT"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dA(
    v,
    g,
    do,
    dA,
    indices,
    offsets,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Backward pass of GSA w.r.t. the intra-chunk logits A: computes the
    # (BC x BC) tile dA[i_i, i_j] = do @ (gated v)^T for each sub-chunk pair;
    # dA tiles are written per value-block (leading i_v axis) and are expected
    # to be reduced over i_v by the caller.
    # Grid: (value blocks, chunk * sub-chunk-pair, batch * query heads).
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # NG query heads share one v/g head (GQA-style grouping — presumably; confirm upstream).
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    # Split the fused program id into (chunk, destination sub-chunk, source sub-chunk).
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        # NOTE: `all` (total time steps across the batch) shadows the Python
        # builtin; kept as-is to match the original code.
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # Destination sub-chunk entirely beyond the sequence: nothing to do.
    if i_t * BT + i_i * BC > T:
        return

    if HEAD_FIRST:
        p_dA = tl.make_block_ptr(dA+(i_v*B*H+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    else:
        p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))

    # [BC, BC]; strictly-upper tiles (i_i < i_j) stay zero by causality.
    b_dA = tl.zeros([BC, BC], dtype=tl.float32)
    if i_i > i_j:
        # Off-diagonal tile: dense matmul with gates rebased at the destination
        # sub-chunk start (b_gn) for numerical stability.
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
            p_gv = tl.make_block_ptr(g + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
            p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
            p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
            p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
            p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0.)
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_dA = tl.dot(b_do, b_vg)
    elif i_i == i_j:
        # Diagonal tile: build one column per source row j, masked causally.
        if HEAD_FIRST:
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
            p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
            p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
        else:
            p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
            p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
            p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
        m_v = o_v < V

        o_i = tl.arange(0, BC)
        # [BC, BC] causal mask within the diagonal tile.
        m_dA = o_i[:, None] >= o_i[None, :]
        for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
            # [BV,]
            b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
            b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
            # [BC,] column j of the tile.
            b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
            b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)

            # Advance the scalar-row pointers to the next time step.
            p_v += (1 if HEAD_FIRST else H) * V
            p_gv += (1 if HEAD_FIRST else H) * V
        b_dA = tl.where(m_dA, b_dA, 0.)
    tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
330
+
331
+
332
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dg,
    dgv,
    dA,
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Backward kernel of the key pass of chunked GSA.

    One program handles one (key block ``i_k``, time chunk ``i_t``, query head
    ``i_bh``) triple.  It recomputes the intra-chunk attention scores ``A``,
    then accumulates gradients ``dq``/``dk`` over all value blocks while
    writing per-key-block partials for ``dv`` and the gate gradient ``dgv``
    (summed over ``i_k`` on the host side).

    NOTE(review): ``NG`` appears to be the GQA group size (query heads per
    shared key/value head, cf. ``i_bg = i_bh // NG``) — confirm against callers.
    """
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # i_bg/i_h map a query head onto its shared key/value head.
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        # Variable-length mode: `indices` holds (sequence, chunk) pairs and
        # `offsets` the cumulative sequence boundaries; T becomes per-sequence.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    # o_t: exclusive end of this chunk, clamped to the sequence length.
    o_t = min(i_t * BT + BT, T)
    # m_s: causal (lower-triangular) mask within the chunk.
    m_s = o_i[:, None] >= o_i[None, :]

    if HEAD_FIRST:
        p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bg * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh) * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]  recomputed causal scores for this key block, stored for reuse.
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # Accumulate inter-chunk gradients over every value block.
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (o_t - 1) * V + o_v, BV), BV)
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dgv = tl.make_block_ptr(dgv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_h = tl.make_block_ptr(h + i_bg * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
            p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]  gate values at the last position of the chunk (normalizer).
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]  gate gradient from the hidden-state path.
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        # Only one key block (i_k == 0) folds the incoming `dg` in, so the
        # host-side sum over i_k counts it exactly once.
        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    if HEAD_FIRST:
        p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    else:
        p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]  intra-chunk score gradient computed by the dA kernel.
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
487
+
488
+
489
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_intra_dvg(
    v,
    g,
    o,
    A,
    do,
    dv,
    dg,
    offsets,
    indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Intra-chunk backward: adds the within-chunk contribution to ``dv`` and
    finalizes the gate gradient ``dg``.

    One program covers one (value block ``i_v``, sub-chunk ``i_c``, query head
    ``i_bh``).  Contributions come from later sub-chunks (block-wise dot with
    ``A``) plus a scalar loop over positions inside the same sub-chunk; at the
    end ``dg = o * do - v * dv`` is written.
    """
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    # Decompose the flat sub-chunk index into (time chunk, sub-chunk).
    i_t, i_i = i_c // NC, i_c % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # Sub-chunk lies entirely past the end of this sequence: nothing to do.
    if i_t * BT + i_i * BC > T:
        return

    if HEAD_FIRST:
        p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (min(i_t * BT + i_i * BC + BC, T) - 1) * V + o_v, BV), BV)
    else:
        p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
    # [BV,]  gates at the last valid position of this sub-chunk.
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)
    # Contributions from strictly later sub-chunks of the same time chunk.
    for i_j in range(i_i + 1, NC):
        if HEAD_FIRST:
            p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
            p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        else:
            p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
            p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        # [BC, BV]
        b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
    # Rescale the accumulated gradient from the sub-chunk-end reference back
    # to per-position gates.
    b_dv *= exp(b_gn[None, :] - b_gv)

    o_i = tl.arange(0, BC)
    o_c = i_i * BC + tl.arange(0, BC)

    if HEAD_FIRST:
        p_g = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
        p_A = tl.max_contiguous(tl.multiple_of(A + i_bh * T*BT + (i_t * BT + i_i * BC) * BT + o_c, BC), BC)
        p_do = tl.max_contiguous(tl.multiple_of(do + i_bh * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
    else:
        p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
        p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
        p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v

    # Position-by-position pass over this sub-chunk (causal within sub-chunk).
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_A = tl.load(p_A)
        # [BV,]
        b_g = tl.load(p_g, mask=m_v, other=0)
        b_do = tl.load(p_do, mask=m_v, other=0)
        # [BC, BV]
        m_i = o_i[:, None] <= j
        b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)

        p_g += (1 if HEAD_FIRST else H) * V
        p_A += (1 if HEAD_FIRST else HQ) * BT
        p_do += (1 if HEAD_FIRST else HQ) * V
    if HEAD_FIRST:
        p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    else:
        p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))

    b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
    b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
    # Merge with the inter-chunk dv partials already stored by the dqkvg kernel.
    b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
    b_dg = b_o * b_do - b_v * b_dv
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
607
+
608
+
609
def chunk_gsa_fwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float = 1.,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward of the slot->value pass of GSA.

    Thin wrapper around ``chunk_gla_fwd``: the gate cumsum ``g`` is already
    precomputed by the caller, so it is passed via ``g_cumsum`` with ``g=None``.
    """
    outs = chunk_gla_fwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    # chunk_gla_fwd returns the gate cumsum first; the caller already has it,
    # so only (A, h, ht, o) are propagated.
    A, h, ht, o = outs[1:]
    return A, h, ht, o
637
+
638
+
639
def chunk_gsa_fwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward of the key->slot pass of chunked GSA.

    Builds the chunked hidden states via ``chunk_fwd_h`` (gate applied on the
    value/slot side, ``gv=g``), then launches the inter- and intra-chunk
    forward kernels.

    Returns:
        A: intra-chunk attention scores, ``[B, HQ, T, BT]`` layout (head-first).
        h: per-chunk hidden states.
        ht: final state (or ``None`` if not requested by ``chunk_fwd_h``).
        o: outputs over the slot dimension.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # Block sizes: BT = time-chunk length, BC = intra-chunk sub-chunk,
    # BV = value-dimension tile.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[1] if head_first else q.shape[2]
    # With variable-length inputs the number of chunks equals len(indices).
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    NC = triton.cdiv(BT, BC)
    # NG: query heads per shared key/value head (GQA group size).
    NG = HQ // H

    h, ht = chunk_fwd_h(
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        h0=h0,
        output_final_state=output_final_state,
        offsets=offsets,
        head_first=head_first,
        chunk_size=BT,
        states_in_fp32=False
    )
    o = v.new_empty(B, *((HQ, T) if head_first else (T, HQ)), V)
    A = q.new_empty(B, *((HQ, T) if head_first else (T, HQ)), BT)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_gsa_fwd_k_kernel_inter[grid](
        q,
        k,
        h,
        g,
        o,
        A,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        HEAD_FIRST=head_first
    )

    # One program per (value tile, sub-chunk, query head).
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_gsa_fwd_k_kernel_intra[grid](
        v,
        g,
        o,
        A,
        offsets=offsets,
        indices=indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        HEAD_FIRST=head_first,
        num_warps=4,
        num_stages=2
    )
    return A, h, ht, o
722
+
723
+
724
def chunk_gsa_bwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Backward of the slot->value pass; delegates to ``chunk_gla_bwd``.

    NOTE: the incoming ``dg`` argument is not read here — the gate gradient is
    produced by ``chunk_gla_bwd`` and returned in its place (this matches the
    call sites, which pass ``dg=None``).

    Returns ``(dq, dk, dv, dg, dh0)``.
    """
    return chunk_gla_bwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=h0,
        h=h,
        A=A,
        do=do,
        dht=dht,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
759
+
760
+
761
def chunk_gsa_bwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h: torch.Tensor,
    h0: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Backward of the key->slot pass of chunked GSA.

    Pipeline:
      1. Recompute ``h`` if it was checkpointed away (``h is None``).
      2. ``chunk_bwd_dh``: hidden-state gradients ``dh`` (and ``dh0``).
      3. ``chunk_gsa_bwd_k_kernel_dA``: intra-chunk score gradients ``dA``,
         accumulated over value tiles.
      4. ``chunk_gsa_bwd_k_kernel_dqkvg``: ``dq``/``dk`` plus per-key-block
         partials of ``dv``/``dgv``, summed over the key-block axis.
      5. ``chunk_gsa_bwd_k_kernel_intra_dvg``: intra-chunk ``dv`` and ``dg``.
      6. Combine ``dgv`` with the reverse cumulative sum of ``dg``.

    Returns ``(dq, dk, dv, dg, dh0)``.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BK = min(64, triton.next_power_of_2(K))
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[1] if head_first else q.shape[2]
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    NC = triton.cdiv(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # NG: query heads per shared key/value head (GQA group size).
    NG = HQ // H

    if h is None:
        # Checkpoint level 2: forward hidden states were discarded; recompute.
        h, _ = chunk_fwd_h(
            k=k,
            v=v,
            g=None,
            gk=None,
            gv=g,
            h0=h0,
            output_final_state=False,
            offsets=offsets,
            head_first=head_first,
            chunk_size=BT,
            states_in_fp32=False
        )
    dh, dh0 = chunk_bwd_dh(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        do=do,
        h0=h0,
        dht=dht,
        scale=scale,
        offsets=offsets,
        head_first=head_first,
        chunk_size=BT,
        states_in_fp32=True
    )
    # Leading NV axis holds per-value-tile partials, reduced right after launch.
    dA = q.new_empty(NV, B, *((HQ, T) if head_first else (T, HQ)), BT)
    grid = (NV, NT * NC * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_dA[grid](
        v,
        g,
        do,
        dA,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        HEAD_FIRST=head_first
    )
    dA = dA.sum(0, dtype=dA.dtype)

    # Leading NK axis likewise holds per-key-block partials.
    A = do.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), BT)
    dq = torch.empty_like(q)
    dk = k.new_empty(B, *((HQ, T) if head_first else (T, HQ)), K)
    dv = v.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V)
    dgv = g.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V, dtype=torch.float)
    grid = (NK, NT, B * HQ)
    chunk_gsa_bwd_k_kernel_dqkvg[grid](
        q,
        k,
        v,
        h,
        g,
        A,
        do,
        dh,
        dq,
        dk,
        dv,
        dg,
        dgv,
        dA,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NG=NG,
        HEAD_FIRST=head_first
    )
    A = A.sum(0, dtype=A.dtype)
    dv = dv.sum(0, dtype=dv.dtype)
    dgv = dgv.sum(0, dtype=dgv.dtype)

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_intra_dvg[grid](
        v,
        g,
        o,
        A,
        do,
        dv,
        dg,
        offsets=offsets,
        indices=indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        HEAD_FIRST=head_first,
        num_warps=4,
        num_stages=2
    )
    # Final gate gradient: dgv plus the reverse (right-to-left) chunk-local
    # cumulative sum of the per-position dg written by the intra kernel.
    dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets, indices=indices, head_first=head_first))

    return dq, dk, dv, dg, dh0
914
+
915
+
916
def chunk_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Full GSA forward: key->slot pass, softmax over slot scores, slot->value pass.

    Returns all intermediates needed by the backward pass:
    ``(Ak, hk, hkt, ok, p, Av, hv, hvt, ov)``.
    """
    if initial_state is not None:
        hk0, hv0 = initial_state
    else:
        hk0 = hv0 = None

    # First pass: queries attend over keys, producing slot scores `ok`.
    Ak, hk, hkt, ok = chunk_gsa_fwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        output_final_state=output_final_state,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # p is kept in fp32 for safe softmax backward
    p = softmax_fwd(ok, dtype=torch.float)
    qv = p.to(q.dtype)

    # Second pass: normalized slot scores act as queries over the values.
    Av, hv, hvt, ov = chunk_gsa_fwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        scale=1.,
        initial_state=hv0,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
965
+
966
+
967
def chunk_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    ok: torch.Tensor,
    p: torch.Tensor,
    A: Tuple[torch.Tensor, torch.Tensor],
    h: Tuple[torch.Tensor, torch.Tensor],
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
    scale: float,
    do: torch.Tensor,
    dht: Tuple[torch.Tensor, torch.Tensor],
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Full GSA backward, mirroring ``chunk_gsa_fwd`` in reverse order:
    slot->value pass backward, softmax backward, then key->slot pass backward.

    Returns ``(dq, dk, dv, ds, dg, dhk0, dhv0)``.
    """
    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state

    # Only Av is needed here; Ak is recomputed inside chunk_gsa_bwd_k.
    _, Av = A
    hk, hv = h
    dhkt, dhvt = dht

    qv = p.to(q.dtype)
    dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        h0=hv0,
        h=hv,
        A=Av,
        do=do,
        dht=dhvt,
        dg=None,
        scale=1.,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # softmax gradient, equivalent to:
    # dok = qv * (dqv - (qv * dqv).sum(-1, True))
    dok = softmax_bwd(p, dqv, dtype=ok.dtype)

    # `dg` from the value pass is threaded into the key pass, which adds its
    # own gate-gradient contribution.
    dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        h=hk,
        o=ok,
        do=dok,
        dht=dhkt,
        dg=dg,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )

    # `s` receives gradient from both passes.
    ds = dsv.add_(dsk)
    # GQA: fold the per-query-head gradients back onto the shared KV heads.
    if q.shape[1] != k.shape[1]:
        dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
    dg = dg.to(s.dtype)
    return dq, dk, dv, ds, dg, dhk0, dhv0
1040
+
1041
+
1042
class ChunkGSAFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked GSA forward/backward.

    ``checkpoint_level`` controls what is saved for backward:
      - 0: keep the fp32 gate cumsum and both hidden-state sequences;
      - 1: drop the gate cumsum (recomputed in backward);
      - 2: additionally drop the hidden states (recomputed in backward from
        the saved initial states).
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: float,
        hk0: Optional[torch.Tensor],
        hv0: Optional[torch.Tensor],
        output_final_state: bool,
        checkpoint_level: int,
        offsets: Optional[torch.LongTensor],
        head_first: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        T = q.shape[2] if head_first else q.shape[1]
        chunk_size = min(64, max(16, triton.next_power_of_2(T)))

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = None
        if offsets is not None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        # Keep the raw gates around so level>=1 can avoid saving the cumsum.
        g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
        Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )

        if checkpoint_level >= 1:
            # Save the raw gates instead of the fp32 cumsum.
            del g
            g = g_org
        if checkpoint_level > 1:
            # Level 2: drop hidden states; keep initial states for recompute.
            del hk
            del hv
            hk, hv = None, None
        else:
            # Levels 0/1: hidden states are saved, so the initial states are
            # not needed again in backward.
            hk0, hv0 = None, None

        ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
        ctx.checkpoint_level = checkpoint_level
        ctx.scale = scale
        ctx.offsets = offsets
        ctx.indices = indices
        ctx.head_first = head_first
        ctx.chunk_size = chunk_size
        return ov, hkt, hvt

    @staticmethod
    @input_guard
    def backward(ctx, dov, dhkt=None, dhvt=None):
        q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
        scale = ctx.scale
        offsets = ctx.offsets
        indices = ctx.indices
        head_first = ctx.head_first
        chunk_size = ctx.chunk_size

        if ctx.checkpoint_level >= 1:
            # Recompute the fp32 gate cumsum that was not saved in forward.
            g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
        dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            ok=ok,
            p=p,
            A=(None, Av),
            h=(hk, hv),
            initial_state=(hk0, hv0),
            scale=scale,
            do=dov,
            dht=(dhkt, dhvt),
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        # One gradient slot per forward argument; non-tensor args get None.
        return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
1139
+
1140
+
1141
@torch.compiler.disable
def chunk_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    checkpoint_level: Optional[int] = 2,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
            GQA is performed if `H` is not equal to `HQ`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
            If not provided, this function is equivalent to vanilla ABC.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
            Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `2`:
            - Level `0`: no memory saved, no recomputation.
            - Level `1`: recompute the fp32 cumulative values during backward.
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (Tuple[torch.Tensor]):
            Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
            `None` otherwise.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gsa import fused_recurrent_gsa
        # inputs with equal lengths
        >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> s = torch.randn(B, T, H, M, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
        >>> o, (hk, hv) = chunk_gsa(q, k, v, s, g,
                                    initial_state=h0,
                                    output_final_state=True,
                                    head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g,
                                                initial_state=h0,
                                                output_final_state=True,
                                                cu_seqlens=cu_seqlens,
                                                head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert hk.allclose(hk_var)
        >>> assert hv.allclose(hv_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.")
    assert checkpoint_level in [0, 1, 2]
    if g is None:
        # Vanilla ABC fallback: derive gates from a log-cumsum-exp
        # normalization of the slot representations along the TIME axis.
        # The time axis is dim 2 in head-first layout ([B, H, T, M]) but
        # dim 1 otherwise ([B, T, H, M]); the previous code hard-coded dim 2,
        # which silently normalized over heads when `head_first=False`.
        # TODO: this 3 steps took huge amount of time, ought to be optimized
        t_dim = 2 if head_first else 1
        z = s.float().logcumsumexp(t_dim)
        # Shift z one step along time (first position keeps its own value).
        g = torch.cat((z.narrow(t_dim, 0, 1), z.narrow(t_dim, 0, z.size(t_dim) - 1)), t_dim) - z
        s = torch.exp(s - z).to(k.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    o, *final_state = ChunkGSAFunction.apply(
        q,
        k,
        v,
        s,
        g,
        scale,
        hk0,
        hv0,
        output_final_state,
        checkpoint_level,
        cu_seqlens,
        head_first
    )
    return o, final_state
fla/ops/gsa/naive.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
def naive_recurrent_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    """Reference (step-by-step) implementation of gated slot attention.

    Two chained gated linear recurrences are run in fp32: keys -> slot scores
    (softmaxed over slots), then slot scores -> values.  Key/value/slot heads
    are replicated to match the number of query heads (GQA).

    Returns ``(ov, final_state)`` where ``final_state`` is a
    ``(hk, hv)`` tuple if ``output_final_state`` else ``None``.
    """
    out_dtype = q.dtype

    # Query heads per shared key/value head.
    num_groups = q.shape[1] // k.shape[1]
    # [batch_size, n_heads, seq_len, n_slots]
    if g is None:
        # Vanilla ABC: gates from a log-cumsum-exp normalization over time.
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    q, k, v, s, g = (x.float() for x in (q, k, v, s, g))
    # Replicate KV-side tensors per query-head group; repeat_interleave along
    # the head dim matches einops' 'b h t d -> b (h g) t d'.
    k, v, s, g = (x.repeat_interleave(num_groups, dim=1) for x in (k, v, s, g))
    if initial_state is not None:
        initial_state = tuple(x.repeat_interleave(num_groups, dim=1) for x in initial_state)

    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None

    # ---- first recurrence: keys attend into the slot dimension ----
    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)
    if initial_state is not None:
        hk = hk + initial_state[0]
    for t in range(T):
        decay = g[:, :, t].exp()
        hk = hk * decay[..., None, :] + k[:, :, t][..., None] * s[:, :, t][..., None, :]
        ok[:, :, t] = ((q[:, :, t] * scale)[..., None] * hk).sum(-2)

    # ---- second recurrence: softmaxed slot scores attend into values ----
    qv = ok.softmax(-1)
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv = hv + initial_state[1]
    for t in range(T):
        decay = g[:, :, t].exp()
        hv = hv * decay[..., :, None] + s[:, :, t][..., None] * v[:, :, t][..., None, :]
        ov[:, :, t] = (qv[:, :, t][..., None] * hv).sum(-2)

    if output_final_state:
        # Keep one representative state per original (un-replicated) head.
        final_state = (hk.view(B, -1, num_groups, K, M)[:, :, 0],
                       hv.view(B, -1, num_groups, M, V)[:, :, 0])
    return ov.to(out_dtype), final_state
+ return ov.to(dtype), final_state
fla/ops/hgrn/chunk.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # this function implements the chunkwise form of HGRN, inspired by
5
+ # [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
6
+ # also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
7
+
8
+ # from tests on H800, with B, D = 16, 128, we see that the chunk can be greatly faster than the recurrent:
9
+ #
10
+ # Performance:
11
+ # seq_len chunk recurrent chunk_bwd recurrent_bwd
12
+ # 0 128.0 0.039360 0.061056 0.312160 0.205008
13
+ # 1 256.0 0.045824 0.123712 0.308784 0.297696
14
+ # 2 512.0 0.058688 0.241952 0.310720 0.626528
15
+ # 3 1024.0 0.088288 0.476992 0.313184 1.333152
16
+ # 4 2048.0 0.169472 0.943264 0.452464 2.724864
17
+ # 5 4096.0 0.329920 1.886144 0.881600 5.551520
18
+ # 6 8192.0 0.647872 3.755040 1.740496 11.117184
19
+ # 7 16384.0 1.272064 7.520576 3.446608 22.362528
20
+
21
+ from typing import Tuple
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+ from fla.ops.utils.op import exp
28
+ from fla.utils import input_guard
29
+
30
+
31
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_fwd_kernel_h(
    x,   # [B, T, D] inputs
    g,   # [B, T, D] log forget gates
    gc,  # [B, T, D] out: cumulative sum of log gates within each chunk
    o,   # [B, T, D] out: chunk-local hidden states (cross-chunk carry added later)
    h0,  # [B, D] optional initial state; read only when USE_INITIAL_STATE
    T,
    D: tl.constexpr,
    BT: tl.constexpr,  # time-chunk length
    BD: tl.constexpr,  # feature-block width (autotuned)
    USE_INITIAL_STATE: tl.constexpr
):
    # Stage 1 of the chunkwise HGRN forward: one program per
    # (feature block, time chunk, batch) runs the scalar recurrence
    #   h[t] = exp(g[t]) * h[t-1] + x[t]
    # sequentially *inside its own chunk only*; `chunk_hgrn_fwd_kernel_o`
    # later folds in the carry from preceding chunks.
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # element pointers positioned at the first timestep of this chunk
    p_x = x + i_b * T * D + i_t * BT * D + o_d
    p_g = g + i_b * T * D + i_t * BT * D + o_d
    p_gc = gc + i_b * T * D + i_t * BT * D + o_d
    p_o = o + i_b * T * D + i_t * BT * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)   # running hidden state
    b_gc = tl.zeros([BD], dtype=tl.float32)  # running sum of log gates
    if USE_INITIAL_STATE:
        # only the very first chunk starts from h0; later chunks start at 0
        # and receive their carry in stage 2
        if i_t == 0:
            b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32)
    for i in range(0, BT):
        mask_t = mask & ((i_t * BT + i) < T)  # guard the ragged last chunk
        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
        b_h = exp(b_g) * b_h + b_x
        b_gc = b_gc + b_g
        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)

        # advance all pointers by one timestep
        p_x += D
        p_g += D
        p_gc += D
        p_o += D
88
+
89
+
90
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_fwd_kernel_o(
    gc,   # [B, T, D] per-chunk cumulative log gates (from stage 1)
    o,    # [B, T, D] in/out: chunk-local states, rewritten to global states
    s_b,  # batch stride
    s_t,  # time stride
    s_d,  # feature stride
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Stage 2 of the forward pass: sweep the chunks sequentially and add the
    # carry-in, i.e. o[t] += exp(gc[t]) * (last state of the previous chunk).
    # Chunk 0 needs no correction, so the sweep starts at i_t = 1.
    i_d, i_b = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(1, tl.cdiv(T, BT)):
        p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,] last (already corrected) state of the previous chunk;
        # correct because the loop over i_t is sequential within this program
        b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
        # [BT, BD]
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_o = b_o + exp(b_gc) * b_h0[None, :]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
117
+
118
+
119
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['D']
)
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_bwd_kernel_h(
    g,    # [B, T, D] log forget gates
    gc,   # [B, T, D] out: reverse cumulative log gates, walked back from chunk end
    dx,   # [B, T, D] out: chunk-local input gradients (inter-chunk carry added in stage 2)
    do,   # [B, T, D] upstream gradients w.r.t. the outputs
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Backward stage 1: walk each chunk from its last valid step to its first,
    # accumulating dh[t] = dh[t+1] * exp(g[t+1]) + do[t] and writing dx[t] = dh[t].
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D
    BC = min(BT, T - i_t * BT)  # actual length of this (possibly ragged) chunk
    NT = tl.num_programs(1)

    # element pointers positioned at the LAST valid timestep of this chunk
    p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d

    if i_t == NT - 1:
        b_gc = tl.zeros([BD], dtype=tl.float32)
    else:
        # seed the reverse cumulative sum with the gate at the first step of
        # the NEXT chunk (presumably so stage 2's boundary carry term lines
        # up across chunks — verify against chunk_hgrn_bwd_kernel_o)
        b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
    b_dh = tl.zeros([BD], dtype=tl.float32)
    for _ in range(BC - 1, -1, -1):
        tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)

        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)

        b_gc = b_gc + b_g
        b_dh = b_dh + b_do
        b_dx = b_dh
        b_dh = b_dh * exp(b_g)  # decay the carry through gate t for step t-1

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)

        # step all pointers one timestep backwards
        p_g -= D
        p_gc -= D
        p_dx -= D
        p_do -= D
171
+
172
+
173
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_bwd_kernel_o(
    g,    # [B, T, D] log forget gates
    gc,   # [B, T, D] reverse cumulative log gates from backward stage 1
    o,    # [B, T, D] forward hidden states (read shifted back by one step)
    dx,   # [B, T, D] in/out: stage-1 gradients, corrected with the inter-chunk carry
    dg,   # [B, T, D] out: gate gradients, dg[t] = o[t-1] * dx[t] * exp(g[t])
    s_b,  # batch stride
    s_t,  # time stride
    s_d,  # feature stride
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Backward stage 2: sweep the chunks from last to first, propagating the
    # hidden-state gradient across chunk boundaries, then form dg.
    i_d, i_b = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        # NOTE the -1 row offset: this block reads o[t-1]; row -1 of the first
        # chunk is masked out by boundary_check
        p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,] gradient carried in from the first step of the next chunk
        # (already corrected because the sweep runs back-to-front); mask_t
        # zeroes it for the final chunk
        mask_t = mask & ((i_t + 1) * BT < T)
        b_ht = tl.load(dx + i_b * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
        # [BT, BD]
        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)

        b_dx = b_dx + exp(b_gc) * b_ht[None, :]
        b_dg = b_o * b_dx * exp(b_g)
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
212
+
213
+
214
class ChunkHGRNFunction(torch.autograd.Function):
    """Autograd wrapper around the two-stage chunkwise HGRN Triton kernels.

    Forward: stage 1 (``chunk_hgrn_fwd_kernel_h``) runs independent per-chunk
    recurrences; stage 2 (``chunk_hgrn_fwd_kernel_o``) adds the cross-chunk
    carry. Backward mirrors this with the two ``bwd`` kernels.
    """

    @staticmethod
    @input_guard
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        # x, g: [B, T, D]; initial_state: [B, D] or None
        B, T, D = x.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        # gc caches the per-chunk cumulative log gates for reuse in stage 2
        gc = torch.empty_like(g, dtype=torch.float)
        o = torch.empty_like(x, dtype=torch.float)
        # stage 1: independent per-chunk recurrences (BD/num_warps autotuned)
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
        chunk_hgrn_fwd_kernel_h[grid](
            x, g, gc, o, initial_state,
            T=T, D=D, BT=BT,
            USE_INITIAL_STATE=initial_state is not None
        )
        # stage 2: serial sweep over chunks to add cross-chunk carries
        def grid(meta): return (triton.cdiv(D, meta['BD']), B)
        chunk_hgrn_fwd_kernel_o[grid](
            gc, o,
            o.stride(-3), o.stride(-2), o.stride(-1),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        final_state = None
        if output_final_state:
            # the final state is simply the last output timestep
            final_state = o[:, -1].clone()
        o = o.to(x.dtype)
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht=None):
        # NOTE(review): dht (gradient w.r.t. the final state) is accepted but
        # ignored — no contribution is propagated; confirm callers never rely
        # on gradients flowing through the returned final state.
        g, o, initial_state = ctx.saved_tensors
        B, T, D = do.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        # here gc holds REVERSE cumulative log gates (recomputed, not reused
        # from the forward pass)
        gc = torch.empty_like(g, dtype=torch.float)
        dx = torch.empty_like(o, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
        chunk_hgrn_bwd_kernel_h[grid](
            g, gc, dx, do,
            T=T, D=D, BT=BT
        )

        dg = torch.empty_like(g, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B)
        chunk_hgrn_bwd_kernel_o[grid](
            g, gc, o, dx, dg,
            o.stride(-3), o.stride(-2), o.stride(-1),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        # dg[t] uses o[t-1]; at t = 0 there is no previous output, so patch
        # the first step with the initial state when one was given
        if initial_state is not None:
            dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype)

        return dx.to(o.dtype), dg, None, None
273
+
274
+
275
@torch.compiler.disable
def chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise HGRN over inputs ``x`` and log forget gates ``g``.

    Args:
        x: inputs of shape ``[B, T, D]``.
        g: log forget gates of shape ``[B, T, D]``.
        initial_state: optional initial hidden state of shape ``[B, D]``.
        output_final_state: if ``True``, also return the final hidden state.

    Returns:
        A ``(outputs, final_state)`` pair; ``final_state`` is ``None`` unless
        ``output_final_state=True``.
    """
    outputs, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
    return outputs, final_state
fla/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    """Step-by-step reference HGRN recurrence: h[t] = exp(g[t]) * h[t-1] + x[t].

    Computation runs in float32; outputs are cast back to the input dtype.
    Returns ``(o, final_state)`` where ``final_state`` is ``None`` unless
    ``output_final_state=True``.
    """
    in_dtype = x.dtype
    x_f, g_f = x.float(), g.float()
    B, T, D = x_f.shape

    # running state, optionally seeded with the provided initial state
    state = torch.zeros(B, D, dtype=torch.float, device=x.device)
    if initial_state is not None:
        state = state + initial_state

    steps = []
    for t in range(T):
        state = g_f[:, t].exp() * state + x_f[:, t]
        steps.append(state)
    o = torch.stack(steps, dim=1)

    final_state = state if output_final_state else None
    return o.to(in_dtype), final_state
32
+
33
+
34
def naive_chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False,
    chunk_size: int = 64
) -> torch.Tensor:
    """Chunkwise reference HGRN: h[t] = exp(g[t]) * h[t-1] + x[t].

    Each chunk is processed with a local recurrence (started from zero); the
    carry from the previous chunk is then added as ``hp * exp(gc[t])`` where
    ``gc`` is the within-chunk cumulative sum of log gates.

    Args:
        x: inputs of shape ``[B, T, D]``; ``T`` must be a multiple of ``chunk_size``.
        g: log forget gates of shape ``[B, T, D]``.
        initial_state: optional initial hidden state of shape ``[B, D]``.
        output_final_state: if ``True``, also return the final hidden state.
        chunk_size: chunk length along the time dimension.

    Returns:
        ``(o, final_state)``; ``o`` has the dtype of ``x``, ``final_state`` is
        ``None`` unless ``output_final_state=True``.
    """
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, T, D = x.shape
    if T % chunk_size != 0:
        raise ValueError(f"sequence length {T} must be a multiple of chunk_size {chunk_size}")

    # BUGFIX: cumulative log gates must be computed per chunk. The previous
    # `g.view(B, chunk_size, D)` only worked when T == chunk_size.
    gc = g.view(B, T // chunk_size, chunk_size, D).cumsum(-2).view_as(g)
    h = torch.zeros(B, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h = h + initial_state

    for i in range(0, T, chunk_size):
        hp = h  # carry-in: full state at the end of the previous chunk
        h = torch.zeros(B, D, dtype=torch.float, device=x.device)
        for j in range(i, i + chunk_size):
            # local recurrence, ignoring the carry
            h = g[:, j].exp() * h + x[:, j]
            # global state = decayed carry + local part
            o[:, j] = hp * gc[:, j].exp() + h
        # BUGFIX: the handoff must happen once per chunk, AFTER the inner
        # loop; updating `h` inside the loop double-counted `hp` from the
        # second step of every chunk onward.
        h = o[:, i + chunk_size - 1].clone()

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (315 Bytes). View file
 
fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (3.75 kB). View file
 
fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/ops/nsa/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .naive import naive_nsa
4
+ from .parallel import parallel_nsa
5
+
6
+ __all__ = [
7
+ 'naive_nsa',
8
+ 'parallel_nsa'
9
+ ]
fla/ops/rebased/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (226 Bytes). View file
 
fla/ops/rebased/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (22.6 kB). View file
 
fla/ops/retention/fused_chunk.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from packaging import version
10
+
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
@triton.jit(do_not_specialize=['T'])
def fused_chunk_retention_fwd_kernel(
    q,      # [B, H, T, K] queries
    k,      # [B, H, T, K] keys
    v,      # [B, H, T, V] values
    o,      # [NK, B, H, T, V] out: per-K-block partial outputs (caller sums over dim 0)
    h0,     # [B, H, K, V] optional initial state
    ht,     # [B, H, K, V] optional final state, written when STORE_FINAL_STATE
    scale,  # scaling factor applied to q
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # time-chunk length
    BK: tl.constexpr,  # key-block width
    BV: tl.constexpr,  # value-block width
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    CHECK: tl.constexpr  # duplicated-branch workaround for a Triton<2.2 compiler bug (see caller)
):
    # One program per (value block, key block, batch*head). Chunks are
    # processed sequentially, combining intra-chunk attention with an
    # inter-chunk running state b_h of shape [BK, BV].
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H

    o_i = tl.arange(0, BT)
    # decay rate given the head index (retention's fixed per-head decay,
    # log2 of 1 - 2^(-5 - head))
    b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))

    # d_b: overall decay for the entire chunk
    # d_o: cumulative decay from the start of the chunk
    # d_h: cumulative decay from the end of the chunk
    d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)

    # [BT, BT] causal mask with relative decay for intra-chunk attention
    m_s = o_i[:, None] >= o_i[None, :]
    d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
    # [BK, BV] inter-chunk running state
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    # make block pointers
    p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (i_k*B*H+i_bh).to(tl.int64) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    NT = tl.cdiv(T, BT)
    for i in range(0, NT):
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [BT, BT] intra-chunk attention scores with decay
        b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
        # [BT, BV]
        b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
        if CHECK and i == 0:
            # duplicated body: see CHECK note above
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
            b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
        else:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
            # the ragged last chunk decays by its true length, not BT
            # NOTE(review): this tail correction is absent from the CHECK/i==0
            # branch above, which matters only when NT == 1 — confirm intended
            if i == NT - 1 and (T % BT) != 0:
                d_b = tl.math.exp2((T % BT) * b_b)
                d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
            b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

        # advance to the next time chunk
        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+
99
@triton.jit(do_not_specialize=['T'])
def fused_chunk_retention_bwd_kernel(
    q,      # [B, H, T, K] queries
    k,      # [B, H, T, K] keys
    v,      # [B, H, T, V] values
    do,     # [B, H, T, V] upstream gradients w.r.t. outputs
    dq,     # [NV, B, H, T, K] out: per-V-block partial dq (caller sums)
    dk,     # [NV, B, H, T, K] out: per-V-block partial dk (caller sums)
    dv,     # [NK, B, H, T, V] out: per-K-block partial dv (caller sums)
    h0,     # [B, H, K, V] optional initial state
    scale,  # same scaling as used in forward
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    CHECK: tl.constexpr  # Triton<2.2 workaround, same as in the fwd kernel
):
    # Two sequential passes share one program: a forward sweep producing dq,
    # then (after a barrier) a backward sweep producing dk and dv.
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H

    o_i = tl.arange(0, BT)
    # per-head retention decay, as in the forward kernel
    b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))
    d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)
    d_b = tl.math.exp2(BT * b_b)

    # [BT, BT] causal relative-decay matrix, scale folded in
    m_s = o_i[:, None] >= o_i[None, :]
    d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
    # [BV, BK] running state for the dq pass (transposed layout)
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    # ---- pass 1: forward sweep over chunks, computing dq ----
    for i in range(0, tl.cdiv(T, BT)):
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))

        # [BT, K]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [V, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dd = (b_do * d_q[:, None]).to(b_do.dtype)

        # [BT, BT] intra-chunk score gradients
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        b_ds = (b_ds * d_s).to(b_k.dtype)
        # [BT, K] intra-chunk contribution
        b_dq = tl.dot(b_ds, b_k, allow_tf32=False)
        # [V, K] inter-chunk contribution via the running state
        if CHECK and i == 0:
            # duplicated body: see CHECK note above
            b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
            b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
        else:
            b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
            b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)

        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # sync threads
    b_h = None
    tl.debug_barrier()
    d_s = tl.trans(d_s)
    # ---- pass 2: backward sweep over chunks, computing dk and dv ----
    # [BK, BV] running gradient of the state
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i in range(1, tl.cdiv(T, BT) + 1):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H).to(tl.int64) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
        # [K, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BT, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dd = (b_do * d_q[:, None]).to(b_do.dtype)

        # [BT, BT]
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
        b_ds = (b_ds * d_s).to(b_k.dtype)

        # [BT, BT] (transposed) intra-chunk scores, reused for dv
        b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
        # [BT, BK]
        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
        # [BT, BV]
        b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
        if CHECK and i == 1:
            # duplicated body: see CHECK note above
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
            b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
        else:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
            b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
209
+
210
+
211
class FusedChunkRetentionFunction(torch.autograd.Function):
    """Autograd wrapper for the fused chunkwise retention kernels.

    Partial results over key blocks (forward) and key/value blocks (backward)
    are accumulated along an extra leading dimension and reduced with
    ``sum(0)`` on the host side.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale, initial_state, output_final_state):
        # q, k: [B, H, T, K]; v: [B, H, T, V]
        B, H, T, K, V = *k.shape, v.shape[-1]

        BT = 64  # time-chunk length
        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4

        # one output slice per key block, reduced below
        o = q.new_empty(NK, B, H, T, V)

        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
        else:
            final_state = None
        # the bug still exists even for Triton 2.2 on H100 GPUs
        # so we always enable initial checks
        CHECK = True
        if version.parse(triton.__version__) < version.parse('2.2.0'):
            import warnings
            warnings.warn(
                "Triton<2.2.0 detected for running this kernel, "
                "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
                "that lead to significant precision loss. "
                "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
                "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
            )
            CHECK = True

        grid = (NV, NK, B * H)
        fused_chunk_retention_fwd_kernel[grid](
            q,
            k,
            v,
            o,
            initial_state,
            final_state,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BT=BT,
            BK=BK,
            BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            CHECK=CHECK,
            num_warps=num_warps,
            num_stages=num_stages
        )

        o = o.sum(0)
        ctx.save_for_backward(q, k, v, initial_state)
        # BUGFIX: remember the scale actually used in forward; backward used
        # to recompute K ** -0.5, which produced wrong gradients whenever the
        # caller supplied a custom scale.
        ctx.scale = scale
        ctx.CHECK = CHECK
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        # NOTE(review): dht (gradient w.r.t. the final state) is accepted but
        # ignored — confirm callers do not rely on it.
        q, k, v, initial_state = ctx.saved_tensors
        B, H, T, K, V = *k.shape, v.shape[-1]
        # use the same scale as the forward pass (see BUGFIX note above)
        scale = ctx.scale

        BT = 64
        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4

        # per-block partial gradients, reduced after the kernel
        dq = q.new_empty(NV, B, H, T, K)
        dk = q.new_empty(NV, B, H, T, K)
        dv = q.new_empty(NK, B, H, T, V)
        grid = (NV, NK, B * H)

        fused_chunk_retention_bwd_kernel[grid](
            q,
            k,
            v,
            do,
            dq,
            dk,
            dv,
            initial_state,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BT=BT,
            BK=BK,
            BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            CHECK=ctx.CHECK,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
320
+
321
+
322
def fused_chunk_retention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Fused chunkwise multi-scale retention.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        scale (Optional[float]):
            scaling applied to attention scores; defaults to `1 / sqrt(K)` when `None`.
        initial_state (Optional[torch.Tensor]):
            initial recurrent state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (bool):
            whether to also return the final state of shape `[B, H, K, V]`. Default: `False`.
        head_first (bool):
            whether the inputs are laid out head-first (`[B, H, T, ...]`). Default: `True`.

    Returns:
        o (torch.Tensor):
            outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            final state of shape `[B, H, K, V]` when `output_final_state=True`, else `None`.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    # the kernels operate on the head-first layout; convert on the way in/out
    if not head_first:
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    o, final_state = FusedChunkRetentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
fla/ops/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (288 Bytes). View file
 
fla/ops/rwkv7/__pycache__/channel_mixing.cpython-312.pyc ADDED
Binary file (14.7 kB). View file
 
fla/ops/simple_gla/naive.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None):
8
+ if scale is None:
9
+ scale = (q.shape[-1] ** -0.5)
10
+ q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale
11
+ k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
12
+ v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
13
+ g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size)
14
+ g = g.cumsum(-1)
15
+ kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
16
+ S = torch.zeros_like(kv)
17
+
18
+ for i in range(1, g.shape[-2]):
19
+ S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]
20
+
21
+ inter = (q * g[..., None].exp()) @ S
22
+ attn = q @ k.transpose(-1, -2)
23
+ attn = attn * (g[..., None] - g[..., None, :]).exp()
24
+ attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
25
+ intra = attn @ v
26
+ o = inter + intra
27
+ return rearrange(o, 'b h n c d -> b h (n c) d')
28
+
29
+
30
def torch_simple_gla_recurrent(q, k, v, g, scale=None, initial_state=None, output_final_state=True):
    """Step-by-step reference for simple GLA.

    Runs the recurrence ``S[t] = exp(g[t]) * S[t-1] + k[t] ⊗ v[t]`` and reads
    ``o[t] = (scale * q[t]) @ S[t]``. All math is done in float32 and the
    output is cast back to the input dtype.

    Args:
        q, k: `[B, H, T, DK]`; v: `[B, H, T, DV]`; g: `[B, H, T]` log decays.
        scale: query scaling; defaults to `DK ** -0.5` when `None`.
        initial_state: optional state `[B, H, DK, DV]`.
        output_final_state: when `False`, the returned state is `None`.

    Returns:
        `(o, S)` with `o` of shape `[B, H, T, DV]`.
    """
    B, H, T, DK = q.shape
    original_dtype = q.dtype
    q, k, v, g = q.float(), k.float(), v.float(), g.float()
    if scale is None:
        scale = DK ** -0.5
    q = q * scale
    DV = v.shape[-1]
    if initial_state is None:
        # BUGFIX: allocate the state on the inputs' device (it was implicitly
        # created on the CPU, which crashed for CUDA inputs)
        S = torch.zeros(B, H, DK, DV, dtype=torch.float, device=q.device)
    else:
        # compute in float32, consistent with the inputs above
        S = initial_state.float()
    # allocate the output directly on the right device/dtype
    o = torch.zeros(B, H, T, DV, dtype=torch.float, device=q.device)
    for t in range(T):
        gate = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1)
        # outer product k[t] ⊗ v[t] -> [B, H, DK, DV]
        kv = k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
        S = S * gate + kv  # creates a new tensor; initial_state is never mutated
        o[:, :, t] = (q[:, :, t].unsqueeze(-1) * S).sum(-2)
    if not output_final_state:
        S = None
    return o.to(original_dtype), S
fla/ops/titans/naive.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from fla.ops.titans.log_impl import combine_params_log
7
+
8
+
9
+ def cal_n(theta, eta, seq_len):
10
+ n = torch.zeros(*theta.shape, seq_len, dtype=theta.dtype).to(
11
+ theta.device
12
+ ) # [batch_size, num_heads, seq_len, seq_len]
13
+
14
+ # 1. deal with diagonal elements
15
+ indices = torch.arange(seq_len, device=theta.device)
16
+ n[..., indices, indices] = theta[..., indices]
17
+
18
+ # 2. Create a cumulative product matrix
19
+ # First create a mask to mark the positions where eta needs to be multiplied
20
+ mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(theta.device)
21
+ # Convert mask to boolean type
22
+ mask = mask.bool()
23
+ # Expand eta to match the target shape
24
+ eta_expanded = eta.unsqueeze(-2).expand(*theta.shape[:-1], seq_len, seq_len)
25
+ # Create a matrix filled with 1s for cumulative product
26
+ cumulative = torch.ones_like(eta_expanded)
27
+ cumulative = torch.where(mask, eta_expanded, cumulative)
28
+ # Calculate the cumulative product
29
+ cumulative_prod = torch.cumprod(cumulative, dim=-1)
30
+
31
+ # 3. Calculate non-diagonal elements
32
+ # Create an expanded version of theta
33
+ theta_expanded = theta.unsqueeze(-1).expand(*theta.shape[:-1], seq_len, seq_len)
34
+ # Create a mask to keep only the upper triangular part (excluding the diagonal)
35
+ upper_triangular = torch.triu(torch.ones_like(n), diagonal=1).bool()
36
+ # Combine theta and cumulative product
37
+ n = torch.where(upper_triangular, theta_expanded * cumulative_prod, n)
38
+ return n
39
+
40
+
41
+ def cal_f(beta, seq_len, m):
42
+ a = torch.tril(beta.to(torch.float32).unsqueeze(-1).expand(*beta.shape, seq_len), 0)
43
+ ratio = (m.to(torch.float32) / beta.to(torch.float32)).unsqueeze(-1)
44
+ f = torch.matmul(a, ratio).squeeze(-1)
45
+ return f.to(beta.dtype)
46
+
47
+
48
+ def cal_G(beta, n, seq_len):
49
+ i_indices = torch.arange(seq_len, device=beta.device)
50
+ j_indices = torch.arange(seq_len, device=beta.device)
51
+ k_indices = torch.arange(seq_len, device=beta.device)
52
+ beta_ratio = beta[..., :, None] / beta[..., None, :] # [..., i, k]
53
+
54
+ # create mask
55
+ k_mask = (k_indices[None, None, :] >= j_indices[None, :, None]) & (
56
+ k_indices[None, None, :] <= i_indices[:, None, None]
57
+ )
58
+
59
+ # use mask to filter out invalid values
60
+ masked_beta_ratio = beta_ratio[..., :, None, :] * k_mask # [..., i, j, k]
61
+ masked_n = n[..., None, :, :] * k_mask # [..., i, j, k]
62
+ # calculate G
63
+ G = torch.sum(masked_beta_ratio * masked_n, dim=-1) # [..., i, j]
64
+ return G
65
+
66
+
67
+ def combine_params(theta, alpha, eta, seq_len):
68
+ theta = theta.squeeze(-1)
69
+ eta = eta.squeeze(-1)
70
+ alpha = alpha.squeeze(-1)
71
+ beta = torch.cumprod(1 - alpha, dim=-1) # β_t = ∏(1 - α_t) in titans paper
72
+ beta_T = beta[..., -1] # β_T
73
+ # Calculate m_i = ∏(k=1 to i) η_k
74
+ m = torch.cumprod(eta, dim=-1) # [batch_size, num_heads, seq_len]
75
+ m_T = m[..., -1] # m_T
76
+ # Calculate n_{i,j}
77
+ # We need to calculate ∏(k=j+1 to i) η_k for each i,j pair
78
+ # # this may be optimized
79
+ # n = torch.zeros(*theta.shape, seq_len, dtype = theta.dtype).to(
80
+ # theta.device) # [batch_size, num_heads, seq_len, seq_len]
81
+ # for i in range(seq_len):
82
+ # for j in range(i + 1):
83
+ # if i == j:
84
+ # n[..., j, i] = theta[..., j]
85
+ # else:
86
+ # # Calculate product of eta from j+1 to i
87
+ # eta_product = torch.prod(eta[..., j + 1:i + 1], dim = -1)
88
+ # n[..., j, i] = theta[..., j] * eta_product
89
+
90
+ n = cal_n(theta, eta, seq_len)
91
+ n_T = n[..., -1] # [batch_size, num_heads, seq_len]
92
+ # Calculate f_t = ∑(i=1 to t) (β_t/β_i) m_i
93
+ # f = torch.zeros_like(theta)
94
+ # for t in range(seq_len):
95
+ # for i in range(t + 1):
96
+ # f[..., t] += (beta[..., t] / beta[..., i]) * m[..., i]
97
+ f = cal_f(beta, seq_len, m)
98
+ f_T = f[..., -1] # [batch_size, num_heads, seq_len]
99
+ # Calculate g_j = ∑(i=j to t) (β_t/β_i) n_{i,j}
100
+ # g = torch.zeros_like(theta) # [batch_size, num_heads, seq_len]
101
+ # for j in range(seq_len):
102
+ # for i in range(j, seq_len):
103
+ # g[..., j] += (beta[..., -1] / beta[..., i]) * n[..., j, i]
104
+ # G = torch.zeros(*beta.shape[:-1], seq_len, seq_len, device = beta.device)
105
+ # # Fill in the lower triangular part
106
+ # for i in range(seq_len): # row
107
+ # for j in range(i + 1): # column
108
+ # # Sum from k=j to i
109
+ # for k in range(j, i + 1):
110
+ # G[..., i, j] += (beta[..., i] / beta[..., k]) * n[..., j, k]
111
+ G = cal_G(beta, n, seq_len)
112
+ g = G[:, :, -1, :] # [batch_size, num_heads, seq_len]
113
+ # g2, G2 = compute_g_and_G(beta, n, seq_len)
114
+ return beta, beta_T, f, f_T, g, G, m_T, n_T
115
+
116
+
117
+ def titans_linear(
118
+ q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
119
+ ):
120
+ """
121
+ Implementation of Titans Linear function based on the update rules:
122
+ M_t = (1 - alpha_t) * M_{t-1} + S_t
123
+ S_t = eta_t * S_{t-1} - theta_t * nabla_l(M_{t-1}; x_t)
124
+
125
+ Args:
126
+ q: Query tensor
127
+ k: Key tensor
128
+ v: Value tensor
129
+ w: Weight tensor
130
+ b: Bias tensor
131
+ theta: Learning rate tensor
132
+ alpha: Momentum decay tensor
133
+ eta: Step size tensor
134
+ eps: Epsilon for numerical stability
135
+ initial_state: Initial state M_0
136
+ output_final_state: Whether to output the final state
137
+
138
+ Returns:
139
+ Tuple of (output tensor, final state)
140
+ """
141
+ B, H, T, D = q.shape
142
+ device = q.device
143
+ w = w.reshape(H, 1, D).to(torch.float32)
144
+ b = b.reshape(H, 1, D).to(torch.float32)
145
+ # Initialize states
146
+ if initial_state is None:
147
+ M_prev = torch.zeros(B, H, D, D, device=device)
148
+ else:
149
+ M_prev = initial_state
150
+ M_prev_nabla = M_prev.clone()
151
+ S_prev = torch.zeros_like(M_prev)
152
+ outputs = []
153
+
154
+ # Process sequence step by step
155
+ for t in range(T):
156
+ # Get current step inputs
157
+ q_t = q[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
158
+ k_t = k[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
159
+ v_t = v[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
160
+ theta_t = theta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
161
+ alpha_t = alpha[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
162
+ eta_t = eta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim)
163
+
164
+ # Compute gradient
165
+ km = k_t @ M_prev_nabla # (batch_size, num_heads, 1, dim)
166
+ reconstruction_target = v_t - k_t
167
+ mean = km.mean(-1, keepdim=True)
168
+ var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
169
+ rstd = torch.sqrt(var + eps).to(torch.float32)
170
+ km_hat = (km - mean) / rstd
171
+
172
+ grad = w * km_hat + b - reconstruction_target
173
+ grad = grad * w
174
+ # v_new = (D * grad - grad.sum(-1, keepdim = True) - km_hat * (grad * km_hat).sum(-1, keepdim = True)) / (
175
+ # rstd * D)
176
+ v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
177
+ proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
178
+ v_new = v_new - proj_term
179
+ # v_new = grad
180
+
181
+ # Update S_t
182
+ S_t = eta_t * S_prev - 2 * theta_t * k_t.transpose(-2, -1) @ v_new
183
+
184
+ # Update M_t
185
+ M_t = (1 - alpha_t) * M_prev + S_t
186
+
187
+ # Store output
188
+ output_t = q_t @ M_t # (batch_size, num_heads, seq_len, dim)
189
+ mean = output_t.mean(dim=-1, keepdim=True)
190
+ var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
191
+ rstd = torch.sqrt(var + eps).to(torch.float32)
192
+ output_t = output_t + (output_t - mean) / rstd * w + b
193
+ outputs.append(output_t)
194
+
195
+ # Update states for next step
196
+ if (t + 1) % chunk_size == 0:
197
+ M_prev_nabla = M_t.clone()
198
+ M_prev = M_t
199
+ S_prev = S_t
200
+
201
+ # Stack outputs along sequence dimension
202
+ output = torch.stack(outputs, dim=-2).squeeze(
203
+ -3
204
+ ) # (batch_size, num_heads, seq_len, dim)
205
+
206
+ if output_final_state:
207
+ return output, M_prev
208
+ return output, None
209
+
210
+
211
+ def chunk_titans_linear(
212
+ q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
213
+ ):
214
+ B, H, T, D = q.shape
215
+ num_batch = T // chunk_size
216
+ # [num_batch, B, num_heads, mini_batch_size, head_dim]
217
+ _q = q.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
218
+ _k = k.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
219
+ _v = v.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4)
220
+ # [num_batch, B, num_heads, mini_batch_size, 1]
221
+ _eta = eta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
222
+ _theta = theta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
223
+ _alpha = alpha.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4)
224
+ # [H, 1, D]
225
+ w = w.reshape(H, 1, D).to(torch.float32)
226
+ b = b.reshape(H, 1, D).to(torch.float32)
227
+ # [num_heads, 1, head_dim]
228
+ if initial_state is None:
229
+ M_prev = torch.zeros((B, H, D, D), device=v.device, dtype=v.dtype).to(
230
+ torch.float32
231
+ )
232
+ else:
233
+ M_prev = initial_state
234
+
235
+ S_prev = torch.zeros_like(M_prev)
236
+
237
+ # [num_batch, B, num_heads, mini_batch_size, head_dim]
238
+ o = torch.empty_like(_v)
239
+
240
+ for i in range(num_batch):
241
+ q_i, k_i, v_i, eta_i, theta_i, alpha_i = [
242
+ x[i] for x in [_q, _k, _v, _eta, _theta, _alpha]
243
+ ]
244
+
245
+ # beta, beta_T, f, f_T, g, G, m_T, n = combine_params(theta_i, alpha_i, eta_i, chunk_size)
246
+ beta, beta_T, f, f_T, g, G, m_T, n = combine_params_log(
247
+ theta_i, alpha_i, eta_i, chunk_size
248
+ )
249
+
250
+ m_T = m_T.unsqueeze(-1).unsqueeze(-1)
251
+ beta_T = beta_T.unsqueeze(-1).unsqueeze(-1)
252
+ f_T = f_T.unsqueeze(-1).unsqueeze(-1)
253
+ g_diag = torch.diag_embed(g).to(q_i.dtype)
254
+ n = torch.diag_embed(n).to(q_i.dtype)
255
+ beta = torch.diag_embed(beta).to(q_i.dtype)
256
+ f = torch.diag_embed(f).to(q_i.dtype)
257
+ km = k_i @ M_prev
258
+ reconstruction_target = v_i - k_i
259
+
260
+ mean = km.mean(-1, True)
261
+ var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
262
+ rstd = torch.sqrt(var + eps).to(torch.float32)
263
+ km_hat = (km - mean) / rstd
264
+
265
+ grad = w * km_hat + b - reconstruction_target
266
+ grad *= w
267
+ v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
268
+ proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
269
+ v_new = v_new - proj_term
270
+ # v_new = (D * grad - grad.sum(-1, True))
271
+ # print(f"Projection term stats: min={torch.abs(beta_T).min()}")
272
+
273
+ # v_new = grad
274
+
275
+ Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) * G
276
+
277
+ # o_i
278
+ output_t = beta @ q_i @ M_prev + f @ q_i @ S_prev - 2 * Attn @ v_new
279
+
280
+ M_t = (
281
+ beta_T * M_prev
282
+ + f_T * S_prev
283
+ - 2 * (g_diag @ k_i).transpose(-1, -2) @ v_new
284
+ )
285
+ # cal S_T from S_0
286
+ S_t = m_T * S_prev - 2 * (n @ k_i).transpose(-1, -2) @ v_new
287
+ # layer norm with residuals
288
+ mean = output_t.mean(dim=-1, keepdim=True)
289
+ var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
290
+ rstd = torch.sqrt(var + eps).to(torch.float32)
291
+ output_t = output_t + (output_t - mean) / rstd * w + b
292
+ o[i] = output_t
293
+ S_prev = S_t
294
+ M_prev = M_t
295
+
296
+ # [B, num_mini_batch, mini_batch_size, num_heads, head_dim]
297
+ o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
298
+ M_prev = M_prev if output_final_state else None
299
+ return o, M_prev
300
+
301
+
302
+ # most of the code is copied from ttt
303
+ def chunk_titans_linear_ref(
304
+ q: torch.Tensor,
305
+ k: torch.Tensor,
306
+ v: torch.Tensor,
307
+ w: torch.Tensor,
308
+ b: torch.Tensor,
309
+ theta: torch.Tensor,
310
+ alpha: torch.Tensor,
311
+ eta: torch.Tensor,
312
+ eps: float = 1e-6,
313
+ chunk_size: int = 16, # chunk size
314
+ initial_state: torch.Tensor = None,
315
+ output_final_state: bool = False,
316
+ head_first: bool = True,
317
+ use_chunk: bool = True,
318
+ ):
319
+ assert q.dtype == k.dtype == v.dtype
320
+ assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
321
+ if not head_first:
322
+ q = q.transpose(1, 2)
323
+ k = k.transpose(1, 2)
324
+ v = v.transpose(1, 2)
325
+ eta = eta.transpose(1, 2)
326
+ alpha = alpha.transpose(1, 2)
327
+ theta = theta.transpose(1, 2)
328
+ seq_len = q.shape[-2]
329
+ pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
330
+ if pad_len > 0:
331
+ q = F.pad(q, (0, 0, 0, pad_len))
332
+ k = F.pad(k, (0, 0, 0, pad_len))
333
+ v = F.pad(v, (0, 0, 0, pad_len))
334
+ theta = F.pad(theta, (0, 0, 0, pad_len))
335
+ alpha = F.pad(alpha, (0, 0, 0, pad_len))
336
+ eta = F.pad(eta, (0, 0, 0, pad_len))
337
+ theta[:, :, -1, :] = theta[:, :, -(pad_len + 1), :]
338
+ alpha[:, :, -1, :] = alpha[:, :, -(pad_len + 1), :]
339
+ eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
340
+ assert q.shape[-2] % chunk_size == 0, "Sequence length should be a multiple of BT."
341
+ q, k, v, w, b = map(lambda x: x.to(torch.float32), [q, k, v, w, b])
342
+ if use_chunk:
343
+ o, final_state = chunk_titans_linear(
344
+ q,
345
+ k,
346
+ v,
347
+ w,
348
+ b,
349
+ theta,
350
+ alpha,
351
+ eta,
352
+ eps,
353
+ chunk_size,
354
+ initial_state,
355
+ output_final_state,
356
+ )
357
+ else:
358
+ o, final_state = titans_linear(
359
+ q,
360
+ k,
361
+ v,
362
+ w,
363
+ b,
364
+ theta,
365
+ alpha,
366
+ eta,
367
+ eps,
368
+ chunk_size,
369
+ initial_state,
370
+ output_final_state,
371
+ )
372
+ o = o[:, :, :seq_len, :]
373
+ if not head_first:
374
+ o = o.transpose(1, 2)
375
+ return o, final_state
fla/ops/ttt/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_ttt_linear
4
+ from .fused_chunk import fused_chunk_ttt_linear
5
+
6
+ __all__ = [
7
+ 'fused_chunk_ttt_linear',
8
+ 'chunk_ttt_linear'
9
+ ]
fla/ops/ttt/fused_chunk.py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.modules.layernorm import group_norm
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=1),
23
+ triton.Config({}, num_warps=2),
24
+ triton.Config({}, num_warps=4)
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def fused_chunk_ttt_linear_fwd_kernel(
30
+ q,
31
+ k,
32
+ v,
33
+ eta,
34
+ w,
35
+ b,
36
+ o,
37
+ scale,
38
+ eps,
39
+ h0,
40
+ hb0,
41
+ ht,
42
+ hbt,
43
+ offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ USE_INITIAL_STATE_B: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr
56
+ ):
57
+ # indices
58
+ i_nh = tl.program_id(0)
59
+ i_n, i_h = i_nh // H, i_nh % H
60
+ if USE_OFFSETS:
61
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
62
+ T = eos - bos
63
+ NT = tl.cdiv(T, BT)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+
68
+ o_i = tl.arange(0, BT)
69
+ v_i = tl.arange(0, BV)
70
+ m_A = o_i[:, None] >= o_i[None, :]
71
+ b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
72
+ b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ # [BV]
77
+ b_hb = tl.zeros([BV], dtype=tl.float32)
78
+ if USE_INITIAL_STATE:
79
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
80
+ b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
81
+ if USE_INITIAL_STATE_B:
82
+ p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
83
+ b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)
84
+
85
+ for i_t in range(NT):
86
+ if HEAD_FIRST:
87
+ p_q = tl.make_block_ptr(q+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
88
+ p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1))
89
+ p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
90
+ p_o = tl.make_block_ptr(o+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
91
+ p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
92
+ p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
93
+ else:
94
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
95
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
97
+ p_o = tl.make_block_ptr(o+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
98
+ p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
99
+ p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
100
+ # [BK, BT]
101
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
102
+ # [BT, BV]
103
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
104
+
105
+ # [BT, BV]
106
+ b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
107
+ b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
108
+ mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
109
+ xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
110
+ var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
111
+ rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
112
+ b_kh_hat = (b_kh - mean) * rstd
113
+
114
+ b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
115
+ b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
116
+ b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
117
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
118
+ * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
119
+
120
+ # [BT, BK]
121
+ b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
122
+ # [BT]
123
+ b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
124
+ b_q = (b_q * scale).to(b_k.dtype)
125
+
126
+ # [BT, BT]
127
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
128
+ b_A = tl.where(m_A, b_A, 0)
129
+ b_Ae = tl.where(m_A, b_e[:, None], 0.0)
130
+
131
+ b_o = - tl.dot(b_e[:, None] * b_A.to(b_v2.dtype), b_v2, allow_tf32=False)
132
+ b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v2.dtype), b_v2, allow_tf32=False)
133
+ b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
134
+ b_e_last = tl.load(p_e_last)
135
+ b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
136
+ b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
137
+ b_h = tl.where((v_i < V)[None, :], b_h, 0.)
138
+ b_hb = tl.where((v_i < V), b_hb, 0.)
139
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
140
+
141
+ if STORE_FINAL_STATE:
142
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
143
+ p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
144
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
145
+ tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))
146
+
147
+
148
+ @triton.heuristics({
149
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
150
+ 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
151
+ })
152
+ @triton.autotune(
153
+ configs=[
154
+ triton.Config({}, num_warps=1),
155
+ triton.Config({}, num_warps=2),
156
+ triton.Config({}, num_warps=4)
157
+ ],
158
+ key=['BT', 'BK', 'BV'],
159
+ )
160
+ @triton.jit(do_not_specialize=['T'])
161
+ def fused_chunk_ttt_linear_bwd_kernel_h(
162
+ k,
163
+ v,
164
+ v2,
165
+ x,
166
+ y,
167
+ r,
168
+ w,
169
+ b,
170
+ eta,
171
+ h0,
172
+ hb0,
173
+ h,
174
+ do,
175
+ dq,
176
+ scale,
177
+ eps,
178
+ T,
179
+ H: tl.constexpr,
180
+ K: tl.constexpr,
181
+ V: tl.constexpr,
182
+ BT: tl.constexpr,
183
+ BK: tl.constexpr,
184
+ BV: tl.constexpr,
185
+ USE_INITIAL_STATE: tl.constexpr,
186
+ USE_INITIAL_STATE_B: tl.constexpr,
187
+ HEAD_FIRST: tl.constexpr
188
+ ):
189
+ # indices
190
+ i_nh = tl.program_id(0)
191
+ i_n, i_h = i_nh // H, i_nh % H
192
+ bos, _ = i_n * T, i_n * T + T
193
+ NT = tl.cdiv(T, BT)
194
+ boh = i_n * NT
195
+
196
+ o_i = tl.arange(0, BT)
197
+ v_i = tl.arange(0, BV)
198
+ m_A = o_i[:, None] >= o_i[None, :]
199
+ b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
200
+ b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)
201
+
202
+ # [BK, BV]
203
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
204
+ # [BV]
205
+ b_hb = tl.zeros([BV], dtype=tl.float32)
206
+ if USE_INITIAL_STATE:
207
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
208
+ b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
209
+ if USE_INITIAL_STATE_B:
210
+ p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
211
+ b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)
212
+
213
+ for i_t in range(NT):
214
+ if HEAD_FIRST:
215
+ p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
216
+ p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1))
217
+ p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
218
+ p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
219
+ p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
220
+ p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
221
+ p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0))
222
+ p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
223
+ p_dq = tl.make_block_ptr(dq+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
224
+ p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
225
+ p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
226
+ else:
227
+ p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
228
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
229
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
230
+ p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
231
+ p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
232
+ p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
233
+ p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
234
+ p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
235
+ p_dq = tl.make_block_ptr(dq+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
236
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
237
+ p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
238
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
239
+ # [BK, BT]
240
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
241
+ # [BT, BV]
242
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
243
+
244
+ b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
245
+ b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
246
+ mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
247
+ xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
248
+ var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
249
+ rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
250
+ b_kh_hat = (b_kh - mean) * rstd
251
+
252
+ b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
253
+ b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
254
+ b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
255
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
256
+ * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
257
+ tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
258
+ tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
259
+ tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
260
+ tl.store(p_v2, b_v2.to(p_v2.dtype.element_ty), boundary_check=(0, 1))
261
+
262
+ b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
263
+ b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
264
+
265
+ b_v2 = tl.where((v_i < V)[None, :], b_v2, 0.)
266
+ b_ds = tl.dot(b_do, tl.trans(b_v2).to(b_do.dtype))
267
+ b_ds = tl.where(m_A, b_ds, 0)
268
+ b_ds = b_ds.to(b_k.dtype)
269
+ b_dq = tl.dot(b_do, tl.trans(b_h).to(b_do.dtype))
270
+ b_dq -= tl.dot(b_ds, tl.trans(b_k)) * b_e[:, None]
271
+ b_dq *= scale
272
+
273
+ b_e_last = tl.load(p_e_last)
274
+ b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
275
+ b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
276
+ b_h = tl.where((v_i < V)[None, :], b_h, 0.)
277
+ b_hb = tl.where((v_i < V), b_hb, 0.)
278
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
279
+
280
+
281
+ @triton.heuristics({
282
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
283
+ 'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
284
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
285
+ 'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
286
+ })
287
+ @triton.autotune(
288
+ configs=[
289
+ triton.Config({}, num_warps=1),
290
+ triton.Config({}, num_warps=2),
291
+ triton.Config({}, num_warps=4)
292
+ ],
293
+ key=['BT', 'BK', 'BV'],
294
+ )
295
+ @triton.jit(do_not_specialize=['T'])
296
+ def fused_chunk_ttt_linear_bwd_kernel_dh(
297
+ q,
298
+ k,
299
+ v,
300
+ v2,
301
+ x,
302
+ y,
303
+ r,
304
+ w,
305
+ b,
306
+ eta,
307
+ h,
308
+ dht,
309
+ dhbt,
310
+ dh0,
311
+ dhb0,
312
+ do,
313
+ dk,
314
+ dv,
315
+ de,
316
+ dw,
317
+ db,
318
+ scale,
319
+ T,
320
+ H: tl.constexpr,
321
+ K: tl.constexpr,
322
+ V: tl.constexpr,
323
+ BT: tl.constexpr,
324
+ BK: tl.constexpr,
325
+ BV: tl.constexpr,
326
+ USE_INITIAL_STATE: tl.constexpr,
327
+ USE_INITIAL_STATE_B: tl.constexpr,
328
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
329
+ USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
330
+ HEAD_FIRST: tl.constexpr
331
+ ):
332
+ # indices
333
+ i_nh = tl.program_id(0)
334
+ i_n, i_h = i_nh // H, i_nh % H
335
+ bos, _ = i_n * T, i_n * T + T
336
+ NT = tl.cdiv(T, BT)
337
+ boh = i_n * NT
338
+
339
+ # [BK, BV]
340
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
341
+ # [BV]
342
+ b_dhb = tl.zeros([BV], dtype=tl.float32)
343
+ if USE_FINAL_STATE_GRADIENT:
344
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
345
+ b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
346
+ if USE_FINAL_STATE_GRADIENT_B:
347
+ p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
348
+ b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")
349
+
350
+ # [BV]
351
+ o_i = tl.arange(0, BT)
352
+ v_i = tl.arange(0, BV)
353
+ m_A = o_i[:, None] >= o_i[None, :]
354
+ m_A_t = o_i[:, None] <= o_i[None, :]
355
+ b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
356
+ b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)
357
+ b_dw = tl.zeros([BV,], dtype=b_w.dtype)
358
+ b_db = tl.zeros([BV,], dtype=b_b.dtype)
359
+ p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
360
+ p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
361
+
362
+ for i_t in range(NT - 1, -1, -1):
363
+ if HEAD_FIRST:
364
+ p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1))
365
+ p_q = tl.make_block_ptr(q+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1))
366
+ p_k = tl.make_block_ptr(k+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
367
+ p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
368
+ p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
369
+ p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
370
+ p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
371
+ p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0))
372
+ p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
373
+ p_dv = tl.make_block_ptr(dv+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
374
+ p_dk = tl.make_block_ptr(dk+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
375
+ p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
376
+ p_de = tl.make_block_ptr(de+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,))
377
+ p_e_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1
378
+ else:
379
+ p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1))
380
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
381
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
382
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
383
+ p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
384
+ p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
385
+ p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
386
+ p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
387
+ p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
388
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
389
+ p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
390
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
391
+ p_de = tl.make_block_ptr(de+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
392
+ p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
393
+ b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
394
+ b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
395
+ b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
396
+ b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
397
+ b_e_last = tl.load(p_e_last)
398
+ b_A = tl.dot(b_k, b_q)
399
+ b_A = - tl.where(m_A_t, b_A * scale * b_e[None, :], 0).to(do.dtype.element_ty)
400
+ b_Ae = - tl.where(m_A_t, b_e[None, :], 0).to(do.dtype.element_ty)
401
+ b_dv_new = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
402
+ b_dv_new -= tl.dot(b_e_last * b_k, b_dh.to(b_k.dtype))
403
+ b_dv_new -= b_e_last * b_dhb.to(b_k.dtype)[None, :]
404
+
405
+ b_v2 = tl.load(p_v2, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
406
+ b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
407
+ b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
408
+ b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
409
+ b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
410
+ b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
411
+ b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
412
+ b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
413
+ b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v2.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)
414
+
415
+ b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
416
+ b_w = b_w.to(b_k.dtype)
417
+ b_b = b_b.to(b_k.dtype)
418
+ b_dv = -b_w * b_dy.to(b_k.dtype)
419
+ b_dk = b_w * b_dy.to(b_k.dtype)
420
+ b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
421
+ (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
422
+ b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
423
+ b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)
424
+
425
+ b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
426
+ b_q = (b_q * scale).to(b_q.dtype)
427
+ b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
428
+ b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
429
+ b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
430
+ b_dkh = tl.where((v_i < V)[None, :] * (o_i < T-i_t*BT)[:, None], b_dkh, 0.)
431
+ b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)
432
+
433
+ b_ds = tl.dot(b_do, tl.trans(b_v2))
434
+ b_ds = tl.where(m_A, b_ds, 0)
435
+ b_ds = b_ds.to(b_k.dtype)
436
+ i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1)
437
+ mask = (o_i == i_last)
438
+ b_dk -= b_e_last * tl.dot(b_v2, tl.trans(b_dh).to(b_v2.dtype))
439
+ b_dk -= tl.dot(tl.trans(b_ds), tl.trans(b_q) * b_e[:, None])
440
+ b_de = mask * tl.sum(- b_dh * tl.trans(tl.dot(tl.trans(b_v2), b_k))).to(b_k.dtype)
441
+ b_de -= mask * tl.sum(b_dhb * tl.sum(b_v2, axis=0)).to(b_k.dtype)
442
+ b_de -= tl.sum(tl.dot(b_ds, b_k) * tl.trans(b_q).to(b_k.dtype), axis=1)
443
+ b_de -= tl.sum(b_ds, axis=1)
444
+ b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
445
+ b_dhb += tl.sum(b_do + b_dkh, axis=0)
446
+ b_dh = tl.where((v_i < V)[None, :], b_dh, 0.)
447
+ b_dhb = tl.where((v_i < V), b_dhb, 0.)
448
+
449
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
450
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
451
+ tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))
452
+ tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
453
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))
454
+
455
+ if USE_INITIAL_STATE:
456
+ p_dh0 = tl.make_block_ptr(dh0+i_nh*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
457
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
458
+ if USE_INITIAL_STATE_B:
459
+ p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (0,), (BV,), (0,))
460
+ tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))
461
+
462
+
463
+ def fused_chunk_ttt_linear_bwd_h(
464
+ q: torch.Tensor,
465
+ k: torch.Tensor,
466
+ v: torch.Tensor,
467
+ w: torch.Tensor,
468
+ b: torch.Tensor,
469
+ eta: torch.Tensor,
470
+ scale: float,
471
+ eps: float,
472
+ do: torch.Tensor,
473
+ BT: int = 16,
474
+ initial_state: torch.Tensor = None,
475
+ initial_state_bias: torch.Tensor = None,
476
+ offsets: Optional[torch.LongTensor] = None,
477
+ head_first: bool = True
478
+ ):
479
+ assert offsets is None, "bwd of varlen is not implemented yet."
480
+ if head_first:
481
+ B, H, T, K, V = *k.shape, v.shape[-1]
482
+ else:
483
+ B, T, H, K, V = *k.shape, v.shape[-1]
484
+ # N: the actual number of sequences in the batch with either equal or variable lengths
485
+ N, NT = B, triton.cdiv(T, BT)
486
+ BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
487
+ assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
488
+
489
+ if head_first:
490
+ h = k.new_empty(B, H, NT, K, V)
491
+ r = v.new_empty(B, H, T, 1, dtype=torch.float32)
492
+ else:
493
+ h = k.new_empty(B, NT, H, K, V)
494
+ r = v.new_empty(B, T, H, 1, dtype=torch.float32)
495
+ v2 = torch.empty_like(v)
496
+ x = torch.empty_like(v)
497
+ y = torch.empty_like(v)
498
+ dq = torch.empty_like(q)
499
+
500
+ grid = (N * H,)
501
+ fused_chunk_ttt_linear_bwd_kernel_h[grid](
502
+ k=k,
503
+ v=v,
504
+ v2=v2,
505
+ x=x,
506
+ y=y,
507
+ r=r,
508
+ w=w,
509
+ b=b,
510
+ eta=eta,
511
+ h0=initial_state,
512
+ hb0=initial_state_bias,
513
+ h=h,
514
+ do=do,
515
+ dq=dq,
516
+ scale=scale,
517
+ eps=eps,
518
+ T=T,
519
+ H=H,
520
+ K=K,
521
+ V=V,
522
+ BT=BT,
523
+ BK=BK,
524
+ BV=BV,
525
+ HEAD_FIRST=head_first
526
+ )
527
+ return dq, h, v2, x, y, r
528
+
529
+
530
def fused_chunk_ttt_linear_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v2: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    r: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    h: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """
    Second backward stage: given the intermediates (`h`, `v2`, `x`, `y`, `r`)
    recovered by `fused_chunk_ttt_linear_bwd_h`, compute the remaining
    gradients by scanning chunks in reverse inside
    `fused_chunk_ttt_linear_bwd_kernel_dh`.

    Returns:
        dk, dv, de: gradients w.r.t. `k`, `v`, `eta` (same shapes/dtypes).
        dw, db: gradients w.r.t. the shared layer-norm weight/bias, reduced
            over the batch dimension so they match `w`/`b` of shape `(H, V)`.
        dh0, dhb0: gradients w.r.t. the initial state / state bias, or `None`
            when the corresponding initial state was not supplied.
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # state gradients are only materialized if the forward received states
    dh0 = torch.empty_like(initial_state, dtype=torch.float32) if initial_state is not None else None
    dhb0 = torch.empty_like(initial_state_bias, dtype=torch.float32) if initial_state_bias is not None else None
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    de = torch.empty_like(eta)
    # per-batch partials for the shared layer-norm params; summed over B below
    dw = w.new_empty(B, H, V)
    db = b.new_empty(B, H, V)

    grid = (N * H,)
    fused_chunk_ttt_linear_bwd_kernel_dh[grid](
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dh0=dh0,
        dhb0=dhb0,
        do=do,
        dk=dk,
        dv=dv,
        de=de,
        dw=dw,
        db=db,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    # reduce the per-batch partials to match the (H, V) parameter shapes
    dw = dw.sum(dim=0)
    db = db.sum(dim=0)
    return dk, dv, de, dw, db, dh0, dhb0
606
+
607
+
608
def fused_chunk_ttt_linear_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    BT: int = 16
):
    """
    Forward launcher for the fused chunked TTT-linear kernel.

    Allocates the output and (optionally) the final recurrent state tensors,
    then runs `fused_chunk_ttt_linear_fwd_kernel` with one program per
    (sequence, head) pair. Unlike the backward path, `offsets`-based
    variable-length batches are accepted here.

    Returns:
        o: output tensor shaped like `v`.
        final_state: `(N, H, K, V)` fp32 tensor if `output_final_state` else `None`.
        final_state_bias: `(N, H, 1, V)` fp32 tensor if `output_final_state` else `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B if offsets is None else len(offsets) - 1
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    o = torch.empty_like(v)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32) if output_final_state else None

    grid = (N * H,)
    fused_chunk_ttt_linear_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        eta=eta,
        w=w,
        b=b,
        o=o,
        scale=scale,
        eps=eps,
        h0=initial_state,
        hb0=initial_state_bias,
        ht=final_state,
        hbt=final_state_bias,
        offsets=offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return o, final_state, final_state_bias
662
+
663
+
664
def fused_chunk_ttt_linear_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """
    Two-stage backward pass for the fused chunked TTT-linear operator.

    Stage 1 (`fused_chunk_ttt_linear_bwd_h`) recomputes the per-chunk hidden
    states and layer-norm intermediates (`h`, `v2`, `x`, `y`, `rstd`) and
    produces `dq`; stage 2 (`fused_chunk_ttt_linear_bwd_dh`) consumes those
    to produce the remaining gradients. Recomputation trades FLOPs for not
    storing intermediates during the forward pass.

    Returns:
        (dq, dk, dv, de, dw, db, dh0, dhb0)
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h(
        q=q,
        k=k,
        v=v,
        w=w,
        b=b,
        eta=eta,
        scale=scale,
        eps=eps,
        do=do,
        BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        offsets=offsets,
        head_first=head_first
    )
    dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd_dh(
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=rstd,
        w=w,
        b=b,
        eta=eta,
        scale=scale,
        h=h,
        do=do,
        dht=dht,
        dhbt=dhbt,
        BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        offsets=offsets,
        head_first=head_first
    )
    return dq, dk, dv, de, dw, db, dh0, dhb0
722
+
723
+
724
class FusedChunkTTTLinearFunction(torch.autograd.Function):
    """Autograd wrapper tying the fused TTT-linear forward and backward kernels together.

    `forward` saves only the raw inputs; the backward recomputes every
    intermediate it needs (see `fused_chunk_ttt_linear_bwd`).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
                initial_state_bias, output_final_state, offsets, head_first):
        # NOTE: argument order here fixes the positional order of the
        # gradients returned by `backward` below — keep them in sync.
        o, final_state, final_state_bias = fused_chunk_ttt_linear_fwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=scale,
            eps=eps,
            BT=BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            output_final_state=output_final_state,
            offsets=offsets,
            head_first=head_first
        )
        ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
        # non-tensor config is stashed directly on ctx
        ctx.BT = BT
        ctx.scale = scale
        ctx.eps = eps
        ctx.offsets = offsets
        ctx.head_first = head_first
        return o.to(q.dtype), final_state, final_state_bias

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht, dhbt):
        # incoming grads correspond to (o, final_state, final_state_bias)
        q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
        dq, dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=ctx.scale,
            eps=ctx.eps,
            do=do,
            dht=dht,
            dhbt=dhbt,
            BT=ctx.BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        # One gradient per forward input, in order:
        # q, k, v, w, b, BT, eta, scale, eps, initial_state,
        # initial_state_bias, output_final_state, offsets, head_first
        return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None
779
+
780
+
781
def norm_residual(x, weight, bias, eps, head_first):
    """Add a grouped layer-norm of `x` back onto `x` (residual GroupNorm).

    One normalization group per head; `x` is updated in place through a
    time-major view and returned.
    """
    def _residual_group_norm(t, num_heads):
        # t is (B, T, H, D); normalize each head's features and add back.
        batch, seq_len = t.shape[0], t.shape[1]
        t += group_norm(
            t.reshape(batch, seq_len, -1).clone(),
            weight=weight.reshape(-1).clone(),
            bias=bias.reshape(-1).clone(),
            eps=eps,
            num_groups=num_heads,
        ).reshape(t.shape)
        return t

    if head_first:
        # bring time to dim 1, apply, then restore head-first layout
        x = _residual_group_norm(x.transpose(1, 2), x.shape[1]).transpose(1, 2)
    else:
        x = _residual_group_norm(x, x.shape[2])
    return x
804
+
805
+
806
def fused_chunk_ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    r"""
    Fused chunked TTT (test-time training) linear attention.

    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            layer norm weight of shape `(H, V)`
        b (torch.Tensor):
            layer norm bias of shape `(H, V)`
        eta (torch.Tensor):
            Learning rate for hidden state, of shape `(B, H, T, 1)`.
            A plain float is also accepted and broadcast to that shape.
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        eps (float):
            Epsilon added to the layer-norm variance for numerical stability.
            Default: `1e-6`.
        chunk_size (int):
            chunk size. Default: `16`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        initial_state_bias (Optional[torch.Tensor]):
            Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` (`[B, T, H, V]` when `head_first=False`).
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
        final_state_bias (torch.Tensor):
            Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
    if isinstance(eta, float):
        # broadcast a scalar learning rate to the per-token shape
        eta = torch.full_like(q[:, :, :, :1], eta)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "Scale must be positive."
    o, final_state, final_state_bias = FusedChunkTTTLinearFunction.apply(
        q,
        k,
        v,
        w,
        b,
        chunk_size,
        eta,
        scale,
        eps,
        initial_state,
        initial_state_bias,
        output_final_state,
        cu_seqlens,
        head_first
    )
    # residual GroupNorm applied outside the autograd Function
    o = norm_residual(o, w, b, eps, head_first)
    return o, final_state, final_state_bias
fla/ops/ttt/naive.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    """
    Naive (reference) TTT-linear recurrence, processed in mini-batches.

    Args:
        q, k, v: `(B, H, T, D)` tensors; T must be divisible by `mini_batch_size`.
        w, b: layer-norm weight/bias of shape `(H, D)`.
        eta: per-token learning rate of shape `(B, H, T, 1)`.
        scale: multiplier applied to `q` before attention.
        eps: layer-norm variance epsilon.
        mini_batch_size: chunk length BT.
        initial_state: optional `(B, H, D, D)` hidden state (fp32).
        initial_state_bias: optional `(B, H, 1, D)` state bias (fp32).
        output_final_state: whether to return the final `(h, hb)`.

    Returns:
        o: `(B, H, T, D)` outputs (before the outer residual norm).
        h, hb: final state / bias if `output_final_state` else `None`.

    Note:
        Inputs are never mutated. The previous implementation did `q *= scale`
        in place, which clobbered the caller's tensor (and relied on
        `reshape` aliasing to propagate the scale into the chunked view).
    """
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # scale out-of-place *before* chunking so the caller's `q` is untouched
    # and the scaling reaches `_q` regardless of contiguity
    q = q * scale
    # [NT, B, H, mini_batch_size, D]
    _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)

    for i in range(NT):
        q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
        kh = k_i @ h + hb
        reconstruction_target = v_i - k_i

        # layer norm of the state read-out
        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / rstd

        # gradient of the per-token reconstruction loss through the layer norm
        g = w * kh_hat + b - reconstruction_target
        g *= w
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)

        # intra-chunk attention uses the causal (lower-triangular) mask
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        # state update uses the chunk's last learning rate
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)

        # layer norm with residual on the chunk output
        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / rstd * w + b

    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb
71
+
72
+
73
def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = True,
):
    """Reference entry point for chunked TTT-linear.

    Normalizes layouts/dtypes, pads T up to a multiple of `mini_batch_size`,
    runs the naive `ttt_linear` recurrence in fp32, and trims/transposes the
    output back to the caller's layout.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if not head_first:
        # work internally in head-first (B, H, T, ...) layout
        q, k, v, eta = (t.transpose(1, 2) for t in (q, k, v, eta))
    seq_len = q.shape[-2]
    pad_len = (-seq_len) % mini_batch_size
    if pad_len > 0:
        q, k, v, eta = (F.pad(t, (0, 0, 0, pad_len)) for t in (q, k, v, eta))
        # padded steps must reuse the last real learning rate for the state update
        eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    q, k, v, eta, w, b = (t.to(torch.float32) for t in (q, k, v, eta, w, b))
    o, final_state, final_state_bias = ttt_linear(
        q,
        k,
        v,
        w,
        b,
        eta,
        scale,
        eps,
        mini_batch_size,
        initial_state,
        initial_state_bias,
        output_final_state,
    )
    # drop padding and restore the caller's layout
    o = o[:, :, :seq_len, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias
fla/ops/utils/cumsum.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, input_guard
11
+
12
+ BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
13
+
14
+
15
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Per-chunk (local) inclusive cumsum of a scalar sequence.
    # Grid: (num_chunks, B * H); each program handles one BT-long chunk
    # of one (batch, head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # varlen: `indices` maps the flat chunk id to (sequence, chunk-in-sequence)
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        # (B, H, T) layout: bos*H == i_b*H*T, so this is (i_b*H + i_h)*T + t.
        # NOTE(review): with USE_OFFSETS this arithmetic would not hold;
        # varlen + head_first is presumably rejected by the callers — confirm.
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # (B, T, H) layout: stride H between consecutive time steps
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # reversed inclusive cumsum: total - forward-cumsum + element
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
60
+
61
+
62
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['S', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Per-chunk inclusive cumsum of an S-dim vector sequence, implemented as a
    # matmul with a triangular 0/1 mask so it runs on tensor cores.
    # Grid: (feature blocks, chunks, B * H).
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # varlen: resolve the flat chunk id to (sequence, chunk) and clip T
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    if REVERSE:
        # upper-triangular mask => suffix sums (reversed inclusive cumsum)
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        # lower-triangular mask => prefix sums (inclusive cumsum)
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    # exact fp32 accumulation: disable tf32 in the masked matmul
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
113
+
114
+
115
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Whole-sequence inclusive cumsum of a scalar sequence: one program per
    # (batch/sequence, head) pair scans the chunks sequentially, carrying the
    # running total in `b_z`.
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_z = tl.zeros([], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        # scan chunks back-to-front when computing the reversed cumsum
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0)
        b_ss = tl.sum(b_s, 0)
        if REVERSE:
            # within-chunk reversed inclusive cumsum
            b_o = -b_o + b_ss + b_s
        # add the carry from previously processed chunks
        b_o += b_z
        # NOTE: `i_c >= 0` is always true (i_c starts at 0) — redundant guard.
        if i_c >= 0:
            b_z += b_ss
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
167
+
168
+
169
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_vector_kernel(
    s,
    z,
    offsets,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,  # feature-block size; supplied by the Python wrapper, not autotuned
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Whole-sequence inclusive cumsum of an S-dim vector sequence.
    # Grid: (feature blocks, B * H); each program scans its chunks in order,
    # carrying a per-feature running sum `b_z` between chunks.
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    o_i = tl.arange(0, BT)
    if REVERSE:
        # upper-triangular mask => within-chunk suffix sums
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        # lower-triangular mask => within-chunk prefix sums
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_z = tl.zeros([BS], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        # carry from previous chunks + within-chunk cumsum via masked matmul
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        # NOTE: `i_c >= 0` is always true (i_c starts at 0) — redundant guard.
        if i_c >= 0:
            b_z += tl.sum(b_s, 0)
224
+
225
+
226
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Chunk-wise (local) inclusive cumsum of a per-head scalar sequence.

    Each chunk of length `chunk_size` is summed independently; `reverse`
    produces suffix sums within each chunk instead of prefix sums.
    """
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    if offsets is not None:
        # varlen batches: one logical sequence per offsets interval
        B = len(offsets) - 1
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    num_chunks = triton.cdiv(T, chunk_size) if offsets is None else len(indices)
    src = g
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)
    chunk_local_cumsum_scalar_kernel[(num_chunks, B * H)](
        src,
        out,
        offsets,
        indices,
        T=T,
        H=H,
        BT=chunk_size,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
258
+
259
+
260
def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Chunk-wise (local) inclusive cumsum of an S-dimensional vector sequence.

    Equivalent to `g.view(B, H, NT, BT, S).cumsum(-2)` flattened back, with
    the accumulation carried out in fp32 inside the kernel.
    """
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    num_chunks = triton.cdiv(T, chunk_size) if offsets is None else len(indices)

    src = g
    out = torch.empty_like(g, dtype=output_dtype or g.dtype)
    # the feature-block size BS is picked by the autotuner at launch time
    grid = lambda meta: (triton.cdiv(meta['S'], meta['BS']), num_chunks, B * H)
    chunk_local_cumsum_vector_kernel[grid](
        src,
        out,
        offsets,
        indices,
        T=T,
        H=H,
        S=S,
        BT=chunk_size,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
295
+
296
+
297
@input_guard
def chunk_global_cumsum_scalar(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Global (whole-sequence) cumulative sum of a scalar tensor.

    Input is [B, H, T] when ``head_first`` else [B, T, H]; output keeps the
    shape and is cast to ``output_dtype`` (falling back to ``dtype``/``s.dtype``).
    """
    dtype = dtype or s.dtype
    if head_first:
        num_seqs, num_heads, seq_len = s.shape
    else:
        num_seqs, seq_len, num_heads = s.shape
    if offsets is not None:
        # Variable-length mode: the batch size comes from the offsets array.
        num_seqs = len(offsets) - 1
    out = torch.empty_like(s, dtype=output_dtype or dtype)
    # One program per (sequence, head) pair scans the full time axis.
    chunk_global_cumsum_scalar_kernel[(num_seqs * num_heads,)](
        s,
        out,
        offsets,
        T=seq_len,
        H=num_heads,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
325
+
326
+
327
@input_guard
def chunk_global_cumsum_vector(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Global (whole-sequence) cumulative sum of a vector-valued tensor.

    Input is [B, H, T, S] when ``head_first`` else [B, T, H, S]; output keeps
    the shape and is cast to ``output_dtype`` (falling back to ``dtype``/``s.dtype``).
    """
    dtype = dtype or s.dtype
    if head_first:
        num_seqs, num_heads, seq_len, dim = s.shape
    else:
        num_seqs, seq_len, num_heads, dim = s.shape
    # Tile the feature dimension; cap the tile at 32 lanes.
    block_s = min(32, triton.next_power_of_2(dim))
    if offsets is not None:
        # Variable-length mode: the batch size comes from the offsets array.
        num_seqs = len(offsets) - 1
    out = torch.empty_like(s, dtype=output_dtype or dtype)
    grid = (triton.cdiv(dim, block_s), num_seqs * num_heads)
    chunk_global_cumsum_vector_kernel[grid](
        s,
        out,
        offsets,
        T=seq_len,
        H=num_heads,
        S=dim,
        BS=block_s,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return out
358
+
359
+
360
@input_guard
def chunk_global_cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Dispatch a global cumulative sum to the scalar (3-D) or vector (4-D) path.

    Args:
        s: [B, H, T] / [B, H, T, D] if ``head_first`` else [B, T, H] / [B, T, H, D].
        dtype: accumulation dtype; defaults to ``s.dtype``.
        reverse: cumulate from the end of the sequence when True.
        offsets: cumulative sequence offsets for variable-length batches
            (requires batch size 1).
        head_first: layout flag, see ``s``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of the same shape as ``s`` holding the global cumulative sum.

    Raises:
        ValueError: if ``s`` is neither 3-D nor 4-D.
    """
    if offsets is not None:
        assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if s.dim() == 3:
        return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first, output_dtype)
    if s.dim() == 4:
        return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype)
    # Fixed: the original message was a broken sentence ("shape X. which should be").
    raise ValueError(f"Unsupported input shape {s.shape}, "
                     f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                     f"or [B, T, H]/[B, T, H, D] otherwise")
379
+
380
+
381
@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Dispatch a chunk-local cumulative sum to the scalar (3-D) or vector (4-D) path.

    Args:
        g: [B, H, T] / [B, H, T, S] if ``head_first`` else [B, T, H] / [B, T, H, S].
        chunk_size: chunk length; must be a power of 2.
        reverse: cumulate within each chunk from its end when True.
        offsets: cumulative sequence offsets for variable-length batches
            (requires batch size 1).
        indices: per-chunk (sequence, chunk) index pairs matching ``offsets``.
        head_first: layout flag, see ``g``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of the same shape as ``g`` with per-chunk cumulative sums.

    Raises:
        ValueError: if ``g`` is neither 3-D nor 4-D.
    """
    if offsets is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if g.dim() == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    if g.dim() == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    # Fixed: the original message mixed "(B, H, T, dim)" with a 3-D
    # "(batch_size, num_heads, seq_len)" tuple; keep it consistent with the
    # scalar/vector helpers above.
    raise ValueError(f"Unsupported input shape {g.shape}, "
                     f"which should be [B, H, T]/[B, H, T, S] if `head_first=True` "
                     f"or [B, T, H]/[B, T, H, S] otherwise")
fla/ops/utils/logcumsumexp.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from fla.ops.utils.op import exp, log
8
+
9
+
10
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def logcumsumexp_fwd_kernel(
    s,  # input, addressed as [B*H, T, S] (contiguous rows of length S)
    z,  # output, same layout as `s`
    T,  # sequence length (runtime value; not specialized)
    S: tl.constexpr,   # feature dimension
    BT: tl.constexpr   # time-block size, chosen by the autotuner
):
    # Numerically stable inclusive log-cumsum-exp along the time axis.
    # One program per batch*head; the T axis is scanned in blocks of BT while
    # a running maximum (b_mp) and a max-rescaled running sum (b_zp) are
    # carried across blocks, as in the streaming (online) softmax trick.
    i_bh = tl.program_id(0)
    o_i = tl.arange(0, BT)
    # Lower-triangular mask: m_s[i, j] = 1 iff j <= i, so a dot with it yields
    # an inclusive prefix sum within the block.
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
    b_zp = tl.zeros([S,], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))

        # [BT, S]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        # [S,] new running max = max(previous running max, block max)
        b_mc = tl.max(b_s, 0)
        b_mc = tl.maximum(b_mp, b_mc)
        # Rescale the carried sum from the old max to the new one.
        b_zp = b_zp * exp(b_mp - b_mc)
        # [BT, S] exponentials relative to the running max
        b_s = exp(b_s - b_mc)
        # Intra-block inclusive prefix sum plus the carry from earlier blocks.
        b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
        # [S,] carry for the next block. NOTE(review): this takes the column
        # max of the prefix sums; since the summands exp(...) are nonnegative,
        # the maximum prefix equals the final (last-row) prefix sum.
        b_zc = tl.max(b_z, 0)
        b_mp = b_mc
        b_zp = b_zc
        # [BT, BS]
        # small eps to prevent underflows
        b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
        tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
fla/ops/utils/solve_tril.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+ from fla.utils import input_guard
12
+
13
+
14
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def solve_tril_16x16_kernel(
    A,        # strictly lower-triangular input blocks, row width BT
    Ad,       # output: inverse of each 16x16 diagonal sub-block of (I + A)
    offsets,  # cumulative sequence offsets (variable-length mode) or None
    indices,  # (sequence, chunk) pairs matching `offsets`
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Inverts one 16x16 unit-lower-triangular tile (I + A_tile) per program
    # via row-by-row forward substitution.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: recover (sequence, chunk) and the local T.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Base pointers and row strides for the two layouts.
    if HEAD_FIRST:
        A = A + i_bh * T * BT
        Ad = Ad + i_bh * T * 16
        stride_16 = 16
        stride_BT = BT
    else:
        A = A + (bos*H + i_h) * BT
        Ad = Ad + (bos*H + i_h) * 16
        stride_16 = H*16
        stride_BT = H*BT

    # Column offset of this 16x16 diagonal tile within the BT-wide row block.
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(A, (T, BT), (stride_BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    # Start from -strict_lower(A); upper part (incl. diagonal) zeroed.
    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)

    o_i = tl.arange(0, 16)
    # Forward substitution: row i of the inverse depends only on rows < i,
    # which have already been rewritten in b_A — the order is essential.
    for i in range(1, min(16, T-i_t*16)):
        b_a = -tl.load(A + (i_t * 16 + i) * stride_BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity to finish (I + A_tile)^-1, which is unit lower triangular.
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
71
+
72
+
73
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_32x32_inverse_kernel(
    A,        # original blocks, row width 32
    Ad,       # 16x16 diagonal-block inverses produced by solve_tril_16x16_kernel
    Ai,       # output: full 32x32 inverse blocks
    offsets,
    indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Assembles a 32x32 inverse from two 16x16 diagonal inverses using the
    # block-triangular identity:
    #   [[A11, 0], [A21, A22]]^-1 = [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]]
    # so only the (2,1) block must be computed here.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: recover (sequence, chunk) and the local T.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Base pointers and row strides for the two layouts.
    if HEAD_FIRST:
        A += (i_bh * T * 32)
        Ad += (i_bh * T * 16)
        Ai += (i_bh * T * 32)
        stride_16 = 16
        stride_32 = 32
    else:
        A += (bos*H + i_h) * 32
        Ad += (bos*H + i_h) * 16
        Ai += (bos*H + i_h) * 32
        stride_16 = 16 * H
        stride_32 = 32 * H

    p_A_21 = tl.make_block_ptr(A, (T, 32), (stride_32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1))
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1))
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1))
    # ieee precision avoids tf32 rounding in the two chained matmuls.
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
133
+
134
+
135
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_64x64_inverse_kernel(
    A,        # original blocks, row width 64
    Ad,       # 16x16 diagonal-block inverses produced by solve_tril_16x16_kernel
    Ai,       # output: full 64x64 inverse blocks
    offsets,
    indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Assembles a 64x64 inverse from four 16x16 diagonal inverses by blocked
    # forward substitution on the 4x4 grid of 16x16 tiles: first the
    # sub-diagonal tiles (2,1), (3,2), (4,3), then (3,1), (4,2), then (4,1).
    # Each off-diagonal tile of the inverse is
    #   Ai_ij = -Ai_ii @ (sum_k A_ik @ Ai_kj),  j < k < i.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: recover (sequence, chunk) and the local T.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Base pointers and row strides for the two layouts.
    if HEAD_FIRST:
        A += i_bh * T * 64
        Ad += i_bh * T * 16
        Ai += i_bh * T * 64
        stride_16 = 16
        stride_64 = 64
    else:
        A += (bos*H + i_h) * 64
        Ad += (bos*H + i_h) * 16
        Ai += (bos*H + i_h) * 64
        stride_16 = 16 * H
        stride_64 = 64 * H

    # Strictly-lower off-diagonal tiles of the input block.
    p_A_21 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    # Precomputed inverses of the four diagonal tiles.
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1))
    A_32 = tl.load(p_A_32, boundary_check=(0, 1))
    A_31 = tl.load(p_A_31, boundary_check=(0, 1))
    A_43 = tl.load(p_A_43, boundary_check=(0, 1))
    A_42 = tl.load(p_A_42, boundary_check=(0, 1))
    A_41 = tl.load(p_A_41, boundary_check=(0, 1))

    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1))
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1))
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1))
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1))

    # Level 1: first sub-diagonal (single product each).
    # ieee precision avoids tf32 rounding in the chained matmuls.
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee')
    Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee')

    # Level 2: second sub-diagonal (depends on level-1 results).
    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision='ieee') +
        tl.dot(A_32, Ai_21, input_precision='ieee'),
        input_precision='ieee'
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision='ieee') +
        tl.dot(A_43, Ai_32, input_precision='ieee'),
        input_precision='ieee'
    )
    # Level 3: bottom-left corner (depends on levels 1 and 2).
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision='ieee') +
        tl.dot(A_42, Ai_21, input_precision='ieee') +
        tl.dot(A_43, Ai_31, input_precision='ieee'),
        input_precision='ieee'
    )

    # Write all ten lower-triangular tiles of the inverse (upper tiles stay 0,
    # which the caller guarantees by zero-initializing Ai).
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
248
+
249
+
250
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
    """
    Compute the inverse of the lower triangular matrix
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, H, T, K] if head_first else [B, T, H, K]
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        head_first (bool):
            If False, the input/output tensor is in the shape of [B, T, H, K].
            If True, the input/output tensor is in the shape of [B, H, T, K].
            Default: False
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]
    assert A.dtype == torch.float, "A should be float32."

    if head_first:
        B, H, T, BT = A.shape
        # Ad holds the inverses of the 16x16 diagonal sub-blocks; it stays in
        # fp32 unless it is itself the final result (BT == 16).
        Ad = torch.empty(B, H, T, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
    else:
        B, T, H, BT = A.shape
        Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)

    # Stage 1: invert every 16x16 diagonal tile of (I + A).
    indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
    NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        offsets=cu_seqlens,
        indices=indices,
        T=T,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
    )
    if BT == 16:
        return Ad

    # Stage 2: merge the 16x16 inverses into full BT x BT inverses.
    # Ai is zero-initialized so the upper-triangular tiles (never written by
    # the merge kernels) remain zero.
    if head_first:
        Ai = torch.zeros(B, H, T, BT, device=A.device, dtype=output_dtype)
    else:
        Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
    indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    # NOTE(review): USE_OFFSETS is also derived by a @triton.heuristics on the
    # merge kernels; passing it explicitly here looks redundant (the 16x16
    # launch above omits it) — confirm it does not conflict with the heuristic.
    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        offsets=cu_seqlens,
        indices=indices,
        T=T,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        USE_OFFSETS=cu_seqlens is not None
    )
    return Ai
profile_trace/iteration_1024/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_11264/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_15360/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_15360/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_20992/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff