msj19 committed on
Commit
0a2b89e
·
verified ·
1 Parent(s): 671f302

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-38.pyc +0 -0
  2. fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-39.pyc +0 -0
  3. fla2/ops/mask_delta_rule/__pycache__/utils.cpython-38.pyc +0 -0
  4. fla2/ops/mask_delta_rule/__pycache__/utils.cpython-39.pyc +0 -0
  5. fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-310.pyc +0 -0
  6. fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-312.pyc +0 -0
  7. fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-38.pyc +0 -0
  8. fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-39.pyc +0 -0
  9. fla2/ops/mask_delta_rule/__pycache__/wy_fast_non.cpython-310.pyc +0 -0
  10. fla2/ops/mask_delta_rule/__pycache__/wy_fast_non.cpython-312.pyc +0 -0
  11. fla2/ops/mask_gated_delta_rule_t/__pycache__/__init__.cpython-312.pyc +0 -0
  12. fla2/ops/mask_gated_delta_rule_t/__pycache__/chunk.cpython-310.pyc +0 -0
  13. fla2/ops/mask_gated_delta_rule_t/__pycache__/chunk.cpython-312.pyc +0 -0
  14. fla2/ops/mask_gated_delta_rule_t/__pycache__/wy_fast.cpython-310.pyc +0 -0
  15. fla2/ops/mask_gated_delta_rule_t/wy_fast.py +541 -0
  16. fla2/ops/mask_gated_delta_rule_t/wy_fast_test.py +676 -0
  17. fla2/ops/retention/__pycache__/chunk_fuse.cpython-312.pyc +0 -0
  18. fla2/ops/retention/__pycache__/chunk_fuse.cpython-38.pyc +0 -0
  19. fla2/ops/retention/__pycache__/chunk_fuse.cpython-39.pyc +0 -0
  20. fla2/ops/retention/__pycache__/parallel.cpython-312.pyc +0 -0
  21. fla2/ops/retention/__pycache__/parallel.cpython-38.pyc +0 -0
  22. fla2/ops/retention/__pycache__/parallel.cpython-39.pyc +0 -0
  23. fla2/ops/retention/__pycache__/recurrent_fuse.cpython-312.pyc +0 -0
  24. fla2/ops/retention/__pycache__/recurrent_fuse.cpython-38.pyc +0 -0
  25. fla2/ops/retention/__pycache__/recurrent_fuse.cpython-39.pyc +0 -0
  26. fla2/ops/rwkv6/__pycache__/__init__.cpython-38.pyc +0 -0
  27. fla2/ops/rwkv6/__pycache__/__init__.cpython-39.pyc +0 -0
  28. fla2/ops/rwkv6/__pycache__/chunk.cpython-312.pyc +0 -0
  29. fla2/ops/rwkv6/__pycache__/chunk.cpython-38.pyc +0 -0
  30. fla2/ops/rwkv6/__pycache__/chunk.cpython-39.pyc +0 -0
  31. fla2/ops/rwkv6/__pycache__/recurrent_fuse.cpython-312.pyc +0 -0
  32. fla2/ops/rwkv6/__pycache__/recurrent_fuse.cpython-38.pyc +0 -0
  33. fla2/ops/rwkv6/__pycache__/recurrent_fuse.cpython-39.pyc +0 -0
  34. fla2/ops/rwkv6/chunk.py +931 -0
  35. fla2/ops/rwkv6/chunk_naive.py +43 -0
  36. fla2/ops/rwkv6/recurrent_fuse.py +368 -0
  37. fla2/ops/rwkv6/recurrent_naive.py +103 -0
  38. fla2/ops/simple_gla/README.md +5 -0
  39. fla2/ops/simple_gla/__init__.py +7 -0
  40. fla2/ops/simple_gla/chunk.py +299 -0
  41. fla2/ops/simple_gla/naive.py +81 -0
  42. fla2/ops/simple_gla/recurrent_fuse.py +21 -0
  43. fla3/__pycache__/__init__.cpython-310.pyc +0 -0
  44. fla3/__pycache__/__init__.cpython-312.pyc +0 -0
  45. fla3/__pycache__/utils.cpython-310.pyc +0 -0
  46. fla3/__pycache__/utils.cpython-312.pyc +0 -0
  47. fla3/layers/__init__.py +51 -0
  48. fla3/layers/__pycache__/__init__.cpython-310.pyc +0 -0
  49. fla3/layers/__pycache__/__init__.cpython-312.pyc +0 -0
  50. fla3/layers/__pycache__/abc.cpython-310.pyc +0 -0
fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-38.pyc ADDED
Binary file (7.25 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-39.pyc ADDED
Binary file (7.19 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/utils.cpython-38.pyc ADDED
Binary file (8.73 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/utils.cpython-39.pyc ADDED
Binary file (8.67 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-310.pyc ADDED
Binary file (21.2 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (34.1 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-38.pyc ADDED
Binary file (10.3 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/wy_fast.cpython-39.pyc ADDED
Binary file (10.2 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/wy_fast_non.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
fla2/ops/mask_delta_rule/__pycache__/wy_fast_non.cpython-312.pyc ADDED
Binary file (32.2 kB). View file
 
fla2/ops/mask_gated_delta_rule_t/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (288 Bytes). View file
 
fla2/ops/mask_gated_delta_rule_t/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (37.8 kB). View file
 
fla2/ops/mask_gated_delta_rule_t/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (97.3 kB). View file
 
fla2/ops/mask_gated_delta_rule_t/__pycache__/wy_fast.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
fla2/ops/mask_gated_delta_rule_t/wy_fast.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import pdb
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from einops import rearrange
7
+ # from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
8
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
9
+ # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
10
+ # o: cumprod
11
+ # o2: cumprodsum
12
+ from typing import Optional
13
@triton.jit
def safe_exp(x):
    # Exponential restricted to non-positive inputs: entries with x > 0 are
    # replaced by -inf before exponentiation, so they evaluate to exp(-inf) = 0
    # instead of growing without bound.
    clamped = tl.where(x <= 0, x, float('-inf'))
    return tl.exp(clamped)
16
+
17
+
18
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def gated_fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    mask_ij,
    w,
    u,
    Aw,
    Au,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    r: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Recompute the w and u outputs for one (chunk, batch*head) program:
    #   w = Aw @ (beta * k, expanded by the per-timestep mask column)
    #   u = Au @ (beta * v, replicated r times)
    # Program axis 0 selects the chunk (i_t), axis 1 the flattened batch*head
    # index (i_bh). K is split into r slices of width dk = K // r.
    # NOTE(review): the mask load below has no boundary check, so this
    # presumably requires T to be a multiple of BT -- confirm with callers.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    dk = K//r
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    # (BT*r, BT*r) block of Aw for this chunk, cast to k's element type.
    p_Aw = tl.make_block_ptr(Aw + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))
    b_Aw = tl.load(p_Aw, boundary_check=(0, 1)).to(k.dtype.element_ty)
    for i_r in range(r):
        # Column i_r of every per-timestep (r, r) mask in this chunk.
        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
        b_mask = tl.load(p_mask)  # shape (BT, r, 1)
        for i_k in range(tl.cdiv(dk, BK)):
            # BK-wide tile of k taken from the i_r-th slice of the K dimension.
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # Scale by beta, then broadcast against the mask column: (BT, r, BK).
            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask.to(b_k.dtype)
            b_kb = tl.reshape(b_kb,(BT*r,BK))
            b_w = tl.dot(b_Aw, b_kb, allow_tf32=False)  # result is (BT*r, BK)
            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
    tl.debug_barrier()
    # Drop the reference to the Aw tile before loading Au.
    b_Aw = None
    p_Au = tl.make_block_ptr(Au + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))
    b_Au = tl.load(p_Au, boundary_check=(0, 1)).to(k.dtype.element_ty)

    # No loop over r is needed for v, and no mask is applied to it.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # Replicate each beta-scaled row of v r times via a ones vector: (BT, r, BV).
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
        b_vb = tl.reshape(b_vb,(BT*r,BV))
        b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
82
+
83
+
84
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK","r"],
)
@triton.jit
def gated_chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    mask_ij,
    A,
    Ag,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    T,
    K,
    r: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
):
    # For one (chunk, batch*head) program, build the strictly lower-triangular
    # (in time) tensor
    #   A[i, j, a, b] = mask[i, a, b] * sum_d(beta_i * k_i[b-th slice] * k_j[b-th slice])
    # and its gated copy Ag = A * safe_exp(g_i - g_j). Both are stored as
    # (BT, BT, r, r) tiles, one tile per chunk.
    # NOTE(review): the output offset uses (i_bh*T)//BT and the mask load has no
    # boundary check, so T is presumably a multiple of BT -- confirm.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)  # fp32 accumulator, (BT, BT, r, r)
    dk = K//r
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    for i_r in range(r):
        # One-hot selector so this iteration only touches last-axis index i_r.
        r_mask = tl.arange(0, r) == i_r
        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
        b_mask = tl.load(p_mask)  # shape (BT, r, 1): column i_r of each per-timestep mask
        ij_mask = b_mask*r_mask[None,None,:]  # (BT, r, r); non-zero only where last index == i_r

        for i_k in range(tl.cdiv(dk, BK)):  # walk the i_r-th slice of K in BK-wide tiles
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)  # (BT, BT)
            b_A += dot[:,:,None,None]*ij_mask[:,None,:,:]  # mask is indexed by the row timestep

    # Keep only entries with row time index strictly greater than column time index.
    b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0)
    p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0))
    tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3))

    # Gate factor safe_exp(g_i - g_j); safe_exp yields 0 wherever the difference is positive.
    p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_g = tl.load(p_g, boundary_check=(0,))
    b_g_diff = b_g[:, None] - b_g[None, :]
    b_g_diff = safe_exp(b_g_diff)

    b_Ag = b_A * ((b_g_diff)[:,:,None,None])  # gate broadcast over the (r, r) axes
    p_Ag = tl.make_block_ptr(Ag + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0))
    tl.store(p_Ag, (b_Ag).to(p_Ag.dtype.element_ty),boundary_check=(0,1,2,3))
141
+
142
+
143
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "r"],
)
@triton.jit
def solve_tril_16x16_kernel(
    A,
    Ad,
    s_A_bh,
    s_Ad_bh,
    T,
    r: tl.constexpr,
    BT: tl.constexpr,
):
    # Process the 16x16 (time) diagonal block of A owned by this program, where
    # every entry is itself an (r, r) matrix. The strictly lower-triangular part
    # is negated, updated row by row, and the identity is added; the result is
    # written to Ad as a (16*r, 16*r) matrix.
    # NOTE(review): this looks like forward substitution producing
    # (I + tril(A, -1))^{-1} for the diagonal block -- confirm against a reference.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    # Column offset of this 16-wide diagonal block inside its BT-wide chunk.
    offset = (i_t * 16) % BT

    p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0))
    b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32)
    b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0)

    for i in range(1, 16):
        mask = tl.arange(0, 16) == i
        # Extract time-row i as a (16, r, r) tensor.
        b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)
        # Contract row i against the current b_A over the shared r axis.
        q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2))
        # Sum over the time axis and keep only columns before i.
        b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None])
        b_A = tl.where(mask[:,None,None,None],b_a,b_A)  # write the updated row back; rows are processed one at a time
    # Add the identity: ones where both the time indices and the r indices match.
    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))

    # (16, 16, r, r) -> (16, r, 16, r) -> (16*r, 16*r)
    b_A = tl.permute(b_A,(0,2,1,3))
    b_A = tl.reshape(b_A,(16*r,16*r))
    p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0))
    tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1))
182
+
183
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["r"],
)
@triton.jit
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    Ai,
    s_A_bh,
    s_Ad_bh,
    T,
    r: tl.constexpr,
    BT: tl.constexpr
):
    # Assemble the result for one 32-step chunk from two 16-step diagonal
    # blocks stored in Ad and the off-diagonal block A21 read from A:
    #   Ai11 = Ad11, Ai22 = Ad22, Ai21 = -Ai22 @ A21 @ Ai11
    # All blocks are (16*r, 16*r). The upper-right block of Ai is never written
    # here. BT is accepted but not referenced in the body.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    # Off-diagonal block: rows 16..31, columns 0..15 of the chunk (in units of r).
    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,32*r),(32*r,1) ,((i_t * 32 + 16) *r, 0), (16*r, 16*r), (1,0))
    b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32)

    # Diagonal blocks previously written to Ad, stacked along the row axis.
    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0))
    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0))

    # Destination blocks inside the (32*r)-wide output.
    p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0))
    p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0))
    p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))

    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
    # 'ieee' input precision requests full fp32 matmuls for the inverse.
    Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee')
    tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
222
+
223
+
224
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["r"],
)
@triton.jit
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ad,
    Ai,
    s_A_bh,
    s_Ad_bh,
    T,
    r: tl.constexpr,
    BT: tl.constexpr
):
    # Assemble the inverse for one 64-step chunk from four 16-step diagonal
    # inverses (read from Ad) and the off-diagonal blocks of A. For a block
    # lower-triangular matrix with diagonal-block inverses Dii and off-diagonal
    # blocks Aij, the inverse blocks are
    #   X21 = -D22 A21 D11
    #   X32 = -D33 A32 D22
    #   X43 = -D44 A43 D33
    #   X31 = -D33 (A31 D11 + A32 X21)
    #   X42 = -D44 (A42 D22 + A43 X32)
    #   X41 = -D44 (A41 D11 + A42 X21 + A43 X31)
    # Every block is (16*r, 16*r). Blocks above the diagonal are never written
    # here. BT is accepted for signature parity but is not referenced.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    # Off-diagonal blocks of A for this chunk.
    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0))
    p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0))
    p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0))
    p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0))
    p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0))
    p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0))

    b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32)
    b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32)
    b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32)
    b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32)
    b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32)
    b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32)

    # Diagonal-block inverses previously written to Ad, stacked along rows.
    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 * r, 0), (16*r,16*r), (1,0))
    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0))
    p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0))
    p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0))

    # Destination blocks inside the (64*r)-wide output.
    p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0))
    p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0))
    p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0))
    p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0))
    p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0))
    p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0))
    p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0))
    p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0))
    p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0))
    p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0))

    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
    Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32)
    Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32)

    # First sub-diagonal: each block is right-multiplied by the inverse of the
    # diagonal block in ITS OWN column (D11, D22, D33 respectively). The
    # previous code used Ai11 for all three, which made Ai32/Ai43 -- and the
    # blocks derived from them (Ai31, Ai42, Ai41) -- incorrect.
    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
    Ai32 = -tl.dot(tl.dot(Ai33, b_A32, input_precision='ieee'), Ai22, input_precision='ieee')
    Ai43 = -tl.dot(tl.dot(Ai44, b_A43, input_precision='ieee'), Ai33, input_precision='ieee')

    # Second sub-diagonal.
    Ai31 = -tl.dot(
        Ai33,
        tl.dot(b_A31, Ai11, input_precision='ieee') +
        tl.dot(b_A32, Ai21, input_precision='ieee'),
        input_precision='ieee')

    Ai42 = -tl.dot(
        Ai44,
        tl.dot(b_A42, Ai22, input_precision='ieee') +
        tl.dot(b_A43, Ai32, input_precision='ieee'),
        input_precision='ieee')

    # Bottom-left corner.
    Ai41 = -tl.dot(
        Ai44,
        tl.dot(b_A41, Ai11, input_precision='ieee') +
        tl.dot(b_A42, Ai21, input_precision='ieee') +
        tl.dot(b_A43, Ai31, input_precision='ieee'),
        input_precision='ieee'
    )

    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
319
+
320
+
321
+
322
def gated_chunk_scaled_dot_kkt_fwd(k: torch.Tensor,
                                   beta: torch.Tensor,
                                   mask: torch.Tensor,
                                   g_cumsum: Optional[torch.Tensor] = None,
                                   BT: int = 32,
                                   output_dtype: torch.dtype = torch.float32):
    """Launch the chunked, masked ``beta * k @ k^T`` kernel.

    Args:
        k: 4-D key tensor unpacked as ``(B, H, T, K)``.
        beta: per-timestep scale, forwarded to the kernel unchanged.
        mask: tensor whose last dimension gives the rank ``r``
            (presumably ``(B, H, T, r, r)`` -- confirm with callers).
        g_cumsum: gate values forwarded to the kernel.
            NOTE(review): the kernel loads this unconditionally, so the
            ``None`` default looks unusable -- confirm.
        BT: chunk length along the time axis.
        output_dtype: dtype of both returned buffers.

    Returns:
        ``(A, Ag)``: two tensors of shape ``(B*H*NT, BT*BT, r*r)`` with
        ``NT = ceil(T / BT)``; ``Ag`` is the gated counterpart of ``A``.
    """
    batch, heads, seq_len, key_dim = k.shape
    rank = mask.shape[-1]
    num_chunks = triton.cdiv(seq_len, BT)
    # K is split into `rank` slices; tile each slice with at most 64 columns.
    block_k = min(triton.next_power_of_2(key_dim // rank), 64)

    # torch.empty already returns contiguous storage.
    out_shape = (batch * heads * num_chunks, BT * BT, rank * rank)
    A = torch.empty(out_shape, device=k.device, dtype=output_dtype)
    Ag = torch.empty(out_shape, device=k.device, dtype=output_dtype)

    grid = (num_chunks, batch * heads)
    gated_chunk_scaled_dot_kkt_fwd_kernel[grid](
        k, beta, g_cumsum, mask, A, Ag,
        seq_len * key_dim, key_dim, 1,  # strides of k: head, time, feature
        seq_len, key_dim, rank, BT, block_k
    )
    return A, Ag
340
+
341
def solve_tril(A,mask,k,BT,output_dtype=torch.float32):
    """Compute the per-chunk result matrices from ``A`` in up to two stages.

    Stage 1 processes every 16-step diagonal block with
    ``solve_tril_16x16_kernel``. For ``BT`` of 32 or 64, stage 2 merges those
    blocks into one ``(BT*r, BT*r)`` matrix per chunk.

    Args:
        A: tensor laid out as ``(B*H*NT, BT*BT, r*r)``, as produced by
            ``gated_chunk_scaled_dot_kkt_fwd``.
        mask: tensor whose last dimension gives the rank ``r``.
        k: 4-D key tensor; only its ``(B, H, T, _)`` shape is used.
        BT: chunk length; must be 16, 32 or 64.
        output_dtype: dtype of the returned tensor.

    Returns:
        For ``BT == 16`` a tensor of shape ``(B, H, NT16*16*r, 16*r)``;
        otherwise a zero-initialised tensor of shape ``(B, H, NT*BT*r, BT*r)``
        whose lower-triangular blocks are filled in.

    Raises:
        ValueError: if ``BT`` is not one of 16, 32, 64. Previously such a
            value fell through every branch and silently returned ``None``.

    NOTE(review): the batch-head strides are computed from ``T`` rather than
    ``NT * BT``, so ``T`` is presumably required to be a multiple of ``BT`` --
    confirm with callers.
    """
    # Validate before launching any kernel so an unsupported size fails loudly.
    if BT not in (16, 32, 64):
        raise ValueError(f"solve_tril only supports BT in (16, 32, 64), got {BT}")

    B, H, T, _ = k.shape
    r = mask.shape[-1]

    # Stage 1: 16-step diagonal blocks. Kept in fp32 when a merge follows.
    NT = triton.cdiv(T, 16)
    Ad = torch.empty(B, H, NT*16*r, 16*r, device=A.device,
                     dtype=torch.float if BT != 16 else output_dtype)
    solve_tril_16x16_kernel[(NT, B*H)](
        A, Ad,
        T*BT*r*r,  # batch-head stride of A
        T*16*r*r,  # batch-head stride of Ad
        T,
        r, BT
    )
    if BT == 16:
        return Ad

    # Stage 2: regroup A from (time, time, r, r) to (time*r, time*r) per chunk
    # so the merge kernels can address 16*r-wide blocks directly.
    A = rearrange(A, 'b (t l) (c r)->b (t c) (l r)', t=BT, c=r).contiguous()
    if BT == 32:
        merge_kernel = merge_16x16_to_32x32_inverse_kernel
    else:
        merge_kernel = merge_16x16_to_64x64_inverse_kernel

    NT = triton.cdiv(T, BT)
    # Zero-initialised because the merge kernels never write above the diagonal.
    Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
    merge_kernel[(NT, B*H)](
        A, Ad, Ai,
        T*BT*r*r,  # batch-head stride of A and Ai
        T*16*r*r,  # batch-head stride of Ad
        T, r, BT
    )
    return Ai
378
+
379
+
380
def gated_fwd_recompute_w_u(k, v, beta,mask, Aw,Au,BT):
    """Allocate ``w`` and ``u`` and launch ``gated_fwd_recompute_w_u_kernel``.

    Args:
        k: 4-D key tensor unpacked as ``(B, H, T, K)``.
        v: value tensor; its last dimension is taken as ``V``.
        beta: per-timestep scale, forwarded to the kernel.
        mask: tensor whose last dimension gives the rank ``r``.
        Aw, Au: per-chunk matrices applied to the key and value sides.
        BT: chunk length along the time axis.

    Returns:
        ``(w, u)`` with shapes ``(B, H, r*T, K)`` and ``(B, H, r*T, V)``,
        both in ``k``'s dtype and on ``k``'s device.
    """
    batch, heads, seq_len, key_dim = k.shape
    value_dim = v.shape[-1]
    rank = mask.shape[-1]

    u = torch.empty(batch, heads, rank * seq_len, value_dim, device=k.device, dtype=k.dtype)
    w = torch.empty(batch, heads, rank * seq_len, key_dim, device=k.device, dtype=k.dtype)

    num_chunks = triton.cdiv(seq_len, BT)
    # Tile widths: powers of two, capped at 64. K is tiled per rank slice.
    block_k = min(triton.next_power_of_2(key_dim // rank), 64)
    block_v = min(triton.next_power_of_2(value_dim), 64)

    grid = (num_chunks, batch * heads)
    gated_fwd_recompute_w_u_kernel[grid](
        k, v, beta, mask, w, u, Aw, Au,
        seq_len * key_dim, key_dim, 1,      # strides of k: head, time, feature
        seq_len * value_dim, value_dim, 1,  # strides of v: head, time, feature
        seq_len, key_dim, value_dim, rank, BT, block_k, block_v
    )
    return w, u
395
+
396
+
397
+
398
+
399
+ # class WYRepresentationPrepration(torch.autograd.Function):
400
+ # @staticmethod
401
+ # @contiguous
402
+ # @autocast_custom_fwd
403
+ # def forward(ctx, k, v, beta,mask,chunk_size=64):
404
+ # ctx.BT = chunk_size
405
+ # w, u, A = fwd_prepare_wy_repr(k, v,beta,mask, ctx.BT)
406
+ # ctx.save_for_backward(k, v, beta,mask,A)
407
+ # return w, u
408
+ # @staticmethod
409
+ # @contiguous
410
+ # @autocast_custom_bwd
411
+ # def backward(ctx, dw, du):
412
+ # k, v, beta,mask, A = ctx.saved_tensors
413
+ # BT = ctx.BT
414
+ # dk, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta,mask, A, dw, du, BT)
415
+ # return dk, dv, dbeta, dmask, None
416
+
417
+ # prepare_wy_repr = WYRepresentationPrepration.apply
418
+
419
+
420
+ # def naive(k, v, beta,maskij,chunk_size):
421
+ # l_org = k.shape[2]
422
+ # l_new = triton.next_power_of_2(l_org)
423
+ # k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
424
+ # v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
425
+ # beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
426
+ # k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
427
+ # beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
428
+
429
+ # b,h,nt,BT,dk = k.shape
430
+ # dv = v.shape[-1]
431
+ # r = maskij.shape[-1]
432
+ # k_beta = k * beta[..., None]
433
+ # k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r)
434
+ # k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij)
435
+ # k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')#l=1 rk=org
436
+ # v_beta = v * beta[..., None]
437
+ # v_beta = v_beta
438
+ # v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1)
439
+ # ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r)
440
+
441
+ # attn = (ki @ ki.transpose(-1, -2))
442
+ # attn = torch.tril(attn, diagonal=-1)#bhnr cc
443
+ # attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)#bhn rr cc
444
+ # attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta)
445
+
446
+ # o = torch.zeros_like(k_beta)
447
+ # o2 = torch.zeros_like(v_beta)
448
+
449
+ # o[..., 0, :,:] = k_beta[..., 0,:,:].clone()
450
+ # o2[..., 0,:, :] = v_beta[..., 0,:,:].clone()
451
+ # for i in range(1, chunk_size):
452
+ # o_i = (o[..., :i,:,:]).clone()#bhn :t cc
453
+ # o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:])
454
+ # o2_i = (o2[..., :i,:,:]).clone()#少一个维度
455
+ # o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:])
456
+ # return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2))
457
+
458
+
459
+ # if __name__ == "__main__":
460
+ # #all compute here
461
+ # import sys
462
+ # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
463
+ # torch.set_default_dtype(torch.bfloat16)
464
+ # seq_len = 32
465
+ # b = 2
466
+ # h = 2
467
+ # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128
468
+ # v = torch.randn(b, h, seq_len, 128)
469
+ # beta = torch.rand(b, h, seq_len).sigmoid()
470
+ # require_grad = True
471
+ # BT = 16
472
+ # k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta))
473
+ # r = 4
474
+ # # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous()
475
+ # mask = torch.randn([r,r])
476
+ # mask = mask.cuda().requires_grad_(require_grad).contiguous()
477
+ # # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16)
478
+ # # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16)
479
+ # # from einops import rearrange
480
+
481
+ # k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = 16,r=r)
482
+ # b2 = rearrange(beta,'b h (n t)-> b h n t',t = 16)
483
+ # a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)#bhnrtt
484
+ # qq = torch.tril(a1,diagonal=-1)
485
+ # qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask)
486
+ # sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)')
487
+ # sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)#这个
488
+
489
+
490
+ # # #长条对角线
491
+ # i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :]))
492
+ # s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
493
+ # s = rearrange(s,'b h n a d c r->b h n (a c) (d r)')
494
+ # s = torch.linalg.inv(s.float()).to(k)#矩阵逆#bhn tr tr
495
+
496
+
497
+ # # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)#bh nt BT bt r r
498
+ # # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.float32)
499
+ # # s = rearrange(s,'b h n a c->(b h) (n a) c')
500
+ # # print(Ad)
501
+ # # print(s)
502
+ # # print((Ad-s).abs().max())
503
+
504
+ # w,u,As = fwd_prepare_wy_repr(k, v, beta,mask, 16)
505
+ # As = rearrange(As,'b h (n t) l->(b h n) t l',t =BT*r)
506
+ # # print((As-s).abs().max())
507
+ # # B*H*NT,BT*r,16*r
508
+ # # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2)
509
+ # # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask)
510
+ # # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)')
511
+ # # wc = s_copy@k_exp
512
+
513
+ # # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT)
514
+ # # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2)
515
+ # # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1)
516
+ # # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v')
517
+ # # uc = s_copy@v_exp
518
+ # # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc))
519
+ # # do = torch.rand_like(wc)
520
+ # # do2 = torch.rand_like(uc)#b h n t t
521
+ # # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)#这个代码有问题
522
+ # # do = torch.rand_like(o1)
523
+ # # do2 = torch.rand_like(o2)#b h n t t
524
+ # # if require_grad:
525
+ # # o1.backward(do, retain_graph=True)
526
+ # # o2.backward(do2, retain_graph=True)
527
+ # # k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad
528
+
529
+ # # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16)
530
+ # # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT)
531
+
532
+ # # print((o1-w0).abs().max())
533
+ # # print((o2-u0).abs().max())
534
+ # # print((k_grad-k_grad2).abs().max())
535
+ # # print((v_grad-v_grad2).abs().max())
536
+ # # print((beta_grad-beta_grad2).abs().max())
537
+ # # print((mask_grad-mask_grad2).abs().max())
538
+ # # print(mask_grad)
539
+ # # print(mask_grad2)
540
+
541
+
fla2/ops/mask_gated_delta_rule_t/wy_fast_test.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import pdb
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from einops import rearrange
7
+ # from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
8
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
9
+ # Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
10
+ # o: cumprod
11
+ # o2: cumprodsum
12
+
13
# Fused forward kernel: builds the per-chunk (BT*r, BT*r) block matrix A and
# the WY outputs w = A @ (beta * k, expanded by mask_ij) and
# u = A @ (beta * v, replicated r times).
#
# Launch grid: axis 0 = chunk index along T (i_t), axis 1 = flattened
# batch*head index (i_bh).
#   k, v, beta  : input pointers; k/v are addressed through the s_qk_* / s_vo_*
#                 strides, beta as a contiguous length-T vector per i_bh.
#   mask_ij     : pointer to an r x r matrix, read column by column.
#   w, u, A     : output pointers.
#   T, K, V     : sequence length, key width, value width.
#   r           : number of key sub-blocks; each sub-block is K // r wide.
#   BT, BK, BV  : tile sizes along T, the key sub-block width and V.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    mask_ij,
    w,
    u,
    A,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    r: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)# logically an (r*BT) x (r*BT) matrix
    dk = K//r
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    # Accumulate (beta * k_sub) @ k_sub^T for every key sub-block i_r and place
    # it, weighted by column i_r of mask_ij, into column i_r of the r x r slot.
    for i_r in range(r):
        r_mask = tl.arange(0, r) == i_r
        p_mask = mask_ij + tl.arange(0,r)* r + i_r# column read, so the index runs over rows
        b_mask = tl.load(p_mask)
        ij_mask = b_mask[:,None]*r_mask[None,:]# rows

        for i_k in range(tl.cdiv(dk, BK)):# load and accumulate over K blocks
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
            b_A += dot[:,:,None,None]*ij_mask[None,None,:,:]
    # Keep only the strictly lower-triangular (in time) part, negated.
    b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0)
    # save this first to inspect

    # Row-by-row forward substitution over the time index.
    for i in range(1, BT):# original note: matrix is BT,r,BT,r here (NOTE(review): b_A was allocated as [BT,BT,r,r] - confirm)
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)# row i, shape BT*r*r
        q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)# block matrix product, shape BT,BT*r*r
        b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])#BT*r*r
        b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, writing each updated row back
    # Add the identity (same time index and same r index).
    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
    b_A = tl.permute(b_A,(0,2,1,3))
    b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r
    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))# the older implementation needed many multiplications
    tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1))
    # the matrix inverse is done at this point
    b_A = b_A.to(k.dtype.element_ty)# inverse solved; next compute the outputs

    # w: for every key sub-block, expand beta*k by column i_r of mask_ij and
    # multiply by A.
    for i_r in range(r):
        p_mask = mask_ij + tl.arange(0,r)*r+i_r# read column i_r
        b_mask = tl.load(p_mask)
        for i_k in range(tl.cdiv(dk, BK)):
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d
            b_kb = tl.reshape(b_kb,(BT*r,BK))
            b_w = tl.dot(b_A, b_kb, allow_tf32=False)# result is (BT*r) x BK
            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    # u: the value path uses no mask; beta*v is simply replicated r times.
    for i_v in range(tl.cdiv(V, BV)):# no mask is used here
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
        b_vb = tl.reshape(b_vb,(BT*r,BV))
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
100
+
101
+
102
# Recomputes w and u from an already computed block matrix A (loaded from
# memory instead of being rebuilt).
#
# Launch grid: axis 0 = chunk index along T (i_t), axis 1 = flattened
# batch*head index (i_bh).
#   k, v, beta, mask_ij, A : input pointers; A is read as a (T*r, BT*r)
#                            row-major matrix per i_bh.
#   w, u                   : output pointers, (T*r, K) and (T*r, V) per i_bh.
#   r                      : number of key sub-blocks of width K // r.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    mask_ij,
    w,
    u,
    A,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    r: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    dk = K//r
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
    # w: expand beta*k of each key sub-block by column i_r of mask_ij, then
    # multiply by A.
    for i_r in range(r):
        # r_mask = tl.arange(0, r) == i_r #
        p_mask = mask_ij + tl.arange(0,r)*r+i_r# read column i_r
        b_mask = tl.load(p_mask)
        for i_k in range(tl.cdiv(dk, BK)):
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d
            b_kb = tl.reshape(b_kb,(BT*r,BK))
            b_w = tl.dot(b_A, b_kb, allow_tf32=False)# result is (BT*r) x BK
            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    # u: the value path uses no mask; beta*v is replicated r times.
    for i_v in range(tl.cdiv(V, BV)):# no mask is used here
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
        b_vb = tl.reshape(b_vb,(BT*r,BV))
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
162
+
163
+ #compute this
164
# Backward kernel of the WY preparation: from the upstream gradients dw / du
# and the saved block matrix A it produces dk, dv, dbeta and one r x r partial
# dmask per program.
#
# Launch grid: axis 0 = chunk index along T (i_t), axis 1 = flattened
# batch*head index (i_bh). The grid is two-dimensional.
#   k, v, beta, mask_ij, A : forward inputs and the saved (T*r, BT*r) matrix.
#   dw, du                 : upstream gradients, (T*r, K) and (T*r, V) per i_bh.
#   dk, dv, dbeta          : output gradients with the layouts of k, v, beta.
#   dmask                  : output buffer holding one r x r slot per
#                            (i_bh, i_t); the caller reduces over slots.
#   r                      : number of key sub-blocks of width K // r.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
    k, v, beta,mask_ij,A,
    dw, du,
    dk, dv, dbeta,dmask,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    r: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Only two grid axes exist; the original unpacked three program ids into
    # two names.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32)
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    b_dmask = tl.zeros([r,r],dtype=tl.float32)
    # Value path: u = A @ (beta * v replicated r times).
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))# (r*BT) x BV
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)# BT*r*BV
        b_v_beta = tl.reshape(b_v_beta,(BT*r,BV))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)# (BT*r) x (BT*r)
        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)# (BT*r) x BV
        b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))
        # Sum over the r replicas created in the forward pass.
        sum_dv = tl.sum(b_dv_beta,-2)
        b_dv = (sum_dv * b_beta[:, None])
        b_dbeta += tl.sum(sum_dv * b_v, 1)
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    block_k = K//r
    # Key path through w = A @ (beta * k expanded by column i_r of mask_ij).
    for i_r in range(r):
        p_mask = mask_ij + tl.arange(0,r)*r + i_r# read column i_r
        b_mask = tl.load(p_mask)# column i_r
        rmask = tl.arange(0, r) == i_r # selector for column i_r
        for i_k in range(tl.cdiv(block_k, BK)):
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
            b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)# BT*r*d
            b_k_beta = tl.reshape(b_k_beta,(BT*r,BK))
            b_dw = tl.load(p_dw, boundary_check=(0, 1))
            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
            b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK))
            sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1)
            b_dk = sum_dk* b_beta[:, None]
            b_dbeta += tl.sum(sum_dk * b_k, 1)

            # Gradient w.r.t. column i_r of the mask from the expansion above:
            # reduce over time and feature, keep the r axis.
            b_ss = b_dk_beta * b_beta[:,None,None] * b_k[:,None,:]
            b_ss = tl.reshape(tl.permute(b_ss,(2,0,1)),(BT*BK,r))
            b_ss = tl.sum(b_ss,0)
            # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1))
            b_dmask += (b_ss[:,None]*rmask[None,:]).to(tl.float32)
            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # Back-propagate through the triangular inverse: restrict dA to entries
    # whose time index is strictly below the diagonal, sandwich it between
    # A^T factors, then negate and mask again.
    i = tl.arange(0, BT * r)[:, None]
    j = tl.arange(0, BT * r)[None, :]
    iB = i // r
    jB = j // r
    da_mask = iB > jB
    b_dA = tl.where(da_mask, b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
    b_dA = tl.where(da_mask, -b_dA, 0) # gradient w.r.t. the masked k k^T matrix; zero on and above the diagonal blocks


    b_dA = tl.reshape(b_dA,(BT,r,BT,r))
    # layout: BT, r, BT, r


    # Key path through the beta * k k^T term, one key sub-block at a time.
    for i_r in range(r):# only the i_r-th term
        p_mask = mask_ij + tl.arange(0,r)*r+i_r# read column i_r
        b_mask = tl.load(p_mask)# column i_r
        rmask = tl.arange(0, r) == i_r # selector for column i_r
        g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)# BT r BT, column i_r extracted
        ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)# BT BT

        for i_k in range(tl.cdiv(block_k, BK)):# original note: ik = 1
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # dk written by the loop above is read back and accumulated into.
            b_dk = tl.load(p_dk, boundary_check=(0, 1))
            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)# BT*BK

            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
            b_dk += b_dk_beta * b_beta[:, None]
            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

            beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))# BT BT

            beta_y = (beta_kkt[:,None,:]*g)
            beta_y = tl.reshape(tl.permute(beta_y,(2,0,1)),(BT*BT,r))
            betas = tl.sum(beta_y,0)
            b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32)

    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))

    # One r x r slot per (i_bh, i_t). The number of chunks per head is
    # cdiv(T, BT), which matches the launch grid; T // BT would make slots of
    # neighbouring heads collide whenever T is not a multiple of BT.
    p_dmask = tl.make_block_ptr(dmask + (i_bh * tl.cdiv(T, BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0))
    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1))
293
+
294
+
295
# Computes, per chunk, the strictly lower-triangular (in time) part of the
# mask-weighted (beta * k) @ k^T tensor and stores it as a [BT, BT, r, r]
# tile. Unlike the fused forward kernel, the result is neither negated nor
# inverted here.
#
# Launch grid: axis 0 = chunk index along T (i_t), axis 1 = flattened
# batch*head index (i_bh).
#   k, beta, mask_ij : input pointers; mask_ij is an r x r matrix read by column.
#   A                : output pointer; one BT*BT*r*r tile per (i_bh, i_t).
#   r                : number of key sub-blocks of width K // r.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "r"],
)
@triton.jit
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    mask_ij,
    A,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    T,
    K,
    r: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)# logically an (r*BT) x (r*BT) matrix
    dk = K//r
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    for i_r in range(r):
        r_mask = tl.arange(0, r) == i_r
        p_mask = mask_ij + tl.arange(0,r)* r + i_r# column read, so the index runs over rows
        b_mask = tl.load(p_mask)
        ij_mask = b_mask[:,None]*r_mask[None,:]# rows

        for i_k in range(tl.cdiv(dk, BK)):# load and accumulate over K blocks
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
            b_A += dot[:,:,None,None]*ij_mask[None,None,:,:]
    b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0)
    # NOTE(review): i_bh*T//BT parses as (i_bh*T)//BT, which equals
    # i_bh*(T//BT) only when T is a multiple of BT - confirm callers guarantee it.
    p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0))
    tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3))
340
+
341
# Inverts the 16x16 (in time) diagonal blocks of I + A by forward
# substitution, where every "scalar" entry is itself an r x r block.
#
# Launch grid: axis 0 = index of the 16-row block along T (i_t), axis 1 =
# flattened batch*head index (i_bh).
#   A        : input pointer, viewed as (T, BT, r, r) per i_bh.
#   Ad       : output pointer, viewed as (T*r, 16*r) per i_bh.
#   s_A_bh   : element stride between consecutive i_bh slices of A.
#   s_Ad_bh  : element stride between consecutive i_bh slices of Ad.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "r"],
)
@triton.jit
def solve_tril_16x16_kernel(
    A,
    Ad,
    s_A_bh,
    s_Ad_bh,
    T,
    r: tl.constexpr,
    BT: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    # Column offset of this 16-wide diagonal block inside its BT-wide chunk.
    offset = (i_t * 16) % BT

    p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0))
    b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32)
    # Keep the strictly lower-triangular part, negated.
    b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0)

    for i in range(1, 16):
        mask = tl.arange(0, 16) == i
        b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)
        q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2))
        b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None])
        b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, writing each updated row back
    # Add the identity (same time index and same r index).
    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))

    b_A = tl.permute(b_A,(0,2,1,3))
    b_A = tl.reshape(b_A,(16*r,16*r))# (16*r) x (16*r)
    p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0))
    tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1))
380
+
381
# Assembles the inverse of a 32-row (in time) lower-triangular block from the
# inverses of its two 16-row diagonal blocks:
#   Ai11 = Ad11, Ai22 = Ad22, Ai21 = -(Ad22 @ A21 @ Ad11).
# The upper-right block is never written by this kernel.
#
# Launch grid: axis 0 = index of the 32-row chunk (i_t), axis 1 = flattened
# batch*head index (i_bh).
#   A   : raw chunk matrix, viewed as (T, 32, r, r) per i_bh.
#   Ad  : 16x16 diagonal inverses, viewed as (T*r, 16*r) per i_bh.
#   Ai  : output, viewed as (T*r, 32*r) per i_bh (offset with s_A_bh).
#   BT  : accepted but not read in the body; the chunk width is hard-coded to 32.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["r"],
)
@triton.jit
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    Ai,
    s_A_bh,
    s_Ad_bh,
    T,
    r: tl.constexpr,
    BT: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    # Lower-left 16x16 block of the raw matrix, flattened to (16*r, 16*r).
    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0))
    b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32)
    b_A21 = tl.permute(b_A21,(0,2,1,3))
    b_A21 = tl.reshape(b_A21,(16*r,16*r))# (16*r) x (16*r)

    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0))
    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0))


    p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0))
    p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0))
    p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))

    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
    # Off-diagonal block of a block lower-triangular inverse.
    Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee')
    tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
423
+
424
def chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32):
    """Launch ``chunk_scaled_dot_kkt_fwd_kernel`` and return its output buffer.

    Args:
        k: key tensor whose shape unpacks as ``(B, H, T, K)``.
        beta: per-position scaling tensor passed straight to the kernel.
        mask: matrix whose last dimension gives ``r``.
        BT: chunk length along ``T``.
        output_dtype: dtype of the returned buffer.

    Returns:
        A tensor of shape ``(B * H * cdiv(T, BT), BT * BT, r * r)``.
    """
    batch, heads, seq_len, key_dim = k.shape
    rank = mask.shape[-1]
    num_chunks = triton.cdiv(seq_len, BT)
    # Tile width over one key sub-block, capped at 64.
    block_k = min(triton.next_power_of_2(key_dim // rank), 64)
    out = torch.empty(
        (batch * heads * num_chunks, BT * BT, rank * rank),
        device=k.device,
        dtype=output_dtype,
    )
    grid = (num_chunks, batch * heads)
    chunk_scaled_dot_kkt_fwd_kernel[grid](
        k, beta, mask, out,
        seq_len * key_dim, key_dim, 1,
        seq_len, key_dim, rank, BT, block_k,
    )
    return out
436
+
437
def solve_tril(A,mask,k,BT,output_dtype=torch.float32):
    """Invert the per-chunk block lower-triangular matrices stored in ``A``.

    The 16-row diagonal blocks are inverted first. For ``BT == 16`` that is
    already the answer; for ``BT == 32`` the two diagonal inverses of every
    chunk are merged into the full 32-row inverse.

    Args:
        A: raw chunk matrices as produced by ``chunk_scaled_dot_kkt_fwd``.
        mask: matrix whose last dimension gives ``r``.
        k: key tensor whose shape unpacks as ``(B, H, T, K)``; only the shape
            is read.
        BT: chunk length; must be 16 or 32.
        output_dtype: dtype of the returned inverse.

    Returns:
        A tensor of shape ``(B, H, NT * BT * r, BT * r)`` with
        ``NT = cdiv(T, BT)``.

    Raises:
        ValueError: if ``BT`` is neither 16 nor 32. The merge kernel is
            hard-coded for 32-row chunks, so any other size would silently
            produce a wrong inverse.
    """
    if BT not in (16, 32):
        raise ValueError(f"solve_tril only supports BT of 16 or 32, got {BT}")
    B, H, T, _ = k.shape
    r = mask.shape[-1]
    NT = triton.cdiv(T, 16)
    # Intermediate 16x16 inverses stay in float32 unless they are the final result.
    Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype)
    solve_tril_16x16_kernel[(NT, B*H)](
        A,Ad,
        T*BT*r*r,  # s_A_bh
        T*16*r*r,  # s_Ad_bh
        T,
        r, BT
    )
    if BT == 16:
        return Ad

    NT = triton.cdiv(T, BT)
    # Zero-initialised because the merge kernel never writes the upper-right block.
    Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype)
    merge_16x16_to_32x32_inverse_kernel[(NT, B*H)](
        A,Ad,Ai,
        T*BT*r*r,  # s_A_bh, also used for Ai
        T*16*r*r,  # s_Ad_bh
        T,r,BT
    )
    return Ai
461
+
462
+
463
def fwd_prepare_wy_repr2(k, v, beta, mask, BT):
    """Three-stage forward: raw chunk matrix, triangular inverse, then w and u.

    Returns:
        ``(w, u, A)`` where ``A`` is the inverse returned by ``solve_tril`` in
        ``k.dtype``.
    """
    # Stage 1: raw chunk matrix in float32.
    raw = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
    # Stage 2: invert it, producing the result in the dtype of k.
    inverse = solve_tril(A=raw, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
    # Stage 3: apply the inverse to the scaled keys and values.
    w, u = fwd_recompute_w_u(k, v, beta, mask, inverse, BT)
    return w, u, inverse
468
+
469
def fwd_prepare_wy_repr(k, v, beta, mask, BT):
    """Allocate the outputs and launch the fused ``fwd_prepare_wy_repr_kernel``.

    Args:
        k: key tensor whose shape unpacks as ``(B, H, T, K)``.
        v: value tensor; its last dimension gives ``V``.
        beta: per-position scaling tensor passed straight to the kernel.
        mask: matrix whose last dimension gives ``r``.
        BT: chunk length along ``T``.

    Returns:
        ``(w, u, A)`` with shapes ``(B, H, r*T, K)``, ``(B, H, r*T, V)`` and
        ``(B, H, NT*BT*r, BT*r)``, all in ``k.dtype``.
    """
    batch, heads, seq_len, key_dim = k.shape
    value_dim = v.shape[-1]
    rank = mask.shape[-1]
    num_chunks = triton.cdiv(seq_len, BT)

    options = {"device": k.device, "dtype": k.dtype}
    u = torch.empty(batch, heads, rank * seq_len, value_dim, **options)
    w = torch.empty(batch, heads, rank * seq_len, key_dim, **options)
    A = torch.empty(batch, heads, num_chunks * BT * rank, BT * rank, **options)

    # Tile widths: one key sub-block / the value width, each capped at 64.
    block_k = min(triton.next_power_of_2(key_dim // rank), 64)
    block_v = min(triton.next_power_of_2(value_dim), 64)

    grid = (num_chunks, batch * heads)
    fwd_prepare_wy_repr_kernel[grid](
        k, v, beta, mask, w, u, A,
        seq_len * key_dim, key_dim, 1,
        seq_len * value_dim, value_dim, 1,
        seq_len, key_dim, value_dim, rank, BT, block_k, block_v,
    )
    return w, u, A
485
+
486
+
487
def fwd_recompute_w_u(k, v, beta, mask, A, BT):
    """Recompute ``w`` and ``u`` from an existing block matrix ``A``.

    Args:
        k: key tensor whose shape unpacks as ``(B, H, T, K)``.
        v: value tensor; its last dimension gives ``V``.
        beta: per-position scaling tensor passed straight to the kernel.
        mask: matrix whose last dimension gives ``r``.
        A: block matrix read by ``fwd_recompute_w_u_kernel``.
        BT: chunk length along ``T``.

    Returns:
        ``(w, u)`` with shapes ``(B, H, r*T, K)`` and ``(B, H, r*T, V)`` in
        ``k.dtype``.
    """
    batch, heads, seq_len, key_dim = k.shape
    value_dim = v.shape[-1]
    rank = mask.shape[-1]

    options = {"device": k.device, "dtype": k.dtype}
    u = torch.empty(batch, heads, rank * seq_len, value_dim, **options)
    w = torch.empty(batch, heads, rank * seq_len, key_dim, **options)

    num_chunks = triton.cdiv(seq_len, BT)
    # Tile widths: one key sub-block / the value width, each capped at 64.
    block_k = min(triton.next_power_of_2(key_dim // rank), 64)
    block_v = min(triton.next_power_of_2(value_dim), 64)

    grid = (num_chunks, batch * heads)
    fwd_recompute_w_u_kernel[grid](
        k, v, beta, mask, w, u, A,
        seq_len * key_dim, key_dim, 1,
        seq_len * value_dim, value_dim, 1,
        seq_len, key_dim, value_dim, rank, BT, block_k, block_v,
    )
    return w, u
502
+
503
def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
    """Launch ``bwd_prepare_wy_repr_kernel`` and reduce the mask gradient.

    Args:
        k: key tensor whose shape unpacks as ``(B, H, T, K)``.
        v: value tensor; its last dimension gives ``V``.
        beta: per-position scaling tensor.
        mask: matrix whose last dimension gives ``r``.
        A: block matrix saved by the forward pass.
        dw: upstream gradient with respect to ``w``.
        du: upstream gradient with respect to ``u``.
        BT: chunk length along ``T``.

    Returns:
        ``(dk, dv, dbeta, dmask)``; ``dk``/``dv``/``dbeta`` have the layouts of
        ``k``/``v``/``beta`` and ``dmask`` is ``(r, r)`` in ``k.dtype``.
    """
    B, H, T, K, V = *k.shape, v.shape[-1]
    r = mask.shape[-1]
    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K//r), 64)
    BV = min(triton.next_power_of_2(V), 64)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v).contiguous()
    dbeta = torch.zeros_like(beta)
    # One r x r partial per (batch*head, chunk). The partials are kept in
    # float32 so that summing B*H*NT of them does not lose precision when k is
    # half precision; the kernel casts to the buffer's element type on store.
    dmask = torch.zeros([B*H*NT,r,r],device=k.device,dtype=torch.float32)
    bwd_prepare_wy_repr_kernel[(NT, B*H)](
        k, v, beta, mask, A,
        dw, du,
        dk, dv, dbeta,dmask,
        T*K, K, 1,
        T*V, V, 1,
        T, K, V, r, BT, BK, BV
    )
    # Reduce over all programs, then return in the dtype callers already expect.
    dmask = dmask.sum(0).to(k.dtype)
    return dk, dv, dbeta, dmask
524
+
525
+
526
class WYRepresentationPrepration(torch.autograd.Function):
    """Autograd wrapper pairing ``fwd_prepare_wy_repr`` with ``bwd_prepare_wy_repr``."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, k, v, beta, mask, chunk_size=64):
        """Compute ``(w, u)`` and stash what the backward pass needs."""
        ctx.BT = chunk_size
        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, chunk_size)
        ctx.save_for_backward(k, v, beta, mask, A)
        return w, u

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, dw, du):
        """Return gradients for ``k, v, beta, mask`` and ``None`` for ``chunk_size``."""
        k, v, beta, mask, A = ctx.saved_tensors
        grads = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, ctx.BT)
        dk, dv, dbeta, dmask = grads
        return dk, dv, dbeta, dmask, None


# Functional entry point.
prepare_wy_repr = WYRepresentationPrepration.apply
545
+
546
+
547
def naive(k, v, beta,maskij,chunk_size):
    # Pure-PyTorch reference for the chunked WY preparation.
    # Pads the sequence length up to the next power of two, splits it into
    # chunks of `chunk_size`, and builds the outputs row by row inside each
    # chunk. Returns a `map` object yielding two tensors rearranged to
    # 'b h (n c r) k'.
    # NOTE(review): a comment in this file's __main__ section says this
    # reference has problems - treat its output with care.
    l_org = k.shape[2]
    l_new = triton.next_power_of_2(l_org)
    # Zero-pad k, v and beta along the sequence axis.
    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)

    b,h,nt,BT,dk = k.shape
    dv = v.shape[-1]
    r = maskij.shape[-1]
    # Scale keys by beta, split the feature axis into r sub-blocks and expand
    # every sub-block by the mask.
    k_beta = k * beta[..., None]
    k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r)
    k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij)
    k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')# original note: l=1, rk=original width
    v_beta = v * beta[..., None]
    v_beta = v_beta
    # Values carry no mask; they are only replicated r times.
    v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1)
    ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r)

    attn = (ki @ ki.transpose(-1, -2))
    attn = torch.tril(attn, diagonal=-1)# b h n r c c
    attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)# original note: b h n r r c c
    attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta)

    o = torch.zeros_like(k_beta)
    o2 = torch.zeros_like(v_beta)

    # Row 0 has no predecessors inside the chunk.
    o[..., 0, :,:] = k_beta[..., 0,:,:].clone()
    o2[..., 0,:, :] = v_beta[..., 0,:,:].clone()
    for i in range(1, chunk_size):
        o_i = (o[..., :i,:,:]).clone()# original note: b h n :t c c
        o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:])
        o2_i = (o2[..., :i,:,:]).clone()# one fewer dimension
        o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:])
    return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2))
584
+
585
+
586
if __name__ == "__main__":
    # Ad-hoc check script: builds random bf16 inputs on CUDA, forms a dense
    # reference inverse with torch.linalg.inv, and runs fwd_prepare_wy_repr2.
    # All comparisons against the reference are currently commented out.
    #all compute here
    import sys
    torch.manual_seed(42)
    # NOTE(review): machine-specific path.
    sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
    torch.set_default_dtype(torch.bfloat16)
    seq_len = 128
    b = 2
    h = 2
    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128
    v = torch.randn(b, h, seq_len, 128)
    beta = torch.rand(b, h, seq_len).sigmoid()
    require_grad = True
    BT = 32
    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta))
    r = 4
    # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous()
    mask = torch.randn([r,r])
    mask = mask.cuda().requires_grad_(require_grad).contiguous()
    # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16)
    # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16)
    # from einops import rearrange

    # Dense reference: per-sub-block beta * k k^T, strictly lower triangular,
    # expanded by the mask.
    k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = BT,r=r)
    b2 = rearrange(beta,'b h (n t)-> b h n t',t = BT)
    a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)# b h n r t t
    qq = torch.tril(a1,diagonal=-1)
    qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask)
    sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)')
    sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)# this one

    # # long block diagonal (identity)
    i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :]))
    s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
    s = rearrange(s,'b h n a d c r->b h n (a c) (d r)')
    s = torch.linalg.inv(s.float()).to(k)# matrix inverse, shape b h n (t r) (t r)


    # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)#bh nt BT bt r r
    # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.bfloat16)
    # s = rearrange(s,'b h n a c->(b h n) a c')
    # print(Ad.shape)
    # print(s.shape)

    w,u,As = fwd_prepare_wy_repr2(k, v, beta,mask, BT)
    # w2,u2,Ad2 = fwd_prepare_wy_repr(k, v, beta,mask, BT)

    # print((w2-w).abs().max())
    # print((u2-u).abs().max())
    # print((As-Ad2).abs().max())

    # print((Ad-s).abs().max())
    # print(Ad-s)

    # print((As-s).abs().max())
    # print(As-s)
    # B*H*NT,BT*r,16*r
    # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2)
    # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask)
    # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)')
    # wc = s_copy@k_exp

    # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT)
    # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2)
    # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1)
    # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v')
    # uc = s_copy@v_exp
    # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc))
    # do = torch.rand_like(wc)
    # do2 = torch.rand_like(uc)#b h n t t
    # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)# this code has problems
    # do = torch.rand_like(o1)
    # do2 = torch.rand_like(o2)#b h n t t
    # if require_grad:
    # o1.backward(do, retain_graph=True)
    # o2.backward(do2, retain_graph=True)
    # k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad

    # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16)
    # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT)

    # print((o1-w0).abs().max())
    # print((o2-u0).abs().max())
    # print((k_grad-k_grad2).abs().max())
    # print((v_grad-v_grad2).abs().max())
    # print((beta_grad-beta_grad2).abs().max())
    # print((mask_grad-mask_grad2).abs().max())
    # print(mask_grad)
    # print(mask_grad2)
675
+
676
+
fla2/ops/retention/__pycache__/chunk_fuse.cpython-312.pyc ADDED
Binary file (19.8 kB). View file
 
fla2/ops/retention/__pycache__/chunk_fuse.cpython-38.pyc ADDED
Binary file (7.94 kB). View file
 
fla2/ops/retention/__pycache__/chunk_fuse.cpython-39.pyc ADDED
Binary file (7.89 kB). View file
 
fla2/ops/retention/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (21.7 kB). View file
 
fla2/ops/retention/__pycache__/parallel.cpython-38.pyc ADDED
Binary file (8.42 kB). View file
 
fla2/ops/retention/__pycache__/parallel.cpython-39.pyc ADDED
Binary file (8.24 kB). View file
 
fla2/ops/retention/__pycache__/recurrent_fuse.cpython-312.pyc ADDED
Binary file (14.1 kB). View file
 
fla2/ops/retention/__pycache__/recurrent_fuse.cpython-38.pyc ADDED
Binary file (5.99 kB). View file
 
fla2/ops/retention/__pycache__/recurrent_fuse.cpython-39.pyc ADDED
Binary file (5.93 kB). View file
 
fla2/ops/rwkv6/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (279 Bytes). View file
 
fla2/ops/rwkv6/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (281 Bytes). View file
 
fla2/ops/rwkv6/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (57.7 kB). View file
 
fla2/ops/rwkv6/__pycache__/chunk.cpython-38.pyc ADDED
Binary file (24.1 kB). View file
 
fla2/ops/rwkv6/__pycache__/chunk.cpython-39.pyc ADDED
Binary file (23.7 kB). View file
 
fla2/ops/rwkv6/__pycache__/recurrent_fuse.cpython-312.pyc ADDED
Binary file (22.9 kB). View file
 
fla2/ops/rwkv6/__pycache__/recurrent_fuse.cpython-38.pyc ADDED
Binary file (9.92 kB). View file
 
fla2/ops/rwkv6/__pycache__/recurrent_fuse.cpython-39.pyc ADDED
Binary file (9.85 kB). View file
 
fla2/ops/rwkv6/chunk.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.ops.utils import chunk_global_reversed_cumsum
12
+ from fla.utils import contiguous
13
+
14
+
15
# Within-chunk cumulative sum along the time axis.
# For every (BT, BS) tile of `s` it writes the inclusive cumulative sum to `o`
# and the exclusive one (inclusive minus the element itself) to `o_minus_s`.
#
# Launch grid: axis 0 = tile index along S (i_s), axis 1 = chunk index along
# T (i_t), axis 2 = i_bh, used to offset all three pointers by i_bh * s_s_h.
@triton.autotune(
    configs=[
        triton.Config({'BS': 16}, num_warps=2),
        triton.Config({'BS': 16}, num_warps=4),
        triton.Config({'BS': 16}, num_warps=8),
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def chunk_rwkv6_fwd_kernel_cum(
    s,
    o,
    o_minus_s,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_i = tl.arange(0, BT)
    # Lower-triangular matrix of ones (diagonal included); multiplying by it
    # sums each row with all earlier rows of the chunk.
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))
54
+
55
+
56
# Adds the gradient contribution of the bonus term `u` to dq and dk, and writes
# the per-position gradient of `u` into `du`.
#
# One program handles BT time steps of one batch*head slice and the full K / V
# width (BK and BV are expected to cover K and V, see the note below).
# With vdo[t] = sum_v(v[t] * do[t]) the kernel computes
#   du[t]  = vdo[t] * k[t] * q[t] * scale      (stored, not accumulated)
#   dq[t] += vdo[t] * k[t] * u    * scale
#   dk[t] += vdo[t] * q[t] * u    * scale
# `du` is written with the same strides as q/k, i.e. one row per time step.
@triton.jit
def post_process_grad(
    q,
    k,
    v,
    u,
    do,
    dk,
    dq,
    du,
    scale,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    H,
    T: tl.constexpr,
    BT: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    # head index, used to pick the row of `u` (u is indexed as u + i_h * K)
    i_h = i_bh % H

    # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)
    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
    p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
    p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
    p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_u = tl.load(p_u, boundary_check=(0,))

    # [BT,] per-time-step inner product of v and do
    b_vdo = tl.sum(b_v * b_do, axis=1)
    b_du = b_vdo[:, None] * b_k * b_q * scale
    b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale
    b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale

    # dq / dk already hold gradients from earlier kernels: read-modify-write
    b_dq += tl.load(p_dq, boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    b_dk += tl.load(p_dk, boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
# Sequentially builds the per-chunk hidden states.
#
# Each program owns one [BK, BV] tile of the state for one batch*head slice and
# walks over the NT time chunks in order. For chunk i_t it first stores the
# state as it is *before* the chunk into slot i_t of `h`, then decays the state
# and adds the chunk's key/value outer products:
#   h <- h * exp(g_last) + (k * exp(g_last - g))^T-contracted-with v
# where `g` holds log-space decays (the host passes the chunk-local cumulative
# sum) and g_last is the decay value taken for the end of the chunk.
# Optionally starts from `h0` and writes the final state to `ht`.
@triton.jit
def chunk_rwkv6_fwd_kernel_h(
    k,
    v,
    g,
    h,
    h0,
    ht,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # running state tile, kept in fp32 for the whole sequence
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        # h0 is addressed as a dense (K, V) matrix per batch*head
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
    for i_t in range(NT):
        # one past the last valid time index of this chunk (clamped to T)
        o_t = min(i_t * BT + BT, T)

        # k and g are read transposed, as [BK, BT] tiles
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # row (o_t - 1) of g, i.e. the last time step of this chunk, viewed as a flat vector
        p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))

        # store the state seen at the start of this chunk before updating it
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BK, BT]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        if i_t < NT - 1:
            # [BK,]
            b_gn = tl.load(p_gn, boundary_check=(0,))
        else:
            # last chunk: take the minimum over the time axis of the loaded tile instead.
            # NOTE(review): presumably relies on the cumulative log-decay being
            # non-increasing within a chunk - confirm against the caller.
            b_gn = tl.min(b_g, axis=1)
        # decay the carried state by the whole-chunk decay
        b_h *= tl.exp(b_gn)[:, None]
        # decay each key from its own time step to the end of the chunk
        b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
        b_h += tl.dot(b_k, b_v, allow_tf32=False)

    if STORE_FINAL_STATE:
        p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
175
+
176
+
177
# Computes the intra-chunk score matrix A, one [BC, BC] sub-block per program.
#
# Program id 1 encodes (time chunk i_t, query sub-chunk i_i, key sub-chunk i_j);
# each chunk of BT rows is split into NC sub-chunks of BC rows.
#   * i_i > i_j  : off-diagonal block, computed with one matmul.
#   * i_i == i_j : diagonal block, computed column by column; entries strictly
#                  below the diagonal use the decayed q.k product, the diagonal
#                  entry uses the bonus term `u` (no decay).
#   * i_i < i_j  : nothing is written by this kernel.
# Each key tile i_k writes its partial result into its own slice of A
# (offset (i_k * n_bh + i_bh) * T * BT).
@triton.jit
def chunk_rwkv6_fwd_kernel_intra(
    q,
    k,
    g,
    gs,
    u,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    scale,
    H,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    DK: tl.constexpr
):
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # unpack the combined index into (chunk, query sub-chunk, key sub-chunk)
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    i_h = i_bh % H
    n_bh = tl.num_programs(2)

    o_k = i_k * BK + tl.arange(0, BK)
    # first time index of the query sub-chunk
    o_q = i_t * BT + i_i * BC
    m_k = o_k < K

    if i_i > i_j:
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        # [BK,]
        # g at the row just before the query sub-chunk; used as a common
        # reference so both exponentials below are taken of differences.
        # This indexing assumes g is laid out densely as (T, K) per batch*head.
        b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_gs = tl.load(p_gs, boundary_check=(0, 1))
        b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype)
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
        # [BC, BC]
        b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
        tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    elif i_i == i_j:
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        # flat pointers to a single row of k / q; advanced by K (one row) per loop step
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
        p_q_u = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))

        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_gs = tl.load(p_gs, boundary_check=(0, 1))
        o_i = tl.arange(0, BC)
        o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k
        # element offsets of column `i_j * BC` for the BC rows of this block
        o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
        m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
        # NOTE(review): the offsets argument `(i_k * BK)` is a bare scalar, not the
        # 1-tuple `(i_k * BK,)` used by every other 1-D block pointer in this file.
        p_u = tl.make_block_ptr(u + i_h * DK, (DK,), (1,), (i_k * BK), (BK,), (0,))
        b_u = tl.load(p_u, boundary_check=(0,))
        for j in range(0, BC):
            # [BK,]
            b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
            b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32)
            # [BC,]
            b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1)
            # keep only rows strictly after key position j (strict causality)
            b_A = tl.where(o_i > j, b_A, 0.)
            # self
            b_q_u = tl.load(p_q_u, boundary_check=(0,)).to(tl.float32)
            b_A_u = tl.sum(b_q_u * b_k * b_u * scale, axis=0)
            # the diagonal entry (row j, column j) uses the bonus `u` instead of a decay
            m_u = tl.arange(0, BC) == j
            b_A = tl.where(m_u, b_A_u, b_A)
            tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A)
            p_k = tl.advance(p_k, (K,))
            p_q_u = tl.advance(p_q_u, (K,))
256
+
257
+
258
# Produces the output tile for one (value tile, time chunk, batch*head):
#   o = (q * scale * exp(gs)) @ h  +  A @ v
# The first term applies the stored chunk-start state `h` (slot i_t), summed
# over all key tiles; the second term applies the intra-chunk score matrix A.
@triton.jit
def chunk_rwkv6_fwd_kernel_inter(
    q,
    v,
    gs,
    h,
    o,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # fp32 accumulator for the output tile
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_gs = tl.load(p_gs, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # (the `i_k >= 0` guard is always true; per the original comment above it is
        # kept deliberately - do not remove it without re-testing the kernel)
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h, allow_tf32=False)
    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    # A is addressed as a dense (T, BT) matrix per batch*head
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
313
+
314
+
315
# Backward counterpart of the hidden-state recurrence.
#
# Each program owns one [BK, BV] tile of the state gradient and walks over the
# time chunks in reverse order. For chunk i_t it stores the gradient accumulated
# from all *later* chunks into slot i_t of `dh`, then decays it by the chunk's
# last-row decay and adds this chunk's contribution (q * scale * exp(gs))^T do.
# After the loop the remaining value is the gradient of the initial state and
# is written to `dh0` when USE_INITIAL_STATE is set.
@triton.jit
def chunk_rwkv6_bwd_kernel_dh(
    q,
    g,
    gs,
    do,
    dh,
    dh0,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # running state gradient, kept in fp32
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        # one past the last valid time index of this chunk (clamped to T)
        o_t = min(i_t * BT + BT, T)

        # q and gs are read transposed, as [BK, BT] tiles
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # row (o_t - 1) of g: the decay value at the last time step of the chunk
        p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))

        # [BK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # store before updating: slot i_t only sees contributions of later chunks
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))

        # [BK,]
        b_gn = tl.load(p_gn, boundary_check=(0,))
        # [BK, BV]
        b_dh *= tl.exp(b_gn)[:, None]
        # [BK, BT]
        b_gs = tl.load(p_gs, boundary_check=(0, 1))
        b_q = (b_q * tl.exp(b_gs)).to(b_q.dtype)

        # [BK, BV]
        b_dh += tl.dot(b_q, b_do, allow_tf32=False)

    if USE_INITIAL_STATE:
        # dh0 is addressed as a dense (K, V) matrix per batch*head
        p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
376
+
377
+
378
# Inter-chunk part of the backward pass for one (key tile, time chunk, batch*head).
#
# Looping over value tiles it computes
#   dv  = (k * exp(g_last - g)) @ dh         (+ A^T @ do, added by key tile 0 only)
#   dA += (do * scale) @ v^T
#   dq += (do * scale) @ h^T
#   dk += v @ dh^T
# and afterwards rescales dq by exp(gs) and dk by exp(g_last - g).
# dv is written to a per-key-tile slice (offset (i_k * n_bh + i_bh) * s_v_h);
# dA is masked to its strictly lower triangle and stored by key tile 0 only.
@triton.jit
def chunk_rwkv6_bwd_kernel_inter(
    k,
    v,
    h,
    g,
    gs,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dA,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    n_bh = tl.num_programs(2)
    # one past the last valid time index of this chunk (clamped to T)
    o_t = min(i_t * BT + BT, T)

    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gq = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))
    # A is read transposed: shape (BT, T) with strides (1, BT)
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))

    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_gq = tl.load(p_gq, boundary_check=(0, 1))
    # despite the name, b_gn holds the factor exp(g_last - g), not a raw decay
    b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
    b_k = (b_k * b_gn).to(b_k.dtype)
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # h is read transposed, as a [BV, BK] tile of slot i_t
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))

        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
        # the intra-chunk term does not depend on the key tile: add it exactly once
        if i_k == 0:
            b_dv += tl.dot(b_A, b_do, allow_tf32=False)
        # scale is applied after the A^T @ do term, which uses the unscaled do
        b_do = (b_do * scale).to(b_do.dtype)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        # [BT, BT]
        b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
        # [BT, BK]
        b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)

    b_dq = b_dq * tl.exp(b_gq)
    b_dk = b_dk * b_gn

    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    o_i = tl.arange(0, BT)
    # strictly lower triangle: row index greater than column index
    m_s = o_i[:, None] > o_i[None, :]
    # [BT, BT]
    b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
    # every key tile computes the same dA; only tile 0 writes it
    if i_k == 0:
        tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
475
+
476
+
477
# Intra-chunk part of the backward pass: propagates dA into dq and dk.
#
# One program handles one sub-chunk (BC rows) of one time chunk for one key tile.
# First half (dq): sub-chunks i_j < i_i contribute through a matmul with the
# decayed keys; the diagonal sub-chunk is handled row by row with the mask
# `o_i > j`. Second half (dk): sub-chunks i_j > i_i contribute through a matmul
# with the decayed queries; the diagonal sub-chunk uses the mask `o_i < j`.
# Both results are *added* to the values already present in dq / dk.
@triton.jit
def chunk_rwkv6_bwd_kernel_intra(
    q,
    k,
    g,
    gs,
    dA,
    dq,
    dk,
    s_k_h,
    s_k_t,
    s_k_d,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr
):
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # unpack into (time chunk, sub-chunk within the chunk)
    i_t, i_i = i_c // NC, i_c % NC

    o_k = i_k * BK + tl.arange(0, BK)
    # first time index of this sub-chunk
    o_q = i_t * BT + i_i * BC
    m_k = o_k < K

    p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # [BK,]
    # g at the row just before this sub-chunk (0 for the first sub-chunk of a
    # chunk); used as the reference point for the exponentials below.
    # This indexing assumes g is laid out densely as (T, K) per batch*head.
    b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
    # [BC, BK]
    b_gs = tl.load(p_gs, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    for i_j in range(0, i_i):
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        # [BC, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
        # [BC, BC]
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
        # [BC, BK]
        b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
    # apply the query-side decay once, after summing all earlier sub-chunks
    b_dq *= tl.exp(b_gs - b_gn[None, :])

    o_i = tl.arange(0, BC)
    # element offsets of the first column of the diagonal [BC, BC] block of dA
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T

    for j in range(0, BC):
        # row j of this sub-chunk of k, as a flat [BK] vector
        p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))

        # [BC,]
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
        # [BK,]
        b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
        b_gkj = tl.load(g + i_bh * T * K + (o_q + j) * K + o_k, mask=(m_k & ((o_q + j) < T)), other=0)
        # [BC, BK]
        # only query rows strictly after key row j receive a contribution
        m_i = o_i[:, None] > j
        # [BC, BK]
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_gs - b_gkj[None, :]), 0.)

    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # accumulate onto the dq already written for this tile
    b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # g at the last row of this sub-chunk: reference point for the dk half
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
    # [BK,]
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    for i_j in range(i_i + 1, NC):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_gs = tl.load(p_gs, boundary_check=(0, 1))
        b_qg = (b_q * tl.exp(b_gs - b_gn[None, :])).to(b_q.dtype)
        # [BC, BC]
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
        # [BC, BK]
        b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
    # apply the key-side decay once, after summing all later sub-chunks
    b_dk *= tl.exp(b_gn[None, :] - b_gk)

    # element offsets of the first row of the diagonal [BC, BC] block of dA
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
    for j in range(0, BC):
        # row j of this sub-chunk of q and gs, as flat [BK] vectors
        p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
        p_gqj = tl.make_block_ptr(gs + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
        # [BC,]
        b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
        # [BK,]
        b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
        b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
        # [BC, BK]
        # only key rows strictly before query row j receive a contribution
        m_i = o_i[:, None] < j
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)

    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # accumulate onto the dk already written for this tile
    b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
585
+
586
+
587
class ChunkRWKV6Function(torch.autograd.Function):
    """Autograd wrapper around the chunked RWKV6 Triton kernels.

    ``forward`` computes the output (and optionally the final state) chunk by
    chunk; ``backward`` recomputes the chunk-local cumulative decays and, when
    ``checkpoint_level >= 1``, the per-chunk hidden states before running the
    gradient kernels.
    """

    @staticmethod
    def _cumsum_decay(g_org, B, H, T, K, BT, NT):
        """Return ``(g, gs)``: the chunk-local cumulative log decay and ``g - g_org``.

        Both outputs are kept in fp32. The kernel is equivalent to::

            g_org = g_org.view(B, H, NT, BT, -1)
            g = g_org.cumsum(-2).view(B, H, T, -1)
            gs = g - g_org
        """
        g = torch.empty_like(g_org, dtype=torch.float)
        gs = torch.empty_like(g_org, dtype=torch.float)

        def grid(meta):
            return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)

        chunk_rwkv6_fwd_kernel_cum[grid](
            g_org, g, gs,
            g.stride(1), g.stride(2), g.stride(3),
            T=T, S=K, BT=BT
        )
        return g, gs

    @staticmethod
    def _fwd_h(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, num_warps, num_stages, h0=None, ht=None):
        """Compute the hidden state at the start of every chunk, shape ``(B, H, NT * K, V)``.

        ``h0`` (optional) seeds the recurrence; ``ht`` (optional) receives the
        final state. ``q`` is only used to allocate the result.
        """
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        h = q.new_empty(B, H, NT * K, V)
        grid = (NV, NK, B * H)
        chunk_rwkv6_fwd_kernel_h[grid](
            k, v, g, h, h0, ht,
            k.stride(1), k.stride(2), k.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2), h.stride(3),
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            USE_INITIAL_STATE=h0 is not None,
            STORE_FINAL_STATE=ht is not None,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return h

    @staticmethod
    def _bwd_dh(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale, num_warps, num_stages):
        """Compute the per-chunk state gradients ``dh`` and, if ``h0`` is given, ``dh0``."""
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        dh = q.new_empty(B, H, NT * K, V)
        dh0 = torch.empty_like(h0) if h0 is not None else None
        grid = (NK, NV, B * H)
        chunk_rwkv6_bwd_kernel_dh[grid](
            q, g, gs, do, dh, dh0,
            q.stride(1), q.stride(2), q.stride(3),
            do.stride(1), do.stride(2), do.stride(3),
            dh.stride(1), dh.stride(2), dh.stride(3),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            USE_INITIAL_STATE=h0 is not None,
            num_warps=num_warps,
            num_stages=num_stages
        )
        return dh, dh0

    @staticmethod
    @contiguous
    def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):
        """Run the chunked forward pass.

        Returns ``(o, final_state)``; ``final_state`` is ``None`` unless
        ``output_final_state`` is true.
        """
        q = r  # alias
        B, H, T, K, V = *q.shape, v.shape[-1]
        BT, BC = 64, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        num_warps = 4 if BK == 64 else 2
        num_stages = 1

        final_state = None
        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float)

        # keep cumulative normalizer in fp32
        g_org = g
        g, gs = ChunkRWKV6Function._cumsum_decay(g_org, B, H, T, K, BT, NT)
        h = ChunkRWKV6Function._fwd_h(
            q=q, k=k, v=v, g=g,
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps, num_stages=num_stages,
            h0=initial_state,
            ht=final_state
        )
        # one partial score matrix per key tile; zero-initialised because the
        # intra kernel does not write the blocks above the diagonal
        A = q.new_zeros(NK, B, H, T, BT)
        grid = (NK, NT * NC * NC, B * H)
        chunk_rwkv6_fwd_kernel_intra[grid](
            q, k, g, gs, u, A,
            k.stride(1), k.stride(2), k.stride(3),
            scale,
            H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,
            num_warps=num_warps,
            num_stages=num_stages
        )
        A = A.sum(0, dtype=A.dtype)
        o = torch.empty_like(v)

        grid = (NV, NT, B * H)
        chunk_rwkv6_fwd_kernel_inter[grid](
            q, v, gs, h, o, A,
            k.stride(1), k.stride(2), k.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2), h.stride(3),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )

        if checkpoint_level >= 1:
            # The hidden states are recomputed in backward, so do not keep them
            # alive. initial_state stays saved: both the recomputation and dh0
            # need it.
            h = None
        del g, gs
        ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A)
        ctx.BT = BT
        ctx.scale = scale
        ctx.checkpoint_level = checkpoint_level
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        """Run the chunked backward pass.

        Returns gradients for ``(r, k, v, g, u, scale, initial_state,
        output_final_state, checkpoint_level)``; the non-tensor arguments get
        ``None``. ``dht`` (gradient of the final state) is accepted but not used.
        """
        q, k, v, g_org, u, h, initial_state, A = ctx.saved_tensors
        B, H, T, K, V = *q.shape, v.shape[-1]
        BT, BC = ctx.BT, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
        NK = triton.cdiv(K, BK)
        num_warps = 4 if BK == 64 else 2
        num_stages = 1
        scale = ctx.scale

        # recompute cumulative log decays (kept in fp32).
        g, gs = ChunkRWKV6Function._cumsum_decay(g_org, B, H, T, K, BT, NT)

        # rerun the forward pass to get h if checkpoint_level >= 1
        if ctx.checkpoint_level >= 1:
            h = ChunkRWKV6Function._fwd_h(
                q=q, k=k, v=v, g=g,
                B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
                num_warps=num_warps, num_stages=num_stages,
                h0=initial_state,
                ht=None
            )

        # g, gs: torch.float32
        dh, dh0 = ChunkRWKV6Function._bwd_dh(
            q.to(torch.float), g, gs, initial_state, do.to(torch.float),
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            scale=scale, num_warps=num_warps, num_stages=num_stages
        )
        dh = dh.to(q)
        if initial_state is not None:
            dh0 = dh0.to(q)
        dq = torch.empty_like(q, dtype=torch.float)
        dk = torch.empty_like(k, dtype=torch.float)
        # one partial dv per key tile, reduced below
        dv = v.new_empty(NK, *v.shape)
        dA = q.new_zeros(B, H, T, BT)
        grid = (NK, NT, B * H)
        chunk_rwkv6_bwd_kernel_inter[grid](
            k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,
            k.stride(1), k.stride(2), k.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2), h.stride(3),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dv = dv.sum(0, dtype=dv.dtype)
        grid = (NK, NT * NC, B * H)
        chunk_rwkv6_bwd_kernel_intra[grid](
            q, k, g, gs, dA, dq, dk,
            k.stride(1), k.stride(2), k.stride(3),
            T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
            num_warps=num_warps,
            num_stages=num_stages
        )

        # TODO: fuse?
        dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]
        dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
        dg = chunk_global_reversed_cumsum(dg).to(g)
        # equivalent to the following pytorch code.
        # du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)
        # dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])
        # dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])
        BT = 64
        grid = (triton.cdiv(T, BT), B * H)
        # per-position gradient of u, reduced over batch and time below
        du = torch.empty_like(g, dtype=torch.float)
        post_process_grad[grid](
            q, k, v, u, do, dk, dq, du, scale,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3), H=H,
            T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),
            num_warps=4
        )
        du = du.sum([0, 2])
        # cast every gradient back to the dtype of the tensor it belongs to
        # (dg goes to the dtype of the original g, not the fp32 temporary)
        return dq.to(q), dk.to(k), dv.to(v), dg.to(g_org), du.to(u), None, dh0, None, None
801
+
802
+
803
def chunk_rwkv6(
    r: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    u: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    checkpoint_level: Optional[int] = 0
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Chunked RWKV6 attention.

    Args:
        r (torch.Tensor):
            reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        g (torch.Tensor):
            data-dependent decays of shape `(B, H, T, K)` in log space! Alias: w.
        u (torch.Tensor):
            bonus of shape `(H, K)`
        scale (Optional[float]):
            Scale factor for the RWKV6 attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `0`:
            - Level `0`: store forward hidden states for backprop.
            - Level `1`: recompute the forward hidden states during backward.

    Returns:
        A tuple `(o, final_state)`: the output of shape `(B, H, T, V)` and the
        final state of shape `(B, H, K, V)`, or `None` when
        `output_final_state` is `False`.
    """
    assert checkpoint_level in [0, 1], "checkpoint_level must be 0 or 1"
    if scale is None:
        scale = r.shape[-1] ** -0.5
    o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)
    return o, final_state
844
+
845
+
846
# Script entry point: checks the chunked implementation against the fused
# recurrent implementation (outputs and all gradients), then benchmarks both.
# Requires a CUDA device.
if __name__ == "__main__":
    import torch.nn.functional as F

    from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
    B = 8
    H = 4
    L = 1024
    # K and V are deliberately not powers of two
    K = 100
    V = 120

    torch.manual_seed(0)
    dtype = torch.float
    q = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
    k = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
    v = torch.randn(B, H, L, V).cuda().to(dtype).requires_grad_(True)
    # log-space decays: -exp(x) is always negative
    w = (-torch.randn(B, H, L, K).exp()).cuda().requires_grad_(True)
    u = torch.randn(H, K).cuda().to(dtype).requires_grad_(True)
    h0 = torch.randn(B, H, K, V).cuda().to(dtype).requires_grad_(True)
    do = torch.rand_like(v).cuda()
    # reference: fused recurrent implementation
    o, ht = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
    o.backward(do)
    # keep a copy of each reference gradient and clear .grad so the second
    # backward pass below does not accumulate onto it
    dq, q.grad = q.grad.clone(), None
    dk, k.grad = k.grad.clone(), None
    dv, v.grad = v.grad.clone(), None
    dw, w.grad = w.grad.clone(), None
    du, u.grad = u.grad.clone(), None
    dh0, h0.grad = h0.grad.clone(), None
    # implementation under test
    o2, ht2 = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
    o2.backward(do)
    torch.testing.assert_close(o, o2, rtol=0, atol=1e-4)
    torch.testing.assert_close(ht, ht2, rtol=0, atol=1e-4)
    torch.testing.assert_close(q.grad, dq, rtol=0, atol=1e-4)
    torch.testing.assert_close(k.grad, dk, rtol=0, atol=1e-4)
    torch.testing.assert_close(v.grad, dv, rtol=0, atol=1e-4)
    torch.testing.assert_close(w.grad, dw, rtol=0, atol=1e-4)
    torch.testing.assert_close(u.grad, du, rtol=0, atol=2e-4)
    torch.testing.assert_close(h0.grad, dh0, rtol=0, atol=2e-4)

    print("All tests passed!")

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            # argument names to use as an x-axis for the plot
            x_names=['T'],
            # different possible values for `x_name`
            x_vals=[128 * 2 ** i for i in range(0, 8)],
            # argument name whose value corresponds to a different line in the plot
            line_arg='provider',
            # possible values for `line_arg``
            line_vals=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
            # label name for the lines
            line_names=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
            # line styles
            styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
            ylabel="Execution Time (ms)",  # label name for the y-axis
            # name for the plot. Used also as a file name for saving the plot.
            plot_name="Performance",
            args={},
        )
    )
    def benchmark(T, provider):
        # Times one provider at sequence length T and returns the do_bench
        # result for the requested quantiles; (0, 0, 0) for an unknown provider.
        device = 'cuda'
        dtype = torch.bfloat16
        requires_grad = True
        B, H, K = 16, 4, 128

        q = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
        k = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
        v = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
        w = F.logsigmoid(torch.randn(B, H, T, K)).to(dtype=dtype, device=device).requires_grad_(True)
        u = torch.randn(H, K, device=device, requires_grad=requires_grad, dtype=dtype)

        do = torch.ones_like(q, dtype=dtype)
        quantiles = [0.5, 0.2, 0.8]
        results = 0, 0, 0
        if provider == 'recurrent':
            results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u), quantiles=quantiles)
        if provider == 'chunk':
            results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles)
        # the *_bwd providers time forward and backward together
        if provider == 'recurrent_bwd':
            results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u)
                                              [0].backward(do), quantiles=quantiles)
        if provider == 'chunk_bwd':
            results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles)
        return results
    benchmark.run(print_data=True)
fla2/ops/rwkv6/chunk_naive.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
def naive_chunk_rwkv6(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    chunk_size: int = 32
):
    """Chunk-wise pure-PyTorch reference for RWKV6 (no query scaling applied).

    Args:
        q, k, w: tensors laid out as ``(b, h, t, d)``; ``w`` is the per-step
            log decay.
        v: tensor laid out as ``(b, h, t, p)``.
        u: bonus applied to the current token's own key/value pair.
        chunk_size: chunk length; must divide the sequence length.

    Returns:
        Output of shape ``(b, h, t, p)`` cast back to the dtype of ``q``.
    """
    assert q.shape[-2] % chunk_size == 0
    out_dtype = q.dtype
    n_chunks = q.shape[-2] // chunk_size
    bonus = u.unsqueeze(0)

    def _split(x):
        # (b, h, t, d) -> (b, h, n, c, d), computed in float32.
        return rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float()

    q, k, v, w = _split(q), _split(k), _split(v), _split(w)

    # Inclusive cumulative log decay inside every chunk.
    decay = w.cumsum(-2)

    # Per-chunk key/value summary, each key decayed to the end of its chunk.
    k_to_end = k * (decay[..., -1, None, :] - decay).exp()
    chunk_kv = k_to_end.transpose(-1, -2) @ v

    # state[:, :, n] is the recurrent state entering chunk n.
    state = torch.zeros_like(chunk_kv)
    for n in range(1, n_chunks):
        carried = state[:, :, n - 1] * decay[:, :, n - 1, -1, :, None].exp()
        state[:, :, n] = carried + chunk_kv[:, :, n - 1]

    # Contribution of earlier chunks: queries see the state decayed by the
    # exclusive cumulative decay (decay - w).
    q_decayed = q * (decay - w).exp()
    o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', state, q_decayed)

    # Contribution of tokens inside the same chunk, one query position at a time.
    o_intra = torch.zeros_like(o_inter)
    positions = torch.arange(0, chunk_size, device=q.device)
    for i in range(chunk_size):
        rel_decay = (decay[:, :, :, i, None] - w[:, :, :, i, None] - decay).exp()
        scores = (q[:, :, :, i, None] * k * rel_decay).sum(-1)
        # Only strictly earlier positions contribute through the decay path.
        scores.masked_fill_(positions >= i, 0)
        from_past = (scores.unsqueeze(-1) * v).sum(-2)
        # The current token is weighted by the bonus `u` instead of a decay.
        self_score = (q[:, :, :, i] * bonus.unsqueeze(2) * k[:, :, :, i]).sum(-1)
        from_self = self_score.unsqueeze(-1) * v[:, :, :, i]
        o_intra[:, :, :, i] = from_past + from_self

    out = o_inter + o_intra
    return rearrange(out, 'b h n c d -> b h (n c) d').to(out_dtype)
fla2/ops/rwkv6/recurrent_fuse.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2024, Songlin Yang
4
+
5
+ from typing import Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.ops.utils import chunk_global_reversed_cumsum
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
13
+
14
+
15
# Forward recurrence for RWKV6. Each program handles one (V-block, K-block,
# batch*head) tile and walks over all T time steps sequentially, keeping its
# tile of the hidden state in registers as a [BV, BK] float32 block.
@triton.jit
def fused_recurrent_rwkv6_fwd_kernel(
    q,  # query [B, H, T, K]
    k,  # key [B, H, T, K]
    v,  # value [B, H, T, V]
    w,  # log gate [B, H, T, K]
    u,  # bonus, indexed as u + i_h * K, i.e. one K-vector per head [H, K]
    o,  # output; one partial result per K-block, [NK, B, H, T, V] in the launcher
    # initial hidden state initialization [B, H, K, V]
    h0,
    ht,  # final hidden state [B, H, K, V]
    s_k_h,  # stride size: T * K
    s_v_h,  # stride size: T * V
    scale,  # K ** -0.5
    B: tl.constexpr,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H

    # Pointers to the first processed time step: t = 0 normally, t = T - 1
    # when running in reverse.
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    # Each K-block writes to its own slice of `o` (offset i_k * B * H).
    p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK

    # Guards for the ragged last block along K and V.
    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
    mask_kv = mask_bv[:, None] & mask_bk[None, :]

    # State tile is held transposed relative to memory: [BV, BK].
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
        # `w` is stored in log space; convert to the multiplicative decay.
        b_w = tl.exp(b_w)
        b_kv = b_k[None, :] * b_v[:, None]
        # Output uses the state *before* this step plus the bonus-weighted
        # contribution of the current token.
        b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        # State update: decay along K, then add the current outer product.
        b_h = b_h * b_w[None, :]
        b_h += b_kv
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
        # Advance one time step in the processing direction.
        p_q += -K if REVERSE else K
        p_k += -K if REVERSE else K
        p_o += -V if REVERSE else V
        p_v += -V if REVERSE else V
        p_w += -K if REVERSE else K

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
81
+
82
+
83
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
# Replays the forward state recurrence to produce the query gradient `dq` and
# the auxiliary term `dq_aux` that the launcher later uses for the gate gradient.
@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dq(
    # B: B, H: H, T: T, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    k,  # key [B, H, T, K]
    v,  # value [B, H, T, V]
    w,  # log gate [B, H, T, K]
    u,  # bonus, indexed as u + i_h * K, i.e. one K-vector per head [H, K]

    do,  # gradient of output [B, H, T, V]
    dq,  # gradient of query [NV, B, H, T, K]
    dq_aux,  # gradient of query_aux [NV, B, H, T, K]

    # initial hidden state initialization [B, H, K, V]
    h0,

    s_k_h,  # stride size: T * K
    s_v_h,  # stride size: T * V

    scale,  # K ** -0.5
    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H
    # Same traversal direction as the forward kernel.
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    # Each V-block writes its own partial result (offset i_v * B * H).
    p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK

    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V
    mask_kv = mask_bv[:, None] & mask_bk[None, :]
    b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
    # Recomputed hidden-state tile, stored as [BV, BK].
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_kv = b_k[None, :] * b_v[:, None]
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
        b_w = tl.exp(b_w)
        # State (before this step) contracted with do over V.
        h_q = b_h * b_do[:, None]
        # dq additionally includes the bonus-weighted current-token term.
        b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0)
        b_dq *= scale
        # dq_aux keeps only the state term and is left unscaled.
        b_dq_aux = tl.sum(h_q, axis=0)
        b_h = b_h * b_w[None, :]
        b_h += b_kv
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
        tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk)
        p_k += -K if REVERSE else K
        p_do += -V if REVERSE else V
        p_v += -V if REVERSE else V
        p_w += -K if REVERSE else K
        p_dq += -K if REVERSE else K
        p_dq_aux += -K if REVERSE else K
155
+
156
+
157
# Walks the sequence opposite to the forward direction, accumulating the
# gradient of the hidden state and emitting dk, dk_aux and dv at every step.
@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dkv(
    # B: B, H: H, T: T, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, T, K]
    k,  # key [B, H, T, K]
    v,  # value [B, H, T, V]
    w,  # log gate [B, H, T, K]
    u,  # bonus, indexed as u + i_h * K, i.e. one K-vector per head [H, K]

    do,  # gradient of output [B, H, T, V]
    dk,  # gradient of key, one partial per V-block [NV, B, H, T, K]
    dk_aux,  # state-only part of dk, same layout as dk
    dv,  # gradient of value, one partial per K-block [NK, B, H, T, V]
    dh0,  # gradient of the initial hidden state [B, H, K, V]

    s_k_h,  # stride size: T * K
    s_v_h,  # stride size: T * V

    scale,  # K ** -0.5
    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h = i_bh % H
    # Start from the step the forward pass processed last: t = T - 1 unless
    # REVERSE, in which case t = 0.
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    # Gradient of the hidden state, stored as [BK, BV] (not transposed here).
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V
    mask_kv = mask_bk[:, None] & mask_bv[None, :]

    p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
    b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)

    for _ in range(T-1, -1, -1):
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        # Outer product of the (scaled) query and the output gradient.
        b_dkv = b_q[:, None] * b_do[None, :]
        # State-only part of dk; stored separately before the bonus term is added.
        b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
        tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk)
        b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1)
        b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
        # Propagate the state gradient one step further back.
        b_dh *= tl.exp(b_w)[:, None]
        b_dh += b_dkv

        p_q += K if REVERSE else -K
        p_k += K if REVERSE else -K
        p_v += V if REVERSE else -V
        p_w += K if REVERSE else -K
        p_do += V if REVERSE else -V
        p_dk += K if REVERSE else -K
        p_dk_aux += K if REVERSE else -K
        p_dv += V if REVERSE else -V

    # After the full sweep b_dh is the gradient w.r.t. the initial state.
    if USE_INITIAL_STATE:
        p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)
235
+
236
+
237
class FusedRecurrentRWKV6Function(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent RWKV6 Triton kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):
        """Run the forward recurrence.

        Args:
            r: receptance/query of shape (B, H, T, K).
            k: keys of shape (B, H, T, K).
            v: values of shape (B, H, T, V).
            w: log-space decays of shape (B, H, T, K).
            u: bonus, read by the kernel as one K-vector per head.
            scale: multiplier applied to the query; ``None`` selects ``K ** -0.5``.
            initial_state: optional (B, H, K, V) state.
            output_final_state: whether to also return the final state.
            reverse: process the sequence from the last step to the first.

        Returns:
            ``(o, final_state)``; ``final_state`` is ``None`` unless requested.
        """
        q = r
        B, H, T, K, V = *q.shape, v.shape[-1]
        # The signature defaults to None, but the kernel multiplies by `scale`
        # directly, so resolve the default before launching.
        if scale is None:
            scale = K ** -0.5

        BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 1

        final_state = q.new_empty(B, H, K, V) if output_final_state else None

        # One float32 partial output per K-block; reduced with sum(0) below.
        o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
        grid = (NV, NK, B * H)
        fused_recurrent_rwkv6_fwd_kernel[grid](
            q, k, v, w, u, o, initial_state, final_state,
            k.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None,
            REVERSE=reverse,
            num_warps=num_warps,
            num_stages=num_stages
        )

        o = o.sum(0)
        ctx.save_for_backward(q, k, v, w, u, initial_state)
        ctx.scale = scale
        ctx.reverse = reverse
        return o.to(q.dtype), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        """Compute gradients for r, k, v, w, u and the initial state.

        NOTE(review): ``dht`` (gradient flowing into the final state) is not
        used anywhere below.
        """
        q, k, v, w, u, initial_state = ctx.saved_tensors
        B, H, T, K, V = *q.shape, v.shape[-1]
        scale = ctx.scale

        # --- dq and dq_aux: replay of the forward recurrence -----------------
        BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 1
        dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
        dq_aux = torch.empty_like(dq)
        grid = (NV, NK, B * H)

        fused_recurrent_rwkv6_bwd_kernel_dq[grid](
            k, v, w, u, do, dq, dq_aux, initial_state,
            q.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            REVERSE=ctx.reverse,
            num_warps=num_warps,
            num_stages=num_stages
        )
        # Reduce the per-V-block partial results.
        dq = dq.sum(0).to(q)
        dq_aux = dq_aux.sum(0)

        # --- dk, dk_aux, dv, dh0: sweep in the opposite direction ------------
        BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

        dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
        dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
        dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
        dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None
        grid = (NV, NK, B * H)
        fused_recurrent_rwkv6_bwd_kernel_dkv[grid](
            q, k, v, w, u, do, dk, dk_aux, dv, dh0,
            q.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages,
            USE_INITIAL_STATE=initial_state is not None,
            REVERSE=ctx.reverse,
        )
        dk = dk.sum(0).to(k)
        dv = dv.sum(0).to(v)
        dk_aux = dk_aux.sum(0)

        # Gate gradient: per-step difference of the auxiliary terms (the step
        # t entry pairs dq_aux at t + 1 with dk_aux at t), zero-padded at the
        # last step and then accumulated with a reversed cumulative sum.
        dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1]
        dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
        dw = chunk_global_reversed_cumsum(dw).to(w)

        # Bonus gradient: summed over the batch and time dimensions.
        du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u)
        # One entry per forward input (scale, output_final_state and reverse
        # receive no gradient).
        return dq, dk, dv, dw, du, None, dh0, None, None
333
+
334
+
335
def fused_recurrent_rwkv6(
    r: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    scale: float = -1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        r (torch.Tensor):
            reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
        u (torch.Tensor):
            bonus of shape `(H, K)`
        scale (Optional[float]):
            Scale factor for the RWKV6 attention scores.
            Both `None` and the sentinel `-1` select the default `1 / sqrt(K)`.
            Default: `-1`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.

    Returns:
        A tuple `(o, final_state)`; `final_state` is `None` unless
        `output_final_state` is set.
    """
    # `None` used to be forwarded to the kernel unchanged and fail there;
    # treat it like the `-1` sentinel.
    if scale is None or scale == -1:
        scale = r.shape[-1] ** -0.5
    o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)
    return o, final_state
fla2/ops/rwkv6/recurrent_naive.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
def naive_recurrent_rwkv6(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
):
    """Step-by-step float32 reference implementation of RWKV6.

    Args:
        q: queries of shape (B, H, T, K).
        k: keys of shape (B, H, T, K).
        v: values of shape (B, H, T, V).
        w: per-step decays in log space, shape (B, H, T, K).
        u: bonus weighting the current token's own key/value pair.
        scale: multiplier for the query; defaults to ``K ** -0.5``.
        initial_state: optional starting state of shape (B, H, K, V); it is
            added into a fresh buffer and never modified.
        output_final_state: return the last state when true.

    Returns:
        ``(o, ht)`` where ``o`` has the dtype of the input ``q`` and ``ht`` is
        the float32 final state or ``None``.
    """
    out_dtype = q.dtype
    batch, heads, steps, key_dim = q.shape
    value_dim = v.shape[-1]
    q, k, v, w, u = q.float(), k.float(), v.float(), w.float(), u.float()

    if scale is None:
        scale = key_dim ** -0.5

    state = torch.zeros(batch, heads, key_dim, value_dim, dtype=torch.float32, device=q.device)
    if initial_state is not None:
        state += initial_state

    out = torch.zeros_like(v)
    # Broadcastable against the (B, H, K, V) state.
    bonus = u[None, ..., None]
    for t in range(steps):
        query_t = q[:, :, t, :] * scale
        decay_t = w[:, :, t].exp()
        # Outer product k_t v_t^T, shape (B, H, K, V).
        outer = k[:, :, t][..., None] * v[:, :, t, :][..., None, :]
        # The output reads the state before it absorbs the current token.
        weighted = (state + bonus * outer) * query_t[..., None]
        out[:, :, t] = weighted.sum(-2)
        state = state * decay_t[..., None] + outer

    final = state if output_final_state else None
    return out.to(out_dtype), final
41
+
42
+
43
@torch.no_grad()
@torch.jit.script
def naive_recurrent_rwkv6_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None
):
    """Reference backward pass for the recurrent RWKV6 formulation.

    No query scale is applied here. All arithmetic is done in float32.

    Args:
        q: queries of shape (B, H, T, K).
        k: keys of shape (B, H, T, K).
        v: values of shape (B, H, T, V).
        w: per-step decays in log space, shape (B, H, T, K).
        u: bonus weighting the current token's own key/value pair.
        o: forward output; accepted for interface compatibility, not read.
        do: gradient of the output, shape (B, H, T, V).
        initial_state: optional starting state of shape (B, H, K, V).

    Returns:
        ``(dq, dk, dv, dw, du, dh)`` where ``dh`` is the gradient with respect
        to the initial state.
    """
    # Explicit casts keep the function independent of how TorchScript handles
    # generator expressions; `o` is not cast because it is never used.
    q = q.to(dtype=torch.float32)
    k = k.to(dtype=torch.float32)
    v = v.to(dtype=torch.float32)
    w = w.to(dtype=torch.float32)
    u = u.to(dtype=torch.float32)
    do = do.to(dtype=torch.float32)
    B, H, T, K, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1]
    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)

    if initial_state is not None:
        h += initial_state

    # Pass 1 (forward in time): rebuild the state to obtain dq and dq_aux.
    for i in range(T):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = (h + u[None, ..., None] * kv_i)
        dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
        # dq_aux excludes the bonus term; it feeds the decay gradient below.
        dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
        dq[:, :, i] = dq_i
        dq_aux[:, :, i] = dq_aux_i
        h = h * w_i[..., None] + kv_i

    du = torch.zeros_like(u)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)

    # Pass 2 (backward in time): accumulate the state gradient dh.
    for i in range(T - 1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        du += du_i.sum(0)
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        # Slice assignment copies, so the in-place update of dk_i below does
        # not leak into dk_aux.
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)

        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # dw = q * dq_aux - k * dk_aux
    # accumulated as a reversed running sum; the last step stays zero.
    dw = torch.zeros_like(w)
    for i in range(T - 2, -1, -1):
        dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]

    return dq, dk, dv, dw, du, dh
fla2/ops/simple_gla/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ - Simple GLA
2
+
3
+ Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA.
4
+
5
+ $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
fla2/ops/simple_gla/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_simple_gla
4
+
5
+ __all__ = [
6
+ 'chunk_simple_gla'
7
+ ]
fla2/ops/simple_gla/chunk.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023, Yu Zhang, Songlin Yang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
10
+ from fla.ops.utils import chunk_local_cumsum, chunk_global_reversed_cumsum
11
+ from fla.ops.common.chunk_h import chunk_fwd_h_fn, chunk_bwd_dh_fn
12
+
13
# Computes one [BT, BV] output tile of chunked Simple GLA: the contribution of
# the chunk-level state `h` plus the causal contribution of tokens inside the
# same chunk.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=4),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_simple_gla_fwd_kernel_o(
    q,
    k,
    v,
    h,  # chunk-level states; chunk i_t is addressed at offset i_t * K * V
    g,  # gate values, one per time step; NOTE(review): the launcher passes a chunk-local cumulative sum
    o,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)
    # Causal mask inside the chunk (query index >= key index, diagonal kept).
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
    # Accumulate q @ h and q @ k^T over blocks of the K dimension.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        # k is viewed transposed, as (K, T).
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
        b_s += tl.dot(b_q, b_k, allow_tf32=False)

    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_g = tl.load(p_g, boundary_check=(0,))
    # State contribution is scaled by exp(g) of the query position.
    b_o = b_o * tl.exp(b_g)[:, None]
    # Pairwise scores are scaled by exp(g_query - g_key), then masked causally.
    b_s = b_s * tl.exp(b_g[:, None] - b_g[None, :])
    b_s = tl.where(m_s, b_s, 0)

    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # `scale` is applied once to the sum of both contributions.
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
74
+
75
+
76
# Backward kernel for chunked Simple GLA. Each program handles one
# (K-block, chunk, batch*head) tile and produces dq, dk, a per-K-block partial
# dv, and a per-K-block partial per-step gate gradient dg.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_simple_gla_bwd_kernel_dqkvg(
    q,
    k,
    v,
    h,  # chunk-level forward states
    g,  # gate values, one per time step; NOTE(review): the launcher passes a chunk-local cumulative sum
    do,
    dh,  # chunk-level state gradients
    dq,
    dk,
    dv,  # partial results indexed by (i_k * n_bh + i_bh); summed over dim 0 by the launcher
    dg,  # partial results indexed by (i_k * n_bh + i_bh); summed over dim 0 by the launcher
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    n_bh = tl.num_programs(2)
    o_i = tl.arange(0, BT)

    # q is loaded transposed as [BK, BT]; k as [BT, BK].
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # Transposed score matrix: rows are key positions, columns query positions.
    b_s = tl.dot(b_k, b_q, allow_tf32=False)
    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_g = tl.load(p_g, boundary_check=(0,))
    # Gate value at the last valid position of this chunk (the final chunk may
    # end at T - 1 rather than at a full BT boundary).
    if i_t < NT - 1:
        b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
    else:
        b_g_last = tl.load(g + i_bh * T + T - 1)
    # exp(g_query - g_key), kept only where key index <= query index; the
    # scale factor is folded into the mask.
    mask = tl.exp(b_g[None, :] - b_g[:, None])
    mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)
    b_s = b_s * mask
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # h is viewed transposed, as (V, NT * K).
        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
        # [BT, BV]
        # State path (decayed from each position to the chunk end) plus the
        # intra-chunk path through the masked scores.
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.exp(-b_g + b_g_last)[:, None]
        b_dv += tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # Apply the gate factors to the state-path parts of dq and dk.
    b_dq = b_dq * tl.exp(b_g)[:, None]
    b_dk = b_dk * tl.exp(-b_g + b_g_last)[:, None]
    b_ds = b_ds * tl.trans(mask)
    b_ds = b_ds.to(b_k.dtype)
    # [BT, BK]
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
    # Drop the large temporaries, then reload q untransposed for dg.
    b_ds = None
    b_s = None
    b_q = None
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
    # Per-step gate gradient: sum over K of (dq * q - dk * k).
    b_dg = tl.sum(b_dq * b_q - b_dk * b_k.to(tl.float32), axis=1)
    p_dg = tl.make_block_ptr(dg + (i_k*n_bh + i_bh) * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
179
+
180
+
181
def chunk_fwd_o_fn(h, q, k, v, g, BT, scale):
    """Launch the Simple GLA forward output kernel.

    Args:
        h: chunk-level states consumed by the kernel.
        q, k: tensors of shape (B, H, T, K).
        v: tensor of shape (B, H, T, V).
        g: gate values passed through to the kernel.
        BT: chunk length along the time dimension.
        scale: multiplier applied to the output inside the kernel.

    Returns:
        A new tensor shaped and typed like ``v`` holding the output.
    """
    batch, heads, seq_len, key_dim = k.shape
    value_dim = v.shape[-1]
    out = torch.empty_like(v)
    # Block sizes are capped at 64 along both feature dimensions.
    block_k = min(triton.next_power_of_2(key_dim), 64)
    block_v = min(triton.next_power_of_2(value_dim), 64)
    launch_grid = (
        triton.cdiv(value_dim, block_v),
        triton.cdiv(seq_len, BT),
        batch * heads,
    )
    chunk_simple_gla_fwd_kernel_o[launch_grid](
        q, k, v, h, g, out,
        *q.stride()[1:],
        *v.stride()[1:],
        h.stride(1), h.stride(2),
        scale,
        T=seq_len, K=key_dim, V=value_dim, BT=BT, BK=block_k, BV=block_v
    )
    return out
198
+
199
+
200
def chunk_bwd_dqkvg_fn(do, q, k, v, g, h, dh, scale, BT=64):
    """Launch the Simple GLA backward kernel and reduce its partial results.

    Args:
        do: gradient of the output, shaped like ``v``.
        q, k: tensors of shape (B, H, T, K).
        v: tensor of shape (B, H, T, V).
        g: gate values passed through to the kernel.
        h: chunk-level forward states.
        dh: chunk-level state gradients; its strides are used for both ``h``
            and ``dh`` inside the kernel.
        scale: multiplier used in the forward pass.
        BT: chunk length. Previously hard-coded to 64; it must match the chunk
            length used to build ``h`` and ``dh``. Default: ``64``.

    Returns:
        ``(dq, dk, dv, dg)``.
    """
    B, H, T, K, V = *k.shape, v.shape[-1]
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    NT, NK = triton.cdiv(T, BT), triton.cdiv(K, BK)
    grid = (NK, NT, B * H)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    # dv and dg get one partial result per K-block; reduced below.
    dv = v.new_empty(NK, *v.shape)
    dg = torch.empty(NK, B, H, T, dtype=torch.float32, device=g.device)
    chunk_simple_gla_bwd_kernel_dqkvg[grid](
        q, k, v, h, g, do, dh, dq, dk, dv, dg,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        dh.stride(1), dh.stride(2),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT
    )
    dv = dv.sum(0)
    dg = dg.sum(0)
    # Turn the per-step values into the gate gradient via a reversed
    # cumulative sum over the sequence.
    dg = chunk_global_reversed_cumsum(dg)
    return dq, dk, dv, dg
223
+
224
+
225
+
226
+
227
class SimpleGLAFunction(torch.autograd.Function):
    """Autograd entry point for chunked Simple GLA (chunk length fixed at 64)."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level=1):
        """Compute the output and, if requested, the final state.

        With ``checkpoint_level == 1`` the chunk-level states are not saved and
        are recomputed in ``backward``.
        """
        B, H, T, K, V = *q.shape, v.shape[-1]
        chunk_len = 64
        # Replace the raw gates by their chunk-local cumulative sum.
        g = chunk_local_cumsum(g, chunk_len)
        h, final_state = chunk_fwd_h_fn(
            k=k, v=v, g=g, gk=None, gv=None, BT=chunk_len,
            h0=initial_state, output_final_state=output_final_state
        )
        out = chunk_fwd_o_fn(h, q, k, v, g, chunk_len, scale)
        saved_h = None if checkpoint_level == 1 else h
        ctx.save_for_backward(q, k, v, saved_h, g, initial_state)
        ctx.scale = scale
        ctx.BT = chunk_len
        return out.to(q.dtype), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        """Return gradients for q, k, v, g and the initial state."""
        q, k, v, h, g, initial_state = ctx.saved_tensors
        chunk_len = ctx.BT
        scale = ctx.scale
        if h is None:
            # Checkpointed: rebuild the chunk-level states.
            h, _ = chunk_fwd_h_fn(
                k=k, v=v, g=g, gk=None, gv=None, BT=chunk_len,
                h0=initial_state, output_final_state=False
            )
        dh, dh0 = chunk_bwd_dh_fn(
            q=q, k=k, v=v, g=g, gk=None, gv=None, do=do,
            h0=initial_state, dht=dht, BT=chunk_len, scale=scale
        )
        dq, dk, dv, dg = chunk_bwd_dqkvg_fn(do, q, k, v, g, h, dh, scale)
        # One entry per forward input; scale, output_final_state and
        # checkpoint_level receive no gradient.
        grads = (dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype))
        return grads + (None, dh0, None, None)
255
+
256
+
257
+
258
def chunk_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,  # log decay
    scale: Optional[float] = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    checkpoint_level: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Chunked Simple GLA with head-wise (scalar per head and step) gating.

    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        g (torch.Tensor):
            Log-space forget gates of shape `(B, H, T)`; converted to float32
            before use. Unlike GLA, the gate is head-wise, not elementwise.
        scale (Optional[float]):
            Scale factor for the attention scores.
            When `None`, `1 / sqrt(K)` is used. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        checkpoint_level (Optional[int]):
            `0` keeps the chunk-level hidden state `h` for the backward pass;
            `1` (default, recommended) recomputes it during backward to save memory.

    Returns:
        A tuple `(o, final_state)`.
    """
    assert checkpoint_level in (0, 1), "checkpoint_level must be 0, 1"
    assert q.dim() == k.dim() == v.dim() == 4, "q, k, v must have 4 dimensions (b, h, l, d)"
    assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype"
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    gates = g.float()
    o, final_state = SimpleGLAFunction.apply(
        q, k, v, gates, scale, initial_state, output_final_state, checkpoint_level
    )
    return o, final_state
fla2/ops/simple_gla/naive.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None):
    """Reference (pure PyTorch) chunk-wise simple GLA forward pass.

    The sequence is split into chunks of ``chunk_size`` steps. Every output is
    the sum of an inter-chunk term (queries read the decayed state accumulated
    over earlier chunks) and an intra-chunk term (causally masked,
    decay-weighted attention inside the chunk). The state before the first
    chunk is zero.

    Sequence lengths that are not a multiple of ``chunk_size`` are supported:
    the inputs are padded at the end with zero q/k/v and zero log-decay, and
    the padded steps are dropped from the result. Because the computation is
    causal, the padding cannot influence any real position.

    Args:
        q: queries of shape ``(B, H, T, K)``.
        k: keys of shape ``(B, H, T, K)``.
        v: values of shape ``(B, H, T, V)``.
        g: per-step log decay of shape ``(B, H, T)``.
        chunk_size: number of time steps per chunk.
        scale: multiplier applied to ``q``; defaults to ``K ** -0.5``.

    Returns:
        Output tensor of shape ``(B, H, T, V)``.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5

    seq_len = q.shape[-2]
    # Number of trailing steps needed to reach a whole number of chunks.
    pad_len = -seq_len % chunk_size
    if pad_len:
        q, k, v = (torch.nn.functional.pad(x, (0, 0, 0, pad_len)) for x in (q, k, v))
        # A log decay of 0 means "no decay", so padded steps are inert.
        g = torch.nn.functional.pad(g, (0, pad_len))

    def _chunk(x):
        # (b, h, n * c, ...) -> (b, h, n, c, ...)
        return x.reshape(x.shape[0], x.shape[1], -1, chunk_size, *x.shape[3:])

    q = _chunk(q) * scale
    k = _chunk(k)
    v = _chunk(v)
    # Cumulative log decay inside each chunk.
    g = _chunk(g).cumsum(-1)

    # Per-chunk key/value summary, with every step decayed to the chunk end.
    kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])

    # S[:, :, i] is the state entering chunk i (zero for the first chunk).
    S = torch.zeros_like(kv)
    for i in range(1, g.shape[-2]):
        S[:, :, i] = S[:, :, i - 1].clone() * g[:, :, i - 1, -1, None, None].exp() + kv[:, :, i - 1]

    # Contribution of all previous chunks.
    inter = (q * g[..., None].exp()) @ S

    # Contribution of the current chunk: decay-weighted, strictly causal.
    attn = q @ k.transpose(-1, -2)
    attn = attn * (g[..., None] - g[..., None, :]).exp()
    future = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device),
        diagonal=1,
    )
    attn = attn.masked_fill(future, 0)
    intra = attn @ v

    o = inter + intra
    # (b, h, n, c, d) -> (b, h, n * c, d), then drop the padded steps.
    return o.flatten(2, 3)[:, :, :seq_len]
28
+
29
+
30
def torch_simple_gla_recurrent(q, k, v, g, initial_state=None, scale=None):
    """Step-by-step (recurrent) reference implementation of simple GLA.

    At every time step the running state is multiplied by ``exp(g_t)``, the
    outer product of the current key and value is added, and the scaled query
    reads the updated state.

    Args:
        q: queries of shape ``(B, H, T, DK)``.
        k: keys, indexed as ``k[:, :, t]`` with ``DK`` features per step.
        v: values of shape ``(B, H, T, DV)``.
        g: per-step log decay, indexed as ``g[:, :, t]``.
        initial_state: optional starting state; zeros of shape
            ``(B, H, DK, DV)`` are used when it is ``None``. The tensor passed
            in is not modified.
        scale: multiplier applied to ``q``; defaults to ``DK ** -0.5``.

    Returns:
        ``(o, S)``: the outputs of shape ``(B, H, T, DV)`` and the state after
        the last step.
    """
    batch, heads, seq_len, key_dim = q.shape
    if scale is None:
        scale = key_dim ** -0.5
    q = q * scale
    _, _, _, value_dim = v.shape

    if initial_state is None:
        state = q.new_zeros(batch, heads, key_dim, value_dim)
    else:
        state = initial_state
    out = q.new_zeros(batch, heads, seq_len, value_dim)

    for t in range(seq_len):
        decay = g[:, :, t].exp()[..., None, None]
        # Outer product k_t v_t^T, shape (B, H, DK, DV).
        outer = k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
        # clone() keeps the previous state intact for autograd / the caller.
        state = state.clone() * decay + outer
        out[:, :, t] = (q[:, :, t, :].unsqueeze(-1) * state).sum(-2)
    return out, state
51
+
52
if __name__ == '__main__':
    # Ad-hoc consistency check: run the fused recurrent kernel and the chunked
    # kernel on the same random inputs and print the largest absolute
    # difference of the outputs and of every input gradient.
    # Requires a CUDA device (inputs are moved with `.cuda()`).
    torch.set_default_dtype(torch.bfloat16)
    B = 4
    H = 4
    L = 100
    DK = 32
    DV = 32
    q = torch.randn(B, H, L, DK)
    k = torch.randn(B, H, L, DK)
    v = torch.randn(B, H, L, DV)
    # logsigmoid keeps the log decay <= 0, i.e. the decay factor exp(g) is in (0, 1].
    g = torch.nn.functional.logsigmoid(torch.randn(B, H, L))
    q, k, v, g = map(lambda x: x.cuda().requires_grad_(True), [q, k, v, g])
    # NOTE(review): imports from the `fla` package although this file is
    # listed under `fla2/` -- confirm this is the intended package.
    from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla

    # Reference run: fused recurrent implementation.
    o, _ = fused_recurrent_simple_gla(q, k, v, g)
    do = torch.randn_like(o)
    o.backward(do)
    # Stash the gradients, then clear them so the second backward pass does
    # not accumulate on top of the first.
    q_grad, k_grad, v_grad, g_grad = q.grad, k.grad, v.grad, g.grad
    q.grad, k.grad, v.grad, g.grad = None, None, None, None
    # Run under test: chunked implementation, same inputs and upstream gradient.
    o2, _ = chunk_simple_gla(q, k, v, g)
    o2.backward(do)
    q_grad2, k_grad2, v_grad2, g_grad2 = q.grad, k.grad, v.grad, g.grad

    # Maximum absolute discrepancy for the output and each gradient.
    print((o-o2).abs().max())
    print((q_grad-q_grad2).abs().max())
    print((k_grad-k_grad2).abs().max())
    print((v_grad-v_grad2).abs().max())
    print((g_grad-g_grad2).abs().max())
80
+
81
+
fla2/ops/simple_gla/recurrent_fuse.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023, Yu Zhang, Songlin Yang
3
+
4
+ from typing import Tuple, Optional
5
+ import torch
6
+ from fla.ops.common.fused_recurrent import fused_recurrent
7
+
8
def fused_recurrent_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Simple GLA via the shared `fused_recurrent` helper.

    This is a thin wrapper: it fills in the default scale and forwards all
    arguments positionally to `fused_recurrent`, passing `None` for the two
    arguments that follow `g`.

    Args:
        q: queries; the size of the last dimension determines the default scale.
        k: keys.
        v: values.
        g: gate tensor forwarded to `fused_recurrent`
            (presumably the head-wise log decay -- TODO confirm against the helper).
        scale: scale factor; defaults to `q.shape[-1] ** -0.5` when `None`.
        initial_state: optional initial state, forwarded unchanged.
        output_final_state: forwarded unchanged to `fused_recurrent`.
        reverse: forwarded unchanged to `fused_recurrent`.

    Returns:
        The pair `(o, final_state)` returned by `fused_recurrent`.
    """
    if scale is None:
        # Default: inverse square root of q's last dimension.
        scale = q.shape[-1] ** -0.5
    # The two `None`s fill the helper's positional slots after `g`; their
    # meaning is defined by `fused_recurrent`, which is not visible here.
    o, final_state = fused_recurrent(q, k, v, g, None, None, scale, initial_state, output_final_state, reverse)
    return o, final_state
fla3/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
fla3/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.01 kB). View file
 
fla3/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.97 kB). View file
 
fla3/__pycache__/utils.cpython-312.pyc ADDED
Binary file (14.3 kB). View file
 
fla3/layers/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # -*- coding: utf-8 -*-
2
+ # # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # from .abc import ABCAttention
5
+ # from .attn import Attention
6
+ # from .based import BasedLinearAttention
7
+ # from .bitattn import BitAttention
8
+ # from .delta_net import DeltaNet
9
+ # from .forgetting_attn import ForgettingAttention
10
+ # from .gated_deltanet import GatedDeltaNet
11
+ # from .gated_deltaproduct import GatedDeltaProduct
12
+ # from .gla import GatedLinearAttention
13
+ # from .gsa import GatedSlotAttention
14
+ # from .hgrn import HGRNAttention
15
+ # from .hgrn2 import HGRN2Attention
16
+ # from .lightnet import LightNetAttention
17
+ # from .linear_attn import LinearAttention
18
+ # from .mamba import Mamba
19
+ # from .mamba2 import Mamba2
20
+ # from .multiscale_retention import MultiScaleRetention
21
+ # from .nsa import NativeSparseAttention
22
+ # from .path_attn import PaTHAttention
23
+ # from .rebased import ReBasedLinearAttention
24
+ # from .rwkv6 import RWKV6Attention
25
+ # from .rwkv7 import RWKV7Attention
26
+
27
+ # __all__ = [
28
+ # 'ABCAttention',
29
+ # 'Attention',
30
+ # 'BasedLinearAttention',
31
+ # 'BitAttention',
32
+ # 'DeltaNet',
33
+ # 'ForgettingAttention',
34
+ # 'GatedDeltaNet',
35
+ # 'GatedDeltaProduct',
36
+ # 'GatedLinearAttention',
37
+ # 'GatedSlotAttention',
38
+ # 'HGRNAttention',
39
+ # 'HGRN2Attention',
40
+ # 'LightNetAttention',
41
+ # 'LinearAttention',
42
+ # 'Mamba',
43
+ # 'Mamba2',
44
+ # 'MultiScaleRetention',
45
+ # 'NativeSparseAttention',
46
+ # 'ReBasedLinearAttention',
47
+ # 'RWKV6Attention',
48
+ # 'RWKV7Attention',
49
+ # 'PaTHAttention'
50
+ # ]
51
+ from .emdeltanet import emdeltanet
fla3/layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (204 Bytes). View file
 
fla3/layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (210 Bytes). View file
 
fla3/layers/__pycache__/abc.cpython-310.pyc ADDED
Binary file (5.55 kB). View file