msj19 commited on
Commit
5b2d430
·
verified ·
1 Parent(s): e73a905

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc +0 -0
  3. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc +0 -0
  4. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
  5. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc +0 -0
  6. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc +0 -0
  7. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc +0 -0
  8. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
  9. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc +0 -0
  10. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc +0 -0
  11. fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  12. fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  13. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc +0 -0
  14. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc +0 -0
  15. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc +0 -0
  16. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  17. fla3/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +365 -0
  18. fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +173 -0
  19. fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +123 -0
  20. fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py +273 -0
  21. fla3/ops/generalized_delta_rule/dplr/naive.py +96 -0
  22. fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +164 -0
  23. fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +284 -0
  24. fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc +0 -0
  25. fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  26. fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  27. fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc +0 -0
  28. fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py +452 -0
  29. fla3/ops/generalized_delta_rule/iplr/wy_fast.py +300 -0
  30. fla3/ops/gla/__pycache__/chunk.cpython-310.pyc +0 -0
  31. fla3/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  32. fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc +0 -0
  33. fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  34. fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  35. fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  36. fla3/ops/gla/chunk.py +1300 -0
  37. fla3/ops/gla/fused_recurrent.py +111 -0
  38. fla3/ops/gla/naive.py +41 -0
  39. fla3/ops/gsa/__init__.py +9 -0
  40. fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc +0 -0
  41. fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  42. fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc +0 -0
  43. fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc +0 -0
  44. fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  45. fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  46. fla3/ops/gsa/chunk.py +1136 -0
  47. fla3/ops/gsa/fused_recurrent.py +525 -0
  48. fla3/ops/gsa/naive.py +69 -0
  49. fla3/ops/hgrn/__init__.py +9 -0
  50. fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc +0 -0
fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (366 Bytes). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc ADDED
Binary file (9.43 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc ADDED
Binary file (22.9 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc ADDED
Binary file (5.19 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc ADDED
Binary file (10.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc ADDED
Binary file (25 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc ADDED
Binary file (3.66 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc ADDED
Binary file (7.6 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc ADDED
Binary file (4.76 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc ADDED
Binary file (7.69 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/chunk_A_bwd.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import exp, gather
12
+ from ....utils import check_shared_mem, is_gather_supported, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
21
+ for num_warps in [2, 4, 8, 16, 32]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BK', 'BT', 'K'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_bwd_kernel_intra(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi,
34
+ ge,
35
+ dAqk,
36
+ dAqb,
37
+ dAak,
38
+ dAab,
39
+ dq,
40
+ dk,
41
+ da,
42
+ db,
43
+ dqg,
44
+ dkg,
45
+ dag,
46
+ dbg,
47
+ dgk,
48
+ dgk_offset,
49
+ cu_seqlens,
50
+ chunk_indices,
51
+ scale: tl.constexpr,
52
+ T,
53
+ H: tl.constexpr,
54
+ K: tl.constexpr,
55
+ BT: tl.constexpr,
56
+ BC: tl.constexpr,
57
+ BK: tl.constexpr,
58
+ IS_VARLEN: tl.constexpr,
59
+ GATHER_SUPPORTED: tl.constexpr
60
+ ):
61
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ i_b, i_h = i_bh // H, i_bh % H
63
+ if IS_VARLEN:
64
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
65
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
66
+ T = eos - bos
67
+ else:
68
+ bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)
69
+
70
+ if i_t * BT >= T:
71
+ return
72
+
73
+ # offset calculation
74
+ ge += (bos*H + i_h) * K
75
+ gi += (bos*H + i_h) * K
76
+ q += (bos*H + i_h) * K
77
+ a += (bos*H + i_h) * K
78
+ b += (bos*H + i_h) * K
79
+ k += (bos*H + i_h) * K
80
+ dq += (bos*H + i_h) * K
81
+ dk += (bos*H + i_h) * K
82
+ da += (bos*H + i_h) * K
83
+ db += (bos*H + i_h) * K
84
+ dqg += (bos*H + i_h) * K
85
+ dag += (bos*H + i_h) * K
86
+ dkg += (bos*H + i_h) * K
87
+ dbg += (bos*H + i_h) * K
88
+ dgk += (bos*H + i_h) * K
89
+ dgk_offset += (bos*H + i_h) * K
90
+ dAqk += (bos*H + i_h) * BT
91
+ dAqb += (bos*H + i_h) * BT
92
+ dAak += (bos*H + i_h) * BT
93
+ dAab += (bos*H + i_h) * BT
94
+
95
+ stride_qk = H*K
96
+ stride_A = H*BT
97
+
98
+ p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
99
+ p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
100
+ # [BC, BK]
101
+ b_ge = tl.load(p_ge, boundary_check=(0, 1))
102
+ b_gi = tl.load(p_gi, boundary_check=(0, 1))
103
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
104
+ b_da = tl.zeros([BC, BK], dtype=tl.float32)
105
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
106
+ b_db = tl.zeros([BC, BK], dtype=tl.float32)
107
+ # intra chunk gradient calculation
108
+ p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
109
+ p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
110
+ p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
111
+ p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
112
+ o_i = tl.arange(0, BC)
113
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
114
+ p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
115
+ p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
116
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
117
+ b_k = tl.load(p_k, boundary_check=(0, 1))
118
+ b_b = tl.load(p_b, boundary_check=(0, 1))
119
+ b_q = tl.load(p_q, boundary_check=(0, 1))
120
+ b_a = tl.load(p_a, boundary_check=(0, 1))
121
+ b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1))
122
+ b_dAab = tl.load(p_dAab, boundary_check=(0, 1))
123
+ b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1))
124
+ b_dAak = tl.load(p_dAak, boundary_check=(0, 1))
125
+
126
+ # inter chunk gradient calculation
127
+ o_k = i_k * BK + tl.arange(0, BK)
128
+ m_k = o_k < K
129
+ # intra chunk gradient calculation
130
+ for j in range(0, min(BC, T - i_t * BT)):
131
+ # trick to index the block
132
+ if GATHER_SUPPORTED:
133
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
134
+ col_idx = tl.full([BC, 1], j, dtype=tl.int16)
135
+ row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
136
+ # [1, BK]
137
+ b_kj = gather(b_k, row_idx, axis=0)
138
+ b_bj = gather(b_b, row_idx, axis=0)
139
+ b_gij = gather(b_gi, row_idx, axis=0)
140
+ b_gej = gather(b_ge, row_idx, axis=0)
141
+ b_qj = gather(b_q, row_idx, axis=0)
142
+ b_aj = gather(b_a, row_idx, axis=0)
143
+ # [BC, 1]
144
+ b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
145
+ b_dAab_j = gather(b_dAab, col_idx, axis=1)
146
+ b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
147
+ b_dAak_j = gather(b_dAak, col_idx, axis=1)
148
+ # [1, BC] -> [BC, 1]
149
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
150
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
151
+ b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
152
+ b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
153
+ b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
154
+ else:
155
+ mask_idx = tl.arange(0, BC) == j
156
+ b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
157
+ b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
158
+ b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
159
+ b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
160
+ b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
161
+ b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
162
+ b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
163
+ b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
164
+ b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
165
+ b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
166
+ b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
167
+ b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
168
+ # [1, BK] b_qj, b_aj
169
+ b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
170
+ b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
171
+
172
+ m_e = o_i[:, None] > j
173
+ m_i = o_i[:, None] >= j
174
+ tmp1 = exp(b_gi - b_gij)
175
+ tmp2 = exp(b_ge - b_gij)
176
+ b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
177
+ b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
178
+ b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
179
+ b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)
180
+
181
+ m_i = o_i[:, None] <= j
182
+ m_e = o_i[:, None] < j
183
+ tmp1 = exp(b_gij - b_gi)
184
+ tmp2 = exp(b_gej - b_gi)
185
+ b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
186
+ b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
187
+ b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
188
+ b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
189
+
190
+ # post processing
191
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
192
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
193
+ p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
194
+ p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
195
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
196
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
197
+ p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
198
+ p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
199
+ p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
200
+ p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
201
+ p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k
202
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
203
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
204
+ b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
205
+ b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
206
+ tmp = exp(b_gn[None, :] - b_gi)
207
+ b_dk += tl.load(p_dkg, boundary_check=(0, 1)).to(tl.float32) * tmp
208
+ b_db += tl.load(p_dbg, boundary_check=(0, 1)).to(tl.float32) * tmp
209
+ tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
210
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
211
+ tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
212
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
213
+ b_dgk = (b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b).to(tl.float32)
214
+ b_dgk_offset = b_da * b_a
215
+ tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
216
+ tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1))
217
+
218
+
219
+ @triton.heuristics({
220
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
221
+ })
222
+ @triton.autotune(
223
+ configs=[
224
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
225
+ for num_warps in [2, 4, 8, 16, 32]
226
+ for num_stages in [2, 3, 4]
227
+ for BK in [32, 64]
228
+ ],
229
+ key=['BK', 'BT', 'K'],
230
+ use_cuda_graph=use_cuda_graph,
231
+ )
232
+ @triton.jit(do_not_specialize=['T'])
233
+ def chunk_dplr_bwd_dgk_kernel(
234
+ dgk,
235
+ dgk_offset,
236
+ dgk_last,
237
+ dgk_output,
238
+ cu_seqlens,
239
+ chunk_indices,
240
+ T,
241
+ H: tl.constexpr,
242
+ K: tl.constexpr,
243
+ BT: tl.constexpr,
244
+ BK: tl.constexpr,
245
+ IS_VARLEN: tl.constexpr,
246
+ ):
247
+ i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
248
+ i_b, i_h = i_bh // H, i_bh % H
249
+ if IS_VARLEN:
250
+ i_tg = i_t
251
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
252
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
253
+ T = eos - bos
254
+ NT = tl.cdiv(T, BT)
255
+ else:
256
+ NT = tl.cdiv(T, BT)
257
+ i_tg = (i_b * NT + i_t).to(tl.int32)
258
+ bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)
259
+
260
+ stride_qk = H * K
261
+ dgk += (bos * H + i_h) * K
262
+ dgk_offset += (bos * H + i_h) * K
263
+ dgk_last += (i_tg * H + i_h) * K
264
+ dgk_output += (bos * H + i_h) * K
265
+ p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
266
+ m_k = tl.arange(0, BK) + i_k * BK < K
267
+ b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
268
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
269
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
270
+ b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
271
+ b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
272
+ # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
273
+ # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
274
+ b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
275
+ b_dgk_cumsum += b_dgk_last[None, :]
276
+ b_dgk_cumsum -= b_dgk_offset
277
+ p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
278
+ tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1))
279
+
280
+
281
def chunk_dplr_bwd_dqk_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    dAqk: torch.Tensor,
    dAqb: torch.Tensor,
    dAak: torch.Tensor,
    dAab: torch.Tensor,
    dqg: torch.Tensor,
    dkg: torch.Tensor,
    dag: torch.Tensor,
    dbg: torch.Tensor,
    dgk_last: torch.Tensor,
    scale: float = 1.0,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
):
    """Launch the intra-chunk backward pass for the DPLR delta rule.

    Runs ``chunk_dplr_bwd_kernel_intra`` to produce dq/dk/da/db plus raw
    gate-gradient buffers, then ``chunk_dplr_bwd_dgk_kernel`` to finalize the
    gate gradient via a reverse cumsum.

    Returns:
        (dq, dk, da, db, dgk_output)
    """
    B, T, H, K = q.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # wider K-blocks only when the device has enough shared memory
    if check_shared_mem():
        BK = min(64, triton.next_power_of_2(K))
    else:
        BK = min(32, triton.next_power_of_2(K))

    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    NK = triton.cdiv(K, BK)

    # gradient outputs, same layout as their forward counterparts
    dq, dk = torch.empty_like(q), torch.empty_like(k)
    da, db = torch.empty_like(a), torch.empty_like(b)
    # fp32 scratch for the gate gradient and its offset correction
    dgk = torch.empty_like(gi, dtype=torch.float)
    dgk_offset = torch.empty_like(gi, dtype=torch.float)

    chunk_dplr_bwd_kernel_intra[(NK, NT, B * H)](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dAqk,
        dAqb=dAqb,
        dAak=dAak,
        dAab=dAab,
        dq=dq,
        dk=dk,
        da=da,
        db=db,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        dgk=dgk,
        dgk_offset=dgk_offset,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BT,
        BK=BK,
        GATHER_SUPPORTED=is_gather_supported
    )

    dgk_output = torch.empty_like(dgk)

    # BK for this kernel is chosen by autotune, hence the meta-dependent grid
    def dgk_grid(meta):
        return (NT, triton.cdiv(K, meta['BK']), B * H)

    chunk_dplr_bwd_dgk_kernel[dgk_grid](
        dgk=dgk,
        dgk_offset=dgk_offset,
        dgk_last=dgk_last,
        dgk_output=dgk_output,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
    )
    return dq, dk, da, db, dgk_output
fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
17
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
18
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV', "V"],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_bwd_kernel_dhu(
31
+ qg,
32
+ bg,
33
+ w,
34
+ gk,
35
+ dht,
36
+ dh0,
37
+ do,
38
+ dh,
39
+ dv,
40
+ dv2,
41
+ cu_seqlens,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ IS_VARLEN: tl.constexpr,
54
+ ):
55
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+ if IS_VARLEN:
58
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ NT = tl.cdiv(T, BT)
61
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
62
+ else:
63
+ bos, eos = i_n * T, i_n * T + T
64
+ NT = tl.cdiv(T, BT)
65
+ boh = i_n * NT
66
+
67
+ # [BK, BV]
68
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
69
+ if USE_FINAL_STATE_GRADIENT:
70
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
71
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
72
+
73
+ mask_k = tl.arange(0, BK) < K
74
+ for i_t in range(NT - 1, -1, -1):
75
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
77
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
78
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
79
+ p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
80
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
81
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
82
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
83
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
84
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
85
+ # [BK, BT]
86
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
87
+ # [BT, BK]
88
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
89
+ b_w = tl.load(p_w, boundary_check=(0, 1))
90
+ # [BT, V]
91
+ b_do = tl.load(p_do, boundary_check=(0, 1))
92
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
93
+ b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
94
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
95
+ # [BK, BV]
96
+ b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
97
+ b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
98
+ last_idx = min((i_t + 1) * BT, T) - 1
99
+ bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
100
+ b_dh *= exp(bg_last)[:, None]
101
+ b_dh += b_dh_tmp
102
+
103
+ if USE_INITIAL_STATE:
104
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
105
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
106
+
107
+
108
def chunk_dplr_bwd_dhu(
    qg: torch.Tensor,
    bg: torch.Tensor,
    w: torch.Tensor,
    gk: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the reverse-time hidden-state-gradient kernel.

    Returns:
        dh:  per-chunk state gradients, shape [B, NT, H, K, V]
        dh0: gradient w.r.t. the initial state (None when h0 is None)
        dv2: value gradient augmented with the state-gradient contribution
    """
    B, T, H, K = qg.shape
    V = do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension being larger than 256."

    # pick tile sizes by available shared memory (largest first)
    if check_shared_mem('hopper', qg.device.index):        # H100-class
        BV, BC = 64, (64 if K <= 128 else 32)
    elif check_shared_mem('ampere', qg.device.index):      # A100-class
        BV, BC = 32, 32
    else:                                                  # e.g. 4090
        BV, BC = 16, 16

    if cu_seqlens is None:
        chunk_indices = None
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        # N: the actual number of sequences in the batch with variable lengths
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        N = len(cu_seqlens) - 1
        NT = len(chunk_indices)
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh = qg.new_empty(B, NT, H, K, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.zeros_like(dv)

    chunk_dplr_bwd_kernel_dhu[(NK, NV, N * H)](
        qg=qg,
        bg=bg,
        w=w,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return dh, dh0, dv2
fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, use_cuda_graph
12
+
13
# candidate block sizes for autotuning; larger tiles only when the device
# reports enough shared memory
BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    qg,
    v,
    v_new,
    A_qk,
    A_qb,
    h,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Forward output kernel for chunked DPLR attention.

    Computes, per (V-block, time-chunk, batch*head) program:
        o = qg @ h  +  tril(A_qk) @ v  +  tril(A_qb) @ v_new
    i.e. the inter-chunk contribution through the recurrent state ``h`` plus
    two causally-masked intra-chunk contributions.
    """
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        # i_tg: global chunk index (indexes the per-chunk state h);
        # i_t is remapped to the chunk index local to sequence i_n
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # inter-chunk term: accumulate qg @ h over K-blocks
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # lower-triangular (causal) mask within the chunk
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
87
+
88
+
89
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Compute the chunked DPLR forward output.

    Thin launcher around ``chunk_dplr_fwd_kernel_o``; block sizes BK/BV are
    selected by autotune, hence the meta-dependent grid.

    Returns:
        o: output tensor with the same shape/dtype as ``v``.
    """
    B, T, H, K = qg.shape
    V = v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    else:
        chunk_indices = None
        NT = triton.cdiv(T, BT)

    o = torch.empty_like(v)

    def launch_grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_fwd_kernel_o[launch_grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils.op import exp
11
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
# Fused recurrent forward kernel for the DPLR (diagonal-plus-low-rank) delta rule.
# One program instance handles one V-tile of one (batch, head) pair and scans
# sequentially over time, carrying the running state `b_h` in registers.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_dplr_delta_rule_fwd_kernel(
    q,            # queries, [B, T, H, K] layout (stride H*K per time step)
    k,            # keys, [B, T, H, K]
    v,            # values, [B, T, H, V]
    a,            # low-rank factor a, [B, T, H, K]
    b,            # low-rank factor b, [B, T, H, K]
    gk,           # per-key decay in log space, [B, T, H, K]
    o,            # output, [B, T, H, V]
    h0,           # optional initial state, [N, H, K, V]
    ht,           # optional final-state buffer, [N, H, K, V]
    cu_seqlens,   # optional [N + 1] cumulative sequence lengths (varlen mode)
    scale,        # scalar multiplier applied to q
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,  # K tile; BK = next_power_of_2(K), so one tile spans all of K
    BV: tl.constexpr,  # V tile size (autotuned)
    REVERSE: tl.constexpr,            # scan the sequence back-to-front if set
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # program ids: i_v indexes the V tile, i_nh the flattened (sequence, head) pair
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        # variable-length mode: sequence bounds come from cu_seqlens
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # start either at the first or (for REVERSE) the last time step
    p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    # bounds masks for the (possibly padded) K and V tiles
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # running state, V-major layout [BV, BK], accumulated in fp32
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        # h0 is stored K-major ([K, V]); transpose on load into the [BV, BK] tile
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        # tmp = h @ a (contraction over K), shape [BV]
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        # state update: h = exp(gk) * h + outer(tmp, b) + outer(v, k)
        b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        # o_t = h @ q (contraction over K)
        b_o = tl.sum(b_h * b_q[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # advance (or rewind, for REVERSE) all pointers by one time step
        p_q += (-1 if REVERSE else 1) * H*K
        p_k += (-1 if REVERSE else 1) * H*K
        p_a += (-1 if REVERSE else 1) * H*K
        p_b += (-1 if REVERSE else 1) * H*K
        p_gk += (-1 if REVERSE else 1) * H*K
        p_v += (-1 if REVERSE else 1) * H*V
        p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        # same transposed addressing as h0
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
106
+
107
+
108
def fused_recurrent_dplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Host-side launcher for the fused recurrent DPLR delta-rule forward kernel.

    Returns the per-step outputs ``o`` (same shape/dtype as ``v``) and the
    final state ``ht`` of shape ``[N, H, K, V]`` in fp32 when
    ``output_final_state`` is set, otherwise ``None``.
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    # number of independent sequences: batch size, or varlen segment count
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK = triton.next_power_of_2(K)

    h0 = initial_state
    # final state is always accumulated/stored in fp32
    ht = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    o = torch.empty_like(v)

    def grid(meta):
        # (V tiles, sequence*head programs); BV is chosen by autotuning
        return (triton.cdiv(V, meta['BV']), N * H)

    fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        a,
        b,
        gk,
        o,
        h0,
        ht,
        cu_seqlens,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
    )
    return o, ht
154
+
155
+
156
class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper for the forward-only fused recurrent DPLR kernel.

    The backward pass is deliberately unsupported: this path exists for
    inference; training should go through `chunk_dplr_delta_rule`.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: Optional[float] = 1.0,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # Nothing is saved on `ctx`: backward is intentionally not implemented.
        return fused_recurrent_dplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
            "This kernel is only for inference. "
            "For training, please use `chunk_dplr_delta_rule`."
        )
199
+
200
+
201
def fused_recurrent_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]`.
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]`. decay term in log space!
        scale (Optional[float]):
            Scale factor applied to the queries.
            If `None`, it will default to `1 / sqrt(K)`. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state` is `True` else `None`.

    Raises:
        ValueError: If `cu_seqlens` is given with a batch size other than 1, or with a
            mismatching number of initial states.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # NOTE: trailing space added so the two f-string fragments do not run together.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        # default query scaling: 1 / sqrt(K)
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
fla3/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
    """Naive step-by-step reference for the DPLR delta rule.

    Per step (state S of shape [b, h, d_k, d_v]):
        S_t = exp(gk_t)[..., None] * S_{t-1} + beta_t (alpha_t^T S_{t-1}) + k_t v_t^T
        o_t = q_t^T S_t
    Queries are scaled by 1/sqrt(d_k). Inputs are laid out [B, H, L, D];
    computation runs in fp32 and the output is cast back to the input dtype.

    Returns (o, S) where S is the final state, or None when
    `output_final_state` is False.
    """
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    # Fix: cast *all* inputs to fp32 — the original omitted `alpha`, leaving it
    # in the input dtype and silently mixing precisions inside the loop.
    q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i].clone()
        _beta = beta[:, :, i].clone()
        # rank-1 write k v^T plus the low-rank correction beta (alpha^T S)
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S
34
+
35
+
36
def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
    """Chunkwise reference implementation of the DPLR delta rule.

    Splits the length-`l` sequence into chunks of `chunk_size`, builds the
    intra-chunk attention matrices (A_qk, A_qb, A_ab, A_ak), inverts
    (I - lower(A_ab)) by forward substitution, and carries the state `S`
    across chunks. Must satisfy `l % chunk_size == 0`.

    Cleanups vs. the original: removed the no-op statements (`v = v`,
    `A_ab = A_ab`) and two `torch.triu` mask assignments that were always
    overwritten or never read.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    assert l % chunk_size == 0

    S = k.new_zeros(b, h, d_k, d_v).to(q)
    if initial_state is not None:
        S += initial_state

    q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
                                                       c=chunk_size).float(), [q, k, v, alpha, beta, gk])

    gk_cumsum = gk.cumsum(-2)

    A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size, device=q.device)
    A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size, device=q.device)
    A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size, device=q.device)
    A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size, device=q.device)

    for i in range(chunk_size):
        alpha_i = alpha[:, :, :, i, None]
        q_i = q[:, :, :, i, None]
        gk_i = gk_cumsum[:, :, :, i, None]
        # inclusive causal mask (diagonal kept) for the q-side matrices
        mask = (torch.arange(chunk_size) <= i).to(q.device)
        attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
        A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
        # strictly-causal mask (diagonal excluded, shifted by one) for the alpha-side
        mask = (torch.arange(chunk_size) < i).to(q.device)
        attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
        A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()

    # forward substitution: row-by-row inversion of (I - lower(A_ab))
    for i in range(1, chunk_size):
        A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)

    A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = A_ab @ (A_ak @ v)
    w = A_ab @ ((gk_cumsum-gk).exp() * alpha)

    o = torch.zeros_like(v)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        # pseudo-values: intra-chunk part plus contribution of the carried state
        v2_i = u_i + w_i @ S

        o_1 = A_qk[:, :, i] @ v_i
        o_2 = A_qb[:, :, i] @ v2_i
        o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
        o[:, :, i] = o_1 + o_2 + o_3
        # decay of each position to the end of the chunk
        decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
        S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
            (beta_i * decay).transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return rearrange(o, 'b h n c d -> b h (n c) d'), S
fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
12
+
13
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
14
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
15
+
16
+
17
# Backward kernel of the WY-representation preparation step.
# One program handles one BT x BT chunk of one (batch, head) pair and produces
# dv, dag, dA_ak and dA_ab from the incoming gradients dw, du and dv0.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_bwd_kernel(
    A_ab_inv,      # [B, T, H, BT] forward-pass inverse of (I - lower(A_ab))
    A_ak,          # [B, T, H, BT] chunk-local A_ak blocks
    ag,            # [B, T, H, K]
    v,             # [B, T, H, V]
    dw,            # incoming grad wrt w, [B, T, H, K]
    du,            # incoming grad wrt u, [B, T, H, V]
    dv,            # output grad wrt v
    dv0,           # additive grad contribution already accumulated for v
    dag,           # output grad wrt ag
    dAak,          # output grad wrt A_ak (fp32)
    dAab,          # output grad wrt A_ab (fp32)
    cu_seqlens,    # optional [N + 1] cumulative lengths (varlen mode)
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # varlen: chunk_indices maps program id -> (sequence, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # the *_t pointers load with swapped strides, i.e. the transposed block
    p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
    b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
    # masks applied in transposed coordinates: strictly-upper here corresponds to
    # the strictly-lower part of A_ak; upper-incl-diagonal to lower(A_ab_inv)
    b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
    b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
    # (A_ab_inv @ A_ak)^T, reused for the dv accumulation below
    b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
    b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        # accumulate du @ v^T across V tiles
        b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
        b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
        # dv = dv0 + (A_ab_inv @ A_ak)^T @ du
        b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # strictly-lower mask in untransposed coordinates
    m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :]
    b_dA_tmp = tl.where(m_i, b_dA_tmp, 0)
    b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
    b_dA_ak = tl.where(m_i, b_dA_ak, 0)
    tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
    b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)

    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        # accumulate dw @ ag^T across K tiles into the A_ab_inv gradient
        b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
        # dag = A_ab_inv^T @ dw
        b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
        tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))

    # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
    # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
    # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
    # denote A = I - lower(A_ab), B = A^-1
    # in the backward pass.
    # dL/dA = -(B)^T @ (dL/dB) @ B^T
    # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    # keep only the strictly-lower part (matches the lower(...) in the formula)
    b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
115
+
116
+
117
def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the WY-representation backward kernel.

    Returns ``(dA_ab, dA_ak, dv, dag)``: the A-matrix gradients in fp32 and
    the ``v``/``ag`` gradients in their original dtypes.
    """
    # The kernel addresses tensors with raw strides, so force contiguity.
    A_ab_inv, A_ak, v, ag, dw, du = (x.contiguous() for x in (A_ab_inv, A_ak, v, ag, dw, du))

    B, T, H, K = dw.shape
    V = du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    # smaller V tiles on devices with limited shared memory
    BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import gather
12
+ from ....utils import is_gather_supported, use_cuda_graph
13
+
14
+
15
# Invert (I - strict_lower(A_ab)) for one BT x BT chunk via row-wise forward
# substitution. Single-block variant used when BT <= 32.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    A_ab,        # [B, T, H, BT] chunk-local A_ab blocks
    A_ab_inv,    # output, same layout as A_ab
    cu_seqlens,  # optional [N + 1] cumulative lengths (varlen mode)
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # varlen: chunk_indices maps program id -> (sequence, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    # keep only the strictly lower-triangular part
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    # forward substitution: row i of the inverse only depends on rows < i
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    # add the identity: the stored result is (I - strict_lower(A_ab))^{-1}
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
57
+
58
+
59
# BT == 64 variant: the 64x64 inversion of (I - strict_lower(A_ab)) is done as
# two independent BC x BC (BC = 32) diagonal-block inversions plus one matmul
# for the off-diagonal block (2x2 block-triangular inverse).
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    A_ab,        # [B, T, H, BT] chunk-local A_ab blocks
    A_ab_inv,    # output, same layout as A_ab
    cu_seqlens,  # optional [N + 1] cumulative lengths (varlen mode)
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,   # sub-block size (BT // 2)
    IS_VARLEN: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # varlen: chunk_indices maps program id -> (sequence, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # sub-blocks of the BT x BT chunk: 1 = top-left, 2 = bottom-right, 3 = bottom-left
    p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    # keep only the strictly lower-triangular parts of the diagonal blocks
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # forward substitution on both diagonal blocks, one row per iteration
    for i in range(1, BC):
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BK] -> [BK]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        # mask is needed below for the write-back even on the gather path
        mask = tl.arange(0, BC) == i
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    # NOTE(review): no explicit negation here — presumably because the matrix
    # being inverted is (I - lower(A_ab)), whose off-diagonal block is -A3, so
    # the minus in the formula above cancels. Verify against the bwd kernel.
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # top-right block of a lower-triangular inverse is zero (causal mask)
    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
136
+
137
+
138
# Forward kernel computing, per BT x BT chunk:
#   w = A_ab_inv @ ag          (over K tiles)
#   u = (A_ab_inv @ A_ak) @ v  (over V tiles)
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,           # output, same layout as ag: [B, T, H, K]
    u,           # output, same layout as v: [B, T, H, V]
    ag,
    v,
    A_ab_inv,    # [B, T, H, BT] inverse blocks from prepare_wy_repr_fwd_kernel_*
    A_ak,        # [B, T, H, BT]
    cu_seqlens,  # optional [N + 1] cumulative lengths (varlen mode)
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # varlen: chunk_indices maps program id -> (sequence, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_s = tl.arange(0, BT)

    p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    # enforce triangular structure: A_ab_inv incl. diagonal, A_ak strictly lower
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    # w = A_ab_inv @ ag, tiled over K
    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    # u = (A_ab_inv @ A_ak) @ v, tiled over V
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
205
+
206
+
207
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``wu_fwd_kernel`` to compute the WY factors.

    Returns ``(w, u)`` where ``w = A_ab_inv @ ag`` and
    ``u = (A_ab_inv @ A_ak) @ v``, computed chunkwise on the GPU; outputs
    share the shapes/dtypes of ``ag`` and ``v``.
    """
    B, T, H, K = ag.shape
    V = v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    wu_fwd_kernel[(NT, B * H)](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
243
+
244
+
245
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the WY representation for chunked DPLR.

    First inverts the chunk-local ``(I - lower(A_ab))`` blocks (choosing the
    single-block or 2x2-block kernel depending on the chunk length), then
    derives ``w`` and ``u`` via ``wu_fwd``. Returns ``(w, u, A_ab_inv)``.
    """
    B, T, H, _ = ag.shape
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BC = min(BT, 32)
    # BT == 64 needs the blockwise (2x2) inversion kernel
    if BT == 64:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk64
    else:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk32
    A_ab_inv = torch.empty_like(A_ab)
    fwd_fn[(NT, B * H)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        BC=BC,
    )
    w, u = wu_fwd(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    return w, u, A_ab_inv
280
+
281
+
282
# Backward-compatible aliases for callers still using the older `fwd_*` names.
fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (366 Bytes). View file
 
fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (25.6 kB). View file
 
fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (19.7 kB). View file
 
fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....utils import input_guard
11
+
12
+
13
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BK"],
)
@triton.jit
def fused_recurrent_fwd_kernel(
    # Fused recurrent forward for the IPLR delta rule.
    # One program handles one (value-block i_v, sequence*head i_nh) pair and
    # scans the whole sequence serially, carrying the state b_h in registers.
    # NOTE(review): pointer arithmetic ((bos*H + i_h)*K, step K*H per step)
    # indicates a head-interleaved [B, T, H, ...] memory layout, not the
    # [B, H, L, ...] layout the original comments suggested.
    q, # query [B, T, H, K]
    k, # key [B, T, H, K]
    v, # value [B, T, H, V]
    a, # a [B, T, H, K]
    b, # b [B, T, H, K]
    o, # output [B, T, H, V]
    ha, # tmp variable [B, T, H, V] for storing intermediate results of (h * a[None, :]).sum(0)
    h0, # initial hidden state [B, H, K, V]
    ht, # final hidden state [B, H, K, V]
    cu_seqlens, # varlen cu_seqlens
    scale, # K ** -0.5
    H, # n_heads
    T, # seq_len
    K: tl.constexpr, # K
    V: tl.constexpr, # V
    BK: tl.constexpr, # BLOCK SIZE along the K dimension
    BV: tl.constexpr, # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr, # whether to store final state
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        # Varlen mode: each "sequence" i_n spans [bos, eos) of the flattened batch.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    # Base pointers for timestep 0 of this (sequence, head); advanced by
    # K*H / V*H per timestep inside the loop below.
    p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
    p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
    p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
    p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
    p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    mask_k = tl.arange(0, BK) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    # Running state, laid out [BV, BK] (value-major) in fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        # to store
        # tmp = h @ a; stored in `ha` so the backward pass can reuse it.
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        # State update: h += (h a) b^T + v k^T  (rank-2 update per step).
        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
        p_q += K*H
        p_k += K*H
        p_o += V*H
        p_v += V*H
        p_ha += V*H
        p_a += K*H
        p_b += K*H

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
102
+
103
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_DHT': lambda args: args['dht'] is not None,
    'USE_DH0': lambda args: args['dh0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=["BK", "BV"],
)
@triton.jit
def fused_recurrent_bwd_kernel(
    # Fused recurrent backward for the IPLR delta rule, in two passes:
    #   1) a reverse-time scan carrying dL/dh to produce dk, dv, db, dha;
    #   2) a forward-time scan re-materializing h to produce da and dq.
    # B: batch_size, H: n_heads, T: seq_len, D: b_dhead
    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
    # NOTE(review): pointer arithmetic indicates a head-interleaved [B, T, H, ...]
    # layout (stride K*H / V*H per timestep), not [B, H, L, ...].
    q, # query [B, T, H, K]
    k, # key [B, T, H, K]
    v, # value [B, T, H, V]
    a, # a [B, T, H, K]
    b, # b [B, T, H, K]
    ha, # ha [B, T, H, V], saved (h @ a) values from the forward pass
    dht, # gradient of final state [B, H, K, V]
    dh0, # gradient of initial state [B, H, K, V]
    do, # gradient of output [B, T, H, V]
    dq, # gradient of query [NV, B, T, H, K]
    dk, # gradient of key [NV, B, T, H, K]
    dv, # gradient of value [B, T, H, V]
    da, # gradient of a [NV, B, T, H, K]
    db, # gradient of b [NV, B, T, H, K]
    dha, # gradient of ha [B, T, H, V]
    h0, # initial state [B, H, K, V]
    scale, # K ** -0.5
    cu_seqlens, # cu_seqlens
    B, # batch_size
    H, # n_heads
    T, # seq_len
    K: tl.constexpr, # K
    V: tl.constexpr, # V
    BK: tl.constexpr, # BLOCK SIZE along the K dimension
    BV: tl.constexpr, # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0
    USE_DH0: tl.constexpr, # whether to use dh0
    USE_DHT: tl.constexpr, # whether to use dht
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    # Each V-split writes its own slab of the K-dim gradients; the host sums
    # over the leading NV axis afterwards.
    dk += i_v * B * H * K * T
    db += i_v * B * H * K * T
    dq += i_v * B * H * K * T
    da += i_v * B * H * K * T
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
    mask_k = tl.arange(0, BK) < K
    mask_v = (tl.arange(0, BV) + i_v * BV) < V

    # Rebase every tensor pointer to this (sequence, head, value-block).
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V + i_v * BV
    ha += (bos * H + i_h) * V + i_v * BV
    a += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    do += (bos * H + i_h) * V + i_v * BV
    dq += (bos * H + i_h) * K
    dk += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V + i_v * BV
    da += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    dha += (bos * H + i_h) * V + i_v * BV

    # ---- Pass 1: reverse scan from the last timestep. ----
    p_q = q + tl.arange(0, BK) + (T - 1) * H*K
    p_k = k + tl.arange(0, BK) + (T - 1) * H*K
    p_v = v + tl.arange(0, BV) + (T - 1) * H*V
    p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V
    p_a = a + tl.arange(0, BK) + (T - 1) * H*K
    p_b = b + tl.arange(0, BK) + (T - 1) * H*K
    p_do = do + tl.arange(0, BV) + (T - 1) * H*V
    p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K
    p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V
    p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V
    p_db = db + tl.arange(0, BK) + (T - 1) * H*K
    p_da = da + tl.arange(0, BK) + (T - 1) * H*K
    p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K

    # Running gradient of the hidden state, [BK, BV] in fp32.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_DHT:
        p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)

    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)

        # Output contributes q do^T to dh at the same timestep.
        b_dh += b_q[:, None] * b_do[None, :]
        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)

        # Gradients through the rank-2 update h += (h a) b^T + v k^T.
        b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
        tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
        b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)

        # Propagate dh one step backwards through the (h a) term.
        b_dh += b_dha[None, :] * b_a[:, None]
        p_do -= H*V
        p_q -= H*K
        p_k -= H*K
        p_v -= H*V
        p_dk -= H*K
        p_dv -= H*V
        p_b -= H*K
        p_db -= H*K
        p_a -= H*K
        p_dha -= H*V
        p_ha -= H*V

    if USE_DH0:
        p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])

    # Pass 2 reads the dha values written in pass 1 (by other V-splits too),
    # so synchronize before switching direction.
    tl.debug_barrier()

    # ---- Pass 2: forward scan re-materializing h for da and dq. ----
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        mask_kv = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    p_k = k + tl.arange(0, BK)
    p_v = v + tl.arange(0, BV)
    p_ha = ha + tl.arange(0, BV)
    p_do = do + tl.arange(0, BV)
    p_dha = dha + tl.arange(0, BV)
    p_da = da + tl.arange(0, BK)
    p_dq = dq + tl.arange(0, BK)
    p_b = b + tl.arange(0, BK)

    for i in range(0, T):
        b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
        # da_t depends on h BEFORE the update at step t.
        d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
        tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
        # Same update as the forward kernel, using the saved b_ha = h @ a.
        b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
        _d_q = b_h * b_do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += H*K
        p_do += H*V
        p_v += H*V
        p_da += H*K
        p_dha += H*V
        p_ha += H*V
        p_dq += H*K
        p_b += H*K
+
276
+
277
class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd bridge around the fused recurrent IPLR delta-rule kernels.

    Forward launches ``fused_recurrent_fwd_kernel`` and stashes the
    intermediate ``ha`` tensor so the backward kernel can avoid recomputing
    ``h @ a`` per step.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: Optional[float] = None,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None
    ):
        """Run the forward scan; returns ``(o, final_state)``.

        ``final_state`` is ``None`` unless ``output_final_state`` is set.
        """
        B, T, H, K, V = *k.shape, v.shape[-1]
        # In varlen mode the "batch" for the kernel grid is the number of
        # sequences described by cu_seqlens, not the tensor's leading dim.
        N = B if cu_seqlens is None else len(cu_seqlens) - 1

        # The whole K dimension is handled by a single block.
        BK = triton.next_power_of_2(K)
        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32)
        else:
            final_state = None

        # ha holds per-step (h @ a) values, consumed by the backward pass.
        ha = torch.empty_like(v, dtype=torch.float32)

        # BV is chosen by the autotuner, hence the grid is a function of meta.
        def grid(meta): return (
            triton.cdiv(V, meta['BV']),
            N * H
        )
        o = torch.empty_like(v)
        fused_recurrent_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            o=o,
            ha=ha,
            h0=initial_state,
            ht=final_state,
            scale=scale,
            cu_seqlens=cu_seqlens,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
        )
        ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        """Run the two-pass backward kernel; returns grads matching forward's
        nine inputs (``None`` for scale/output_final_state/cu_seqlens)."""
        q, k, v, a, b, ha, initial_state = ctx.saved_tensors
        B, T, H, K, V = *q.shape, v.shape[-1]
        N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1
        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
        NV = triton.cdiv(V, BV)
        scale = ctx.scale

        # K-dim gradients get one slab per V-split (leading NV axis),
        # reduced by .sum(0) after the launch.
        dq = q.new_empty(NV, *q.shape)
        dk = k.new_empty(NV, *k.shape)
        da = a.new_empty(NV, *a.shape)
        db = b.new_empty(NV, *b.shape)
        dv = torch.empty_like(v)
        dha = torch.empty_like(ha)
        grid = (NV, N * H)

        if initial_state is not None and initial_state.requires_grad:
            dh0 = torch.empty_like(initial_state, dtype=torch.float32)
        else:
            dh0 = None

        fused_recurrent_bwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            ha=ha,
            dht=dht,
            dh0=dh0,
            do=do,
            dq=dq,
            dk=dk,
            dv=dv,
            da=da,
            db=db,
            dha=dha,
            h0=initial_state,
            scale=scale,
            cu_seqlens=ctx.cu_seqlens,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BK=BK,
            BV=BV,
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        da = da.sum(0)
        db = db.sum(0)
        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None
+
388
+
389
def fused_recurrent_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        a (torch.Tensor):
            as of shape `[B, T, H, K]`
        b (torch.Tensor):
            bs of shape `[B, T, H, K]`
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.

    Raises:
        ValueError: If `cu_seqlens` is provided with a batch size other than 1,
            or if the number of initial states does not match the number of sequences.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # BUGFIX: the two adjacent f-strings used to concatenate without a
            # separating space ("...`cu_seqlens`.Please flatten...").
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens
    )
    return o, final_state
fla3/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ....ops.utils import prepare_chunk_indices
12
+ from ....utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    # Computes A = (I - tril(a b^T, -1))^{-1} for one chunk of length BT<=32.
    # One program per (chunk i_t, batch*head i_bh).
    a,
    b,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr, # dummy placeholder
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices maps the flat chunk id to (sequence, chunk-in-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Accumulate the chunk-local Gram matrix a @ b^T over K tiles.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_A += tl.dot(b_a, b_b)

    # Keep only the strictly lower triangle, then invert (I - L) by forward
    # substitution: row i of the inverse is built from rows < i.
    # NOTE: the loop order matters; do not reorder these statements.
    b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity to finish the inverse.
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
69
+
70
+
71
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    # Same inverse as the chunk32 kernel, but for BT=64 chunks: the chunk is
    # split into 2x2 sub-blocks of size BC and the blockwise inverse formula
    # is applied (see comment below).
    a,
    b,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # b_A  = top-left (A11), b_A2 = bottom-right (A22), b_A3 = bottom-left (A21).
    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_A3 = tl.zeros([BC, BC], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
        p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
        p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
        b_a1 = tl.load(p_a1, boundary_check=(0, 1))
        b_a2 = tl.load(p_a2, boundary_check=(0, 1))
        b_b1 = tl.load(p_b1, boundary_check=(0, 1))
        b_b2 = tl.load(p_b2, boundary_check=(0, 1))
        b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
        b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
        b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)

    # Strictly lower triangle within each diagonal sub-block.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # Forward-substitution inverse of both diagonal sub-blocks (same scheme
    # as the chunk32 kernel). Statement order is significant.
    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)

    p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
    # causal mask
    # NOTE(review): b_A3 is stored WITHOUT negation; presumably the consumer
    # accounts for the sign of the off-diagonal block — confirm against wu_fwd.
    tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
149
+
150
+
151
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    # Given the per-chunk inverse A, computes
    #   w = A @ a                       (per K tile)
    #   u = A @ (tril(a k^T, -1) @ v)   (per V tile)
    # One program per (chunk i_t, batch*head i_bh).
    w,
    u,
    a,
    k,
    v,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    # Chunk-local a k^T, accumulated across K tiles.
    b_Aak = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_w = tl.dot(b_A, b_a)
        b_Aak += tl.dot(b_a, tl.trans(b_k))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    # Strictly-lower (causal) mask, then downcast once so the V-loop dots run
    # in the input dtype.
    b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
    b_Aak = b_Aak.to(k.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
        b_u = tl.dot(b_A, b_v)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
214
+
215
+
216
def prepare_wy_repr_fwd(
    a: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the chunked WY representation for the IPLR delta rule.

    Builds the per-chunk inverse ``A = (I - tril(a b^T, -1))^{-1}`` on device,
    then derives the ``w``/``u`` factors from it via :func:`wu_fwd`.

    Args:
        a: ``[B, T, H, K]`` tensor.
        b: ``[B, T, H, K]`` tensor.
        v: ``[B, T, H, V]`` values.
        k: ``[B, T, H, K]`` keys.
        cu_seqlens: optional cumulative sequence lengths for varlen mode.
        chunk_size: nominal chunk length (capped by the padded sequence length).

    Returns:
        Tuple ``(w, u, A)`` where ``A`` has shape ``[B, T, H, BT]``.
    """
    B, T, H, K = a.shape
    # Effective chunk length: at most chunk_size, at least 16.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype)
    # 64-wide chunks use the 2x2 blockwise-inverse kernel.
    kernel = (
        prepare_wy_repr_fwd_kernel_chunk64
        if BT == 64
        else prepare_wy_repr_fwd_kernel_chunk32
    )

    kernel[(NT, B * H)](
        a=a,
        b=b,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
    )
    w, u = wu_fwd(
        a=a,
        v=v,
        k=k,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    return w, u, A
257
+
258
+
259
def wu_fwd(
    a: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``wu_fwd_kernel`` to produce the WY factors ``w`` and ``u``.

    Args:
        a: ``[B, T, H, K]`` tensor.
        v: ``[B, T, H, V]`` values.
        k: ``[B, T, H, K]`` keys.
        A: per-chunk inverse of shape ``[B, T, H, BT]``.
        cu_seqlens: optional cumulative sequence lengths for varlen mode.
        chunk_size: nominal chunk length (capped by the padded sequence length).

    Returns:
        Tuple ``(w, u)`` with the shapes of ``a`` and ``v`` respectively.
    """
    B, T, H, K, V = *a.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    # Larger tiles only when the device has enough shared memory.
    tile = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tile)
    BV = min(triton.next_power_of_2(V), tile)

    w = torch.empty_like(a)
    u = torch.empty_like(v)
    wu_fwd_kernel[(NT, B * H)](
        a=a,
        v=v,
        w=w,
        u=u,
        A=A,
        k=k,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u


fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
fla3/ops/gla/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (29.5 kB). View file
 
fla3/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (66.3 kB). View file
 
fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (35 kB). View file
 
fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (4.83 kB). View file
 
fla3/ops/gla/chunk.py ADDED
@@ -0,0 +1,1300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
11
+ from fla.ops.utils import prepare_chunk_indices
12
+ from fla.ops.utils.cumsum import chunk_local_cumsum
13
+ from fla.ops.utils.op import exp, safe_exp
14
+ from fla.utils import check_shared_mem, input_guard
15
+
16
+ BK_LIST = [32, 64] if check_shared_mem() else [16, 32]
17
+ BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32]
18
+
19
+
20
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_inter(
    q,
    k,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Fills the inter-subchunk (strictly block-lower-triangular) tiles of the
    # intra-chunk matrix A: for subchunk pair (i_i, i_j) with i_i > i_j it
    # computes A = (q * exp(g - g_n)) @ (k * exp(g_n - g_k))^T * scale,
    # where g_n is the (cumulative) gate at the first row of subchunk i_i.
    # Grid: (num chunks, NC*NC subchunk pairs, B*H).
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        # map the flat chunk id to (sequence id, chunk-within-sequence)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return
    # only the strictly lower-triangular subchunk pairs are handled here;
    # the diagonal (i_i == i_j) is produced by the *_sub_intra kernels.
    if i_i <= i_j:
        return

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        # gate at the first row of subchunk i_i, used as a normalizer so the
        # exponentials stay bounded on both factors
        p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = b_k * exp(b_gn[:, None] - b_gk)
        # [BC, BC] using tf32 to improve precision here.
        b_A += tl.dot(b_qg, b_kg)

    p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
91
+
92
+
93
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=["BK", "BT"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_intra(
    q,
    k,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Fills the diagonal subchunk tiles of A (i_i == i_j), i.e. the causal
    # within-subchunk scores. Used when the whole feature dim fits in one
    # block (BK >= K); otherwise the *_split/_merge pair below is used.
    # Iterates over the BC key rows one at a time, writing one column of the
    # [BC, BC] tile per step.
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_j = i_i
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # flat offsets of the tile's rows inside A ([B, T, H, BT] layout)
    o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
    p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
    p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
    p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        # score of every query row against key row j
        b_A = tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1)
        # causal mask: only rows at or after j attend to key j
        b_A = tl.where(o_i >= j, b_A * scale, 0.)

        tl.store(A + o_A + j, b_A, mask=m_A)
        # advance key pointers one timestep
        p_k += H*K
        p_gk += H*K
156
+
157
+
158
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BC', 'BK']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_intra_split(
    q,
    k,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Large-K variant of the diagonal-tile kernel: each program handles one
    # K-slice (i_k) and writes its partial scores into a per-slice buffer of
    # shape [NK, B, T, H, BC]; the *_merge kernel sums the NK slices into A.
    i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_t, i_i = i_tc // NC, i_tc % NC
    i_j = i_i
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # `all` is the total timestep count across the batch, used to stride
        # between the NK per-slice sub-buffers. NOTE: shadows the builtin.
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T

    o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC
    p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_A = tl.zeros([BC], dtype=tl.float32)
        b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        # partial score over this K-slice only
        b_A += tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1)
        b_A = tl.where(o_i >= j, b_A * scale, 0.)
        tl.store(A + o_A + j, b_A, mask=m_A)
        p_k += H*K
        p_gk += H*K
227
+
228
+
229
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BC']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_intra_merge(
    A,
    A2,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    NK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Reduces the NK per-K-slice partial score buffers produced by
    # *_sub_intra_split (A, shape [NK, B, T, H, BC]) into the diagonal tiles
    # of the full intra-chunk matrix A2 ([B, T, H, BT]).
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # total timesteps across the batch; stride between slice buffers.
        # NOTE: shadows the builtin `all`.
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    if i_t * BT + i_c * BC >= T:
        return

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    # sum the NK partial contributions for this diagonal tile
    for i_k in range(0, NK):
        p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0))
        b_A += tl.load(p_A, boundary_check=(0, 1))
    p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0))
    tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1))
275
+
276
+
277
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Combines inter-chunk and intra-chunk contributions into the output:
    #   o = (q * exp(g)) @ h * scale      (state carried from previous chunks)
    #     + tril(A) @ v                   (intra-chunk attention-like term)
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg: global chunk index used to address the per-chunk states h
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # lower-triangular (causal) mask for the intra-chunk term
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # [BT, BV]
        # NOTE(review): `i_k >= 0` is always true; presumably a workaround for
        # a Triton compiler quirk — do not simplify without re-testing.
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype)
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
353
+
354
+
355
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BK', 'NC', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_intra(
    q,
    k,
    g,
    dA,
    dq,
    dk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Intra-chunk backward: turns dA into dq and dk for one [BC, BK] tile.
    # Phase 1 (dq): accumulate over earlier subchunks (block rows of dA below
    # the diagonal block column), plus the within-subchunk causal part.
    # Phase 2 (dk): mirror accumulation over later subchunks, using dA^T.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_t, i_i = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos
    if i_t * BT + i_i * BC >= T:
        return

    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K

    p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # [BC, BK]
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    if i_i > 0:
        # gate at the first row of this subchunk, used to normalize exponents
        p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        for i_j in range(0, i_i):
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
            p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
            p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
            # [BC, BK]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_kg = (b_k * exp(b_gn[None, :] - b_gk))
            # [BC, BC]
            b_dA = tl.load(p_dA, boundary_check=(0, 1))
            # [BC, BK]
            b_dq += tl.dot(b_dA, b_kg)
        b_dq *= exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC
    p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # within-subchunk (diagonal) contribution, one key row j at a time
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
        # [BK,]
        b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32)
        b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32)
        # [BC, BK]
        m_i = o_i[:, None] >= j
        # [BC, BK]
        # (SY 09/17) important to not use bf16 here to have a good precision.
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.)
        p_kj += H*K
        p_gkj += H*K
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
    p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)

    # clamp NC to the number of subchunks actually present in this chunk
    NC = min(NC, tl.cdiv(T - i_t * BT, BC))
    if i_i < NC - 1:
        # gate at the last row of this subchunk (normalizer for the dk side)
        p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        for i_j in range(i_i + 1, NC):
            p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0))
            p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0))
            p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            # [BC, BK]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_gq = tl.load(p_gq, boundary_check=(0, 1))
            b_qg = b_q * safe_exp(b_gq - b_gn[None, :])
            # [BC, BC]
            b_dA = tl.load(p_dA, boundary_check=(0, 1))
            # [BC, BK]
            # (SY 09/17) important to not use bf16 here to have a good precision.
            b_dk += tl.dot(b_dA, b_qg)
        b_dk *= exp(b_gn[None, :] - b_gk)
    o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC)
    p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,] — row j of the diagonal dA tile (read along the row stride)
        b_dA = tl.load(dA + o_dA + j * H*BT)
        # [BK,]
        b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32)
        b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32)
        # [BC, BK]
        m_i = o_i[:, None] <= j
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.)
        p_qj += H*K
        p_gqj += H*K
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
489
+
490
+
491
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BV', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_dA(
    v,
    do,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Gradient of the intra-chunk matrix: dA = tril(do @ v^T) * scale,
    # accumulated over V in blocks of BV.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dA += tl.dot(b_do, b_v)
    p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # causal mask: only the lower triangle of A carries gradient
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_dA = tl.where(m_s, b_dA * scale, 0.)
    tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
538
+
539
+
540
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in BK_LIST
        for BV in BV_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_dv(
    k,
    g,
    A,
    do,
    dh,
    dv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Value gradient: dv = triu(A)^T-style intra-chunk term A^T @ do
    # plus the inter-chunk term (k * exp(g_last - g)) @ dh.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # global chunk index used to address the per-chunk state grads dh
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # A is loaded transposed; masking keeps the upper triangle, which is A^T
    # of the causal (lower-triangular) forward matrix
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.)
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # (SY 09/17) important to disallow tf32 here to maintain a good precision.
    b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False)

    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        # gate at the last row of the chunk (state hand-off point)
        p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k
        p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk)
        b_k = (b_k * b_gn).to(b_k.dtype)
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BV]
        # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here
        b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
612
+
613
+
614
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in BK_LIST
        for BV in BV_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_inter(
    q,
    k,
    v,
    h,
    g,
    do,
    dh,
    dq,
    dk,
    dq2,
    dk2,
    dg,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Inter-chunk backward: adds the state-mediated gradients to the
    # intra-chunk dq/dk (read from dq/dk, written to dq2/dk2) and computes
    # the gate gradient dg via a reverse cumulative sum within the chunk.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K

    p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # gate at the last row of the chunk
    p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k
    b_gn = tl.load(p_gn, mask=m_k, other=0)
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk = tl.zeros([BK,], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BK] — contribution of the carried state to the gate gradient
        b_dgk += tl.sum(b_h * b_dh, axis=0)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
    b_dgk *= exp(b_gn)
    b_dq *= scale
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dq = b_dq * exp(b_gk)
    b_dk = b_dk * exp(b_gn[None, :] - b_gk)

    p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dgk += tl.sum(b_dk * b_k, axis=0)
    # fold in the intra-chunk gradients computed by chunk_gla_bwd_kernel_intra
    b_dq += tl.load(p_dq, boundary_check=(0, 1))
    b_dk += tl.load(p_dk, boundary_check=(0, 1))
    b_dg = b_q * b_dq - b_k * b_dk
    # tl.debug_barrier()
    # reverse (suffix) cumsum over the chunk: dg[t] = sum_{s>=t} (q*dq - k*dk)[s] + dgk
    b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :]
    # Buggy due to strange triton compiler issue.
    # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.)
    # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :]
    p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
717
+
718
+
719
def chunk_gla_fwd_intra_gk(
    q: torch.Tensor,
    k: torch.Tensor,
    g: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Compute the intra-chunk score matrix A ([B, T, H, BT], fp32).

    Launches the inter-subchunk kernel for the off-diagonal tiles, then
    fills the diagonal tiles either in one pass (K <= 256) or via a
    split-over-K + merge pair for larger feature dims.
    """
    B, T, H, K = k.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BC = min(16, BT)  # subchunk size
    NC = triton.cdiv(BT, BC)

    A = q.new_empty(B, T, H, BT, dtype=torch.float)
    grid = (NT, NC * NC, B * H)
    chunk_gla_fwd_A_kernel_intra_sub_inter[grid](
        q,
        k,
        g,
        A,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    grid = (NT, NC, B * H)
    # load the entire [BC, K] blocks into SRAM at once
    if K <= 256:
        BK = triton.next_power_of_2(K)
        chunk_gla_fwd_A_kernel_intra_sub_intra[grid](
            q,
            k,
            g,
            A,
            cu_seqlens,
            chunk_indices,
            scale,
            T=T,
            H=H,
            K=K,
            BT=BT,
            BC=BC,
            BK=BK,
        )
    # split then merge
    else:
        BK = min(128, triton.next_power_of_2(K))
        NK = triton.cdiv(K, BK)
        # per-K-slice partial scores, reduced by the merge kernel below
        A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float)

        grid = (NK, NT * NC, B * H)
        chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid](
            q,
            k,
            g,
            A_intra,
            cu_seqlens,
            chunk_indices,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            BT=BT,
            BC=BC,
            BK=BK,
            NC=NC,
        )

        grid = (NT, NC, B * H)
        chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid](
            A_intra,
            A,
            cu_seqlens,
            chunk_indices,
            T=T,
            B=B,
            H=H,
            BT=BT,
            BC=BC,
            NK=NK,
        )
    return A
811
+
812
+
813
def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Launch the forward output kernel: o = (q*exp(g)) @ h * scale + tril(A) @ v."""
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    o = torch.empty_like(v)

    # one program per (V-block, chunk, batch*head); BV is chosen by autotuning
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_gla_fwd_kernel_o[grid](
        q,
        v,
        g,
        h,
        o,
        A,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
848
+
849
+
850
def chunk_gla_bwd_dA(
    v: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Launch the kernel computing dA = tril(do @ v^T) * scale per chunk (fp32)."""
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BV = min(64, triton.next_power_of_2(V))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    dA = v.new_empty(B, T, H, BT, dtype=torch.float)
    chunk_gla_bwd_kernel_dA[(NT, B * H)](
        v,
        do,
        dA,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dA
880
+
881
+
882
def chunk_gla_bwd_dv(
    k: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Launch the value-gradient kernel (intra-chunk A^T @ do + inter-chunk k @ dh)."""
    B, T, H, K, V = *k.shape, do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    dv = torch.empty_like(do)

    # one program per (V-block, chunk, batch*head); BV comes from autotuning
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_gla_bwd_kernel_dv[grid](
        k,
        g,
        A,
        do,
        dh,
        dv,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return dv
915
+
916
+
917
def chunk_gla_bwd_dqk_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    g: torch.Tensor,
    dA: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Launch the intra-chunk backward kernel, producing fp32 dq and dk from dA."""
    B, T, H, K = q.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)   # subchunk size
    BK = min(64, triton.next_power_of_2(K))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    NK = triton.cdiv(K, BK)

    dq = torch.empty_like(q, dtype=torch.float)
    dk = torch.empty_like(k, dtype=torch.float)
    chunk_gla_bwd_kernel_intra[(NK, NT * NC, B * H)](
        q,
        k,
        g,
        dA,
        dq,
        dk,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
        NC=NC,
    )
    return dq, dk
956
+
957
+
958
def chunk_gla_bwd_dqkg(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Launch the inter-chunk backward kernel.

    Adds the state-mediated gradients to the intra-chunk dq/dk (inputs are
    read-only; results land in fresh dq2/dk2) and computes the gate gradient dg.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    dg = torch.empty_like(g)
    dq2 = torch.empty_like(dq)
    dk2 = torch.empty_like(dk)

    # one program per (K-block, chunk, batch*head); BK comes from autotuning
    def grid(meta):
        return (triton.cdiv(K, meta['BK']), NT, B * H)

    chunk_gla_bwd_kernel_inter[grid](
        q,
        k,
        v,
        h,
        g,
        do,
        dh,
        dq,
        dk,
        dq2,
        dk2,
        dg,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return dq2, dk2, dg
1005
+
1006
+
1007
def chunk_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: Optional[torch.Tensor],
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of chunk-parallel GLA.

    Returns `(g_cumsum, A, h, ht, o)`: the local gate cumsum, the intra-chunk
    attention matrix, the per-chunk states, the final state (or None), and the
    output.
    """
    T = q.shape[1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # materialize the per-chunk gate cumsum when the caller did not supply it
    if g_cumsum is None:
        g_cumsum = chunk_local_cumsum(g, BT, cu_seqlens=cu_seqlens)

    # intra-chunk attention scores; kept in fp32 — the extra cost has a very
    # marginal effect on the entire throughput
    A = chunk_gla_fwd_intra_gk(
        q=q,
        k=k,
        g=g_cumsum,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    # chunk-level recurrent states (independent of A, so order is arbitrary)
    h, ht = chunk_fwd_h(
        k=k,
        v=v,
        g=None,
        gk=g_cumsum,
        gv=None,
        h0=initial_state,
        output_final_state=output_final_state,
        states_in_fp32=False,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v,
        g=g_cumsum,
        A=A,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    return g_cumsum, A, h, ht, o
1058
+
1059
+
1060
def chunk_gla_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: Optional[torch.Tensor],
    scale: float,
    initial_state: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Backward pass of chunk-parallel GLA.

    Returns `(dq, dk, dv, dg, dh0)` — gradients w.r.t. queries, keys, values,
    gates, and the initial state.
    """
    T = q.shape[1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # rebuild the local gate cumsum when it was not cached by the forward pass
    if g_cumsum is None:
        g_cumsum = chunk_local_cumsum(g, BT, cu_seqlens=cu_seqlens)

    # recompute the forward chunk states (in fp32) when they were not cached
    if h is None:
        h, _ = chunk_fwd_h(
            k=k, v=v, g=None, gk=g_cumsum, gv=None,
            h0=initial_state, output_final_state=False,
            cu_seqlens=cu_seqlens, chunk_size=BT, states_in_fp32=True
        )
    dh, dh0 = chunk_bwd_dh(
        q=q, k=k, v=v, g=None, gk=g_cumsum, gv=None,
        do=do, h0=initial_state, dht=dht, scale=scale,
        cu_seqlens=cu_seqlens, chunk_size=BT, states_in_fp32=True
    )

    # dA only depends on (v, do), so it can be computed independently of dh
    dA = chunk_gla_bwd_dA(
        v=v, do=do, scale=scale,
        cu_seqlens=cu_seqlens, chunk_size=BT
    )
    dv = chunk_gla_bwd_dv(
        k=k, g=g_cumsum, A=A, do=do, dh=dh,
        cu_seqlens=cu_seqlens, chunk_size=BT
    )
    # dq/dk are accumulated in fp32 across the two stages below
    dq, dk = chunk_gla_bwd_dqk_intra(
        q=q, k=k, g=g_cumsum, dA=dA,
        cu_seqlens=cu_seqlens, chunk_size=BT
    )
    dq, dk, dg = chunk_gla_bwd_dqkg(
        q=q, k=k, v=v, h=h, g=g_cumsum, do=do, dh=dh,
        dq=dq, dk=dk, scale=scale,
        cu_seqlens=cu_seqlens, chunk_size=BT
    )
    return dq, dk, dv, dg, dh0
1150
+
1151
+
1152
class ChunkGLAFunction(torch.autograd.Function):
    """Autograd bridge between `chunk_gla_fwd` and `chunk_gla_bwd`."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q,
        k,
        v,
        g,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    ):
        T = q.shape[1]
        # chunk length clamped to [16, 64]; rounded to a power of two for short T
        chunk_size = min(64, max(16, triton.next_power_of_2(T)))

        g_cumsum, A, h, ht, o = chunk_gla_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            g_cumsum=None,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # recompute g_cumsum in bwd pass
        # Only one of (g, g_cumsum) is saved: when `g` is not fp32 the fp32
        # cumsum is dropped and recomputed in backward; when `g` is fp32 the
        # cumsum itself is saved and `g` dropped.
        # NOTE(review): presumably a memory/recompute trade-off — confirm.
        if g.dtype != torch.float:
            g_cumsum = None
        else:
            g = None
        ctx.save_for_backward(q, k, v, g, g_cumsum, initial_state, A)
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        return o, ht

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        # exactly one of (g, g_cumsum) is non-None here; chunk_gla_bwd handles both
        q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors
        chunk_size, scale, cu_seqlens = ctx.chunk_size, ctx.scale, ctx.cu_seqlens
        dq, dk, dv, dg, dh0 = chunk_gla_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            g_cumsum=g_cumsum,
            scale=scale,
            h=None,  # forward states were not cached; recomputed inside
            A=A,
            initial_state=initial_state,
            do=do,
            dht=dht,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # one gradient per forward input:
        # (q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
        return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None
1214
+
1215
+
1216
+ @torch.compiler.disable
1217
+ def chunk_gla(
1218
+ q: torch.Tensor,
1219
+ k: torch.Tensor,
1220
+ v: torch.Tensor,
1221
+ g: torch.Tensor,
1222
+ scale: Optional[int] = None,
1223
+ initial_state: torch.Tensor = None,
1224
+ output_final_state: bool = False,
1225
+ cu_seqlens: Optional[torch.LongTensor] = None,
1226
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1227
+ r"""
1228
+ Args:
1229
+ q (torch.Tensor):
1230
+ queries of shape `[B, T, H, K]`.
1231
+ k (torch.Tensor):
1232
+ keys of shape `[B, T, H, K]`.
1233
+ v (torch.Tensor):
1234
+ values of shape `[B, T, H, V]`.
1235
+ g (torch.Tensor):
1236
+ Forget gates of shape `[B, T, H, K]`.
1237
+ scale (Optional[int]):
1238
+ Scale factor for the attention scores.
1239
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1240
+ initial_state (Optional[torch.Tensor]):
1241
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
1242
+ For equal-length input sequences, `N` equals the batch size `B`.
1243
+ Default: `None`.
1244
+ output_final_state (Optional[bool]):
1245
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
1246
+ cu_seqlens (torch.LongTensor):
1247
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1248
+ consistent with the FlashAttention API.
1249
+
1250
+ Returns:
1251
+ o (torch.Tensor):
1252
+ Outputs of shape `[B, T, H, V]`.
1253
+ final_state (torch.Tensor):
1254
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
1255
+
1256
+ Examples::
1257
+ >>> import torch
1258
+ >>> import torch.nn.functional as F
1259
+ >>> from einops import rearrange
1260
+ >>> from fla.ops.gla import chunk_gla
1261
+ # inputs with equal lengths
1262
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
1263
+ >>> q = torch.randn(B, T, H, K, device='cuda')
1264
+ >>> k = torch.randn(B, T, H, K, device='cuda')
1265
+ >>> v = torch.randn(B, T, H, V, device='cuda')
1266
+ >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
1267
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
1268
+ >>> o, ht = chunk_gla(
1269
+ q, k, v, g,
1270
+ initial_state=h0,
1271
+ output_final_state=True
1272
+ )
1273
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
1274
+ >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
1275
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
1276
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
1277
+ >>> o_var, ht_var = chunk_gla(
1278
+ q, k, v, g,
1279
+ initial_state=h0,
1280
+ output_final_state=True,
1281
+ cu_seqlens=cu_seqlens
1282
+ )
1283
+ >>> assert o.allclose(o_var.view(o.shape))
1284
+ >>> assert ht.allclose(ht_var)
1285
+ """
1286
+ if cu_seqlens is not None:
1287
+ if q.shape[0] != 1:
1288
+ raise ValueError(
1289
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
1290
+ f"Please flatten variable-length inputs before processing."
1291
+ )
1292
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
1293
+ raise ValueError(
1294
+ f"The number of initial states is expected to be equal to the number of input sequences, "
1295
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
1296
+ )
1297
+ if scale is None:
1298
+ scale = q.shape[-1] ** -0.5
1299
+ o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
1300
+ return o, final_state
fla3/ops/gla/fused_recurrent.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.common.fused_recurrent import fused_recurrent
9
+
10
+
11
def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        gk (torch.Tensor):
            Forget gates of shape `[B, T, H, K]`.
        gv (torch.Tensor):
            Forget gates of shape `[B, T, H, V]` applied to values.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gla import fused_recurrent_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gla(
            q, k, v, g,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gla(
            q, k, v, g,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # fix: the two f-string fragments previously concatenated without a
            # space ("...`cu_seqlens`.Please flatten...")
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state
fla3/ops/gla/naive.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
def ceildiv(a, b):
    """Integer ceiling division of ``a`` by ``b`` (matches -(a // -b))."""
    quot, rem = divmod(a, b)
    return quot + bool(rem)
10
+
11
+
12
def naive_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False
):
    """Reference (step-by-step) implementation of gated linear attention.

    All math is done in fp32; inputs arrive as `[B, T, H, *]` and are processed
    internally as `[B, H, T, *]`. Returns the output (cast back to the input
    dtype) and, optionally, the final `[B, H, K, V]` state.
    """
    in_dtype = q.dtype
    # move the time axis inward and promote to fp32 for the recurrence
    q, k, v, gk = (x.transpose(1, 2).float() for x in (q, k, v, gk))
    B, H, T, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5

    state = q.new_zeros(B, H, K, V, dtype=torch.float32)
    if initial_state is not None:
        state = state + initial_state.float()

    outputs = torch.zeros_like(v)
    for t in range(T):
        # decay the running state by the (log-space) gate, then add k ⊗ v
        decay = gk[:, :, t].exp()
        state = state * decay[..., None] + k[:, :, t, :, None] * v[:, :, t, None, :]
        # readout: scaled query contracted against the state over the key axis
        outputs[:, :, t] = ((q[:, :, t] * scale)[..., None] * state).sum(-2)

    final = state if output_final_state else None
    return outputs.transpose(1, 2).to(in_dtype), final
fla3/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gsa
4
+ from .fused_recurrent import fused_recurrent_gsa
5
+
6
+ __all__ = [
7
+ 'chunk_gsa',
8
+ 'fused_recurrent_gsa'
9
+ ]
fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (305 Bytes). View file
 
fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (320 Bytes). View file
 
fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (27.3 kB). View file
 
fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (57 kB). View file
 
fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (23.1 kB). View file
 
fla3/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange, reduce
11
+
12
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
13
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
14
+ from fla.ops.utils import prepare_chunk_indices
15
+ from fla.ops.utils.cumsum import chunk_local_cumsum
16
+ from fla.ops.utils.op import exp, safe_exp
17
+ from fla.ops.utils.softmax import softmax_bwd, softmax_fwd
18
+ from fla.utils import input_guard
19
+
20
+
21
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_fwd_k_kernel_inter(
    q,
    k,
    h,
    g,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # GSA forward, inter-chunk part: for each (value tile, chunk, batch*head)
    # program, accumulate o += (q·h) over key tiles and, for i_v == 0, write
    # the causally masked intra-chunk score matrix A = q·kᵀ.
    # NG maps HQ query heads onto H key/value heads (grouped heads).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        # varlen: chunk_indices maps the flat chunk id to (sequence, chunk-in-sequence)
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    # lower-triangular mask enforcing causality inside the chunk
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BT]
        b_A += tl.dot(b_q, b_k)
    p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV] — apply the (log-space) gate to the inter-chunk output
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o = b_o * exp(b_g)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    # [BT, BT] — mask out future positions; A does not depend on i_v, so only
    # the first value-tile program writes it
    b_A = tl.where(m_s, b_A, 0.)
    if i_v == 0:
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
103
+
104
+
105
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_fwd_k_kernel_intra(
    v,
    g,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # GSA forward, intra-chunk part: adds A·(gated v) contributions into `o`,
    # working on BC-sized sub-chunks of each BT chunk (NC sub-chunks per chunk).
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    i_t, i_i = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # this sub-chunk lies entirely past the end of the sequence
    if i_t * BT + i_i * BC > T:
        return

    p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # gate value at the start of this sub-chunk, used as a reference point so
    # exponentials stay in a numerically safe range
    p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
    # [BV,]
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_o = tl.zeros([BC, BV], dtype=tl.float32)
    # earlier sub-chunks: block-level matmuls against gated values
    for i_j in range(0, i_i):
        p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        b_o += tl.dot(b_A, b_vg)
    # [BC, BV] — re-reference the accumulated output to each row's own gate
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o *= exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # same sub-chunk: scan position-by-position to honor causality within BC
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
        p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
        # [BC,]
        b_A = tl.load(A + o_A + j, mask=m_A, other=0)
        # [BV,]
        b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
        b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
        # [BC, BV]
        b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
        # avoid 0 * inf = inf
        b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
    p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # accumulate on top of the inter-chunk output written by the previous kernel
    b_o += tl.load(p_o, boundary_check=(0, 1))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
183
+
184
+
185
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8]
    ],
    key=["BT"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dA(
    v,
    g,
    do,
    dA,
    chunk_indices,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # GSA backward: gradient of the intra-chunk score matrix, dA = do · (gated v)ᵀ.
    # Each program handles one (value tile i_v, sub-chunk pair (i_i, i_j), batch*head);
    # only the causal region i_i >= i_j produces nonzero output.
    # Partial results are written per value tile and summed on the host side
    # (the i_v * all offset below indexes the per-tile slab).
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T  # NOTE: shadows the builtin `all`; total token count across the batch
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # destination sub-chunk is past the end of the sequence
    if i_t * BT + i_i * BC > T:
        return

    p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))

    # [BC, BC]
    b_dA = tl.zeros([BC, BC], dtype=tl.float32)
    if i_i > i_j:
        # strictly-earlier sub-chunk: a single block matmul with gate re-referencing
        p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
        p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
        p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
        p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        # [BV,] — reference gate at the start of sub-chunk i_i
        b_gn = tl.load(p_gn, mask=m_v, other=0.)
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_dA = tl.dot(b_do, b_vg)
    elif i_i == i_j:
        # diagonal sub-chunk: scan column-by-column and mask to the lower triangle
        p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
        p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
        m_v = o_v < V

        o_i = tl.arange(0, BC)
        # [BC, BC]
        m_dA = o_i[:, None] >= o_i[None, :]
        for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
            # [BV,]
            b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
            b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
            # [BC,] — column j of the dA block
            b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
            b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)

            # advance the scalar pointers to the next time step
            p_v += H*V
            p_gv += H*V
        b_dA = tl.where(m_dA, b_dA, 0.)
    tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
282
+
283
+
284
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dg,
    dgv,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # GSA backward: one program per (key tile i_k, chunk i_t, batch*head) that
    # (1) recomputes and stores the causal score block A = q·kᵀ,
    # (2) accumulates dq/dk from both the state path (h, dh) and the score path (dA),
    # (3) writes per-key-tile partials of dv and the gate gradient dgv.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T  # NOTE: shadows the builtin `all`; total token count across the batch
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    # index of the last valid position in this chunk (exclusive), clipped to T
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT] — recompute the causal score block for this key tile
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # gate at the last valid position of the chunk — reference point for decays
        p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # decay from each position to the end of the chunk
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV] — gate gradient from the state path
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK] — state-path contributions to dq and dk
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV] — per-key-tile partial of dv (reduced over i_k on the host)
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        # fold the pre-existing dg into the i_k == 0 slab only, so the host-side
        # sum over key tiles counts it exactly once
        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK] — score-path contributions
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
416
+
417
+
418
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_intra_dvg(
    v,
    g,
    o,
    A,
    do,
    dv,
    dg,
    cu_seqlens,
    chunk_indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Intra-chunk backward kernel: adds the remaining gradient contributions
    # to dv (from later sub-chunks of the same chunk and from the causal
    # diagonal), then computes dg = o * do - v * dv.
    # Grid: (value blocks, num_chunks * num_sub_chunks, batch * query heads).
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    # NG query heads share one kv head (GQA-style grouping).
    i_h = i_hq // NG
    # i_t: chunk index; i_i: sub-chunk (size BC) index within the chunk.
    i_t, i_i = i_c // NC, i_c % NC
    if IS_VARLEN:
        # Remap the flat chunk index to (sequence, local chunk) for packed
        # variable-length inputs; T becomes the current sequence length.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # Sub-chunk entirely past the end of this sequence: nothing to compute.
    if i_t * BT + i_i * BC > T:
        return

    p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # Gate values at the last valid position of this sub-chunk, used as a
    # normalization anchor for numerically safe exponentials.
    p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
    # [BV,]
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)
    # Contributions from strictly later sub-chunks of the same chunk.
    for i_j in range(i_i + 1, NC):
        p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        # [BC, BV]
        b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
    # Rescale from the anchor gate back to per-row gate values.
    b_dv *= exp(b_gn[None, :] - b_gv)

    o_i = tl.arange(0, BC)
    o_c = i_i * BC + tl.arange(0, BC)

    p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
    p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
    p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v
    # Diagonal (within the sub-chunk) contributions, processed row by row so
    # the causal mask o_i <= j is honored; pointers advance one row per step.
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_A = tl.load(p_A)
        # [BV,]
        b_g = tl.load(p_g, mask=m_v, other=0)
        b_do = tl.load(p_do, mask=m_v, other=0)
        # [BC, BV]
        m_i = o_i[:, None] <= j
        b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)

        p_g += H * V
        p_A += HQ * BT
        p_do += HQ * V
    p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))

    b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
    b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
    # Accumulate onto the dv values already present in global memory
    # (presumably written by a preceding kernel — verify against the launcher).
    b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
    b_dg = b_o * b_do - b_v * b_dv
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
512
+
513
+
514
def chunk_gsa_fwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float = 1.,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Value-branch forward pass of GSA, delegating to the chunked GLA forward.

    `g` is supplied as an already-cumsummed gate (`g_cumsum`), so the raw-gate
    slot of the GLA call is `None`. The leading element of the GLA result is
    discarded; the remainder `(A, h, ht, o)` is returned unchanged.
    """
    outputs = chunk_gla_fwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    # Drop the first entry; keep (A, h, ht, o).
    return outputs[1:]
538
+
539
+
540
def chunk_gsa_fwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Key-branch forward pass of GSA.

    Computes chunked hidden states with value-side gating (`gv=g`), then runs
    the inter-chunk and intra-chunk output kernels.

    Returns:
        A: intra-chunk attention scores of shape `[B, T, HQ, BT]`.
        h: per-chunk hidden states.
        ht: final state (or `None` if `output_final_state=False`).
        o: outputs of shape `[B, T, HQ, V]`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    # Chunk length: at most `chunk_size`, at least 16, power of two.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[2]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    # Number of query heads per kv head (GQA group size).
    NG = HQ // H

    # Per-chunk hidden states; gating is applied on the value dimension only.
    h, ht = chunk_fwd_h(
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        h0=h0,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
        states_in_fp32=False
    )
    o = v.new_empty(B, T, HQ, V)
    A = q.new_empty(B, T, HQ, BT)
    # Inter-chunk contributions to o, plus the intra-chunk score matrix A.
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_gsa_fwd_k_kernel_inter[grid](
        q,
        k,
        h,
        g,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
    )

    # Intra-chunk contributions, one program per (V block, sub-chunk, head).
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_gsa_fwd_k_kernel_intra[grid](
        v,
        g,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        num_warps=4,
        num_stages=2
    )
    return A, h, ht, o
617
+
618
+
619
def chunk_gsa_bwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Value-branch backward pass of GSA, delegating to the chunked GLA backward.

    `g` is the cumulative gate (`g_cumsum`); the incoming `dg` argument is not
    consumed here — gate gradients are recomputed inside `chunk_gla_bwd`.
    Returns `(dq, dk, dv, dg, dh0)`.
    """
    return chunk_gla_bwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=h0,
        h=h,
        A=A,
        do=do,
        dht=dht,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
650
+
651
+
652
def chunk_gsa_bwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h: torch.Tensor,
    h0: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Key-branch backward pass of GSA.

    Recomputes forward hidden states if `h is None` (checkpointing), then runs
    the backward kernels for the state gradients (`dh`), the intra-chunk score
    gradients (`dA`), the main `dq/dk/dv/dg` kernel, and a final intra-chunk
    pass that completes `dv` and `dg`.

    Returns `(dq, dk, dv, dg, dh0)`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BK = min(64, triton.next_power_of_2(K))
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[2]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # Query heads per kv head (GQA group size).
    NG = HQ // H

    if h is None:
        # Hidden states were not saved in forward (checkpoint level > 1):
        # recompute them here.
        h, _ = chunk_fwd_h(
            k=k,
            v=v,
            g=None,
            gk=None,
            gv=g,
            h0=h0,
            output_final_state=False,
            cu_seqlens=cu_seqlens,
            chunk_size=BT,
            states_in_fp32=False
        )
    # Gradients of the per-chunk hidden states and of the initial state.
    dh, dh0 = chunk_bwd_dh(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        do=do,
        h0=h0,
        dht=dht,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
        states_in_fp32=True
    )
    # dA is accumulated per V block (leading NV dim) and reduced afterwards.
    dA = q.new_empty(NV, B, T, HQ, BT)
    grid = (NV, NT * NC * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_dA[grid](
        v,
        g,
        do,
        dA,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
    )
    dA = dA.sum(0, dtype=dA.dtype)

    # A, dv, dgv are accumulated per K block (leading NK dim).
    A = do.new_empty(NK, B, T, HQ, BT)
    dq = torch.empty_like(q)
    dk = k.new_empty(B, T, HQ, K)
    dv = v.new_empty(NK, B, T, HQ, V)
    dgv = g.new_empty(NK, B, T, HQ, V, dtype=torch.float)
    grid = (NK, NT, B * HQ)
    chunk_gsa_bwd_k_kernel_dqkvg[grid](
        q,
        k,
        v,
        h,
        g,
        A,
        do,
        dh,
        dq,
        dk,
        dv,
        dg,
        dgv,
        dA,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NG=NG,
    )
    A = A.sum(0, dtype=A.dtype)
    dv = dv.sum(0, dtype=dv.dtype)
    dgv = dgv.sum(0, dtype=dgv.dtype)

    # Final intra-chunk pass: completes dv and writes dg = o*do - v*dv.
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_intra_dvg[grid](
        v,
        g,
        o,
        A,
        do,
        dv,
        dg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        num_warps=4,
        num_stages=2
    )
    # Combine the value-side gate gradients with the reverse cumulative sum of
    # the pointwise gate gradients.
    dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, cu_seqlens=cu_seqlens))

    return dq, dk, dv, dg, dh0
797
+
798
+
799
def chunk_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Two-stage GSA forward pass.

    Stage 1 (key branch) attends `q`/`k` against the slot representations `s`,
    producing slot scores `ok`; these are normalized with a softmax into `p`.
    Stage 2 (value branch) uses `p` as queries against `s`/`v` to produce the
    final outputs `ov`.

    Returns `(Ak, hk, hkt, ok, p, Av, hv, hvt, ov)` — intermediate scores,
    hidden states, final states, and outputs of both branches.
    """
    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    Ak, hk, hkt, ok = chunk_gsa_fwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        output_final_state=output_final_state,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

    # p is kept in fp32 for safe softmax backward
    p = softmax_fwd(ok, dtype=torch.float)

    # The normalized slot scores act as queries for the value branch; the
    # value branch therefore uses no extra scaling (scale=1.).
    qv = p.to(q.dtype)
    Av, hv, hvt, ov = chunk_gsa_fwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        scale=1.,
        initial_state=hv0,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
842
+
843
+
844
def chunk_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    ok: torch.Tensor,
    p: torch.Tensor,
    A: Tuple[torch.Tensor, torch.Tensor],
    h: Tuple[torch.Tensor, torch.Tensor],
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
    scale: float,
    do: torch.Tensor,
    dht: Tuple[torch.Tensor, torch.Tensor],
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Backward pass of the two-stage GSA forward.

    Reverses `chunk_gsa_fwd`: first differentiates the value branch, then maps
    the resulting query gradients back through the softmax, and finally
    differentiates the key branch.

    Returns `(dq, dk, dv, ds, dg, dhk0, dhv0)`.
    """
    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state

    # Only the value-branch score matrix Av is needed; Ak slot is unused here.
    _, Av = A
    hk, hv = h
    dhkt, dhvt = dht

    qv = p.to(q.dtype)
    dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        h0=hv0,
        h=hv,
        A=Av,
        do=do,
        dht=dhvt,
        dg=None,
        scale=1.,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

    # softmax gradient, equivalent to:
    # dok = qv * (dqv - (qv * dqv).sum(-1, True))
    dok = softmax_bwd(p, dqv, dtype=ok.dtype)

    dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        h=hk,
        o=ok,
        do=dok,
        dht=dhkt,
        dg=dg,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

    # Slot gradients accumulate from both branches (in-place add).
    ds = dsv.add_(dsk)
    if q.shape[1] != k.shape[1]:
        # NOTE(review): reduces grouped-head gradients back to kv heads; the
        # einops pattern indexes dim 1 — confirm it matches the [B, T, H, ...]
        # layout used elsewhere in this file.
        dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
    dg = dg.to(s.dtype)
    return dq, dk, dv, ds, dg, dhk0, dhv0
911
+
912
+
913
class ChunkGSAFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked GSA forward/backward, with optional
    recomputation of cumulative gates and hidden states (checkpointing)."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: float,
        hk0: Optional[torch.Tensor],
        hv0: Optional[torch.Tensor],
        output_final_state: bool,
        checkpoint_level: int,
        cu_seqlens: Optional[torch.LongTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        T = q.shape[1]
        chunk_size = min(64, max(16, triton.next_power_of_2(T)))

        # Keep the raw gates around so level >= 1 can save them instead of the
        # fp32 cumulative values.
        g_org, g = g, chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens)
        Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )

        if checkpoint_level >= 1:
            # Level >= 1: save raw gates; backward recomputes the cumsum.
            del g
            g = g_org
        if checkpoint_level > 1:
            # Level 2: also drop hidden states; backward recomputes them.
            del hk
            del hv
            hk, hv = None, None
        else:
            # Hidden states are saved, so the initial states are redundant.
            hk0, hv0 = None, None

        ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
        ctx.checkpoint_level = checkpoint_level
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        ctx.chunk_size = chunk_size
        return ov, hkt, hvt

    @staticmethod
    @input_guard
    def backward(ctx, dov, dhkt=None, dhvt=None):
        q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
        scale = ctx.scale
        cu_seqlens = ctx.cu_seqlens
        chunk_size = ctx.chunk_size

        if ctx.checkpoint_level >= 1:
            # Recompute the fp32 cumulative gates saved only as raw gates.
            g = chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens)
        dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            ok=ok,
            p=p,
            A=(None, Av),
            h=(hk, hv),
            initial_state=(hk0, hv0),
            scale=scale,
            do=dov,
            dht=(dhkt, dhvt),
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # One gradient slot per forward argument; non-tensor args get None.
        return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
993
+
994
+
995
@torch.compiler.disable
def chunk_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    checkpoint_level: Optional[int] = 2,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is performed if `H` is not equal to `HQ`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
            If not provided, this function is equivalent to vanilla ABC.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
            Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `2`:
            - Level `0`: no memory saved, no recomputation.
            - Level `1`: recompute the fp32 cumulative values during backward.
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (Tuple[torch.Tensor]):
            Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
            `None` otherwise.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gsa import fused_recurrent_gsa
        # inputs with equal lengths
        >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> s = torch.randn(B, T, H, M, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
        >>> o, (hk, hv) = chunk_gsa(
            q, k, v, s, g,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, (hk_var, hv_var) = chunk_gsa(
            q, k, v, s, g,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert hk.allclose(hk_var)
        >>> assert hv.allclose(hv_var)
    """
    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}."
            )
    assert checkpoint_level in [0, 1, 2]
    if g is None:
        # Vanilla-ABC fallback: derive gates from the slot logits via a
        # running log-normalizer over *time*.
        # BUG FIX: the previous code cumulated over dim 2 (the head dim of the
        # [B, T, H, M] layout) and then concatenated shape-mismatched slices
        # along dim 1, which raised at runtime. Time is dim 1 here.
        # TODO: this 3 steps took huge amount of time, ought to be optimized
        z = s.float().logcumsumexp(1)
        g = torch.cat((z[:, :1], z[:, :-1]), 1) - z
        s = torch.exp(s - z).to(k.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    o, *final_state = ChunkGSAFunction.apply(
        q,
        k,
        v,
        s,
        g,
        scale,
        hk0,
        hv0,
        output_final_state,
        checkpoint_level,
        cu_seqlens
    )
    return o, final_state
fla3/ops/gsa/fused_recurrent.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.fused_recurrent import fused_recurrent_bwd_kernel, fused_recurrent_fwd_kernel
11
+ from fla.ops.utils import chunk_global_cumsum
12
+ from fla.ops.utils.op import exp
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
14
+
15
+
16
@triton.jit
def fused_recurrent_gsa_inference_kernel(
    q,
    k,
    v,
    s,
    g,
    o,
    hk0,
    hv0,
    hkt,
    hvt,
    scale,
    K: tl.constexpr,
    V: tl.constexpr,
    M: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    # Single-step (T == 1) GSA decoding kernel: updates both recurrent states
    # (key->slot and slot->value) and emits one output token per head.
    # NOTE(review): M is used directly as a tile size (tl.arange(0, M)), which
    # presumably requires M to be a power of two — confirm at the call site.
    i_bh = tl.program_id(0)
    # NG query heads share one kv head; i_bg indexes the shared kv head.
    i_bg = i_bh // NG

    b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32)
    b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32)
    # Decay factor applied to both states.
    b_g = exp(b_g)

    # Slot logits accumulated over all K blocks.
    b_ok = tl.zeros([M], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)

        p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None]
        # [BK,]
        mask_k = o_k < K
        # [M, BK]
        mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :]
        # [M, BK]
        b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32)
        # [BK,]
        b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale
        b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32)
        # State update: decay then rank-1 outer-product write of (s, k).
        b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]
        b_ok += tl.sum(b_hk * b_q[None, :], axis=1)

        # Only the first query head of each group writes the shared state.
        if i_bh % NG == 0:
            p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None]
            tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk)

    # Normalized slot scores act as queries for the value state.
    b_qv = tl.softmax(b_ok)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)

        p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
        # [BV,]
        mask_v = o_v < V
        # [BV, M]
        mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :]
        # [BV, M]
        b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32)
        # [BV,]
        b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32)
        # State update: decay then rank-1 outer-product write of (s, v).
        b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]
        b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)

        tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v)

        if i_bh % NG == 0:
            p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
            tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv)
85
+
86
+
87
def fused_recurrent_gsa_inference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
) -> torch.Tensor:
    """Single-token GSA decoding step.

    Launches one program per (batch, query head) and updates both recurrent
    states in place when possible. Expects `T == 1` inputs (one new token).

    Returns `(o, (hkt, hvt))`, where the final states are `None` unless
    `output_final_state=True`.
    """
    B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    HQ = q.shape[2]
    BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
    # Query heads per kv head (GQA group size).
    NG = HQ // H

    if initial_state != (None, None) and initial_state is not None:
        hk0, hv0 = initial_state
    else:
        # No incoming state: start from zeros.
        hk0, hv0 = q.new_zeros(B, H, K, M, dtype=torch.float), q.new_zeros(B, H, M, V, dtype=torch.float)

    hkt, hvt = None, None
    if output_final_state:
        if NG == 1:
            # No grouping: the kernel can update the input states in place.
            hkt, hvt = hk0, hv0
        else:
            hkt, hvt = q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float)

    o = v.new_empty(B, T, HQ, V)
    grid = (B * HQ,)
    fused_recurrent_gsa_inference_kernel[grid](
        q,
        k,
        v,
        s,
        g,
        o,
        hk0,
        hv0,
        hkt,
        hvt,
        scale=scale,
        K=K,
        V=V,
        M=M,
        BK=BK,
        BV=BV,
        NG=NG
    )
    return o, (hkt, hvt)
136
+
137
+
138
def fused_recurrent_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    """Recurrent (token-by-token) GSA forward for training.

    Two passes of the generic fused recurrent kernel: first `q/k` against the
    slots `s` with value-side gating to get slot logits `ok`; then the
    softmaxed logits `qv` against `s/v` with key-side gating to get outputs.

    Returns `(ok, hkt, qv, ov, hvt)`.
    """
    B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    HQ = q.shape[2]
    if HQ != H:
        raise ValueError("GQA not supported yet.")

    # NOTE(review): BM = min(M, 64) is not rounded with next_power_of_2 like
    # BK/BV — presumably M is already a power of two; confirm upstream.
    BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64)
    NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)

    hk0, hv0 = None, None
    if initial_state != (None, None) and initial_state is not None:
        hk0, hv0 = initial_state
    hkt, hvt = None, None
    if output_final_state:
        hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float)

    # Slot logits are accumulated per K block (leading NK dim), summed below.
    ok = q.new_empty(NK, *s.shape, dtype=torch.float)
    gk, gv = None, g
    grid = (NM, NK, N * H)
    fused_recurrent_fwd_kernel[grid](
        q=q,
        k=k,
        v=s,
        g=None,
        gk=gk,
        gv=gv,
        o=ok,
        h0=hk0,
        ht=hkt,
        cu_seqlens=cu_seqlens,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=M,
        BK=BK,
        BV=BM,
        USE_G=False,
        USE_GK=False,
        USE_GV=True,
        REVERSE=reverse
    )
    ok = ok.sum(0)

    # Softmax over slots; kept in fp32 for a stable backward.
    qv = ok.softmax(-1, dtype=torch.float)
    ov = q.new_empty(NM, *v.shape, dtype=torch.float)
    # Second pass: the gate now sits on the "key" (slot) dimension.
    gk, gv = g, None
    grid = (NV, NM, N * H)
    fused_recurrent_fwd_kernel[grid](
        q=qv,
        k=s,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        o=ov,
        h0=hv0,
        ht=hvt,
        cu_seqlens=cu_seqlens,
        scale=1.,
        B=B,
        T=T,
        H=H,
        K=M,
        V=V,
        BK=BM,
        BV=BV,
        USE_G=False,
        USE_GK=True,
        USE_GV=False,
        REVERSE=reverse,
    )
    ov = ov.sum(0)
    return ok, hkt, qv, ov, hvt
225
+
226
+
227
def fused_recurrent_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    qv: torch.Tensor,
    hk0: Optional[torch.Tensor] = None,
    hv0: Optional[torch.Tensor] = None,
    ok: Optional[torch.Tensor] = None,
    do: Optional[torch.Tensor] = None,
    dhkt: Optional[torch.Tensor] = None,
    dhvt: Optional[torch.Tensor] = None,
    scale: float = 1.,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
    """Backward pass of the fused recurrent GSA forward.

    Mirrors `fused_recurrent_gsa_fwd` in reverse: first differentiates the
    value branch, maps the query gradients through the softmax, then
    differentiates the key branch. Gate gradients from both branches are
    combined via reverse global cumulative sums.

    Returns `(dq, dk, dv, ds, dg, dhk0, dhv0)`.
    """
    B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)
    NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)

    # Per-block accumulators (leading NV/NM dims), reduced after each launch.
    dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
    dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
    dv = q.new_empty(NM, B, T, H, V, dtype=torch.float)
    dhk0 = torch.empty_like(hk0) if hk0 is not None else None
    dhv0 = torch.empty_like(hv0) if hv0 is not None else None

    # Value branch backward (gate on the slot/"key" dimension).
    gk, gv = g, None
    grid = (NV, NM, N * H)
    fused_recurrent_bwd_kernel[grid](
        q=qv,
        k=s,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        h0=hv0,
        do=do,
        dq=dqv,
        dk=dsv,
        dv=dv,
        dht=dhvt,
        dh0=dhv0,
        cu_seqlens=cu_seqlens,
        scale=1.,
        B=B,
        T=T,
        H=H,
        K=M,
        V=V,
        BK=BM,
        BV=BV,
        USE_G=False,
        USE_GK=True,
        USE_GV=False,
        REVERSE=reverse,
    )
    dqv = dqv.sum(0)
    dsv = dsv.sum(0)
    dv = dv.sum(0)
    # Gate gradient of the value branch, as a reverse cumulative sum.
    dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens)

    # Softmax backward: grad wrt the slot logits ok.
    dok = qv * (dqv - (qv * dqv).sum(-1, True))
    dq = q.new_empty(NM, B, T, H, K, dtype=torch.float)
    dk = q.new_empty(NM, B, T, H, K, dtype=torch.float)
    dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float)
    # Key branch backward (gate on the value/slot dimension).
    gk, gv = None, g
    grid = (NM, NK, N * H)
    fused_recurrent_bwd_kernel[grid](
        q=q,
        k=k,
        v=s,
        g=None,
        gk=gk,
        gv=gv,
        h0=hk0,
        do=dok,
        dq=dq,
        dk=dk,
        dv=dsk,
        dht=dhkt,
        dh0=dhk0,
        cu_seqlens=cu_seqlens,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=M,
        BK=BK,
        BV=BM,
        USE_G=False,
        USE_GK=False,
        USE_GV=True,
        REVERSE=reverse,
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dsk = dsk.sum(0)

    # Gate gradient of the key branch.
    dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens)

    # Slot and gate gradients accumulate from both branches (in-place adds).
    ds = dsk.add_(dsv)
    dg = dgk.add_(dgv)

    return dq, dk, dv, ds, dg, dhk0, dhv0
335
+
336
+
337
class FusedRecurrentGSAFunction(torch.autograd.Function):
    """Autograd bridge for the fused recurrent GSA kernels.

    ``forward`` returns ``(ov, hkt, hvt)``; ``backward`` therefore receives one
    incoming gradient per output and must return one gradient per forward
    input (``None`` for the non-tensor arguments).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: Optional[float] = None,
        hk0: Optional[torch.Tensor] = None,
        hv0: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Run the fused forward pass and stash intermediates for backward."""
        T = q.shape[1]
        # Fast path: single-token decoding with no autograd bookkeeping.
        # NOTE(review): the guard only checks q.requires_grad — presumably
        # k/v/s/g never require grad without q; confirm against callers.
        if T == 1 and not q.requires_grad:
            o, (hkt, hvt) = fused_recurrent_gsa_inference(
                q=q,
                k=k,
                v=v,
                s=s,
                g=g,
                initial_state=(hk0, hv0),
                output_final_state=output_final_state,
                scale=scale,
            )
            return o, hkt, hvt
        # Full path: also produces ok (pre-softmax logits) and qv (softmax
        # weights), both needed by the backward pass.
        ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok)
        ctx.scale = scale
        ctx.reverse = reverse
        ctx.cu_seqlens = cu_seqlens
        return ov.to(q.dtype), hkt, hvt

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dhkt=None, dhvt=None):
        """Compute input gradients via `fused_recurrent_gsa_bwd`.

        Args:
            do: gradient of the output `ov`.
            dhkt: gradient of the final key-side state (or None).
            dhvt: gradient of the final value-side state (or None).
        """
        q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
        scale = ctx.scale
        reverse = ctx.reverse
        cu_seqlens = ctx.cu_seqlens

        # not supported yet: propagating final-state gradients together with
        # trainable gates.
        if dhkt is not None or dhvt is not None:
            if g is not None:
                assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
        dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            qv=qv,
            hk0=hk0,
            hv0=hv0,
            ok=ok,
            do=do,
            dhkt=dhkt,
            dhvt=dhvt,
            scale=scale,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        # One gradient per forward input: q, k, v, s, g, scale, hk0, hv0,
        # output_final_state, reverse, cu_seqlens.
        return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None
418
+
419
+
420
def fused_recurrent_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    reverse: Optional[bool] = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, T, H, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, H, M]` applied to keys.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, M]` and `[N, H, M, V]`.
            Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (Tuple[torch.Tensor]):
            Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gsa import fused_recurrent_gsa
        # inputs with equal lengths
        >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> s = torch.randn(B, T, H, M, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
        >>> o, (hk, hv) = fused_recurrent_gsa(
            q, k, v, s, g,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa(
            q, k, v, s, g,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert hk.allclose(hk_var)
        >>> assert hv.allclose(hv_var)
    """
    # Validate the variable-length calling convention before touching kernels.
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if initial_state is None:
        initial_state = (None, None)
    # `apply` returns (o, hkt, hvt); capture the two states via star-unpacking.
    o, *final_state = FusedRecurrentGSAFunction.apply(
        q,
        k,
        v,
        s,
        g,
        scale,
        *initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    # Return the states as a tuple, matching the documented return type.
    return o, tuple(final_state)
fla3/ops/gsa/naive.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
def naive_recurrent_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    """Pure-PyTorch reference implementation of recurrent GSA.

    Runs two chained gated linear recurrences over time:
      1. key pass:   hk = hk * exp(g) + k ⊗ s,  ok = (scale * q) @ hk
      2. value pass: hv = exp(g) * hv + s ⊗ v,  ov = softmax(ok) @ hv

    Args:
        q: queries of shape `[B, T, H, K]`.
        k: keys of shape `[B, T, HK, K]`; when `HK < H` the kv heads are
            shared across groups of `H // HK` query heads (GQA-style).
        v: values of shape `[B, T, HK, V]`.
        s: slot representations of shape `[B, T, HK, M]`.
        g: optional log forget gates of shape `[B, T, HK, M]`. If `None`,
            gates and normalized slots are derived from `s` via a running
            log-cumsum-exp normalization.
        scale: attention scale; defaults to `K ** -0.5`.
        initial_state: optional `(hk0, hv0)` with shapes `[B, HK, K, M]` and
            `[B, HK, M, V]`.
        output_final_state: whether to also return the final `(hk, hv)`
            (one copy per kv head).

    Returns:
        `(ov, final_state)` where `ov` has shape `[B, T, H, V]` and
        `final_state` is `None` unless `output_final_state=True`.
    """
    dtype = q.dtype
    # Work internally in [B, H, T, D] float32 layout.
    q, k, v, s = map(lambda x: x.transpose(1, 2).contiguous().float(), (q, k, v, s))
    # `g` may legitimately be None (derived from `s` below), so transpose it
    # separately instead of crashing inside the map above.
    if g is not None:
        g = g.transpose(1, 2).contiguous().float()

    NG = q.shape[1] // k.shape[1]
    # [batch_size, n_heads, seq_len, n_slots]
    if g is None:
        # Derive gates from the slots: a running logcumsumexp normalization,
        # so the recurrence accumulates softmax-normalized slot weights.
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    # Expand kv heads to match query heads; repeat_interleave lays groups out
    # consecutively, equivalent to einops 'b h t d -> b (h g) t d'.
    k, v, s, g = map(lambda x: x.repeat_interleave(NG, dim=1), (k, v, s, g))
    if initial_state is not None:
        initial_state = tuple(x.repeat_interleave(NG, dim=1) for x in initial_state)

    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]

    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None
    if initial_state is not None:
        hk += initial_state[0]

    # Key pass: accumulate k ⊗ s outer products into the gated state.
    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = s[:, :, i]
        g_i = g[:, :, i].exp()
        hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
        ok[:, :, i] = (q_i[..., None] * hk).sum(-2)

    # Value pass: softmax-normalized slot scores query the gated value state.
    qv = ok.softmax(-1)
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv += initial_state[1]

    for i in range(T):
        q_i = qv[:, :, i]
        k_i = s[:, :, i]
        v_i = v[:, :, i]
        g_i = g[:, :, i].exp()
        hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
        ov[:, :, i] = (q_i[..., None] * hv).sum(-2)

    if output_final_state:
        # Heads within a group share state; keep one representative per group.
        final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
    ov = ov.transpose(1, 2).contiguous()
    return ov.to(dtype), final_state
fla3/ops/hgrn/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_hgrn
4
+ from .fused_recurrent import fused_recurrent_hgrn
5
+
6
+ __all__ = [
7
+ 'chunk_hgrn',
8
+ 'fused_recurrent_hgrn'
9
+ ]
fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (308 Bytes). View file