msj19 committed on
Commit
ab9496c
·
verified ·
1 Parent(s): 31fa9bd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. opencompass/models/fla2/ops/abc/__init__.py +11 -0
  2. opencompass/models/fla2/ops/abc/chunk.py +1192 -0
  3. opencompass/models/fla2/ops/abc/chunk_gate.py +1333 -0
  4. opencompass/models/fla2/ops/abc/naive.py +96 -0
  5. opencompass/models/fla2/ops/abc/recurrent_fuse.py +490 -0
  6. opencompass/models/fla2/ops/based/__init__.py +9 -0
  7. opencompass/models/fla2/ops/based/chunk_fuse.py +389 -0
  8. opencompass/models/fla2/ops/based/naive.py +72 -0
  9. opencompass/models/fla2/ops/based/parallel.py +403 -0
  10. opencompass/models/fla2/ops/common/chunk_h.py +249 -0
  11. opencompass/models/fla2/ops/common/fused_recurrent.py +346 -0
  12. opencompass/models/fla2/ops/delta_rule/README.md +4 -0
  13. opencompass/models/fla2/ops/delta_rule/__init__.py +11 -0
  14. opencompass/models/fla2/ops/delta_rule/chunk.py +543 -0
  15. opencompass/models/fla2/ops/delta_rule/chunk_fuse.py +448 -0
  16. opencompass/models/fla2/ops/delta_rule/naive.py +97 -0
  17. opencompass/models/fla2/ops/delta_rule/recurrent_fuse.py +330 -0
  18. opencompass/models/fla2/ops/delta_rule/utils.py +292 -0
  19. opencompass/models/fla2/ops/delta_rule/wy_fast.py +374 -0
  20. opencompass/models/fla2/ops/generalized_delta_rule/README.md +37 -0
  21. opencompass/models/fla2/ops/generalized_delta_rule/__init__.py +9 -0
  22. opencompass/models/fla2/ops/generalized_delta_rule/dplr/__init__.py +7 -0
  23. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk.py +364 -0
  24. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +365 -0
  25. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +196 -0
  26. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +173 -0
  27. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +173 -0
  28. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +428 -0
  29. opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +123 -0
  30. opencompass/models/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py +273 -0
  31. opencompass/models/fla2/ops/generalized_delta_rule/dplr/naive.py +96 -0
  32. opencompass/models/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +164 -0
  33. opencompass/models/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +284 -0
  34. opencompass/models/fla2/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  35. opencompass/models/fla2/ops/generalized_delta_rule/iplr/chunk.py +500 -0
  36. opencompass/models/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py +452 -0
  37. opencompass/models/fla2/ops/generalized_delta_rule/iplr/naive.py +69 -0
  38. opencompass/models/fla2/ops/generalized_delta_rule/iplr/wy_fast.py +300 -0
  39. opencompass/models/fla2/ops/gla/__init__.py +11 -0
  40. opencompass/models/fla2/ops/gla/chunk.py +491 -0
  41. opencompass/models/fla2/ops/gla/chunk_fuse.py +575 -0
  42. opencompass/models/fla2/ops/gla/chunk_util.py +125 -0
  43. opencompass/models/fla2/ops/gla/naive.py +116 -0
  44. opencompass/models/fla2/ops/gla/recurrent_fuse.py +27 -0
  45. opencompass/models/fla2/ops/hgrn/__init__.py +9 -0
  46. opencompass/models/fla2/ops/hgrn/chunk.py +290 -0
  47. opencompass/models/fla2/ops/hgrn/naive.py +63 -0
  48. opencompass/models/fla2/ops/hgrn/recurrent_fuse.py +182 -0
  49. opencompass/models/fla2/ops/linear_attn/__init__.py +11 -0
  50. opencompass/models/fla2/ops/linear_attn/chunk.py +361 -0
opencompass/models/fla2/ops/abc/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_abc
4
+ from .chunk_gate import chunk_gated_abc
5
+ from .recurrent_fuse import fused_recurrent_gated_abc
6
+
7
+ __all__ = [
8
+ 'chunk_abc',
9
+ 'chunk_gated_abc',
10
+ 'fused_recurrent_gated_abc'
11
+ ]
opencompass/models/fla2/ops/abc/chunk.py ADDED
@@ -0,0 +1,1192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ...ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel,
12
+ softmax_fwd_kernel)
13
+ from ...utils import contiguous
14
+
15
+
16
+ @triton.jit
17
+ def chunk_abc_fwd_kernel_h(
18
+ k,
19
+ v,
20
+ z,
21
+ h,
22
+ h0,
23
+ ht,
24
+ s_k_h,
25
+ s_k_t,
26
+ s_k_d,
27
+ s_v_h,
28
+ s_v_t,
29
+ s_v_d,
30
+ s_h_h,
31
+ s_h_t,
32
+ s_h_d,
33
+ T: tl.constexpr,
34
+ K: tl.constexpr,
35
+ V: tl.constexpr,
36
+ BT: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ NT: tl.constexpr,
40
+ NORMK: tl.constexpr,
41
+ USE_INITIAL_STATE: tl.constexpr,
42
+ STORE_FINAL_STATE: tl.constexpr
43
+ ):
44
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
45
+
46
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
47
+ if USE_INITIAL_STATE:
48
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
49
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
50
+ if NORMK:
51
+ p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,))
52
+ else:
53
+ p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,))
54
+ b_zp = tl.load(p_z0).to(tl.float32)
55
+ for i_t in range(NT):
56
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
57
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
58
+ p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
59
+
60
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
61
+ # [BK, BT]
62
+ b_k = tl.load(p_k, boundary_check=(0, 1))
63
+ # [BT, BV]
64
+ b_v = tl.load(p_v, boundary_check=(0, 1))
65
+ if NORMK:
66
+ p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
67
+ # [BK,]
68
+ b_zc = tl.load(p_zc, boundary_check=(0,))
69
+ b_r, b_zp = tl.exp(b_zp - b_zc), b_zc
70
+ # [BK, BV]
71
+ b_h = b_h * b_r[:, None]
72
+ b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype)
73
+ else:
74
+ p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
75
+ # [BV,]
76
+ b_zc = tl.load(p_zc, boundary_check=(0,))
77
+ b_r, b_zp = tl.exp(b_zp - b_zc), b_zc
78
+ # [BK, BV]
79
+ b_h = b_h * b_r[None, :]
80
+ b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)
81
+ # [BK, BV]
82
+ b_h += tl.dot(b_k, b_v, allow_tf32=False)
83
+
84
+ if STORE_FINAL_STATE:
85
+ p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
86
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
87
+
88
+
89
+ @triton.jit
90
+ def chunk_abc_fwd_kernel_intra_K(
91
+ v,
92
+ z,
93
+ o,
94
+ A,
95
+ s_v_h,
96
+ s_v_t,
97
+ s_v_d,
98
+ T: tl.constexpr,
99
+ V: tl.constexpr,
100
+ BT: tl.constexpr,
101
+ BC: tl.constexpr,
102
+ BV: tl.constexpr,
103
+ NC: tl.constexpr
104
+ ):
105
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
106
+ i_t, i_i = i_c // NC, i_c % NC
107
+
108
+ p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
109
+ p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
110
+ # [BV,]
111
+ b_zn = tl.load(p_zn, boundary_check=(0,))
112
+ # [BC, BV]
113
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
114
+ for i_j in range(0, i_i):
115
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
116
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
117
+ # [BC, BV]
118
+ b_v = tl.load(p_v, boundary_check=(0, 1))
119
+ # [BC, BC]
120
+ b_A = tl.load(p_A, boundary_check=(0, 1))
121
+ b_o += tl.dot(b_A, tl.exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)
122
+ b_z = tl.load(p_z, boundary_check=(0, 1))
123
+ b_o *= tl.exp(b_zn[None, :] - b_z)
124
+
125
+ o_i = tl.arange(0, BC)
126
+ o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
127
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
128
+ for j in range(0, BC):
129
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
130
+ # [BC,]
131
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
132
+ # [BV,]
133
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
134
+ # [BC, BV]
135
+ # avoid 0 * inf = inf
136
+ m_i = o_i[:, None] >= j
137
+ b_o += tl.where(m_i, b_A[:, None] * tl.exp(b_v[None, :] - b_z), 0)
138
+ p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
139
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
140
+
141
+
142
+ @triton.jit
143
+ def chunk_abc_fwd_kernel_K(
144
+ q,
145
+ k,
146
+ z,
147
+ h,
148
+ o,
149
+ A,
150
+ s_k_h,
151
+ s_k_t,
152
+ s_k_d,
153
+ s_v_h,
154
+ s_v_t,
155
+ s_v_d,
156
+ s_h_h,
157
+ s_h_t,
158
+ s_h_d,
159
+ scale,
160
+ T: tl.constexpr,
161
+ K: tl.constexpr,
162
+ V: tl.constexpr,
163
+ BT: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr
166
+ ):
167
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
168
+ i_p = tl.maximum(i_t * BT - 1, 0)
169
+
170
+ o_i = tl.arange(0, BT)
171
+ m_s = o_i[:, None] >= o_i[None, :]
172
+
173
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
174
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
175
+ for i_k in range(tl.cdiv(K, BK)):
176
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
177
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
178
+ p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
179
+
180
+ # [BT, BK]
181
+ b_q = tl.load(p_q, boundary_check=(0, 1))
182
+ b_q = (b_q * scale).to(b_q.dtype)
183
+ # [BK, BT]
184
+ b_k = tl.load(p_k, boundary_check=(0, 1))
185
+ # [BK, BV]
186
+ b_h = tl.load(p_h, boundary_check=(0, 1))
187
+ # [BT, BV]
188
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
189
+ # [BT, BT]
190
+ b_A += tl.dot(b_q, b_k, allow_tf32=False)
191
+ p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
192
+ p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
193
+ # [BT, BV]
194
+ b_z = tl.load(p_z, boundary_check=(0, 1))
195
+ # [BT, BV]
196
+ p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,))
197
+ b_zp = tl.load(p_zp, boundary_check=(0,))
198
+ b_o = b_o * tl.exp(b_zp[None, :] - b_z)
199
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
200
+
201
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
202
+ # [BT, BT]
203
+ b_A = tl.where(m_s, b_A, 0.)
204
+ if i_v == 0:
205
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
206
+
207
+
208
+ @triton.jit
209
+ def chunk_abc_fwd_kernel_intra_V(
210
+ q,
211
+ k,
212
+ z,
213
+ A,
214
+ s_k_h,
215
+ s_k_t,
216
+ s_k_d,
217
+ scale,
218
+ T: tl.constexpr,
219
+ K: tl.constexpr,
220
+ BT: tl.constexpr,
221
+ BC: tl.constexpr,
222
+ BK: tl.constexpr,
223
+ NC: tl.constexpr
224
+ ):
225
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
226
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
227
+ n_bh = tl.num_programs(2)
228
+
229
+ if i_i > i_j:
230
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
231
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
232
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
233
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
234
+ p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
235
+ # [BK,]
236
+ b_zn = tl.load(p_zn, boundary_check=(0,))
237
+ # [BC, BK]
238
+ b_q = tl.load(p_q, boundary_check=(0, 1))
239
+ b_z = tl.load(p_z, boundary_check=(0, 1))
240
+ b_q = (b_q * tl.exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype)
241
+ # [BK, BC]
242
+ b_k = tl.load(p_k, boundary_check=(0, 1))
243
+ b_k = tl.exp(b_k - b_zn[:, None]).to(b_k.dtype)
244
+ # [BC, BC]
245
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
246
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
247
+ elif i_i == i_j:
248
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
249
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
250
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
251
+ # [BC, BK]
252
+ b_q = tl.load(p_q, boundary_check=(0, 1))
253
+ b_z = tl.load(p_z, boundary_check=(0, 1))
254
+
255
+ o_i = tl.arange(0, BC)
256
+ o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
257
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
258
+ for j in range(0, BC):
259
+ # [BK,]
260
+ b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
261
+ # [BC,]
262
+ b_A = tl.sum(b_q * tl.exp(b_k[None, :] - b_z) * scale, 1)
263
+ b_A = tl.where(o_i >= j, b_A, 0.)
264
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
265
+
266
+ p_k = tl.advance(p_k, (K,))
267
+
268
+
269
+ @triton.jit
270
+ def chunk_abc_fwd_kernel_V(
271
+ q,
272
+ v,
273
+ z,
274
+ h,
275
+ o,
276
+ A,
277
+ s_k_h,
278
+ s_k_t,
279
+ s_k_d,
280
+ s_v_h,
281
+ s_v_t,
282
+ s_v_d,
283
+ s_h_h,
284
+ s_h_t,
285
+ s_h_d,
286
+ scale,
287
+ T: tl.constexpr,
288
+ K: tl.constexpr,
289
+ V: tl.constexpr,
290
+ BT: tl.constexpr,
291
+ BK: tl.constexpr,
292
+ BV: tl.constexpr
293
+ ):
294
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
295
+ i_p = tl.maximum(i_t * BT - 1, 0)
296
+
297
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
298
+ for i_k in range(tl.cdiv(K, BK)):
299
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
300
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
301
+ p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
302
+ p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,))
303
+
304
+ # [BT, BK]
305
+ b_q = tl.load(p_q, boundary_check=(0, 1))
306
+ b_q = (b_q * scale).to(b_q.dtype)
307
+ # [BT, BK]
308
+ b_z = tl.load(p_z, boundary_check=(0, 1))
309
+ # [BT, BK]
310
+ b_zp = tl.load(p_zp, boundary_check=(0,))
311
+ b_q = (b_q * tl.exp(b_zp[None, :] - b_z)).to(b_q.dtype)
312
+ # [BK, BV]
313
+ b_h = tl.load(p_h, boundary_check=(0, 1))
314
+ # works but dkw, owing to divine benevolence
315
+ # [BT, BV]
316
+ if i_k >= 0:
317
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
318
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
319
+ p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
320
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
321
+ # [BT, BV]
322
+ b_v = tl.load(p_v, boundary_check=(0, 1))
323
+ # [BT, BT]
324
+ b_A = tl.load(p_A, boundary_check=(0, 1))
325
+ b_o += tl.dot(b_A, b_v, allow_tf32=False)
326
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
327
+
328
+
329
+ @triton.jit
330
+ def chunk_abc_bwd_kernel_dh(
331
+ q,
332
+ z,
333
+ do,
334
+ dh,
335
+ s_k_h,
336
+ s_k_t,
337
+ s_k_d,
338
+ s_v_h,
339
+ s_v_t,
340
+ s_v_d,
341
+ s_h_h,
342
+ s_h_t,
343
+ s_h_d,
344
+ scale,
345
+ T: tl.constexpr,
346
+ K: tl.constexpr,
347
+ V: tl.constexpr,
348
+ BT: tl.constexpr,
349
+ BK: tl.constexpr,
350
+ BV: tl.constexpr,
351
+ NT: tl.constexpr,
352
+ NORMK: tl.constexpr
353
+ ):
354
+ i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
355
+
356
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
357
+ b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32)
358
+ for i_t in range(NT - 1, -1, -1):
359
+ i_p = tl.maximum(i_t * BT - 1, 0)
360
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
361
+ p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
362
+ p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
363
+
364
+ # [BK, BT]
365
+ b_q = tl.load(p_q, boundary_check=(0, 1))
366
+ b_q = (b_q * scale).to(b_q.dtype)
367
+ # [BT, BV]
368
+ b_do = tl.load(p_do, boundary_check=(0, 1))
369
+
370
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
371
+ if NORMK:
372
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
373
+ p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,))
374
+ # [BK,]
375
+ b_zc = tl.load(p_zc, boundary_check=(0,))
376
+ b_r, b_zp = tl.exp(b_zc - b_zp), b_zc
377
+ # [BK, BT]
378
+ b_z = tl.load(p_z, boundary_check=(0, 1))
379
+ b_q = (b_q * tl.exp(b_zc[:, None] - b_z)).to(b_q.dtype)
380
+ # [BK, BV]
381
+ b_dh = b_dh * b_r[:, None]
382
+ else:
383
+ p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
384
+ p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,))
385
+ # [BV,]
386
+ b_zc = tl.load(p_zc, boundary_check=(0,))
387
+ b_r, b_zp = tl.exp(b_zc - b_zp), b_zc
388
+ # [BT, BV]
389
+ b_z = tl.load(p_z, boundary_check=(0,))
390
+ b_do = (b_do * tl.exp(b_zc[None, :] - b_z)).to(b_do.dtype)
391
+ # [BK, BV]
392
+ b_dh = b_dh * b_r[None, :]
393
+ # [BK, BV]
394
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
395
+
396
+
397
+ @triton.jit
398
+ def chunk_abc_bwd_kernel_V(
399
+ k,
400
+ v,
401
+ z,
402
+ h,
403
+ A,
404
+ do,
405
+ dh,
406
+ dq,
407
+ dk,
408
+ dv,
409
+ dA,
410
+ s_k_h,
411
+ s_k_t,
412
+ s_k_d,
413
+ s_v_h,
414
+ s_v_t,
415
+ s_v_d,
416
+ s_h_h,
417
+ s_h_t,
418
+ s_h_d,
419
+ scale,
420
+ T: tl.constexpr,
421
+ K: tl.constexpr,
422
+ V: tl.constexpr,
423
+ BT: tl.constexpr,
424
+ BK: tl.constexpr,
425
+ BV: tl.constexpr
426
+ ):
427
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
428
+ i_p = tl.maximum(i_t * BT - 1, 0)
429
+ n_bh = tl.num_programs(2)
430
+
431
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
432
+ p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
433
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
434
+
435
+ # [BK,]
436
+ b_zc = tl.load(p_zc, boundary_check=(0,))
437
+ # [BT, BK]
438
+ b_k = tl.load(p_k, boundary_check=(0, 1))
439
+ b_k = tl.exp(b_k - b_zc[None, :]).to(b_k.dtype)
440
+ # [BT, BT]
441
+ b_A = tl.load(p_A, boundary_check=(0, 1))
442
+
443
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
444
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
445
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
446
+ for i_v in range(tl.cdiv(V, BV)):
447
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
448
+ p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
449
+ p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
450
+ p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
451
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
452
+
453
+ # [BT, BV]
454
+ b_v = tl.load(p_v, boundary_check=(0, 1))
455
+ # [BV, BK]
456
+ b_h = tl.load(p_h, boundary_check=(0, 1))
457
+ # [BT, BV]
458
+ b_do = tl.load(p_do, boundary_check=(0, 1))
459
+ # [BK, BV]
460
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
461
+
462
+ # [BT, BV]
463
+ b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
464
+ if i_k == 0:
465
+ b_dv += tl.dot(b_A, b_do, allow_tf32=False)
466
+ b_do = (b_do * scale).to(b_do.dtype)
467
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
468
+ # [BT, BT]
469
+ b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
470
+ # [BT, BK]
471
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
472
+ # [BT, BK]
473
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
474
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
475
+ p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,))
476
+ # [BK,]
477
+ b_zp = tl.load(p_zp, boundary_check=(0,))
478
+ # [BT, BK]
479
+ b_z = tl.load(p_z, boundary_check=(0, 1))
480
+ b_z = tl.exp(b_zp[None, :] - b_z)
481
+ # [BT, BK]
482
+ b_dq = b_dq * b_z
483
+ b_dk = b_dk * b_k
484
+
485
+ p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
486
+ p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
487
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
488
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
489
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
490
+
491
+ o_i = tl.arange(0, BT)
492
+ m_s = o_i[:, None] >= o_i[None, :]
493
+ # [BT, BT]
494
+ b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
495
+ if i_k == 0:
496
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
497
+
498
+
499
+ @triton.jit
500
+ def chunk_abc_bwd_kernel_intra_V(
501
+ q,
502
+ k,
503
+ z,
504
+ dA,
505
+ dq,
506
+ dk,
507
+ s_k_h,
508
+ s_k_t,
509
+ s_k_d,
510
+ T: tl.constexpr,
511
+ K: tl.constexpr,
512
+ BT: tl.constexpr,
513
+ BC: tl.constexpr,
514
+ BK: tl.constexpr,
515
+ NC: tl.constexpr
516
+ ):
517
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
518
+ i_t, i_i = i_c // NC, i_c % NC
519
+
520
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
521
+ p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
522
+ # [BK,]
523
+ b_zn = tl.load(p_zn, boundary_check=(0,))
524
+ # [BC, BK]
525
+ b_z = tl.load(p_z, boundary_check=(0, 1))
526
+ b_zq = tl.exp(b_zn[None, :] - b_z)
527
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
528
+ for i_j in range(0, i_i):
529
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
530
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
531
+ # [BC, BK]
532
+ b_k = tl.load(p_k, boundary_check=(0, 1))
533
+ b_kz = tl.exp(b_k - b_zn[None, :]).to(b_k.dtype)
534
+ # [BC, BC]
535
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
536
+ # [BC, BK]
537
+ b_dq += tl.dot(b_dA, b_kz, allow_tf32=False)
538
+ b_dq *= b_zq
539
+
540
+ o_i = tl.arange(0, BC)
541
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
542
+ m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
543
+ for j in range(0, BC):
544
+ p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
545
+ # [BC,]
546
+ b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
547
+ # [BK,]
548
+ b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
549
+ # [BC, BK]
550
+ m_i = o_i[:, None] >= j
551
+ # [BC, BK]
552
+ b_dq += tl.where(m_i, b_dA[:, None] * tl.exp(b_kj[None, :] - b_z), 0.)
553
+ p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
554
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
555
+
556
+ tl.debug_barrier()
557
+ p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
558
+ p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
559
+ # [BK,]
560
+ b_zn = tl.load(p_zn, boundary_check=(0,))
561
+ # [BC, BK]
562
+ b_k = tl.load(p_k, boundary_check=(0, 1))
563
+ b_kz = tl.exp(b_k - b_zn[None, :])
564
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
565
+ for i_j in range(i_i + 1, NC):
566
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
567
+ p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
568
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
569
+ # [BC, BK]
570
+ b_q = tl.load(p_q, boundary_check=(0, 1))
571
+ b_z = tl.load(p_z, boundary_check=(0, 1))
572
+ b_qz = (b_q * tl.exp(b_zn[None, :] - b_z)).to(b_q.dtype)
573
+ # [BC, BC]
574
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
575
+ # [BC, BK]
576
+ b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False)
577
+ b_dk *= b_kz
578
+
579
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
580
+ for j in range(0, BC):
581
+ p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
582
+ p_zj = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
583
+ # [BC,]
584
+ b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
585
+ # [BK,]
586
+ b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
587
+ b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32)
588
+ # [BC, BK]
589
+ m_i = o_i[:, None] <= j
590
+ b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_k - b_zj[None, :]), 0.)
591
+ p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
592
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
593
+
594
+
595
+ @triton.jit
596
+ def chunk_abc_bwd_kernel_intra_K(
597
+ v,
598
+ z,
599
+ do,
600
+ dA,
601
+ s_v_h,
602
+ s_v_t,
603
+ s_v_d,
604
+ scale,
605
+ T: tl.constexpr,
606
+ V: tl.constexpr,
607
+ BT: tl.constexpr,
608
+ BC: tl.constexpr,
609
+ BV: tl.constexpr,
610
+ NC: tl.constexpr
611
+ ):
612
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
613
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
614
+ n_bh = tl.num_programs(2)
615
+
616
+ if i_i > i_j:
617
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
618
+ p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
619
+ p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
620
+ p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
621
+ p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
622
+ # [BV,]
623
+ b_zn = tl.load(p_zn, boundary_check=(0,))
624
+ # [BC, BV]
625
+ b_z = tl.load(p_z, boundary_check=(0, 1))
626
+ b_do = tl.load(p_do, boundary_check=(0, 1))
627
+ b_do = (b_do * tl.exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype)
628
+ # [BV, BC]
629
+ b_v = tl.load(p_v, boundary_check=(0, 1))
630
+ b_v = tl.exp(b_v - b_zn[:, None]).to(b_v.dtype)
631
+ # [BC, BC]
632
+ b_dA = tl.dot(b_do, b_v, allow_tf32=False)
633
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
634
+ elif i_i == i_j:
635
+ p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
636
+ p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
637
+ p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
638
+ # [BC, BV]
639
+ b_z = tl.load(p_z, boundary_check=(0, 1))
640
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
641
+
642
+ o_i = tl.arange(0, BC)
643
+ o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
644
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
645
+ for j in range(0, BC):
646
+ # [BV,]
647
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
648
+ # [BC,]
649
+ b_dA = tl.sum(b_do * tl.exp(b_v[None, :] - b_z), 1)
650
+ b_dA = tl.where(o_i >= j, b_dA, 0)
651
+ tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A)
652
+
653
+ p_v = tl.advance(p_v, (V,))
654
+
655
+
656
@triton.jit
def chunk_abc_bwd_kernel_K(
    q,
    k,
    v,
    z,
    h,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dA,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Backward pass over the "K" (inter-chunk) part of chunk-wise ABC attention.
    # Grid: (i_k over K-dim blocks, i_t over time chunks, i_bh over batch*heads).
    # Recomputes the intra-chunk attention matrix A from q/k, then accumulates
    # dq, dk (into per-(i_k) slices) and dv (into per-(i_k) slices, summed on the
    # host side) using the chunk-level states h and their gradients dh.
    # z is a running normalizer kept in log space; all rescaling below is done
    # via exp(z_ref - z) for numerical stability — presumably produced by a
    # log-cumsum-exp forward pass (see the fwd_pre helper in this file).
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # index of the last position of the previous chunk (clamped at 0 for chunk 0),
    # used as the reference point z_p for rescaling this chunk's do
    i_p = tl.maximum(i_t * BT - 1, 0)
    n_bh = tl.num_programs(2)

    # lower-triangular (causal) mask within the chunk
    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # A is laid out as NK stacked (B*H, T, BT) slices; this program writes slice i_k
    p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]: recompute the (causally masked) scaled attention scores for this chunk
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False)
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # accumulate inter-chunk contributions over all V-dim blocks
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # z at the end of the previous chunk (p) and at the end of this chunk (c)
        p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,))
        p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K*V, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))

        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        # [BV,]
        b_zp = tl.load(p_zp, boundary_check=(0,))
        b_zc = tl.load(p_zc, boundary_check=(0,))
        # [BT, BV]: v rescaled to the end-of-chunk normalizer
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype)
        b_z = tl.load(p_z, boundary_check=(0, 1))
        b_z = tl.exp(b_zp[None, :] - b_z)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]: do rescaled to the previous-chunk normalizer and pre-scaled
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * b_z * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))

        # [BT, BK]: inter-chunk gradients via state h and state gradient dh
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
        # [BT, BV]: dv contribution through the state recurrence
        b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    # add the intra-chunk contribution coming through dA
    p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BT]
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k, allow_tf32=False)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False)

    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
752
+
753
+
754
@triton.jit
def chunk_abc_bwd_kernel_intra_KV(
    v,
    z,
    A,
    do,
    dv,
    s_v_h,
    s_v_t,
    s_v_d,
    T: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr
):
    # Intra-chunk backward for dv: each chunk of size BT is split into NC
    # sub-chunks of size BC; gradients flow from later sub-chunks (i_j > i_i,
    # handled with block dots) and from positions within the same sub-chunk
    # (handled with the per-row loop over j, masked to keep causality).
    # Grid: (i_v over V-blocks, i_c = i_t * NC + i_i over sub-chunks, i_bh).
    # z is the log-space running normalizer; rescaling uses exp(z_n - z) where
    # z_n is z at the last row of this sub-chunk.
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_t, i_i = i_c // NC, i_c % NC

    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # normalizer at the last row of this sub-chunk, used as the common reference
    p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T*V,), (s_v_d,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,))
    # [BV,]
    b_zn = tl.load(p_zn, boundary_check=(0,))
    # [BC, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)
    # contributions from strictly later sub-chunks of the same chunk
    for i_j in range(i_i + 1, NC):
        p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # A is read transposed: rows of the later sub-chunk attend to this one
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_z = tl.load(p_z, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * tl.exp(b_zn[None, :] - b_z)).to(b_do.dtype)
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        b_dv += tl.dot(b_A, b_do, allow_tf32=False)
    # rescale the accumulated inter-sub-chunk part back to per-row normalizers
    b_dv *= tl.exp(b_v - b_zn[None, :])

    o_i = tl.arange(0, BC)
    # same-sub-chunk contributions, one source row j at a time (causal: rows <= j)
    for j in range(0, BC):
        p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
        # [BC,]
        b_A = tl.load(p_A, boundary_check=(0,))
        # [BV,]
        b_z = tl.load(p_z, boundary_check=(0,))
        b_do = tl.load(p_do, boundary_check=(0,))
        # [BC, BV]
        m_i = o_i[:, None] <= j
        b_dv += tl.where(m_i, tl.exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.)
    p_dv = tl.make_block_ptr(dv + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
809
+
810
+
811
@triton.jit
def chunk_abc_bwd_kernel_rcum_inter(
    s,
    z,
    ss,
    doo,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    NT: tl.constexpr
):
    # Inter-chunk part of a reversed (right-to-left) cumulative sum in log space:
    # walks chunks from the last to the first, carrying a running sum b_sp and
    # its reference normalizer b_zp, and writes each chunk's suffix contribution
    # into doo. The intra-chunk remainder is added by the _rcum_intra kernel.
    # Grid: (i_m over S-dim blocks, i_bh over batch*heads).
    i_m, i_bh = tl.program_id(0), tl.program_id(1)

    # running suffix sum and its log-space reference; inf makes the first
    # (i.e. last-chunk) rescale exp(... - inf) = 0, so the carry starts empty
    b_sp = tl.zeros([BS,], dtype=tl.float32)
    b_zp = tl.full([BS,], float('inf'), dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
        # normalizer at the first row of this chunk
        p_zc = tl.make_block_ptr(z + i_bh * s_s_h, (T * S,), (s_s_d,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,))
        p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
        p_doo = tl.make_block_ptr(doo + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
        # [BS,]
        b_zc = tl.load(p_zc, boundary_check=(0,))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1))
        b_z = tl.load(p_z, boundary_check=(0, 1))
        b_ss = tl.load(p_ss, boundary_check=(0, 1))

        # contribution of everything strictly after this chunk
        b_doo = tl.exp(b_s - b_zp[None, :]) * b_sp[None, :]
        tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
        # [BS,]: fold this chunk into the carry and move the reference normalizer
        b_sp = b_sp * tl.exp(b_zc - b_zp) + tl.sum(b_ss * tl.exp(b_zc[None, :] - b_z), 0)
        b_zp = b_zc
848
+
849
+
850
@triton.jit
def chunk_abc_bwd_kernel_rcum_intra(
    s,
    z,
    ss,
    doo,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    NC: tl.constexpr
):
    # Intra-chunk part of the reversed cumulative sum: adds, for each sub-chunk
    # of size BC, the contributions from later sub-chunks of the same chunk
    # (block dot with an all-ones matrix, i.e. a broadcasted sum) and from
    # positions at or after each row within the same sub-chunk (masked loop).
    # The result is accumulated into doo on top of the inter-chunk part written
    # by _rcum_inter.
    # Grid: (i_s over S-blocks, i_c = i_t * NC + i_i over sub-chunks, i_bh).
    i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_t, i_i = i_c // NC, i_c % NC

    o_i = tl.arange(0, BC)
    # all-ones [BC, BC]; dotting with it sums the accumulated rows for every row
    m_o = tl.full([BC, BC], 1., dtype=tl.float32)

    p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
    # normalizer at the last row of this sub-chunk (common reference point)
    p_zn = tl.make_block_ptr(z + i_bh * s_s_h, (T*S,), (s_s_d,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,))
    p_doo = tl.make_block_ptr(doo + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
    # [BC, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1))
    # [BS,]
    b_zn = tl.load(p_zn, boundary_check=(0,))

    b_doo = tl.zeros([BC, BS], dtype=tl.float32)
    # later sub-chunks within the same chunk
    for i_j in range(i_i + 1, NC):
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
        p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
        # [BC, BS]
        b_z = tl.load(p_z, boundary_check=(0, 1))
        b_ss = tl.load(p_ss, boundary_check=(0, 1))
        # [BC, BS]
        b_doo += b_ss * tl.exp(b_zn[None, :] - b_z)
    b_doo = tl.exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False)

    # rows j >= i within the same sub-chunk
    for j in range(0, BC):
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T * S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
        p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T * S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
        # [BS,]
        b_z = tl.load(p_z, boundary_check=(0,))
        b_ss = tl.load(p_ss, boundary_check=(0,))
        # [BC, BS]
        m_i = o_i[:, None] <= j
        b_doo += tl.where(m_i, tl.exp(b_s - b_z[None, :]) * b_ss[None, :], 0.)
    # accumulate on top of the inter-chunk contribution already stored in doo
    b_doo += tl.load(p_doo, boundary_check=(0, 1))
    tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
902
+
903
+
904
class ChunkABCFunction(torch.autograd.Function):
    # Autograd wrapper for chunk-wise ABC attention: the forward runs a
    # two-stage linear-attention pipeline (q/k against slot scores s, then a
    # softmax over slots, then slot probabilities against v), with a log-space
    # running normalizer z for numerical stability. The backward replays the
    # pipeline in reverse with the matching Triton backward kernels.

    @staticmethod
    @contiguous
    def forward(ctx, q, k, v, s, initial_state, output_final_state):
        """Compute the ABC forward pass.

        Args:
            q, k: (B, H, T, K) query/key tensors.
            v: (B, H, T, V) value tensor.
            s: (B, H, T, M) slot scores (M = number of memory slots).
            initial_state: optional pair of (K, M) and (M, V) states.
            output_final_state: if True, also return the final states.

        Returns:
            (ov, final_state): output of shape like v, and the final states
            (or None when output_final_state is False).
        """
        B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
        # chunk size BT, sub-chunk size BC; block sizes capped at 64
        BT, BC = 64, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        BM = min(64, triton.next_power_of_2(M))
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
        NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM)
        num_warps = 4 if BK == 64 else 2
        num_stages = 1

        def fwd_pre(s, B, H, T, S):
            # log-cumsum-exp of the slot scores; this is the normalizer z used
            # by every kernel below.
            # keep cummulative normalizer in fp32
            z = torch.empty_like(s, dtype=torch.float)
            grid = (B * H,)
            logcumsumexp_fwd_kernel[grid](
                s, z,
                s.stride(1), s.stride(2), s.stride(3),
                T=T, S=S
            )
            return z

        def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None):
            # materialize the per-chunk states h (stacked along dim 2 as NT*K);
            # normk selects whether z normalizes the key side or the value side.
            NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
            h = q.new_empty(B, H, NT * K, V)
            grid = (NV, NK, B * H)
            chunk_abc_fwd_kernel_h[grid](
                k, v, z, h, h0, ht,
                k.stride(1), k.stride(2), k.stride(3),
                v.stride(1), v.stride(2), v.stride(3),
                h.stride(1), h.stride(2), h.stride(3),
                T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
                NORMK=normk,
                USE_INITIAL_STATE=h0 is not None,
                STORE_FINAL_STATE=ht is not None,
                num_warps=num_warps,
                num_stages=num_stages
            )
            return h

        final_state = None
        if output_final_state:
            final_state = (q.new_empty(B, H, K, M, dtype=torch.float),
                           q.new_empty(B, H, M, V, dtype=torch.float))

        z = fwd_pre(s, B, H, T, M)
        # stage 1: q/k attention over slot scores s, scaled by 1/sqrt(K)
        scale = K ** -0.5
        hk = fwd_inner(
            q=q, k=k, v=s, z=z,
            B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
            normk=False,
            h0=initial_state[0] if initial_state is not None else None,
            ht=final_state[0] if final_state is not None else None
        )
        # inter-chunk part of stage-1 output (ok1) + intra-chunk attention Ak
        ok1 = torch.empty_like(s)
        Ak = q.new_empty(B, H, T, BT)
        grid = (NM, NT, B * H)
        chunk_abc_fwd_kernel_K[grid](
            q, k, z, hk, ok1, Ak,
            k.stride(1), k.stride(2), k.stride(3),
            s.stride(1), s.stride(2), s.stride(3),
            hk.stride(1), hk.stride(2), hk.stride(3),
            scale=scale,
            T=T, K=K, V=M, BT=BT, BK=BK, BV=BM,
            num_warps=num_warps,
            num_stages=num_stages
        )
        # intra-chunk part of stage-1 output (ok0)
        ok0 = torch.empty_like(s)
        grid = (NM, NT * NC, B * H)
        chunk_abc_fwd_kernel_intra_K[grid](
            s, z, ok0, Ak,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
            num_warps=2,
            num_stages=num_stages
        )
        ok = ok0.add_(ok1)

        scale = 1.
        # equivalent to:
        # p = ok.softmax(-1, torch.float)
        # p is kept in fp32 for safe softmax backward
        p = torch.empty_like(ok, dtype=torch.float)
        grid = (NT, B * H)
        softmax_fwd_kernel[grid](
            ok, p,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, S=M, BT=BT
        )
        qv = p.to(q.dtype)

        # stage 2: slot probabilities (qv) attend over v, normalized on the key side
        scale = 1.
        hv = fwd_inner(
            q=qv, k=s, v=v, z=z,
            B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
            normk=True,
            h0=initial_state[1] if initial_state is not None else None,
            ht=final_state[1] if final_state is not None else None
        )
        # intra-chunk attention for stage 2, accumulated per M-block then summed
        Av = q.new_zeros(NM, B, H, T, BT)
        grid = (NM, NT * NC * NC, B * H)
        chunk_abc_fwd_kernel_intra_V[grid](
            qv, s, z, Av,
            s.stride(1), s.stride(2), s.stride(3),
            scale=scale,
            T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
            num_warps=2,
            num_stages=num_stages
        )
        Av = Av.sum(0)
        ov = torch.empty_like(v)
        grid = (NV, NT, B * H)
        chunk_abc_fwd_kernel_V[grid](
            qv, v, z, hv, ov, Av,
            s.stride(1), s.stride(2), s.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            hv.stride(1), hv.stride(2), hv.stride(3),
            scale=scale,
            T=T, K=M, V=V, BT=BT, BK=BM, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av)
        ctx.BT = BT
        return ov, final_state

    @staticmethod
    @contiguous
    def backward(ctx, dov, dht=None):
        # NOTE(review): dht (gradient of the final state) is accepted but not
        # used below — gradients do not flow back through the final state.
        q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors
        B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
        BT, BC = ctx.BT, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        BM = min(64, triton.next_power_of_2(M))
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
        NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM)
        num_warps = 4 if BK == 64 else 2
        num_stages = 1

        def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False):
            # gradients of the per-chunk states h
            NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
            dh = q.new_empty(B, H, NT * K, V)
            grid = (NK, NV, B * H)
            chunk_abc_bwd_kernel_dh[grid](
                q, z, do, dh,
                q.stride(1), q.stride(2), q.stride(3),
                do.stride(1), do.stride(2), do.stride(3),
                dh.stride(1), dh.stride(2), dh.stride(3),
                scale=scale,
                T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
                NORMK=normk,
                num_warps=num_warps,
                num_stages=num_stages
            )
            return dh

        def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS):
            # reversed cumulative sum of ss in log space (inter- then intra-chunk);
            # this is the correction term for the gradient of the normalizer z.
            doo = torch.empty_like(s)
            grid = (NS, B * H)
            chunk_abc_bwd_kernel_rcum_inter[grid](
                s, z, ss, doo,
                s.stride(1), s.stride(2), s.stride(3),
                T=T, S=S, BT=BT, BS=BS, NT=NT,
                num_warps=num_warps,
                num_stages=num_stages
            )
            grid = (NS, NT * NC, B * H)
            chunk_abc_bwd_kernel_rcum_intra[grid](
                s, z, ss, doo,
                s.stride(1), s.stride(2), s.stride(3),
                T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC,
                num_warps=num_warps,
                num_stages=num_stages
            )
            return doo

        # stage 2 backward: through the slot-probability attention over v
        scale = 1.
        qv = p.to(q.dtype)
        dhv = bwd_inner(
            qv, z, dov,
            B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
            scale=scale,
            normk=True
        )
        dp1 = torch.empty_like(p)
        dsv1 = torch.empty_like(s, dtype=torch.float)
        dv = v.new_empty(NM, *v.shape)
        dAv = q.new_zeros(B, H, T, BT)
        grid = (NM, NT, B * H)
        chunk_abc_bwd_kernel_V[grid](
            s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv,
            s.stride(1), s.stride(2), s.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            hv.stride(1), hv.stride(2), hv.stride(3),
            scale=scale,
            T=T, K=M, V=V, BT=BT, BK=BM, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dv = dv.sum(0)
        dp0 = torch.empty_like(p)
        dsv0 = s.new_zeros(s.shape, dtype=torch.float)
        grid = (NM, NT * NC, B * H)
        chunk_abc_bwd_kernel_intra_V[grid](
            qv, s, z, dAv, dp0, dsv0,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
            num_warps=2,
            num_stages=num_stages
        )
        dp = dp1.add_(dp0)
        dsv = dsv1.add_(dsv0)

        # softmax gradient, equivalent to:
        # dok = p * (dp - (p * dp).sum(-1, True))
        dok = torch.empty_like(ok)
        grid = (NT, B * H)
        softmax_bwd_kernel[grid](
            p, dp, dok,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, S=M, BT=BT
        )

        # stage 1 backward: through the q/k attention over slot scores
        scale = K ** -0.5
        dhk = bwd_inner(
            q, z, dok,
            B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
            scale=scale,
            normk=False
        )
        dAk = q.new_zeros(NM, B, H, T, BT)
        grid = (NM, NT * NC * NC, B * H)
        chunk_abc_bwd_kernel_intra_K[grid](
            s, z, dok, dAk,
            s.stride(1), s.stride(2), s.stride(3),
            scale=scale,
            T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
            num_warps=2,
            num_stages=num_stages
        )
        dAk = dAk.sum(0)

        # per-K-block partial results (Ak, dsk1) are summed on the host afterwards
        Ak = q.new_zeros(NK, B, H, T, BT)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float)
        grid = (NK, NT, B * H)
        chunk_abc_bwd_kernel_K[grid](
            q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk,
            q.stride(1), q.stride(2), q.stride(3),
            s.stride(1), s.stride(2), s.stride(3),
            hk.stride(1), hk.stride(2), hk.stride(3),
            scale=scale,
            T=T, K=K, V=M, BT=BT, BK=BK, BV=BM,
            num_warps=num_warps,
            num_stages=num_stages
        )
        Ak = Ak.sum(0)
        dsk1 = dsk1.sum(0)
        dsk0 = torch.empty_like(s, dtype=torch.float)
        grid = (NM, NT * NC, B * H)
        chunk_abc_bwd_kernel_intra_KV[grid](
            s, z, Ak, dok, dsk0,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
            num_warps=2,
            num_stages=num_stages
        )
        # combine all ds contributions and subtract the normalizer correction
        ds = dsv.add_(dsk1.add_(dsk0))
        ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM)
        ds = ds.to(s.dtype)
        # no gradients for initial_state / output_final_state
        return dq, dk, dv, ds, None, None
1181
+
1182
+
1183
def chunk_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    """Chunk-wise ABC attention.

    Thin functional wrapper that dispatches to :class:`ChunkABCFunction`.

    Args:
        q, k: query/key tensors of shape (B, H, T, K).
        v: value tensor of shape (B, H, T, V).
        s: slot-score tensor of shape (B, H, T, M).
        initial_state: optional pair of initial states.
        output_final_state: whether to also return the final states.

    Returns:
        The attention output and the final state pair (None if not requested).
    """
    return ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)
opencompass/models/fla2/ops/abc/chunk_gate.py ADDED
@@ -0,0 +1,1333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import reduce
11
+
12
+ from ...ops.utils import (chunk_global_reversed_cumsum, chunk_local_cumsum, softmax_bwd_kernel,
13
+ softmax_fwd_kernel)
14
+ from ...utils import contiguous
15
+
16
+
17
+
18
@triton.jit
def chunk_gated_abc_fwd_kernel_h(
    k,
    v,
    g,
    h,
    h0,
    ht,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    GATEK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    # Forward recurrence over chunks for gated ABC: maintains a [BK, BV] state
    # b_h that is decayed by the (log-space) gate g and updated with k @ v per
    # chunk. GATEK selects whether the gate lives on the key dimension (shape
    # like k) or on the value dimension (shape like v). The state *before* each
    # chunk's update is stored into h; optionally seeded from h0 and flushed to ht.
    # Grid: (i_v over V-blocks, i_k over K-blocks, i_bh over batch*heads).
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
    for i_t in range(NT):
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # last valid position of this chunk (handles a ragged final chunk)
        o_t = min(i_t * BT + BT, T)

        # store the state as seen at the *start* of chunk i_t
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        if GATEK:
            # gate acts on the key side: decay state rows, rescale k to the
            # end-of-chunk gate so the in-chunk cumulative gate cancels
            p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))
            # [BK,]
            b_gn = tl.load(p_gn, boundary_check=(0,))
            # [BK, BV]
            b_h *= tl.exp(b_gn)[:, None]
            # [BK, BT]
            b_g = tl.load(p_g, boundary_check=(0, 1))
            b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
        else:
            # gate acts on the value side: decay state columns, rescale v
            p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((o_t - 1) * V + i_v * BV,), (BV,), (0,))
            # [BV,]
            b_gn = tl.load(p_gn, boundary_check=(0,))
            # [BK, BV]
            b_h *= tl.exp(b_gn)[None, :]
            # [BT, BV]
            b_g = tl.load(p_g, boundary_check=(0, 1))
            b_v = (b_v * tl.exp(b_gn[None, :] - b_g)).to(b_v.dtype)
        # [BK, BV]: fold this chunk's outer product into the state
        b_h += tl.dot(b_k, b_v, allow_tf32=False)

    if STORE_FINAL_STATE:
        p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
89
+
90
+
91
@triton.jit
def chunk_gated_abc_fwd_kernel_intra_K(
    v,
    g,
    o,
    A,
    s_v_h,
    s_v_t,
    s_v_d,
    T: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr
):
    # Intra-chunk forward for the gated "K" stage: adds, into the partial output
    # o, the attention contributions that the inter-chunk kernel cannot cover —
    # earlier sub-chunks (block dot with A) and same-sub-chunk rows (masked
    # per-row loop). The gate g is applied in log space around the reference
    # b_gn taken at this sub-chunk's first row.
    # Grid: (i_v over V-blocks, i_c = i_t * NC + i_i over sub-chunks, i_bh).
    # NG appears to fold grouped heads: gate/value pointers use i_bg = i_bh // NG.
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_t, i_i = i_c // NC, i_c % NC

    p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # gate at the first row of this sub-chunk, used as the common reference
    p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
    # [BV,]
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BV]
    b_o = tl.zeros([BC, BV], dtype=tl.float32)
    # earlier sub-chunks of the same chunk
    for i_j in range(0, i_i):
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]: v decayed from its own gate up to the reference point
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * tl.exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        b_o += tl.dot(b_A, b_vg, allow_tf32=False)
    # [BC, BV]: rescale accumulated output to each row's own gate
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o *= tl.exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # same-sub-chunk contributions, one source row j at a time (causal mask)
    for j in range(0, BC):
        p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
        p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
        # [BC,]
        b_A = tl.load(A + o_A + j, mask=m_A, other=0)
        # [BV,]
        b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
        b_gv = tl.load(p_gv, boundary_check=(0,)).to(tl.float32)
        # [BC, BV]
        b_vg = b_v[None, :] * tl.exp(b_g - b_gv[None, :])
        # avoid 0 * inf = inf
        b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))

    # accumulate on top of the inter-chunk output already stored in o
    b_o += tl.load(p_o, boundary_check=(0, 1))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
152
+
153
+
154
@triton.jit
def chunk_gated_abc_fwd_kernel_K(
    q,
    k,
    h,
    g,
    o,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    # Inter-chunk forward for the gated "K" stage: o = (q @ h) * exp(g) per
    # chunk, and the raw (causally masked) intra-chunk score matrix A = q @ k^T,
    # both accumulated over K-dim blocks inside this program.
    # Grid: (i_v over V-blocks, i_t over chunks, i_bh over batch*heads).
    # NG folds grouped heads: k/h/g are indexed with i_bg = i_bh // NG.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG

    # causal mask within the chunk
    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]: contribution through the chunk-level state
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
        # [BT, BT]: raw intra-chunk scores (gating applied by the intra kernel)
        b_A += tl.dot(b_q, b_k, allow_tf32=False)
    p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    # [BT, BV]: apply the (log-space) gate to the state-based output
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o = b_o * tl.exp(b_g)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BT]
    b_A = tl.where(m_s, b_A, 0.)
    # A does not depend on i_v; only one V-block program writes it
    if i_v == 0:
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
216
+
217
+
218
@triton.jit
def chunk_gated_abc_fwd_kernel_intra_Vk(
    q,
    k,
    g,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    i_k,
    i_c,
    i_bh,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr
):
    # Intra-chunk score computation for the gated "V" stage, for one (i_i, i_j)
    # sub-chunk pair of one chunk. Note the block indices (i_k, i_c, i_bh) are
    # regular arguments, not tl.program_id — presumably this is invoked as a
    # subroutine from a parent @triton.jit kernel that owns the launch grid.
    # Writes A[i_i*BC : , i_j*BC :], accumulating across i_k (A += on i_k != 0).
    i_bg = i_bh // NG
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC

    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))

    b_A = tl.zeros([BC, BC], tl.float32)
    if i_i > i_j:
        # strictly-earlier sub-chunk: full BCxBC block, gate-rescaled q and k
        # around the reference gate b_gn at this sub-chunk's first row
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))

        # [BK,]
        b_gn = tl.load(p_gn, boundary_check=(0,))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
        # [BC, BC]
        b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
        # accumulate across K-dim blocks
        if i_k != 0:
            b_A += tl.load(p_A, boundary_check=(0, 1))
        tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    elif i_i == i_j:
        # diagonal sub-chunk: per-row loop over j with a causal lower-triangular mask
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
        p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))

        o_i = tl.arange(0, BC)
        # [BC, BC]
        m_A = o_i[:, None] >= o_i[None, :]
        for j in range(0, BC):
            # [BK,]
            b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
            b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
            # [BC,]: scores of all rows against source position j, written into column j
            b_Aj = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
            b_A = tl.where((o_i == j)[None, :], b_Aj[:, None], b_A)

            p_k = tl.advance(p_k, (K,))
            p_gk = tl.advance(p_gk, (K,))
        b_A = tl.where(m_A, b_A, 0.)
        if i_k != 0:
            b_A += tl.load(p_A, boundary_check=(0, 1))
        tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    else:
        # set the upper triangular part to 0
        if i_k == 0:
            tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
297
+
298
+
299
@triton.jit
def chunk_gated_abc_fwd_kernel_intra_V(
    q,
    k,
    g,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    NK: tl.constexpr,
    NG: tl.constexpr
):
    # Driver kernel: iterates the NK key slices *sequentially* within one
    # program and delegates each slice to chunk_gated_abc_fwd_kernel_intra_Vk.
    # The serialization is what makes the read-accumulate-store on A in the
    # inner kernel race-free across slices.
    i_c, i_bh = tl.program_id(0), tl.program_id(1)

    for i_k in range(0, NK):
        chunk_gated_abc_fwd_kernel_intra_Vk(
            q,
            k,
            g,
            A,
            s_k_h,
            s_k_t,
            s_k_d,
            i_k,
            i_c,
            i_bh,
            scale,
            T,
            K,
            BT,
            BC,
            BK,
            NC,
            NG,
        )
341
+
342
+
343
@triton.jit
def chunk_gated_abc_fwd_kernel_V(
    q,
    v,
    g,
    h,
    o,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    # Output kernel (key-gated variant): o = (q * exp(g)) @ h  +  A @ v,
    # i.e. the inter-chunk contribution through the chunk state h plus the
    # intra-chunk contribution through the precomputed score matrix A.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # [BT, BV]
        # NOTE(review): `i_k >= 0` is always true; it appears to be kept as a
        # compiler workaround (per the comment above) — confirm before removing.
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h, allow_tf32=False)
    p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
400
+
401
+
402
@triton.jit
def chunk_gated_abc_bwd_kernel_dh(
    q,
    g,
    do,
    dh,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    NG: tl.constexpr,
    GATEK: tl.constexpr
):
    # Backward recurrence for the chunk states: walks chunks in reverse time
    # order, maintaining the running state gradient b_dh.  dh[i_t] is stored
    # *before* folding in chunk i_t's own q @ do contribution, matching the
    # forward convention that h[i_t] is the state entering chunk i_t.
    # GATEK selects whether the gate g decays the key (K) or value (V) side.
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG

    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # end of this chunk, clipped to the sequence length for the last chunk
        o_t = min(i_t * BT + BT, T)

        # [BK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        if GATEK:
            p_g = tl.make_block_ptr(g + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))
            # [BK,]  gate at the last step of the chunk
            b_gn = tl.load(p_gn, boundary_check=(0,))
            # [BK, BV]  decay the carried state across the chunk boundary
            b_dh *= tl.exp(b_gn)[:, None]
            # [BK, BT]
            b_g = tl.load(p_g, boundary_check=(0, 1))
            b_q = (b_q * tl.exp(b_g)).to(b_q.dtype)
        else:
            p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((o_t - 1) * V + i_v * BV,), (BV,), (0,))
            # [BV,]
            b_gn = tl.load(p_gn, boundary_check=(0,))
            # [BK, BV]
            b_dh *= tl.exp(b_gn)[None, :]
            # [BT, BV]
            b_g = tl.load(p_g, boundary_check=(0, 1))
            b_do = (b_do * tl.exp(b_g)).to(b_do.dtype)
        # [BK, BV]
        b_dh += tl.dot(b_q, b_do, allow_tf32=False)
467
+
468
+
469
@triton.jit
def chunk_gated_abc_bwd_kernel_V(
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dA,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    # Inter-chunk backward for the key-gated (V-output) variant: accumulates
    # dq, dk, dv and the intra-chunk score gradient dA over all value slices.
    # dv is written into a per-i_k slot ((i_k*n_bh + i_bh)) and reduced over
    # i_k by the host; dA is only written by the i_k == 0 program.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    n_bh = tl.num_programs(2)
    o_t = min(i_t * BT + BT, T)

    p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,))
    # A is read transposed ((BT, T) view) for the dv += A^T @ do term
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))

    # [BK,]
    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    # decay factor from each step to the end of the chunk
    b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
    b_k = (b_k * b_gn).to(b_k.dtype)
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))

        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
        # the A^T @ do term is added once (by the i_k == 0 program only)
        if i_k == 0:
            b_dv += tl.dot(b_A, b_do, allow_tf32=False)
        b_do = (b_do * scale).to(b_do.dtype)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        # [BT, BT]
        b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
        # [BT, BK]
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
    # re-apply the gate decays that the forward applied to q and k
    b_dq = b_dq * tl.exp(b_gk)
    b_dk = b_dk * b_gn

    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]
    # [BT, BT]  keep only the causal (lower-triangular) part of dA
    b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
    if i_k == 0:
        tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
565
+
566
+
567
@triton.jit
def chunk_gated_abc_bwd_kernel_intra_V(
    q,
    k,
    g,
    dA,
    dq,
    dk,
    dg,
    s_k_h,
    s_k_t,
    s_k_d,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    OVERWRITE: tl.constexpr
):
    # Intra-chunk backward (key-gated variant): propagates dA into dq, dk and
    # the gate gradient dg for one [BC, BK] tile.  The first half computes dq
    # (from sub-blocks at or before i_i), the second half computes dk (from
    # sub-blocks at or after i_i); dg = q*dq - k*dk.
    # OVERWRITE controls whether dg is written fresh or accumulated.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_t, i_i = i_c // NC, i_c % NC

    p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
    # [BK,]  gate at the first step of this row sub-block (stabilizer)
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BK]
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    # dq contribution from strictly earlier sub-blocks of the same chunk
    for i_j in range(0, i_i):
        p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        # [BC, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
        # [BC, BC]
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
        # [BC, BK]
        b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
    b_dq *= tl.exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    # raw (non-block-ptr) offsets into dA's diagonal sub-block, masked to T
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # dq contribution from the diagonal sub-block, one column j at a time
    for j in range(0, BC):
        p_kj = tl.make_block_ptr(k + i_bg * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
        p_gkj = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
        # [BC,]
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
        # [BK,]
        b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
        b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
        # [BC, BK]  causal mask: row t receives column j only if t >= j
        m_i = o_i[:, None] >= j
        # [BC, BK]
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # accumulate onto the inter-chunk dq computed by the previous kernel
    b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # barrier between the dq and dk phases (dk reads q/g written elsewhere)
    tl.debug_barrier()
    p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
    # [BK,]  gate at the last step of this row sub-block
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    # dk contribution from strictly later sub-blocks of the same chunk
    for i_j in range(i_i + 1, NC):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
        # [BC, BC]
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
        # [BC, BK]
        b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
    b_dk *= tl.exp(b_gn[None, :] - b_gk)

    # dk contribution from the diagonal sub-block, one row j at a time
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
    for j in range(0, BC):
        p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
        p_gqj = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
        # [BC,]
        b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
        # [BK,]
        b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
        b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
        # [BC, BK]  causal mask for the transposed direction: key t feeds query j only if t <= j
        m_i = o_i[:, None] <= j
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
    b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)).to(tl.float32)
    # gate gradient: d(exp-gate) flows as +q*dq through queries, -k*dk through keys
    b_dg = b_q * b_dq - b_k * b_dk
    if not OVERWRITE:
        b_dg = b_dg + tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
681
+
682
+
683
@triton.jit
def chunk_gated_abc_bwd_kernel_intra_K(
    v,
    g,
    do,
    dA,
    s_v_h,
    s_v_t,
    s_v_d,
    scale,
    T: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr
):
    # Intra-chunk dA for the value-gated (K-output) variant:
    # dA = do @ v^T with the value-side gate decays applied.  Each program
    # writes one [BC, BC] sub-block into a per-i_v slot of dA; the host
    # reduces over i_v afterwards.
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    i_bg = i_bh // NG
    n_bh = tl.num_programs(2)

    p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))

    # [BC, BC]
    b_dA = tl.zeros([BC, BC], dtype=tl.float32)
    if i_i > i_j:
        # strictly lower-triangular sub-block: dense matmul, gates referenced
        # to b_gn (gate at the row block start) for numerical stability
        p_v = tl.make_block_ptr(v + i_bg * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
        p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
        p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
        p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        # [BV,]
        b_gn = tl.load(p_gn, boundary_check=(0,))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * tl.exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * tl.exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_dA = tl.dot(b_do, b_vg, allow_tf32=False)
    elif i_i == i_j:
        # diagonal sub-block: computed one column j at a time with a causal mask
        p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
        p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
        p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * scale

        o_i = tl.arange(0, BC)
        # [BC, BC]
        m_dA = o_i[:, None] >= o_i[None, :]
        for j in range(0, BC):
            # [BV,]
            b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
            b_gv = tl.load(p_gv, boundary_check=(0,)).to(tl.float32)
            # [BC,]
            b_dAj = tl.sum(b_do * b_v[None, :] * tl.exp(b_g - b_gv[None, :]), 1)
            b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)

            # advance by one time step (V elements in the flattened view)
            p_v = tl.advance(p_v, (V,))
            p_gv = tl.advance(p_gv, (V,))
        b_dA = tl.where(m_dA, b_dA, 0.)
    # upper-triangular sub-blocks fall through with b_dA == 0
    tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
752
+
753
+
754
@triton.jit
def chunk_gated_abc_bwd_kernel_K(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dA,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    # Inter-chunk backward for the value-gated (K-output) variant.  Also
    # recomputes the un-gated intra-chunk score matrix A = tril(scale*q @ k^T)
    # (into a per-i_k slot, host-reduced) since it is needed downstream.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    n_bh = tl.num_programs(2)

    o_i = tl.arange(0, BT)
    # end of this chunk, clipped to the sequence length for the last chunk
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]  causal score recomputation
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False)
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * K*V, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((o_t - 1) * V + i_v * BV,), (BV,), (0,))

        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        # [BV,]  gate at the last step of the chunk
        b_gn = tl.load(p_gn, boundary_check=(0,))
        # [BT, BV]  value decayed from each step to the chunk end
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_v = b_v * tl.exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * tl.exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))

        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
        b_dk += tl.dot(b_v.to(b_dh.dtype), tl.trans(b_dh), allow_tf32=False)
        # [BT, BV]  inter-chunk part of dv (per-i_k slot, host-reduced)
        b_dv = tl.exp(b_gn[None, :] - b_g) * tl.dot(b_k, b_dh, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BT]  fold the intra-chunk score gradient into dq/dk
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k, allow_tf32=False)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False)

    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
849
+
850
+
851
@triton.jit
def chunk_gated_abc_bwd_kernel_intra_KV(
    v,
    g,
    o,
    A,
    do,
    dv,
    dg,
    s_v_h,
    s_v_t,
    s_v_d,
    T: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    OVERWRITE: tl.constexpr
):
    # Intra-chunk backward for the value-gated variant: adds the A^T @ do
    # contribution to dv for one [BC, BV] tile and computes the value-side
    # gate gradient dg = o*do - v*dv.  OVERWRITE controls whether dg is
    # written fresh or accumulated.
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_bh // NG
    i_t, i_i = i_c // NC, i_c % NC

    p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T*V,), (s_v_d,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,))
    # [BV,]  gate at the last step of this row sub-block (stabilizer)
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)
    # dv contribution from strictly later sub-blocks of the same chunk
    for i_j in range(i_i + 1, NC):
        p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * tl.exp(b_g - b_gn[None, :])).to(b_do.dtype)
        # [BC, BC]  A read transposed
        b_A = tl.load(p_A, boundary_check=(0, 1))
        b_dv += tl.dot(b_A, b_do, allow_tf32=False)
    b_dv *= tl.exp(b_gn[None, :] - b_gv)

    o_i = tl.arange(0, BC)
    # dv contribution from the diagonal sub-block, one row j at a time
    for j in range(0, BC):
        p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
        # [BC,]
        b_A = tl.load(p_A, boundary_check=(0,))
        # [BV,]
        b_g = tl.load(p_g, boundary_check=(0,))
        b_do = tl.load(p_do, boundary_check=(0,))
        # [BC, BV]  causal mask: value at t feeds output j only if t <= j
        m_i = o_i[:, None] <= j
        b_dv += tl.where(m_i, tl.exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    p_dg = tl.make_block_ptr(dg + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))

    b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
    b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
    # accumulate onto the inter-chunk dv computed by the previous kernel
    b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
    # gate gradient: +o*do through the output gate, -v*dv through the value decay
    b_dg = b_o * b_do - b_v * b_dv
    if not OVERWRITE:
        b_dg = b_dg + tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
924
+
925
+
926
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, gatek=False, h0=None, ht=None):
    """Launch the chunk-wise state kernel and return the per-chunk states.

    The returned tensor ``h`` stores one ``[K, V]`` state per time chunk,
    laid out as ``(B, H, NT * K, V)``.  ``h0``/``ht``, when given, supply the
    initial state and receive the final state; ``gatek`` selects whether the
    gate ``g`` acts on the key (True) or value (False) side.
    """
    n_chunks = triton.cdiv(T, BT)
    n_k = triton.cdiv(K, BK)
    n_v = triton.cdiv(V, BV)
    warps = 4 if BK == 64 else 2

    h = q.new_empty(B, H, n_chunks * K, V)
    chunk_gated_abc_fwd_kernel_h[(n_v, n_k, B * H)](
        k, v, g, h, h0, ht,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        h.stride(1), h.stride(2), h.stride(3),
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=n_chunks,
        GATEK=gatek,
        USE_INITIAL_STATE=h0 is not None,
        STORE_FINAL_STATE=ht is not None,
        num_warps=warps,
        num_stages=1
    )
    return h
947
+
948
+
949
def fwd_v(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.):
    """Forward pass of the key-gated ("V") variant.

    Builds the chunk states ``h``, the intra-chunk score matrix ``A``, and
    the output ``o``; returns ``(o, h, A)``.  ``HQ`` may exceed ``H`` when
    several query heads share one k/v/g head (grouped heads, ``NG = HQ//H``).
    """
    HQ = q.shape[1]
    n_t = triton.cdiv(T, BT)
    n_k = triton.cdiv(K, BK)
    n_v = triton.cdiv(V, BV)
    n_c = triton.cdiv(BT, BC)
    n_g = HQ // H
    warps = 4 if BK == 64 else 2

    # per-chunk states (gate applied on the key side for this variant)
    h = fwd_inner(
        q=q, k=k, v=v, g=g,
        B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
        gatek=True,
        h0=h0,
        ht=ht
    )
    # intra-chunk scores
    A = q.new_empty(B, HQ, T, BT)
    chunk_gated_abc_fwd_kernel_intra_V[(n_t * n_c * n_c, B * HQ)](
        q, k, g, A,
        k.stride(1), k.stride(2), k.stride(3),
        scale,
        T=T, K=K, BT=BT, BC=BC, BK=BK, NC=n_c, NK=n_k, NG=n_g,
        num_warps=warps,
        num_stages=1
    )
    # outputs: inter-chunk (q @ h) plus intra-chunk (A @ v)
    o = v.new_empty(B, HQ, T, V)
    chunk_gated_abc_fwd_kernel_V[(n_v, n_t, B * HQ)](
        q, v, g, h, o, A,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        h.stride(1), h.stride(2), h.stride(3),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=n_g,
        num_warps=warps,
        num_stages=1
    )
    return o, h, A
988
+
989
+
990
def fwd_k(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.):
    """Forward pass of the value-gated ("K") variant.

    Builds the chunk states ``h``, the output ``o`` (inter-chunk part first,
    then the intra-chunk correction), and the score matrix ``A``; returns
    ``(o, h, A)``.
    """
    HQ = q.shape[1]
    n_t = triton.cdiv(T, BT)
    n_v = triton.cdiv(V, BV)
    n_c = triton.cdiv(BT, BC)
    n_g = HQ // H
    warps = 4 if BK == 64 else 2

    # per-chunk states (gate applied on the value side for this variant)
    h = fwd_inner(
        q=q, k=k, v=v, g=g,
        B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
        gatek=False,
        h0=h0,
        ht=ht
    )
    o = v.new_empty(B, HQ, T, V)
    A = q.new_empty(B, HQ, T, BT)
    chunk_gated_abc_fwd_kernel_K[(n_v, n_t, B * HQ)](
        q, k, h, g, o, A,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        h.stride(1), h.stride(2), h.stride(3),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=n_g,
        num_warps=warps,
        num_stages=1
    )
    # intra-chunk correction applied to o in place
    chunk_gated_abc_fwd_kernel_intra_K[(n_v, n_t * n_c, B * HQ)](
        v, g, o, A,
        v.stride(1), v.stride(2), v.stride(3),
        T=T, V=V, BT=BT, BC=BC, BV=BV, NC=n_c, NG=n_g,
        num_warps=warps,
        num_stages=1
    )
    return o, h, A
1028
+
1029
+
1030
def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, scale, gatek=False):
    """Launch the reverse-time state-gradient kernel and return ``dh``.

    ``dh`` holds one ``[K, V]`` state gradient per time chunk, laid out as
    ``(B, HQ, NT * K, V)``; ``gatek`` selects the key- vs value-gated decay.
    """
    HQ = q.shape[1]
    n_chunks = triton.cdiv(T, BT)
    n_k = triton.cdiv(K, BK)
    n_v = triton.cdiv(V, BV)
    n_g = HQ // H
    warps = 4 if BK == 64 else 2

    dh = q.new_empty(B, HQ, n_chunks * K, V)
    chunk_gated_abc_bwd_kernel_dh[(n_k, n_v, B * HQ)](
        q, g, do, dh,
        q.stride(1), q.stride(2), q.stride(3),
        do.stride(1), do.stride(2), do.stride(3),
        dh.stride(1), dh.stride(2), dh.stride(3),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=n_chunks, NG=n_g,
        GATEK=gatek,
        num_warps=warps,
        num_stages=1
    )
    return dh
1052
+
1053
+
1054
def bwd_v(q, k, v, g, h, A, do, dg, B, H, T, K, V, BT, BK, BV, BC, scale=1.):
    """Backward pass of the key-gated ("V") variant.

    Returns ``(dq, dk, dv, dg)``.  When ``dg`` is ``None`` a fresh gradient
    buffer is allocated and overwritten; otherwise the intra-chunk kernel
    accumulates into the caller-provided ``dg``.
    """
    HQ = q.shape[1]
    n_t = triton.cdiv(T, BT)
    n_k = triton.cdiv(K, BK)
    n_c = triton.cdiv(BT, BC)
    n_g = HQ // H
    warps = 4 if BK == 64 else 2

    overwrite_dg = dg is None
    # state gradients, reverse over chunks (key-gated decay)
    dh = bwd_inner(
        q, g, do,
        B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
        scale=scale,
        gatek=True
    )
    dq = torch.empty_like(q, dtype=torch.float)
    dk = k.new_empty(B, HQ, T, K, dtype=torch.float)
    # dv carries a leading per-key-slice axis that is reduced below
    dv = v.new_empty(n_k, B, HQ, T, V)
    if dg is None:
        dg = g.new_empty(B, HQ, T, K, dtype=torch.float)
    dA = v.new_empty(B, HQ, T, BT)

    chunk_gated_abc_bwd_kernel_V[(n_k, n_t, B * HQ)](
        k, v, h, g, A, do, dh, dq, dk, dv, dA,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        h.stride(1), h.stride(2), h.stride(3),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=n_g,
        num_warps=warps,
        num_stages=1
    )
    dv = dv.sum(0, dtype=dv.dtype)
    chunk_gated_abc_bwd_kernel_intra_V[(n_k, n_t * n_c, B * HQ)](
        q, k, g, dA, dq, dk, dg,
        k.stride(1), k.stride(2), k.stride(3),
        T=T, K=K, BT=BT, BC=BC, BK=BK, NC=n_c, NG=n_g,
        OVERWRITE=overwrite_dg,
        num_warps=warps,
        num_stages=1
    )
    return dq, dk, dv, dg
1098
+
1099
+
1100
def bwd_k(q, k, v, g, h, o, do, dg, B, H, T, K, V, BT, BK, BV, BC, scale=1.):
    """Backward pass of the key-side (gate-on-value) chunked recurrence.

    Pipeline: intra-chunk dA (`..._bwd_kernel_intra_K`), then the main
    inter-chunk kernel (`..._bwd_kernel_K`) producing dq/dk/dv partials and A,
    then a final intra kernel (`..._bwd_kernel_intra_KV`) for dv/dg.

    Args:
        q, k, v, g: forward inputs (q may have HQ >= H heads for GQA).
        h:  forward hidden states from the key pass.
        o:  forward output of the key pass (needed for dg).
        do: output gradient.
        dg: gate-gradient accumulator; if None a fresh buffer is allocated
            and the final kernel overwrites instead of accumulating.
        Remaining args: dimensions and block sizes, as in the forward pass.

    Returns:
        (dq, dk, dv, dg).
    """
    HQ = q.shape[1]
    NT = triton.cdiv(T, BT)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    # sub-chunks per chunk for the intra kernels
    NC = triton.cdiv(BT, BC)
    NG = HQ // H
    num_warps = 4 if BK == 64 else 2
    num_stages = 1

    overwrite_dg = dg is None
    dh = bwd_inner(
        q, g, do,
        B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
        scale=scale,
        gatek=False
    )
    # per-V-block partial dA; NT * NC * NC covers all sub-chunk pairs
    dA = q.new_empty(NV, B, HQ, T, BT)
    grid = (NV, NT * NC * NC, B * HQ)
    chunk_gated_abc_bwd_kernel_intra_K[grid](
        v, g, do, dA,
        v.stride(1), v.stride(2), v.stride(3),
        scale,
        T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG,
        num_warps=num_warps,
        num_stages=num_stages
    )
    dA = dA.sum(0, dtype=dA.dtype)

    # A is recomputed here (per-K-block partials) rather than saved in forward
    A = do.new_empty(NK, B, HQ, T, BT)
    dq = torch.empty_like(q)
    dk = k.new_empty(B, HQ, T, K)
    dv = v.new_empty(NK, B, HQ, T, V)
    # note: in this pass the gate lives on the value dim, hence shape (..., V)
    dg = g.new_empty(B, HQ, T, V, dtype=torch.float) if dg is None else dg
    grid = (NK, NT, B * HQ)
    chunk_gated_abc_bwd_kernel_K[grid](
        q, k, v, h, g, A, do, dh, dq, dk, dv, dA,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        h.stride(1), h.stride(2), h.stride(3),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=NG,
        num_warps=num_warps,
        num_stages=num_stages
    )
    A = A.sum(0, dtype=A.dtype)
    dv = dv.sum(0, dtype=dv.dtype)
    grid = (NV, NT * NC, B * HQ)
    chunk_gated_abc_bwd_kernel_intra_KV[grid](
        v, g, o, A, do, dv, dg,
        v.stride(1), v.stride(2), v.stride(3),
        T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG,
        OVERWRITE=overwrite_dg,
        num_warps=num_warps,
        num_stages=num_stages
    )
    return dq, dk, dv, dg
1156
+
1157
+
1158
class ChunkGatedABCFunction(torch.autograd.Function):
    """Autograd wrapper for chunked gated ABC attention.

    Forward runs two chunked recurrences: a key pass (`fwd_k`) producing slot
    scores `ok`, a fp32 softmax over the M slots (`softmax_fwd_kernel`), and a
    value pass (`fwd_v`) producing the output `ov`.  `checkpoint_level`
    controls how much intermediate state is saved vs. recomputed in backward.
    """

    @staticmethod
    @contiguous
    def forward(ctx, q, k, v, s, g, scale, hk0, hv0, output_final_state, checkpoint_level):
        # B, H, T, K from k; V and M (number of slots) from v and s
        B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
        BT, BC = 64, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        BM = min(64, triton.next_power_of_2(M))

        hkt, hvt = None, None
        if output_final_state:
            # final states for the key pass (K x M) and value pass (M x V)
            hkt = q.new_empty(B, H, K, M, dtype=torch.float)
            hvt = q.new_empty(B, H, M, V, dtype=torch.float)

        # chunk-local cumulative gates; keep the original g for checkpointing
        g_cumsum = chunk_local_cumsum(g, BT)
        g_org, g = g, g_cumsum
        ok, hk, _ = fwd_k(
            q=q, k=k, v=s, g=g,
            B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, BC=BC,
            h0=hk0,
            ht=hkt,
            scale=scale
        )

        # equivalent to:
        # p = ok.softmax(-1, torch.float)
        # p is kept in fp32 for safe softmax backward
        p = torch.empty_like(ok, dtype=torch.float)
        def grid(meta): return (triton.cdiv(meta['T'], meta['BT']), p.shape[0] * p.shape[1])
        softmax_fwd_kernel[grid](
            ok, p,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, S=M, BT=BT
        )

        # value pass: softmaxed slot scores act as queries, scale fixed to 1
        ov, hv, Av = fwd_v(
            q=p.to(q.dtype), k=s, v=v, g=g,
            B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, BC=BC,
            h0=hv0,
            ht=hvt,
            scale=1.
        )

        # checkpointing: level >= 1 drops the fp32 cumulative gates,
        # level > 1 additionally drops the hidden states (recomputed later)
        if checkpoint_level >= 1:
            del g
            g = g_org
        if checkpoint_level > 1:
            del hk
            del hv
            hk, hv = None, None
        else:
            hk0, hv0 = None, None

        ctx.save_for_backward(q, k, v, s, g, ok, p, hk, hv, Av, hk0, hv0)
        ctx.checkpoint_level = checkpoint_level
        ctx.scale = scale
        ctx.BT = BT
        return ov, (hkt, hvt)

    @staticmethod
    @contiguous
    def backward(ctx, dov, dht=None):
        q, k, v, s, g, ok, p, hk, hv, Av, hk0, hv0 = ctx.saved_tensors
        # softmaxed slot scores (queries of the value pass)
        qv = p.to(q.dtype)
        B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
        BT, BC = ctx.BT, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        BM = min(64, triton.next_power_of_2(M))

        if ctx.checkpoint_level >= 1:
            g = chunk_local_cumsum(g, BT)

        # rerun the forward pass to get h if checkpoint_level >= 1
        if ctx.checkpoint_level > 1:
            hk = fwd_inner(
                q=q, k=k, v=s, g=g,
                B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM,
                gatek=False,
                h0=hk0,
                ht=None
            )
            hv = fwd_inner(
                q=qv, k=s, v=v, g=g,
                B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV,
                gatek=True,
                h0=hv0,
                ht=None
            )

        # backward of the value pass; dg is freshly allocated (dg=None)
        dqv, dsv, dv, dg = bwd_v(
            q=qv, k=s, v=v, g=g, h=hv, A=Av, do=dov, dg=None,
            B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, BC=BC,
            scale=1.
        )

        # softmax gradient, equivalent to:
        # dok = qv * (dqv - (qv * dqv).sum(-1, True))
        dok = torch.empty_like(ok)
        def grid(meta): return (triton.cdiv(meta['T'], meta['BT']), p.shape[0] * p.shape[1])
        softmax_bwd_kernel[grid](
            p, dqv, dok,
            s.stride(1), s.stride(2), s.stride(3),
            T=T, S=M, BT=BT
        )

        # backward of the key pass; accumulates into the same dg buffer
        dq, dk, dsk, dg = bwd_k(
            q=q, k=k, v=s, g=g, h=hk, o=ok, do=dok, dg=dg,
            B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, BC=BC,
            scale=ctx.scale
        )

        ds = dsv.add_(dsk)
        # reversed cumsum, equivalent to:
        #
        # def reversed_cumsum(x, dim=-1):
        #     c = x.cumsum(dim)
        #     return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c
        dg = chunk_global_reversed_cumsum(dg).to(s.dtype)
        if q.shape[1] != H:
            # GQA: fold the per-query-head gradients back onto the kv heads
            dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=H), (dk, dv, ds, dg))
        return dq, dk, dv, ds, dg, None, None, None, None, None
1282
+
1283
+
1284
def chunk_gated_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    checkpoint_level: Optional[int] = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, HQ, T, K)`.
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`. GQA is performed if `H` is not equal to `HQ`.
        v (torch.Tensor):
            values of shape `(B, H, T, V)`.
        s (torch.Tensor):
            slot representations of shape `(B, H, T, M)`, where `M` is the number of memory slots.
        g (torch.Tensor):
            Forget gates of shape `(B, H, T, M)` applied to keys.
            If not provided, this function is equivalent to vanilla ABC.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `2`:
            - Level `0`: no memory saved, no recomputation.
            - Level `1`: recompute the fp32 cumulative values during backward.
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.

    Returns:
        The output of shape `(B, HQ, T, V)` and the final state tuple `(hk, hv)`
        (both entries `None` unless `output_final_state` is set).
    """
    assert checkpoint_level in [0, 1, 2]
    if g is None:
        # Vanilla ABC: derive gates from a numerically-stable cumulative softmax.
        # TODO: this 3 steps took huge amount of time, ought to be optimized
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z).to(k.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    ov, final_state = ChunkGatedABCFunction.apply(q, k, v, s, g, scale, hk0, hv0, output_final_state, checkpoint_level)
    return ov, final_state
opencompass/models/fla2/ops/abc/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
def naive_recurrent_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    """Sequential (per-timestep) reference implementation of gated ABC.

    Runs two recurrences in fp32: a key pass producing slot scores `ok`,
    a softmax over the `M` slots, then a value pass producing the output.
    Returns `(output, final_state)` where `final_state` is a `(hk, hv)`
    tuple if `output_final_state` is set, else `None`.

    Args:
        q: queries of shape (B, HQ, T, K); HQ may be a multiple of H (GQA).
        k: keys of shape (B, H, T, K).
        v: values of shape (B, H, T, V).
        s: slot representations of shape (B, H, T, M).
        g: forget gates of shape (B, H, T, M); if None, vanilla ABC gates
           are derived from `s` via a cumulative log-softmax.
        scale: score scale, defaults to 1 / sqrt(K).
        initial_state: optional (hk, hv) tuple of per-kv-head states.
        output_final_state: whether to also return the final (hk, hv).
    """
    dtype = q.dtype

    # GQA group size: number of query heads sharing one kv head
    NG = q.shape[1] // k.shape[1]
    # [batch_size, n_heads, seq_len, n_slots]
    if g is None:
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
    # expand kv heads to match query heads; torch-native equivalent of
    # einops.repeat(x, 'b h t d -> b (h g) t d', g=NG)
    k, v, s, g = map(lambda x: x.repeat_interleave(NG, dim=1), (k, v, s, g))
    if initial_state is not None:
        initial_state = tuple(map(lambda x: x.repeat_interleave(NG, dim=1), initial_state))

    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]

    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None
    if initial_state is not None:
        hk += initial_state[0]

    # key pass: hk is a K x M state gated along the slot dim
    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = s[:, :, i]
        g_i = g[:, :, i].exp()
        hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
        ok[:, :, i] = (q_i[..., None] * hk).sum(-2)

    # softmaxed slot scores act as queries of the value pass
    qv = ok.softmax(-1)
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv += initial_state[1]

    # value pass: hv is an M x V state gated along the slot dim
    for i in range(T):
        q_i = qv[:, :, i]
        k_i = s[:, :, i]
        v_i = v[:, :, i]
        g_i = g[:, :, i].exp()
        hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
        ov[:, :, i] = (q_i[..., None] * hv).sum(-2)

    if output_final_state:
        # keep one state per kv head (drop the GQA replicas)
        final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
    return ov.to(dtype), final_state
69
+
70
+
71
def naive_cumsum_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor
) -> torch.Tensor:
    """
    A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
    This is just for demonstration purposes, with no numerical stabilities guaranteed.
    """
    orig_dtype = q.dtype
    q = q.float()
    k = k.float()
    v = v.float()
    s = s.float()

    scale = q.shape[-1] ** -0.5
    # slot logits -> non-negative weights, shifted by the running max
    # [batch_size, n_heads, seq_len, n_slots]
    s = torch.exp(s - s.max(2, keepdim=True)[0])
    z = s.cumsum(2)
    denom = z.unsqueeze(-1)
    # cumulative slot-weighted key/value summaries
    # [batch_size, n_heads, seq_len, n_slots, d_head]
    K = torch.cumsum(s.unsqueeze(-1) * k.unsqueeze(-2), 2) / denom
    V = torch.cumsum(s.unsqueeze(-1) * v.unsqueeze(-2), 2) / denom
    # attention over the n_slots memory slots
    # [batch_size, n_heads, seq_len, n_slots]
    p = torch.softmax(torch.einsum('...d,...md->...m', q * scale, K), -1)
    # [batch_size, n_heads, seq_len, d_head]
    o = torch.einsum('...m,...md->...d', p, V)
    return o.to(orig_dtype), None
opencompass/models/fla2/ops/abc/recurrent_fuse.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2024, Yu Zhang, Songlin Yang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
12
+
13
+
14
+ @triton.jit
15
+ def fused_recurrent_gated_abc_inference_kernel(
16
+ q,
17
+ k,
18
+ v,
19
+ s,
20
+ g,
21
+ o,
22
+ hk0,
23
+ hv0,
24
+ hkt,
25
+ hvt,
26
+ scale,
27
+ K: tl.constexpr,
28
+ V: tl.constexpr,
29
+ M: tl.constexpr,
30
+ BK: tl.constexpr,
31
+ BV: tl.constexpr,
32
+ NG: tl.constexpr
33
+ ):
34
+ i_bh = tl.program_id(0)
35
+ i_bg = i_bh // NG
36
+
37
+ b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32)
38
+ b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32)
39
+ b_g = tl.exp(b_g)
40
+
41
+ b_ok = tl.zeros([M], dtype=tl.float32)
42
+ for i_k in range(tl.cdiv(K, BK)):
43
+ o_k = i_k * BK + tl.arange(0, BK)
44
+
45
+ p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None]
46
+ # [BK,]
47
+ mask_k = o_k < K
48
+ # [M, BK]
49
+ mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :]
50
+ # [M, BK]
51
+ b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32)
52
+ # [BK,]
53
+ b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale
54
+ b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32)
55
+ b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]
56
+ b_ok += tl.sum(b_hk * b_q[None, :], axis=1)
57
+
58
+ if i_bh % NG == 0:
59
+ p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None]
60
+ tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk)
61
+
62
+ b_qv = tl.softmax(b_ok)
63
+ for i_v in range(tl.cdiv(V, BV)):
64
+ o_v = i_v * BV + tl.arange(0, BV)
65
+
66
+ p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
67
+ # [BV,]
68
+ mask_v = o_v < V
69
+ # [BV, M]
70
+ mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :]
71
+ # [BV, M]
72
+ b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32)
73
+ # [BV,]
74
+ b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32)
75
+ b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]
76
+ b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)
77
+
78
+ tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v)
79
+
80
+ if i_bh % NG == 0:
81
+ p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
82
+ tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv)
83
+
84
+
85
@triton.jit
def fused_recurrent_gated_abc_fwd_kernel(
    q,         # queries [B, H, T, K]
    k,         # keys [B, H, T, K]
    v,         # values [B, H, T, V]
    gk,        # optional gates along K (log space)
    gv,        # optional gates along V (log space)
    o,         # output, one [B, H, T, V] slab per K-block (summed by caller)
    h0,        # optional initial state [B, H, K, V]
    ht,        # optional final-state output [B, H, K, V]
    s_k_h,     # stride of q/k along the head dim (T * K)
    s_v_h,     # stride of v/o along the head dim (T * V)
    scale,
    B: tl.constexpr,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    REVERSE: tl.constexpr,   # iterate time from T-1 down to 0
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr
):
    # indices: one program per (V-block, K-block, batch*head)
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # element pointers; start at the last timestep when REVERSE
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    # each K-block writes a distinct output slab (offset i_k * B * H)
    p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)

    if USE_GK:
        p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    if USE_GV:
        p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)

    mask_k = (i_k * BK + tl.arange(0, BK)) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V

    # running state block, kept in fp32 as [BV, BK]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    mask_h = mask_k[None, :] & mask_v[:, None]

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        # decay the state along whichever dim carries the gate
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gk)[None, :]
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gv)[:, None]
        # rank-1 state update, then readout with the (scaled) query
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # advance one timestep
        p_q += -K if REVERSE else K
        p_k += -K if REVERSE else K
        p_o += -V if REVERSE else V
        p_v += -V if REVERSE else V
        if USE_GK:
            p_gk += -K if REVERSE else K
        if USE_GV:
            p_gv += -V if REVERSE else V

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
160
+
161
+
162
@triton.jit
def fused_recurrent_gated_abc_bwd_kernel(
    q,         # queries [B, H, T, K]
    k,         # keys [B, H, T, K]
    v,         # values [B, H, T, V]
    gk,        # optional gates along K (log space)
    gv,        # optional gates along V (log space)
    do,        # output gradient [B, H, T, V]
    dq,        # query gradient, one slab per V-block (summed by caller)
    dk,        # key gradient, one slab per V-block
    dv,        # value gradient, one slab per K-block
    dh0,       # optional initial-state gradient output [B, H, K, V]
    h0,        # optional initial state [B, H, K, V]
    s_k_h,     # stride of q/k along the head dim (T * K)
    s_v_h,     # stride of v/do along the head dim (T * V)
    scale,
    B: tl.constexpr,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
):
    # one program per (V-block, K-block, batch*head)
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # ---- phase 1: replay the forward recurrence to produce dq ----
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    if USE_GK:
        p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    if USE_GV:
        p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    mask_k = i_k * BK + tl.arange(0, BK) < K
    mask_v = i_v * BV + tl.arange(0, BV) < V
    mask_h = mask_k[:, None] & mask_v[None, :]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gk)[:, None]
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gv)[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # dq_t = (h_t @ do_t) * scale
        b_dq = tl.sum(b_h * b_do[None, :], axis=1) * scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += -K if REVERSE else K
        p_v += -V if REVERSE else V
        p_q += -K if REVERSE else K
        p_do += -V if REVERSE else V
        p_dq += -K if REVERSE else K
        if USE_GK:
            p_gk += -K if REVERSE else K
        if USE_GV:
            p_gv += -V if REVERSE else V

    # sync threads
    tl.debug_barrier()

    # ---- phase 2: sweep time in the opposite direction, accumulating the
    # adjoint state b_dh to produce dk/dv (and dh0) ----
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    if USE_GK:
        p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    if USE_GV:
        p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)

    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        # accumulate adjoint state, read out dk/dv, then apply the gate decay
        # (decay after readout because the gate sits before the rank-1 update)
        b_dh += b_q[:, None] * b_do[None, :]
        b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
        b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_dh *= tl.exp(b_gk)[:, None]
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_dh *= tl.exp(b_gv)[None, :]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)

        p_q += K if REVERSE else -K
        p_k += K if REVERSE else -K
        p_v += V if REVERSE else -V
        p_do += V if REVERSE else -V
        p_dk += K if REVERSE else -K
        p_dv += V if REVERSE else -V
        if USE_GK:
            p_gk += K if REVERSE else -K
        if USE_GV:
            p_gv += V if REVERSE else -V

    if USE_INITIAL_STATE:
        p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
280
+
281
+
282
class FusedRecurrentGatedABCFunction(torch.autograd.Function):
    """Autograd wrapper for the fused-recurrent gated ABC kernels.

    Forward runs the key pass (gate on the value/slot dim), a fp32 softmax
    over the M slots, then the value pass (gate on the key/slot dim).  A
    specialized single-token kernel is used in inference mode (T == 1).
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: Optional[float] = None,
        hk0: Optional[torch.Tensor] = None,
        hv0: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        inference_mode: bool = False
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
        HQ = q.shape[1]

        BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)
        NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
        # GQA group size
        NG = HQ // H
        num_warps = 1
        num_stages = 1

        hkt, hvt = None, None
        if output_final_state:
            # in single-head-group inference the initial-state buffers are updated in place
            hkt, hvt = (hk0, hv0) if inference_mode and NG == 1 else (q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float))

        if inference_mode:
            BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 16)
            NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

            # single decoding step: one fused kernel, no autograd bookkeeping
            o = v.new_empty(B, HQ, T, V)
            grid = (B * HQ,)
            fused_recurrent_gated_abc_inference_kernel[grid](
                q, k, v, s, g, o, hk0, hv0, hkt, hvt,
                scale=scale,
                K=K, V=V, M=M, BK=BK, BV=BV, NG=NG,
                num_warps=num_warps,
                num_stages=num_stages
            )
            return o, (hkt, hvt)

        # key pass: s plays the role of values (V=M), gate on the value dim;
        # per-K-block partials are summed afterwards
        ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)
        gk, gv = None, g
        grid = (NM, NK, B * H)
        fused_recurrent_gated_abc_fwd_kernel[grid](
            q, k, s, gk, gv, ok, hk0, hkt,
            k.stride(1),
            s.stride(1),
            scale=scale,
            B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
            USE_INITIAL_STATE=hk0 is not None,
            STORE_FINAL_STATE=hkt is not None,
            USE_GK=False,
            USE_GV=True,
            REVERSE=reverse,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ok = ok.sum(0)

        # softmax over slots; the result is the value-pass query
        qv = ok.softmax(-1, dtype=torch.float)
        ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)
        gk, gv = g, None
        grid = (NV, NM, B * H)
        fused_recurrent_gated_abc_fwd_kernel[grid](
            qv, s, v, gk, gv, ov, hv0, hvt,
            s.stride(1),
            v.stride(1),
            scale=1.,
            B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
            USE_INITIAL_STATE=hv0 is not None,
            STORE_FINAL_STATE=hvt is not None,
            USE_GK=True,
            USE_GV=False,
            REVERSE=reverse,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ov = ov.sum(0)

        ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok)
        ctx.scale = scale
        ctx.reverse = reverse
        return ov.to(q.dtype), (hkt, hvt)


    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
        B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
        scale = ctx.scale

        BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)
        NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
        num_warps = 1
        num_stages = 1

        # per-block partial gradients; summed after each kernel launch
        dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
        dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
        dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
        dhk0 = torch.empty_like(hk0)if hk0 is not None else None
        dhv0 = torch.empty_like(hv0)if hv0 is not None else None

        # backward of the value pass (gate on key dim)
        gk, gv = g, None
        grid = (NV, NM, B * H)
        fused_recurrent_gated_abc_bwd_kernel[grid](
            qv, s, v, gk, gv, do, dqv, dsv, dv, dhv0, hv0,
            s.stride(1),
            v.stride(1),
            scale=1.,
            B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
            USE_INITIAL_STATE=hv0 is not None,
            REVERSE=ctx.reverse,
            USE_GK=gk is not None,
            USE_GV=gv is not None,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dqv = dqv.sum(0)
        dsv = dsv.sum(0)
        dv = dv.sum(0)
        # value-pass gate gradient via reversed cumulative sum over time
        dgk = dqv * qv.float() - dsv * s.float()
        dgk_cumsum = dgk.cumsum(-2)
        dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum

        # softmax backward: d(ok) from d(qv)
        dok = qv * (dqv - (qv * dqv).sum(-1, True))
        dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
        dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
        dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
        # backward of the key pass (gate on value dim)
        gk, gv = None, g
        grid = (NM, NK, B * H)
        fused_recurrent_gated_abc_bwd_kernel[grid](
            q, k, s, gk, gv, dok, dq, dk, dsk, dhk0, hk0,
            q.stride(1),
            s.stride(1),
            scale=scale,
            B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
            USE_INITIAL_STATE=hk0 is not None,
            REVERSE=ctx.reverse,
            USE_GK=gk is not None,
            USE_GV=gv is not None,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dsk = dsk.sum(0)

        # key-pass gate gradient, same reversed-cumsum trick
        dgv = dok.float() * ok.float() - dsk * s.float()
        dgv_cumsum = dgv.cumsum(-2)
        dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum

        # both passes contribute to ds and dg
        ds = dsk.add_(dsv)
        dg = dgk.add_(dgv)

        return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None
446
+
447
+
448
def fused_recurrent_gated_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        s (torch.Tensor):
            slot representations of shape `(B, H, T, M)`, where `M` is the number of memory slots.
        g (torch.Tensor):
            Forget gates of shape `(B, H, T, M)` applied to keys.
            If not provided, this function is equivalent to vanilla ABC.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.

    Returns:
        The output of shape `(B, H, T, V)` and the final state tuple `(hk, hv)`
        (both entries `None` unless `output_final_state` is set).
    """
    if g is None:
        # Vanilla ABC: derive gates from a numerically-stable cumulative softmax.
        # TODO: this 3 steps took huge amount of time, ought to be optimized
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z).to(k.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5
    if initial_state is None:
        initial_state = (None, None)
    # single-token decoding with no grad requirement gets the fused inference kernel
    inference_mode = q.shape[2] == 1 and not q.requires_grad
    ov, final_state = FusedRecurrentGatedABCFunction.apply(
        q, k, v, s, g, scale, *initial_state, output_final_state, False, inference_mode
    )
    return ov, final_state
opencompass/models/fla2/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk_fuse import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
opencompass/models/fla2/ops/based/chunk_fuse.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
10
+
11
+ # on-the-fly computation without materializing hidden statets into HBMs
12
+
13
+
14
# Forward pass of fused-chunk Based linear attention.
# Grid: (NV, NK, B*H). Program (i_v, i_k, i_bh) produces the partial output
# o[i_k] and partial normalizer z[i_k] for one (V-block, K-block) pair by
# scanning the sequence chunk by chunk (chunk size BT); the host sums the
# NK partial results afterwards.
# The attention weight is the 2nd-order Taylor expansion of exp(<q, k>):
# 1 + s + 0.5 * s^2. Inter-chunk contributions are carried in the running
# states b_h_{0,1,2}o (outputs) and k_{0,1,2}o (normalizer); intra-chunk
# contributions use a causally masked [BT, BT] score matrix.
@triton.jit
def fused_chunk_based_fwd_kernel(
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, V]
    v,  # value [B, H, L, V]
    o,  # output [B, H, L, V]
    z,  # normalizer [B, H, L, 1]
    s_qk_h,  # stride size: L * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1
    s_vo_h,  # stride size: L * V
    s_vo_t,  # stride size: V
    s_vo_d,  # stride size: 1
    scale,  # K ** -0.5
    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
):
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)

    # [BT, BT] causal mask: row i may attend to columns j <= i within a chunk
    m_s = o_i[:, None] >= o_i[None, :]

    # [BV], zero-order taylor expansion
    b_h_0o = tl.zeros([BV], dtype=tl.float32)
    # [BK, BV], first-order taylor expansion
    b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
    # [BK, BK, BV] second-order taylor expansion (flattened to [BK*BK, BV])
    b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)

    # make block pointers
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
    # output/normalizer are written into per-i_k slots and reduced on the host
    p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))

    p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
    # running key statistics for the normalizer (2nd-, 1st-, 0th-order)
    k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
    k_1o = tl.zeros([1, BK], dtype=tl.float32)
    k_0o = 0

    for i in range(0, tl.cdiv(T, BT)):
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK*BK, BT] outer product k ⊗ k, flattened
        b_k_2o = b_k[:, None, :] * b_k[None, :, :]
        b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BK]
        b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
        b_o = tl.zeros([BT, BV], dtype=tl.float32)
        b_z = tl.zeros([BT], dtype=tl.float32)

        # interchunk: contributions from all previous chunks via running state
        # zero-order
        b_o += b_h_0o
        b_z += k_0o
        # first-order
        b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
        b_z += tl.sum(b_q * k_1o, axis=1)
        # second-order
        b_q_2o = b_q[:, :, None] * b_q[:, None, :]
        b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
        b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
        b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5

        # update running statistics
        k_1o += tl.sum(b_k, axis=1)[None, :]
        k_2o += tl.sum(b_k_2o, axis=1)[None, :]
        k_0o += BT

        # intrachunk: masked Taylor scores within the current chunk
        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        b_s = 1 + b_s + 0.5 * b_s * b_s
        b_s = tl.where(m_s, b_s, 0)
        b_z += tl.sum(b_s, axis=1)
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
        # [TB, BV]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
        # mask guards the (possibly partial) last chunk
        tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)

        # update hidden state with this chunk's keys/values
        # [BK, BV]
        b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
        b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
        b_h_0o = b_h_0o + tl.sum(b_v, axis=0)

        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))
        p_z += BT
116
+
117
+
118
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
119
# Backward pass of fused-chunk Based linear attention.
# Grid: (NV, NK, B*H). Two sequential scans:
#   1. forward scan over chunks computing dq (uses running key/value states);
#   2. reverse scan over chunks computing dk and dv (uses running gradient
#      states b_dh_{0,1,2}o accumulated from later chunks).
# dz is the gradient w.r.t. the normalizer z; the i_v == 0 program alone
# folds dz-terms in, so they are not double-counted across V-splits.
@triton.jit
def fused_chunk_based_bwd_kernel(
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, V]
    v,  # value [B, H, L, V]
    do,  # gradient of output [B, H, L, V]
    dz,  # gradient of normalizer [B, H, L]
    dq,  # gradient of query [NV, B, H, L, K]
    dk,  # gradient of key [NV, B, H, L, K]
    dv,  # gradient of value [NK, B, H, L, V]
    s_qk_h,  # stride size: L * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1
    s_vo_h,  # stride size: L * V
    s_vo_t,  # stride size: V
    s_vo_d,  # stride size: 1
    scale,  # K ** -0.5
    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)
    # lower-triangular causal mask for the forward (dq) scan
    m_s = o_i[:, None] >= o_i[None, :]

    # [BV], zero-order taylor expansion
    # b_h_0o = tl.zeros([BV], dtype=tl.float32)
    # [BK, BV], first-order taylor expansion (stored transposed: [BV, BK])
    b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
    # [BK, BK, BV] second-order taylor expansion (stored as [BV, BK*BK])
    b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)

    # running key statistics for the dz contributions
    k_1o = tl.zeros([1, BK], dtype=tl.float32)
    k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)

    # ---- pass 1: forward scan, accumulate dq chunk by chunk ----
    for i in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
        p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
        b_dq = tl.zeros([BT, BK], dtype=tl.float32)

        # load tensors
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
        b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
        # [BV, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # inter-chunk: dq from the carried states of earlier chunks
        b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
        if i_v == 0:
            # dz-terms contributed only once (by the first V-split)
            b_dq += b_dz[:, None] * k_1o
        b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
        if i_v == 0:
            b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
        b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
        # product rule over the two q factors of the second-order term
        b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
        b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
        b_dq *= scale

        # intra-chunk
        # [BT, BT]
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[:, None]
        b_ds = tl.where(m_s, b_ds, 0) * scale
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0)
        # d(1 + s + 0.5 s^2)/ds = 1 + s
        b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)

        # store
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

        # update hidden state
        # [BT, BK*BK]
        b_k_2o = b_k[:, :, None] * b_k[:, None, :]
        b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
        # [BV, BK*BK]
        b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
        # [BV, BK]
        b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)

        if i_v == 0:
            # update running statistics
            k_1o += tl.sum(b_k, axis=0)[None, :]
            k_2o += tl.sum(b_k_2o, axis=0)[None, :]

    tl.debug_barrier()
    b_h_1o = None
    b_h_2o = None

    # ---- pass 2: reverse scan, accumulate dk and dv ----
    # [BK, BV], first-order taylor expansion
    b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
    # [BK, BK, BV] second-order taylor expansion
    b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
    b_dh_0o = tl.zeros([BV], dtype=tl.float32)
    # upper-triangular mask: in the reverse direction keys see later queries
    m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]

    # running dz-weighted query statistics
    dq_1o = tl.zeros([1, BK], dtype=tl.float32)
    dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)

    for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))
        p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i

        b_dk = tl.zeros([BT, BK], dtype=tl.float32)
        b_dv = tl.zeros([BT, BV], dtype=tl.float32)

        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
        b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
        b_q = (b_q * scale).to(b_k.dtype)

        # intra chunk
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
        if i_v == 0:
            b_ds += b_dz[None, :]
        b_ds = tl.where(m_s, b_ds, 0)
        b_s = tl.dot(b_k, b_q, allow_tf32=False)
        b_s2 = 1 + b_s + 0.5 * b_s * b_s
        b_s = tl.where(m_s, b_s, 0)
        b_s2 = tl.where(m_s, b_s2, 0)
        b_ds *= (1+b_s)

        b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
        b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)

        # inter chunk: gradients flowing back from later chunks
        b_k_2o = b_k[:, :, None] * b_k[:, None, :]
        b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)

        b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
        b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
        b_dv += b_dh_0o

        b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)

        if i_v == 0:
            b_dk += dq_1o

        b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)
        if i_v == 0:
            b_dk_2o += dq_2o
        b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
        b_k_fp32 = tl.trans(b_k.to(tl.float32))
        # product rule over the two k factors of the second-order term
        b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
        b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
        b_dk += tl.trans(b_dk2)

        # hidden state update
        b_dh_0o += tl.sum(b_do, axis=0)
        b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
        b_q_2o = b_q[None, :, :] * b_q[:, None, :]
        b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
        b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5

        if i_v == 0:
            dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
            dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
301
+
302
+
303
class FusedChunkBasedFunction(torch.autograd.Function):
    """Autograd wrapper around the fused-chunk Based Triton kernels.

    ``forward`` returns the unnormalized output ``o`` and the normalizer
    ``z`` (each summed over the NK key-block splits); ``backward`` computes
    gradients with ``fused_chunk_based_bwd_kernel`` and reduces the per-split
    partial gradients the same way.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale=1):
        """Launch the forward kernel.

        Args:
            q, k: `(B, H, T, K)` tensors.
            v: `(B, H, T, V)` tensor.
            scale: multiplier applied to q inside the kernel.

        Returns:
            o: unnormalized output `(B, H, T, V)` cast back to ``q.dtype``.
            z: float32 normalizer `(B, H, T)`.
        """
        B, H, T, K, V = *k.shape, v.shape[-1]

        BT = 16  # chunk size along the sequence dimension
        BK, BV = min(K, 16), min(V, 32)
        BK, BV = max(BK, 16), max(BV, 16)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

        num_warps = 4

        # the norm of o might explode, so we need to use float32 here
        o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
        z = q.new_empty(NK, B, H, T, dtype=torch.float32)

        grid = (NV, NK, B * H)
        fused_chunk_based_fwd_kernel[grid](
            q, k, v, o, z,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            scale,
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
        )
        # reduce the per-key-block partial results
        o = o.sum(0)
        z = z.sum(0)
        ctx.save_for_backward(q, k, v)
        ctx.scale = scale
        # z deliberately stays float32 (see the comment above)
        return o.to(q.dtype), z

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dz):
        """Launch the backward kernel.

        Args:
            do: gradient of the output, `(B, H, T, V)`.
            dz: gradient of the normalizer, `(B, H, T)`.

        Returns:
            (dq, dk, dv, None) matching the forward signature.
        """
        q, k, v = ctx.saved_tensors
        B, H, T, K, V = *k.shape, v.shape[-1]
        scale = ctx.scale

        BT = 16
        BK, BV = min(K, 16), min(V, 32)
        BK, BV = max(BK, 16), max(BV, 16)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4

        # per-split gradient buffers, reduced below
        dq = q.new_empty(NV, B, H, T, K)
        dk = q.new_empty(NV, B, H, T, K)
        dv = q.new_empty(NK, B, H, T, V)
        grid = (NV, NK, B * H)

        fused_chunk_based_bwd_kernel[grid](
            q, k, v, do, dz, dq, dk, dv,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            scale,
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
371
+
372
+
373
# Module-level alias: callable entry point into the autograd Function.
triton_fused_chunk_based = FusedChunkBasedFunction.apply
374
+
375
+
376
def fused_chunk_based(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True
):
    """Fused-chunk Based linear attention.

    Args:
        q, k: `(B, H, T, K)` tensors; the feature dim K must be <= 16.
        v: `(B, H, T, V)` tensor.
        scale: score multiplier; defaults to `K ** -0.5` when ``None``.
        use_norm: divide the output by the running normalizer when True.

    Returns:
        Output tensor `(B, H, T, V)` in ``q.dtype``.
    """
    assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
    effective_scale = q.shape[-1] ** -0.5 if scale is None else scale
    out, normalizer = triton_fused_chunk_based(q, k, v, effective_scale)
    if use_norm:
        # small epsilon keeps the division finite for all-zero rows
        out = out / (normalizer[..., None] + 1e-6)
    return out.to(q.dtype)
opencompass/models/fla2/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import rearrange
7
+
8
+
9
+ def naive_parallel_based(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ scale: Optional[float] = None,
14
+ use_norm: bool = True
15
+ ):
16
+ if scale is None:
17
+ scale = q.shape[-1] ** -0.5
18
+ q = q * scale
19
+ attn = q @ k.transpose(-2, -1)
20
+ attn = 1 + attn + 1/2 * (attn ** 2)
21
+ attn.masked_fill_(~torch.tril(torch.ones(
22
+ q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
23
+ o = attn @ v
24
+ if use_norm:
25
+ z = attn.sum(-1)
26
+ return o / (z[..., None] + 1e-6)
27
+ else:
28
+ return o
29
+
30
+
31
+ def naive_chunk_based(q, k, v, chunk_size=256):
32
+ q = q * (q.shape[-1] ** -0.5)
33
+ # compute normalizer.
34
+ k_cumsum = torch.cumsum(k, dim=-2)
35
+ kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
36
+ # first
37
+ z = (q * k_cumsum).sum(-1)
38
+ # second order
39
+ z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
40
+ # zero-th order
41
+ z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
42
+
43
+ # compute o
44
+ # constant term
45
+ _o = v.cumsum(-2)
46
+
47
+ q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
48
+
49
+ k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
50
+ v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
51
+
52
+ intra_chunk_attn = q @ k.transpose(-2, -1)
53
+ intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
54
+ intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
55
+ o = intra_chunk_attn @ v
56
+
57
+ # quadractic term
58
+ kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
59
+ kv = kv.cumsum(2)
60
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
61
+
62
+ o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
63
+
64
+ # linear term
65
+ kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
66
+ kv = kv.cumsum(2)
67
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
68
+ o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
69
+
70
+ o = rearrange(o, 'b h n c d -> b h (n c) d')
71
+ o = o + _o
72
+ return o / (z[..., None] + 1e-6)
opencompass/models/fla2/ops/based/parallel.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
11
+
12
+ # Based: An Educational and Effective Sequence Mixer
13
+ # https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
14
+
15
+
16
+ @triton.jit
17
+ def parallel_based_fwd_kernel(
18
+ q, # query [B, H, L, K]
19
+ k, # key [B, H, L, V]
20
+ v, # value [B, H, L, V]
21
+ o, # output [B, H, L, V]
22
+ z, # normalizer [B, H, L]
23
+ s_qk_h, # stride size: L * K
24
+ s_qk_t, # stride size: K
25
+ s_qk_d, # stride size: 1
26
+ s_vo_h, # stride size: L * V
27
+ s_vo_t, # stride size: V
28
+ s_vo_d, # stride size: 1
29
+ scale, # K ** -0.5
30
+ B: tl.constexpr, # batch size
31
+ H: tl.constexpr, # H
32
+ T: tl.constexpr, # T
33
+ K: tl.constexpr, # K
34
+ V: tl.constexpr, # V
35
+ BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
36
+ BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
37
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
38
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
39
+ ):
40
+ # i_c: chunk index. used for sequence parallelism
41
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
42
+ NV = tl.cdiv(V, BV)
43
+ i_k = i_kv // (NV)
44
+ i_v = i_kv % (NV)
45
+
46
+ p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
47
+ p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
48
+ p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
49
+
50
+ # [BQ, BD] block Q, in the shared memory throughout the whole kernel
51
+ b_q = tl.load(p_q, boundary_check=(0, 1))
52
+ b_q = (b_q * scale).to(b_q.dtype)
53
+ b_o = tl.zeros([BTL, BV], dtype=tl.float32)
54
+ b_z = tl.zeros([BTL], dtype=tl.float32)
55
+
56
+ # Q block and K block have no overlap
57
+ # no need for mask, thereby saving flops
58
+ for _ in range(0, i_c * BTL, BTS):
59
+ # [BK, BTS]
60
+ b_k = tl.load(p_k, boundary_check=(0, 1))
61
+
62
+ # [BTS, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+ # [BTL, BTS]
65
+ b_s = tl.dot(b_q, (b_k), allow_tf32=False)
66
+ b_s = 1 + b_s + 0.5 * b_s * b_s
67
+ b_z += tl.sum(b_s, axis=1)
68
+
69
+ # [BQ, BD]
70
+ b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
71
+ p_k = tl.advance(p_k, (0, BTS))
72
+ p_v = tl.advance(p_v, (BTS, 0))
73
+
74
+ # # rescale interchunk output
75
+ tl.debug_barrier()
76
+ o_q = tl.arange(0, BTL)
77
+ # # sync threads, easy for compiler to optimize
78
+ # tl.debug_barrier()
79
+
80
+ o_k = tl.arange(0, BTS)
81
+ p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
82
+ p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
83
+ # Q block and K block have overlap. masks required
84
+ for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
85
+ # [BK, BTS]
86
+ b_k = tl.load(p_k, boundary_check=(0, 1))
87
+ # [BTS, BV]
88
+ b_v = tl.load(p_v, boundary_check=(0, 1))
89
+ # [BTL, BTS]
90
+ m_s = o_q[:, None] >= o_k[None, :]
91
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
92
+ b_s = 1 + b_s + 0.5 * b_s * b_s
93
+ b_s = tl.where(m_s, b_s, 0)
94
+ b_z += tl.sum(b_s, axis=1)
95
+ # [BTL, BV]
96
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
97
+
98
+ p_k = tl.advance(p_k, (0, BTS))
99
+ p_v = tl.advance(p_v, (BTS, 0))
100
+ o_k += BTS
101
+
102
+ p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
103
+ p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
104
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
105
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))
106
+
107
+
108
+ @triton.jit
109
+ def _parallel_based_bwd_dq(
110
+ i_bh,
111
+ i_c,
112
+ i_k,
113
+ i_v,
114
+ i_h,
115
+ q,
116
+ k,
117
+ v,
118
+ do,
119
+ dz,
120
+ dq,
121
+ s_qk_h,
122
+ s_qk_t,
123
+ s_qk_d,
124
+ s_vo_h,
125
+ s_vo_t, s_vo_d, B, H, T, scale,
126
+ BTL: tl.constexpr,
127
+ BTS: tl.constexpr,
128
+ BK: tl.constexpr,
129
+ BV: tl.constexpr,
130
+ K: tl.constexpr,
131
+ V: tl.constexpr,
132
+ ):
133
+ p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
134
+ (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
135
+ p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, K),
136
+ (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
137
+ b_q = tl.load(p_q, boundary_check=(0, 1))
138
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
139
+ b_q = (b_q * scale).to(b_q.dtype)
140
+ b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
141
+ p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
142
+ p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
143
+ p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
144
+ b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
145
+
146
+ for _ in range(0, i_c * BTL, BTS):
147
+ # [BTS, BK]
148
+ b_k = tl.load(p_k, boundary_check=(0, 1))
149
+ # [BV, BTS]
150
+ b_v = tl.load(p_v, boundary_check=(0, 1))
151
+ # [BTL, BTS]
152
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
153
+ if i_v == 0:
154
+ b_ds += b_dz[:, None]
155
+ else:
156
+ b_ds = b_ds
157
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
158
+ # [BQ, BD]
159
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
160
+ p_k = tl.advance(p_k, (BTS, 0))
161
+ p_v = tl.advance(p_v, (0, BTS))
162
+
163
+ b_dq *= scale
164
+ o_q = tl.arange(0, BTL)
165
+ o_k = tl.arange(0, BTS)
166
+ p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
167
+ p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
168
+ # Q block and K block have overlap. masks required
169
+ for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
170
+ # [BTS, BK]
171
+ b_k = tl.load(p_k, boundary_check=(0, 1))
172
+ # [BV, BTS]
173
+ b_v = tl.load(p_v, boundary_check=(0, 1))
174
+ # [BTL, BTS]
175
+ m_s = o_q[:, None] >= o_k[None, :]
176
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
177
+ if i_v == 0:
178
+ b_ds += b_dz[:, None]
179
+ else:
180
+ b_ds = b_ds
181
+ b_ds = tl.where(m_s, b_ds, 0) * scale
182
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
183
+ b_s = tl.where(m_s, b_s, 0)
184
+ # [BTL, BK]
185
+ b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype),
186
+ b_k, allow_tf32=False)
187
+ p_k = tl.advance(p_k, (BTS, 0))
188
+ p_v = tl.advance(p_v, (0, BTS))
189
+ o_k += BTS
190
+ p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, K),
191
+ (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
192
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
193
+ return
194
+
195
+
196
+ @triton.jit
197
+ def _parallel_based_bwd_dkv(
198
+ i_bh, i_c, i_k, i_v, i_h,
199
+ q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
200
+ s_vo_t, s_vo_d, B, H, T, scale,
201
+ BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
202
+ K: tl.constexpr, V: tl.constexpr,
203
+ ):
204
+ # compute dk dv
205
+ p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
206
+ p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
207
+ b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
208
+ p_v, boundary_check=(0, 1))
209
+ b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
210
+ [BTL, BV], dtype=tl.float32)
211
+
212
+ for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
213
+ p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
214
+ p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
215
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
216
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
217
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
218
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
219
+ b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
220
+ scale # [BTL, BTS]
221
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
222
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
223
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
224
+ if i_v == 0:
225
+ b_ds += b_dz[None, :] * scale
226
+ else:
227
+ b_ds = b_ds
228
+ b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
229
+
230
+ tl.debug_barrier()
231
+ o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
232
+ for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
233
+ p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
234
+ p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
235
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
236
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
237
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
238
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
239
+ # [BK, BQ]
240
+ m_s = o_k[:, None] <= o_q[None, :]
241
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
242
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
243
+ b_s = tl.where(m_s, b_s, 0)
244
+ b_s2 = tl.where(m_s, b_s2, 0)
245
+
246
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False)
247
+ if i_v == 0:
248
+ b_ds += b_dz[None, :]
249
+ else:
250
+ b_ds = b_ds
251
+ b_ds = tl.where(m_s, b_ds, 0) * scale
252
+ # [BK, BD]
253
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
254
+ b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
255
+ tl.trans(b_q), allow_tf32=False)
256
+ o_q += BTS
257
+
258
+ p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, (T, K),
259
+ (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
260
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, (T, V),
261
+ (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
262
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
263
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
264
+ return
265
+
266
+
267
+ @triton.jit
268
+ def parallel_based_bwd_kernel(
269
+ q,
270
+ k,
271
+ v,
272
+ do,
273
+ dz,
274
+ dq,
275
+ dk,
276
+ dv,
277
+ s_qk_h,
278
+ s_qk_t,
279
+ s_qk_d,
280
+ s_vo_h,
281
+ s_vo_t,
282
+ s_vo_d,
283
+ scale,
284
+ B: tl.constexpr,
285
+ H: tl.constexpr,
286
+ T: tl.constexpr,
287
+ K: tl.constexpr,
288
+ V: tl.constexpr,
289
+ BTL: tl.constexpr,
290
+ BTS: tl.constexpr,
291
+ BK: tl.constexpr,
292
+ BV: tl.constexpr,
293
+ ):
294
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
295
+ NV = tl.cdiv(V, BV)
296
+ i_k = i_kv // (NV)
297
+ i_v = i_kv % (NV)
298
+ i_h = i_bh % H
299
+ _parallel_based_bwd_dq(
300
+ i_bh, i_c, i_k, i_v, i_h,
301
+ q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
302
+ s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, K=K, V=V
303
+ )
304
+ tl.debug_barrier()
305
+ _parallel_based_bwd_dkv(
306
+ i_bh, i_c, i_k, i_v, i_h,
307
+ q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
308
+ s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, K, V
309
+ )
310
+
311
+
312
+ class ParallelBasedFunction(torch.autograd.Function):
313
+
314
+ @staticmethod
315
+ @contiguous
316
+ @autocast_custom_fwd
317
+ def forward(ctx, q, k, v, scale):
318
+ BTL, BTS = 128, 32
319
+ assert BTL % BTS == 0
320
+ # assert q.shape[-1] % 16 == 0
321
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
322
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
323
+ BK, BV = max(BK, 16), max(BV, 16)
324
+ B, H, T, K, V = *k.shape, v.shape[-1]
325
+ num_stages = 2
326
+ num_warps = 4
327
+ NK = triton.cdiv(K, BK)
328
+ NV = triton.cdiv(V, BV)
329
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
330
+
331
+ assert NK == 1, "will encounter some synchronization issue if not."
332
+
333
+ o = torch.empty(NK, B, H, T, V, device=q.device)
334
+ z = torch.empty(NK, B, H, T, device=q.device)
335
+ parallel_based_fwd_kernel[grid](
336
+ q, k, v, o, z,
337
+ q.stride(1), q.stride(2), q.stride(3),
338
+ v.stride(1), v.stride(2), v.stride(3),
339
+ scale,
340
+ B=B, H=H, T=T, K=K, V=V,
341
+ BTL=BTL, BTS=BTS, BK=BK, BV=BV,
342
+ num_warps=num_warps,
343
+ num_stages=num_stages
344
+ )
345
+ ctx.save_for_backward(q, k, v)
346
+ ctx.scale = scale
347
+ return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
348
+
349
+ @staticmethod
350
+ @contiguous
351
+ @autocast_custom_bwd
352
+ def backward(ctx, do, dz):
353
+ q, k, v = ctx.saved_tensors
354
+ scale = ctx.scale
355
+ BTL, BTS = 64, 32
356
+ assert BTL % BTS == 0
357
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
358
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
359
+ BK, BV = max(BK, 16), max(BV, 16)
360
+ B, H, T, K, V = *k.shape, v.shape[-1]
361
+ num_stages = 2
362
+ num_warps = 4
363
+ NK = triton.cdiv(K, BK)
364
+ NV = triton.cdiv(V, BV)
365
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
366
+
367
+ assert NK == 1, "will encounter some synchronization issue if not"
368
+
369
+ dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
370
+ dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
371
+ dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
372
+
373
+ parallel_based_bwd_kernel[grid](
374
+ q, k, v, do, dz, dq, dk, dv,
375
+ q.stride(1), q.stride(2), q.stride(3),
376
+ v.stride(1), v.stride(2), v.stride(3),
377
+ scale,
378
+ B=B, H=H, T=T, K=K, V=V,
379
+ BTL=BTL, BTS=BTS, BK=BK, BV=BV,
380
+ num_warps=num_warps,
381
+ num_stages=num_stages
382
+ )
383
+
384
+ return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
385
+
386
+
387
+ triton_parallel_based = ParallelBasedFunction.apply
388
+
389
+
390
+ def parallel_based(
391
+ q: torch.Tensor,
392
+ k: torch.Tensor,
393
+ v: torch.Tensor,
394
+ scale: Optional[float] = None,
395
+ use_norm: bool = True
396
+ ):
397
+ assert q.shape[-1] <= 128, "only support feature dim up to 128"
398
+ if scale is None:
399
+ scale = q.shape[-1] ** -0.5
400
+ o, z = triton_parallel_based(q, k, v, scale)
401
+ if use_norm:
402
+ o = o / (z[..., None] + 1e-6)
403
+ return o.to(q.dtype)
opencompass/models/fla2/ops/common/chunk_h.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+ import triton.language as tl
3
+ import torch
4
+
5
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'],
)
@triton.jit
# Forward chunk-state kernel: each program owns one (BK, BV) tile of the
# hidden state for one batch-head, stores the state at the start of every
# chunk into `h`, applies the optional scalar/key/value log-space decays,
# and accumulates h += k^T v over the chunk.
def chunk_fwd_kernel_h(
    k,
    v,
    h,
    g,
    gk,
    gv,
    h0,
    ht,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # State at the *start* of chunk i_t is materialized for later use.
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # Index of the last valid timestep in this (possibly partial) chunk.
        last_idx = min((i_t + 1) * BT, T) - 1

        # scalar decay
        if USE_G:
            b_g_last = tl.load(g + i_bh * T + last_idx)
            b_h *= tl.exp(b_g_last)
            p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v = (b_v * tl.exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

        # vector decay, h = Diag(gk) @ h
        if USE_GK:
            p_gk_last = tl.make_block_ptr(gk + i_bh * s_qk_h, (T * K,), (s_qk_d,), (last_idx * K + i_k * BK,), (BK,), (0,))
            b_gk_last = tl.load(p_gk_last, boundary_check=(0,))
            b_h *= tl.exp(b_gk_last)[:, None]

            p_gk = tl.make_block_ptr(gk + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_k = (b_k * tl.exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)

        # vector decay, h = h @ Diag(gv)
        if USE_GV:
            p_gv_last = tl.make_block_ptr(gv + i_bh * s_vo_h, (T * V,), (s_vo_d,), (last_idx * V + i_v * BV,), (BV,), (0,))
            # FIX: load from p_gv_last; the original read `p_gv`, which is not
            # defined until two lines below (compile error when USE_GV=True),
            # mirroring the correct p_gk_last pattern in the USE_GK branch.
            b_gv_last = tl.load(p_gv_last, boundary_check=(0,))
            b_h *= tl.exp(b_gv_last)[None, :]

            p_gv = tl.make_block_ptr(gv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_v = (b_v * tl.exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)

        b_h += tl.dot(b_k, b_v, allow_tf32=False)

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
99
+
100
+
101
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'],
)
@triton.jit
# Backward chunk-state kernel: walks chunks from last to first carrying the
# running hidden-state gradient b_dh, stores it per chunk into `dh`, applies
# the optional decays, and accumulates b_dh += q^T do.
def chunk_bwd_kernel_dh(
    q,
    g,
    gk,
    gv,
    do,
    dh,
    dht,
    dh0,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    LOAD_FINAL_STATE_GRADIENT: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if LOAD_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT - 1, -1, -1):
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        # Index of the last valid timestep in this (possibly partial) chunk.
        last_idx = min(i_t * BT + BT, T) - 1
        # [BK, BT]
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BV]
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))

        if USE_G:
            p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_q = (b_q * tl.exp(b_g)[None, :]).to(b_q.dtype)
            b_g_last = tl.load(g + i_bh * T + last_idx)
            b_dh *= tl.exp(b_g_last)

        if USE_GK:
            p_gk = tl.make_block_ptr(gk + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_q = (b_q * tl.exp(b_gk)).to(b_q.dtype)
            p_gk_last = tl.make_block_ptr(gk + i_bh * s_qk_h, (T * K,), (s_qk_d,), (last_idx * K + i_k * BK,), (BK,), (0,))
            b_gk_last = tl.load(p_gk_last, boundary_check=(0,))
            b_dh *= tl.exp(b_gk_last)[:, None]

        if USE_GV:
            p_gv = tl.make_block_ptr(gv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_do = (b_do * tl.exp(b_gv)).to(b_do.dtype)
            p_gv_last = tl.make_block_ptr(gv + i_bh * s_vo_h, (T * V,), (s_vo_d,), (last_idx * V + i_v * BV,), (BV,), (0,))
            # FIX: load the last-step gate from p_gv_last; the original loaded
            # `p_gv` (the [BT, BV] block pointer) here, decaying b_dh with the
            # wrong values. Mirrors the correct p_gk_last pattern above.
            b_gv_last = tl.load(p_gv_last, boundary_check=(0,))
            b_dh *= tl.exp(b_gv_last)[None, :]

        b_dh += tl.dot(b_q, b_do, allow_tf32=False)

    if STORE_INITIAL_STATE_GRADIENT:
        p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
190
+
191
+
192
+
193
+
194
def chunk_fwd_h_fn(k, v, g, gk, gv, BT, h0, output_final_state):
    """Launch chunk_fwd_kernel_h over (key-tile, value-tile, batch*head).

    Args:
        k, v: key/value tensors, [B, H, T, K] and [B, H, T, V].
        g, gk, gv: optional log-space decay gates (any may be None).
        BT: chunk (time-tile) length.
        h0: optional initial state [B, H, K, V].
        output_final_state: when True, also return the final state.

    Returns:
        (h, ht): chunk-start states stacked as [B, H, NT * K, V], and the
        final state [B, H, K, V] in float32 (None when not requested).
    """
    B, H, T, K, V = *k.shape, v.shape[-1]
    BK = min(64, triton.next_power_of_2(K))
    BV = min(64, triton.next_power_of_2(V))
    NT = triton.cdiv(T, BT)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)

    ht = k.new_empty(B, H, K, V, dtype=torch.float32) if output_final_state else None
    h = k.new_empty(B, H, NT * K, V)

    chunk_fwd_kernel_h[(NK, NV, B * H)](
        k, v, h, g, gk, gv, h0, ht,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        h.stride(1), h.stride(2),
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
        USE_INITIAL_STATE=h0 is not None,
        STORE_FINAL_STATE=output_final_state,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None
    )
    return h, ht
218
+
219
+
220
+
221
def chunk_bwd_dh_fn(q, k, v, g, gk, gv, do, h0, dht, BT, scale):
    """Launch chunk_bwd_kernel_dh: gradients w.r.t. per-chunk hidden states.

    Returns:
        dh:  [B, H, NT * K, V] gradient of each chunk-start hidden state.
        dh0: [B, H, K, V] float32 gradient of the initial state, or None
             when ``h0`` is None.
    """
    B, H, T, K, V = *k.shape, v.shape[-1]
    # NOTE(review): the caller-supplied BT is shadowed by a fixed 64 here —
    # presumably the backward pass always uses 64-long chunks regardless of
    # the forward chunking; confirm against the callers before relying on
    # the BT parameter.
    BT = 64
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
    dh = k.new_empty(B, H, NT * K, V)
    grid = (NK, NV, B * H)
    if h0 is not None:
        dh0 = torch.empty_like(h0, dtype=torch.float32)
    else:
        dh0 = None
    # Gated variants recompute the initial-state gradient elsewhere, so the
    # two code paths are mutually exclusive by construction.
    USE_GATE = (g is not None) or (gk is not None) or (gv is not None)
    assert not (USE_GATE and dht is not None), "Cannot load final state gradient and use gates at the same time"
    chunk_bwd_kernel_dh[grid](
        q, g, gk, gv, do, dh, dht, dh0,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        dh.stride(1), dh.stride(2),
        scale,
        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
        USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None,
        STORE_INITIAL_STATE_GRADIENT=dh0 is not None,
        LOAD_FINAL_STATE_GRADIENT=dht is not None
    )
    return dh, dh0
247
+
248
+
249
+
opencompass/models/fla2/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+ from typing import Tuple
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
9
+ from ...ops.utils import chunk_global_reversed_cumsum, chunk_global_cumsum
10
+
11
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
)
@triton.jit
# Step-by-step recurrent forward: each program owns one (BK, BV) tile of the
# hidden state for one batch-head and scans the sequence one timestep at a
# time, applying optional log-space decay gates before each state update.
def fused_recurrent_fwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, K]
    v,  # value [B, H, L, V]
    g,  # log gate [B, H, L] or None
    gk,  # log gate [B, H, L, K] or None
    gv,  # log gate [B, H, L, V] or None
    o,  # output [NK, B, H, L, V]
    h0,  # initial hidden state [B, H, K, V]
    ht,  # final hidden state [B, H, K, V]
    s_qk_h,  # stride size: L * K
    s_vo_h,  # stride size: L * V
    scale,  # K ** -0.5
    B: tl.constexpr,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    REVERSE: tl.constexpr,  # whether to reverse the recurrence
    USE_GK: tl.constexpr,  # whether to use gk
    USE_GV: tl.constexpr,  # whether to use gv
    USE_G: tl.constexpr,  # whether to use g
):
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # Element pointers at the first (or last, if REVERSE) timestep.
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    # Output is written into the i_k-th slice of the [NK, ...] buffer; the
    # launcher sums over NK afterwards.
    p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)

    if USE_G:
        p_g = g + i_bh * T + ((T-1) if REVERSE else 0)
    if USE_GK:
        p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    if USE_GV:
        p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)

    # Tail masks for K/V not divisible by the tile size.
    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
    mask_kv = mask_bk[None, :] & mask_bv[:, None]
    # Hidden-state tile held value-major, [BV, BK], in fp32 registers.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        # Decay the running state with whichever gates are enabled.
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gk[None, :])
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gv[:, None])
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_h = b_h * tl.exp(b_g)
        # State update h += v k^T (outer product), then o_t = h q_t.
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
        # Advance (or rewind, if REVERSE) all pointers by one timestep.
        p_q += -K if REVERSE else K
        p_k += -K if REVERSE else K
        p_o += -V if REVERSE else V
        p_v += -V if REVERSE else V
        if USE_GK:
            p_gk += -K if REVERSE else K
        if USE_GV:
            p_gv += -V if REVERSE else V
        if USE_G:
            p_g += -1 if REVERSE else 1

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
105
+
106
+
107
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
# Two-pass recurrent backward: pass 1 re-runs the forward recurrence to
# produce dq; pass 2 scans in the opposite direction carrying the hidden
# state gradient to produce dk, dv and (optionally) dh0.
def fused_recurrent_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, V]
    v,  # value [B, H, L, V]
    g,  # log gate [B, H, L]
    gk,  # log gate [B, H, L, K] \alpha
    gv,  # log gate [B, H, L, V] \bete
    do,  # gradient wrt output [B, H, L, V]
    dq,  # gradient wrt query [NV, B, H, L, K]
    dk,  # gradient wrt key [NV, B, H, L, K]
    dv,  # gradient wrt value [NK, B, H, L, V]
    dht,  # gradient wrt final hidden state [B, H, K, V]
    dh0,  # gradient wrt initial hidden state [B, H, K, V]
    h0,  # initial hidden state [B, H, K, V]
    s_qk_h,  # stride size: L * K
    s_vo_h,  # stride size: L * V
    scale,  # K ** -0.5
    B,
    H,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
    USE_GK: tl.constexpr,  # whether to use gk
    USE_GV: tl.constexpr,  # whether to use gv
    USE_G: tl.constexpr,  # whether to use g
    USE_FINAL_STATE_GRADIENT: tl.constexpr,  # whether to compute gradient wrt final state
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,  # whether to store gradient wrt initial state
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # ---- Pass 1: same direction as forward; rebuild h and emit dq. ----
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    # dq is split over value tiles (leading NV dim); summed by the launcher.
    p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    if USE_GK:
        p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
    if USE_GV:
        p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
    if USE_G:
        p_g = g + i_bh * T + ((T-1) if REVERSE else 0)
    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
    for i in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        # Same decay sequence as the forward kernel.
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gk[:, None])
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
            b_h = b_h * tl.exp(b_gv[None, :])
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_h = b_h * tl.exp(b_g)
        b_h += b_k[:, None] * b_v[None, :]
        # dq_t = (h_t @ do_t) * scale
        b_dq = b_h * b_do[None, :]
        d_q = tl.sum(b_dq, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)

        p_k += -K if REVERSE else K
        p_v += -V if REVERSE else V
        p_q += -K if REVERSE else K
        p_do += -V if REVERSE else V
        p_dq += -K if REVERSE else K
        if USE_GK:
            p_gk += -K if REVERSE else K
        if USE_GV:
            p_gv += -V if REVERSE else V
        if USE_G:
            p_g += -1 if REVERSE else 1

    # sync threads
    tl.debug_barrier()

    # ---- Pass 2: opposite direction; carry dh and emit dk, dv. ----
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    if USE_GK:
        p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
    if USE_GV:
        p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
    if USE_G:
        p_g = g + i_bh * T + ((T - 1) if not REVERSE else 0)
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = dht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_dh += tl.load(p_dht, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(T):
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        # Accumulate this step's output gradient into dh first, then read
        # off dk/dv; gates are applied *after*, i.e. for the previous step.
        b_dh += b_q[:, None] * b_do[None, :]
        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
            b_dh *= tl.exp(b_gk)[:, None]
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
            b_dh *= tl.exp(b_gv)[None, :]
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_dh *= tl.exp(b_g)
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)

        p_q += K if REVERSE else -K
        p_k += K if REVERSE else -K
        p_v += V if REVERSE else -V
        p_do += V if REVERSE else -V
        p_dk += K if REVERSE else -K
        p_dv += V if REVERSE else -V
        if USE_GK:
            p_gk += K if REVERSE else -K
        if USE_GV:
            p_gv += V if REVERSE else -V
        if USE_G:
            p_g += 1 if REVERSE else -1

    if STORE_INITIAL_STATE_GRADIENT:
        p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)
258
+
259
+
260
+
261
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent (optionally gated)
    linear-attention Triton kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
        """Run the forward recurrence.

        Returns:
            (o, ht): output [B, H, T, V] in q's dtype, and the final hidden
            state [B, H, K, V] in float32 (None unless requested).
        """
        B, H, T, K, V = *q.shape, v.shape[-1]
        # default scale
        if scale is None:
            scale = K ** -0.5

        BK, BV = min(K, 64), min(V, 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

        # Per-key-tile partial outputs (leading NK dim); summed below.
        o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)

        h0 = initial_state
        if output_final_state:
            ht = q.new_empty(B, H, K, V, dtype=torch.float32)
        else:
            ht = None

        grid = (NV, NK, B * H)
        fused_recurrent_fwd_kernel[grid](
            q, k, v, g, gk, gv, o, h0, ht,
            q.stride(1), v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V,
            BK=BK, BV=BV,
            USE_INITIAL_STATE=h0 is not None,
            STORE_FINAL_STATE=ht is not None,
            USE_GK=gk is not None,
            USE_GV=gv is not None,
            USE_G=g is not None,
            REVERSE=reverse,
        )

        o = o.sum(0)
        # o is saved because the gv gradient below needs do * o.
        ctx.save_for_backward(q, k, v, g, gk, gv, h0, o)
        ctx.scale = scale
        ctx.reverse = reverse
        return o.to(q.dtype), ht

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        """Backward pass; gate gradients are assembled on the host from the
        kernel's dq/dk/dv outputs via a global (reversed) cumulative sum."""
        q, k, v, g, gk, gv, h0, o = ctx.saved_tensors
        batch_size, n_heads, seq_len, K = q.shape
        V = v.shape[-1]
        scale = ctx.scale

        BK, BV = min(K, 64), min(V, 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

        # Leading NV/NK dims hold per-tile partial gradients; summed below.
        dq = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32)
        dk = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32)
        dv = q.new_empty(NK, batch_size, n_heads, seq_len, V, dtype=torch.float32)
        dh0 = torch.empty_like(h0) if (h0 is not None) else None
        grid = (NV, NK, batch_size * n_heads)

        fused_recurrent_bwd_kernel[grid](
            q, k, v, g, gk, gv, do, dq, dk, dv, dht, dh0, h0,
            q.stride(1),
            v.stride(1), scale,
            B=batch_size, H=n_heads, T=seq_len, K=K, V=V, BK=BK, BV=BV,
            USE_INITIAL_STATE=h0 is not None,
            REVERSE=ctx.reverse,
            USE_GK=gk is not None,
            USE_GV=gv is not None,
            USE_G=g is not None,
            USE_FINAL_STATE_GRADIENT=dht is not None,
            STORE_INITIAL_STATE_GRADIENT=dh0 is not None
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)
        # Gate gradients: cumulative-sum direction flips with the recurrence.
        fn = chunk_global_cumsum if ctx.reverse else chunk_global_reversed_cumsum
        dgk = fn(dq * q.float() - dk * k.float()) if gk is not None else None
        dgv = fn(do.float() * o.float() - dv * v.float()) if gv is not None else None
        dg = fn((dq * q.float() - dk * k.float()).sum(-1)) if g is not None else None
        # Trailing Nones match forward's scale / output_final_state / reverse.
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None
343
+
344
+
345
def fused_recurrent(q, k, v, g=None, gk=None, gv=None, scale=None, initial_state=None, output_final_state=False, reverse=False):
    """Autograd-enabled fused recurrent linear attention (optionally gated)."""
    args = (q, k, v, g, gk, gv, scale, initial_state, output_final_state, reverse)
    return FusedRecurrentFunction.apply(*args)
opencompass/models/fla2/ops/delta_rule/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ - Delta Rule
2
+
3
+ The implementation of the delta rule, as described in https://arxiv.org/abs/2102.11174 ("Linear Transformers Are Secretly Fast Weight Programmers").
4
+
opencompass/models/fla2/ops/delta_rule/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""Delta-rule linear-attention ops: chunked, fused-chunk and fused-recurrent
implementations."""

from .chunk import chunk_delta_rule
from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_delta_rule

# Public API of the delta-rule ops package.
__all__ = [
    'fused_chunk_delta_rule',
    'fused_recurrent_delta_rule',
    'chunk_delta_rule'
]
opencompass/models/fla2/ops/delta_rule/chunk.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023, Yu Zhang, Songlin Yang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from ...ops.delta_rule.wy_fast import (bwd_prepare_wy_repr,
9
+ fwd_prepare_wy_repr, fwd_recompute_w_u)
10
+ from ...ops.utils import contiguous
11
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd
12
+
13
+
14
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
# Builds the intra-chunk score matrix A = K Q^T (scaled), masks it to the
# upper triangle (i.e. the transpose of the causal mask), and writes
# dv = A @ do for every value tile of the chunk.
def fwd_prepare_dv_kernel(
    q,
    k,
    do,
    dv,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    scale,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    # Accumulate A over key-dim tiles.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_k.dtype)
        b_A += tl.dot(b_k, b_q, allow_tf32=False)

    # Keep only entries with row <= col: A[i, j] = k_i . q_j for j >= i.
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
64
+
65
+
66
def fwd_prepare_dv(q, k, do, BT, scale=None):
    """Launch fwd_prepare_dv_kernel, computing the intra-chunk dv term.

    Args:
        q, k: [B, H, T, K] query/key tensors.
        do: [B, H, T, V] gradient w.r.t. the output.
        BT: chunk (time-tile) length.
        scale: attention scale. Defaults to ``K ** -0.5``, the value the
            original hard-coded, so existing callers are unaffected; passing
            it explicitly keeps this launcher consistent with the other
            launchers in this file that take ``scale``.

    Returns:
        dv: tensor shaped like ``do``.
    """
    dv = torch.empty_like(do)
    B, H, T, K, V = *k.shape, do.shape[-1]
    if scale is None:
        scale = K ** -0.5
    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    fwd_prepare_dv_kernel[(NT, B*H)](
        q, k, do, dv,
        k.stride(1), k.stride(2), k.stride(3),
        do.stride(1), do.stride(2), do.stride(3),
        T, K, V, scale, BT, BK, BV
    )
    return dv
+
80
+
81
+ @triton.autotune(
82
+ configs=[
83
+ triton.Config({}, num_warps=1),
84
+ triton.Config({}, num_warps=2),
85
+ triton.Config({}, num_warps=4),
86
+ triton.Config({}, num_warps=8),
87
+ triton.Config({}, num_warps=16)
88
+ ],
89
+ key=["BT", "BK", "BV"],
90
+ )
91
+ @triton.jit
92
+ def chunk_delta_rule_fwd_kernel_h(
93
+ k,
94
+ v,
95
+ d,
96
+ v_new,
97
+ h,
98
+ initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
99
+ final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
100
+ s_qk_h,
101
+ s_qk_t,
102
+ s_qk_d,
103
+ s_vo_h,
104
+ s_vo_t,
105
+ s_vo_d,
106
+ s_h_h,
107
+ s_h_t,
108
+ H: tl.constexpr,
109
+ T: tl.constexpr,
110
+ K: tl.constexpr,
111
+ V: tl.constexpr,
112
+ BT: tl.constexpr,
113
+ BC: tl.constexpr,
114
+ BK: tl.constexpr,
115
+ BV: tl.constexpr,
116
+ NT: tl.constexpr,
117
+ USE_INITIAL_STATE: tl.constexpr,
118
+ STORE_FINAL_STATE: tl.constexpr
119
+ ):
120
+ i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
121
+
122
+ # [BK, BV]
123
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
124
+
125
+ if USE_INITIAL_STATE:
126
+ p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
127
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
128
+
129
+ for i_t in range(NT):
130
+ p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
131
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
132
+ b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
133
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
134
+ for i_c in range(tl.cdiv(BT, BC)):
135
+ p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
136
+ (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
137
+ p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
138
+ (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
139
+ p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
140
+ (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
141
+ p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
142
+ (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
143
+ b_k = tl.load(p_k, boundary_check=(0, 1))
144
+ # [BT, BK]
145
+ b_d = tl.load(p_d, boundary_check=(0, 1))
146
+ # [BT, BV]
147
+ b_v = tl.load(p_v, boundary_check=(0, 1))
148
+ b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
149
+ # [BK, BV]
150
+ tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
151
+ b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
152
+ b_h += b_h_cumsum
153
+
154
+ if STORE_FINAL_STATE:
155
+ p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
156
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
157
+
158
+
159
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
# Chunked linear-attention output: o = q @ h (inter-chunk contribution from
# the chunk-start state) + causal-masked (q @ k^T) @ v (intra-chunk part).
def chunk_linear_attn_fwd_kernel_o(
    q,
    k,
    v,
    h,
    o,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # Causal (lower-triangular, inclusive) mask over the chunk.
    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
    # Accumulate both terms over key-dim tiles.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
        b_s += tl.dot(b_q, b_k, allow_tf32=False)

    b_s = tl.where(m_s, b_s, 0)
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
220
+
221
+
222
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dhu(
    q,       # queries [B, H, T, K]
    k,       # keys [B, H, T, K]
    d,       # correction term "w" [B, H, T, K]
    do,      # gradient w.r.t. the output [B, H, T, V]
    dh,      # out: per-chunk state gradients [B, H, NT * K, V]
    dv,      # in: value gradients from the intra-chunk pass
    dv2,     # out: value gradients including the inter-chunk part
    s_qk_h,  # q/k/d stride over heads
    s_qk_t,  # q/k/d stride over time
    s_qk_d,  # q/k/d stride over features
    s_vo_h,  # v-like stride over heads
    s_vo_t,  # v-like stride over time
    s_vo_d,  # v-like stride over features
    s_h_h,   # dh stride over heads
    s_h_t,   # dh stride over time
    scale,   # query scaling (K ** -0.5 in the launcher)
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk size along time
    BC: tl.constexpr,  # sub-chunk size; bounds SRAM usage within a chunk
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr   # number of time chunks
):
    # Reverse-time recurrence over chunks producing dh (state gradient at
    # the start of each chunk) and dv2 (dv plus the gradient flowing back
    # through the recurrent state).  Grid: (K tiles, V tiles, B * H).
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV] running gradient of the recurrent state
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        # the dh stored for chunk i_t excludes this chunk's own updates
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        # walk sub-chunks backwards; updates accumulate in b_dh_tmp and are
        # folded into b_dh only once the whole chunk has been processed, so
        # every sub-chunk sees the state gradient from *later* chunks only
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)
            # [BC, BK]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            # [BC, BV]
            b_do = tl.load(p_do, boundary_check=(0, 1))

            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            # inter-chunk contribution to dv: k @ dh
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), boundary_check=None, allow_tf32=False) if False else b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) - b_dv
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) - tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
                                      (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV] state-gradient update: + q @ do - d @ dv
            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
        b_dh += b_dh_tmp
297
+
298
+
299
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dqkw(
    q,       # queries [B, H, T, K]
    k,       # keys [B, H, T, K]
    v,       # (transformed) values [B, H, T, V]
    w,       # correction term [B, H, T, K]
    h,       # per-chunk hidden states [B, H, NT * K, V]
    do,      # gradient w.r.t. the output [B, H, T, V]
    dh,      # per-chunk state gradients [B, H, NT * K, V]
    dq,      # out: query gradients
    dk,      # out: key gradients
    dv,      # in: value gradients (du in the launcher)
    dw,      # out: gradients of the correction term
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,   # query scaling (K ** -0.5 in the launcher)
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    # Computes dq, dk and dw for one [BT, BK] tile by accumulating both the
    # inter-chunk terms (through h / dh) and the masked intra-chunk terms.
    # Grid: (K tiles, time chunks, B * H).
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_i = tl.arange(0, BT)

    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    # accumulate over V tiles
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK] this chunk's state, loaded transposed
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BT] intra-chunk score gradients: do @ v^T
        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK] inter-chunk terms: do @ h and v @ dh^T
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)

        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        # dw accumulates dv @ h (negated at the store below)
        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)

    # intra-chunk (causally masked) contributions
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
    b_dq *= scale
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))

    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    # dw enters the forward with a minus sign, hence the negation here
    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
388
+
389
+
390
def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
    """Launch the chunked forward hidden-state kernel.

    Materializes the per-chunk hidden states ``h`` ([B, H, NT * K, V]) and
    the transformed values ``v_new`` consumed by the output kernel.

    NOTE(review): when ``final_state`` is not None the kernel writes into it
    in place (STORE_FINAL_STATE).
    """
    B, H, T, K, V = *k.shape, u.shape[-1]

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # shrink the V tile and sub-chunk size as the K tile grows to fit SRAM
    BV = 16 if BK > 128 else 32
    BV = 64 if BK <= 64 else BV
    BC = 16 if BK > 128 else 32
    BC = 64 if BK <= 64 else BC
    BC = min(BT, BC)
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = k.new_empty(B, H, NT * K, V)
    grid = (NK, NV, B * H)
    v_new = torch.empty_like(u)
    chunk_delta_rule_fwd_kernel_h[grid](
        k, u, w, v_new, h, initial_state, final_state,
        k.stride(1), k.stride(2), k.stride(3),
        u.stride(1), u.stride(2), u.stride(3),
        h.stride(1), h.stride(2),
        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
        USE_INITIAL_STATE=initial_state is not None,
        STORE_FINAL_STATE=final_state is not None,
    )
    return h, v_new
416
+
417
+
418
def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
    """Launch the backward hidden-state kernel.

    Returns:
        dh: [B, H, NT * K, V] per-chunk gradients of the recurrent state.
        dv2: value gradients augmented with the inter-chunk contribution.
    """
    B, H, T, K, V = *q.shape, do.shape[-1]

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
    # smaller V tiles / sub-chunks for large K tiles to fit SRAM
    BV = 16 if BK > 128 else 32
    BV = 64 if BK <= 64 else BV
    BC = 16 if BK > 128 else 32
    BC = 64 if BK <= 64 else BC
    BC = min(BT, BC)
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh = q.new_empty(B, H, NT * K, V)
    grid = (NK, NV, B * H)
    dv2 = torch.empty_like(dv)
    chunk_delta_rule_bwd_kernel_dhu[grid](
        q, k, w, do, dh, dv, dv2,
        q.stride(1), q.stride(2), q.stride(3),
        do.stride(1), do.stride(2), do.stride(3),
        dh.stride(1), dh.stride(2),
        K**-0.5,
        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
    )
    return dh, dv2
444
+
445
+
446
def chunk_fwd_o_fn(q, k, v_new, h, BT):
    """Launch the chunked forward output kernel.

    Computes ``o = q @ h + causal(q @ k^T) @ v_new`` per time chunk.

    Args:
        q, k: [B, H, T, K] queries and keys.
        v_new: [B, H, T, V] transformed values from the hidden-state pass.
        h: [B, H, NT * K, V] per-chunk hidden states.
        BT: chunk size along the time dimension.

    Returns:
        o: [B, H, T, V] attention output (same dtype/layout as ``v_new``).
    """
    B, H, T, K, V = *q.shape, v_new.shape[-1]

    o = torch.empty_like(v_new)
    # tile sizes capped at 64 to bound shared-memory usage
    # (the previous uncapped `BK = triton.next_power_of_2(K)` was dead code)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT)
    grid = (NV, NT, B * H)
    chunk_linear_attn_fwd_kernel_o[grid](
        q, k, v_new, h, o,
        q.stride(1), q.stride(2), q.stride(3),
        v_new.stride(1), v_new.stride(2), v_new.stride(3),
        h.stride(1), h.stride(2),
        scale=K**-0.5,
        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
    )
    return o
465
+
466
+
467
def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
    """Launch the backward kernel computing dq, dk and dw.

    Args:
        q, k: [B, H, T, K] queries and keys.
        v_new: [B, H, T, V] transformed values; w: [B, H, T, K] correction term.
        h / dh: per-chunk hidden states and their gradients [B, H, NT * K, V].
        du: value gradients (the ``dv2`` produced by the dhu pass).
        do: [B, H, T, V] gradient w.r.t. the output.
        BT: chunk size along the time dimension.

    Returns:
        (dq, dk, dw) cast back to the input dtypes.
    """
    B, H, T, K, V = *q.shape, v_new.shape[-1]

    # tile sizes capped at 64 to bound shared-memory usage
    # (the previous uncapped `BK = triton.next_power_of_2(K)` was dead code)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    NK = triton.cdiv(K, BK)
    NT = triton.cdiv(T, BT)
    grid = (NK, NT, B * H)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    chunk_delta_rule_bwd_kernel_dqkw[grid](
        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
        q.stride(1), q.stride(2), q.stride(3),
        v_new.stride(1), v_new.stride(2), v_new.stride(3),
        dh.stride(1), dh.stride(2),
        scale=K ** -0.5,
        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
    )
    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
488
+
489
+
490
class ChunkDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper for the chunked delta-rule attention kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
        # obtain WY representation. u is actually the new v.
        w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
        # forward hidden-state pass (fills final_state in place when requested)
        final_state = None
        if output_final_state:
            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
                                      dtype=torch.float32, requires_grad=False)
        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
        # obtain output
        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
        # save memory: with checkpoint_level == 1 the (large) hidden states
        # are dropped here and recomputed in backward instead
        if checkpoint_level == 1:
            h, v_new = None, None
        ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
        ctx.BT = BT
        return o.to(q.dtype), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, d_ht=None):
        # NOTE(review): the gradient of the final state (d_ht) is ignored;
        # callers backpropagating through final_state get no gradient.
        q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
        BT = ctx.BT
        w, u = fwd_recompute_w_u(k, v, beta, A, BT)
        # checkpoint_level == 1: hidden states were dropped, recompute them
        if h is None:
            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
        dv = fwd_prepare_dv(q, k, do, BT)
        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
        dk.add_(dk2)
        # NOTE(review): 8 gradients are returned but ``apply`` is invoked with
        # 7 inputs (checkpoint_level left to its default) — confirm torch
        # accepts the extra trailing None.
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
529
+
530
+
531
def chunk_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    BT: int,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
):
    """Chunked delta-rule attention.

    Args:
        q, k: [B, H, T, K] queries and keys.
        v: [B, H, T, V] values.
        beta: [B, H, T] per-token write strengths.
        BT: chunk size along the time dimension.
        initial_state: optional starting recurrent state.
        output_final_state: when True, also return the final state.

    Returns:
        (o, final_state): output [B, H, T, V] and final state (None unless
        ``output_final_state`` is True).

    Raises:
        AssertionError: on dtype mismatch or float32 inputs.
    """
    assert q.dtype == k.dtype == v.dtype
    # fix: the message previously referred to FusedChunkDeltaRuleFunction
    # (copy-paste from the fused variant); this path uses ChunkDeltaRuleFunction
    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
    return o, final_state
opencompass/models/fla2/ops/delta_rule/chunk_fuse.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from ...ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
10
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
11
+ import torch.nn.functional as F
12
+
13
def ceildiv(a, b):
    """Integer ceiling division: smallest integer >= a / b."""
    return -(-a // b)
15
+
16
def pad(x, chunk_size=16):
    """Zero-pad ``x`` for the fused kernels.

    The sequence dimension (dim -2) is padded up to a multiple of
    ``chunk_size`` and the feature dimension (dim -1) up to a multiple of 32.
    Returns ``x`` unchanged when no padding is needed.
    """
    seq_len = x.shape[-2]
    target_len = ((seq_len + chunk_size - 1) // chunk_size) * chunk_size
    if seq_len % chunk_size != 0:
        x = F.pad(x, (0, 0, 0, target_len - seq_len))
    feat_rem = x.shape[-1] % 32
    if feat_rem != 0:
        x = F.pad(x, (0, 32 - feat_rem))
    return x
25
+
26
def pad_b(x, chunk_size=16):
    """Pad beta along its last (sequence) dimension with ones.

    Pads up to the next multiple of ``chunk_size``; padding with 1.0 keeps
    the padded positions' write strength neutral for the kernels.
    Returns ``x`` unchanged when the length already divides evenly.
    """
    seq_len = x.shape[-1]
    target_len = ((seq_len + chunk_size - 1) // chunk_size) * chunk_size
    if seq_len % chunk_size != 0:
        x = F.pad(x, (0, target_len - seq_len), value=1.0)
    return x
33
+
34
+ # on-the-fly computation without materializing hidden statets into HBMs
35
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=["BT", "BK"],
)
@triton.jit
def fused_chunk_delta_rule_fwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    q,  # query [B, H, L, D_head_K]
    k,  # key [B, H, L, D_head_K]
    v,  # value [B, H, L, D_head_V]
    v_new,  # out: delta-corrected values [B, H, L, D_head_V]
    d,  # decay / correction term [B, H, L, D_head_K]
    o,  # output [B, H, L, D_head_V], one slice per K split
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
    s_qk_h,  # stride size: L * D_head_K
    s_qk_t,  # stride size: D_head_K
    s_qk_d,  # stride size: 1
    s_vo_h,  # stride size: L * D_head_V
    s_vo_t,  # stride size: D_head_V
    s_vo_d,  # stride size: 1
    B,  # batch size
    H,  # n_heads
    T,  # seq_len
    scale,  # D_head_K ** -0.5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    DK: tl.constexpr,  # D_head_K
    DV: tl.constexpr,  # D_head_V
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Forward pass that carries the [BK, BV] state b_h across chunks on the
    # fly (never materializing per-chunk states in HBM).
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)

    # [BT, BT] causal mask within a chunk
    m_s = o_i[:, None] >= o_i[None, :]
    # [BK, BV] recurrent state
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    # make block pointers (advanced by one chunk per loop iteration)
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
    p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
    p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    for i in range(0, tl.cdiv(T, BT)):
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_d = tl.load(p_d, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_k.dtype)

        # [BT, BT] masked intra-chunk scores
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0)
        # [BT, BV] delta correction: subtract d @ h from the raw values
        b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
        b_v = b_v - b_v_prime
        tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

        b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
        # NOTE: both branches are intentionally identical — splitting off the
        # first iteration works around a Triton < 2.2 compiler issue (see the
        # launcher's commented warning / openai/triton#2852).
        if CHECK and i == 0:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
        else:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_v_new = tl.advance(p_v_new, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))
        p_d = tl.advance(p_d, (BT, 0))

    if STORE_FINAL_STATE:
        p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
132
+
133
+
134
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
135
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fused_chunk_delta_rule_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, L, D_head_K]
    k,  # key [B, H, L, D_head_K]
    v,  # value [B, H, L, D_head_V]
    d,  # decay / correction term [B, H, L, D_head_K]
    do,  # gradient of output [B, H, L, D_head_V]
    dq,  # gradient of query [NV, B, H, L, D_head_K]
    dk,  # gradient of key [NV, B, H, L, D_head_K]
    dv,  # gradient of value [NK, B, H, L, D_head_V]
    dd,  # gradient of decay [NV, B, H, L, D_head_K]
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
    s_qk_h,  # stride size: L * D_head_K
    s_qk_t,  # stride size: D_head_K
    s_qk_d,  # stride size: 1
    s_vo_h,  # stride size: L * D_head_V
    s_vo_t,  # stride size: D_head_V
    s_vo_d,  # stride size: 1
    B,  # batch_size
    H,  # n_heads
    T,  # seq_len
    scale,  # D_head_K ** -0.5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    DK: tl.constexpr,  # D_head_K
    DV: tl.constexpr,  # D_head_V
    USE_INITIAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Two phases: a reverse-time sweep accumulating the state gradient b_dh
    # to produce dk/dv, then (after a barrier) a forward sweep re-deriving
    # the state b_h to produce dq and the decay gradient dd.
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_i = tl.arange(0, BT)

    # ---- phase 1: reverse sweep over chunks ----
    # [BK, BV] running gradient of the recurrent state
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # upper-triangular mask (incl. diagonal) for the transposed intra-chunk terms
    m_s = o_i[:, None] <= o_i[None, :]
    for i in range(1, tl.cdiv(T, BT) + 1):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))

        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
        # [DK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, DK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, DV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # [BT, BT] intra-chunk score gradients (transposed layout)
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
        b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
        # [BT, BT]
        b_s = tl.dot(b_k, b_q, allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
        # [BT, DK]
        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
        # [BT, DV]
        b_dv = tl.dot(b_s, b_do, allow_tf32=False)
        b_d = tl.load(p_d, boundary_check=(0, 1))
        # NOTE: both branches are intentionally identical — the first-iteration
        # split works around a Triton < 2.2 compiler issue (openai/triton#2852).
        if CHECK and i == 1:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
        else:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)

        tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # sync threads before reusing registers/SRAM for the forward sweep
    b_h = None
    tl.debug_barrier()
    # ---- phase 2: forward sweep over chunks ----
    m_s = o_i[:, None] >= o_i[None, :]
    # [BV, BK] forward state, carried transposed
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
    NT = tl.cdiv(T, BT)
    for i in range(0, NT):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))

        # [BT, DK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [DV, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, DV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # [BT, BT] causally-masked score gradients
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        b_ds = tl.where(m_s, b_ds, 0)
        # [BT, DK] intra-chunk part of dq
        b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
        # NOTE: identical branches again — first-iteration compiler workaround
        if CHECK and i == 0:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
        else:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
        b_dq *= scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

        # dd for chunk i+1 depends on the state after chunk i, so it is
        # written one chunk ahead; chunk 0's dd is never stored here (the
        # launcher zeroes it)
        if i < (NT - 1):
            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
            p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
                                     ((i+1) * BT, i_k * BK), (BT, BK), (1, 0))
            tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))
273
+
274
+
275
def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
    """Launch the fused-chunk delta-rule forward kernel.

    Args:
        q, k: [B, H, T, DK] queries and keys.
        v: [B, H, T, DV] values; d: [B, H, T, DK] correction term.
        BT: chunk size along the time dimension.
        initial_state: optional [B, H, DK, DV] starting state.
        output_final_state: when True, allocate and fill the final state.

    Returns:
        (o, v_new, CHECK, final_state): output, delta-corrected values, the
        kernel's CHECK flag (needed again by backward), and the final state
        (or None).
    """
    batch_size, n_heads, seq_len, d_head_qk = q.shape
    d_head_v = v.shape[-1]
    scale = d_head_qk ** -0.5
    # K is processed in one tile; V is split into tiles of at most 32
    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
    assert NK == 1, 'NK should be 1'
    o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
    if output_final_state:
        final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
    else:
        final_state = None
    # CHECK enables a first-iteration special case that works around a
    # Triton < 2.2 compiler precision issue (openai/triton#2852); it is kept
    # always-on here at a small speed cost.
    CHECK = True
    grid = (NV, NK, batch_size * n_heads)
    v_new = torch.empty_like(v)
    fused_chunk_delta_rule_fwd_kernel[grid](
        q, k, v, v_new, d, o, initial_state, final_state,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        batch_size, n_heads, seq_len, scale,
        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
        USE_INITIAL_STATE=initial_state is not None,
        STORE_FINAL_STATE=output_final_state,
        CHECK=CHECK,
    )
    return o, v_new, CHECK, final_state
313
+
314
+
315
def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
    """Launch the fused-chunk delta-rule backward kernel.

    Args:
        q, k: [B, H, T, DK]; v: [B, H, T, DV] (delta-corrected values).
        d: [B, H, T, DK] correction term; do: [B, H, T, DV] output gradients.
        BT: chunk size; CHECK: workaround flag returned by the forward pass.
        initial_state: optional [B, H, DK, DV] starting state.

    Returns:
        (dq, dk, dv, dd) gradients, reduced over the kernel's V/K splits.
    """
    batch_size, n_heads, seq_len, d_head_qk = q.shape
    d_head_v = v.shape[-1]
    scale = d_head_qk ** -0.5
    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
    assert NK == 1
    # one slice per V split (resp. K split); reduced with sum(0) below
    dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
    grid = (NV, NK, batch_size * n_heads)
    fused_chunk_delta_rule_bwd_kernel[grid](
        q, k, v, d, do, dq, dk, dv, dd, initial_state,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        batch_size, n_heads, seq_len, scale,
        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
        USE_INITIAL_STATE=initial_state is not None,
        CHECK=CHECK,
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    dd = dd.sum(0)
    # the kernel writes dd one chunk ahead and never stores the first chunk,
    # so zero that otherwise-uninitialized region explicitly
    dd[:, :, 0:BT] = 0
    return dq, dk, dv, dd
344
+
345
+
346
class FusedChunkDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around the fused-chunk delta-rule kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
        # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
        assert checkpoint_level in [0, 1]
        k_origin = k
        # k = _l2_norm_fwd(k_origin)  # optional key normalization (disabled)
        k = k
        # NOTE(review): fwd_prepare_wy_repr is unpacked into two values here;
        # confirm this matches the current signature in ..delta_rule.utils.
        d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
        o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
        if checkpoint_level == 1:
            # drop the WY tensors; backward recomputes them
            d, v_new = None, None
        ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
        ctx.CHECK = CHECK
        ctx.chunk_size = BT
        return o.to(q.dtype), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, d_final_state=None):
        # NOTE(review): d_final_state is ignored — no gradient flows back
        # through the returned final state.
        q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
        chunk_size = ctx.chunk_size
        k = k_origin
        # k = _l2_norm_fwd(k_origin)  # would need to mirror the forward choice
        # checkpoint_level == 1: WY representation was dropped, recompute it
        if d is None:
            d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
        dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
        dk.add_(dk2)
        # dk = _l2_norm_bwd(k_origin, dk)
        # NOTE(review): dbeta is cast to d.dtype rather than beta.dtype —
        # confirm that is intentional.
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None
381
+
382
+
383
def fused_chunk_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    BT: int,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused-chunk delta-rule attention entry point.

    q/k: [B, H, T, DK]; v: [B, H, T, DV]; beta: [B, H, T]; BT: chunk size.
    Inputs are zero-padded in time (and feature dims up to a multiple of 32);
    beta is padded with ones. The output is sliced back to the original
    sequence length and value dimension.

    NOTE(review): ``pad`` is invoked with its default chunk_size of 16, not
    BT — confirm this is sufficient when BT > 16.
    NOTE(review): when padding occurs, ``final_state`` keeps the padded
    feature dimensions — confirm callers expect that.

    Raises:
        AssertionError: on dtype mismatch or float32 inputs.
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."

    if initial_state is not None:
        initial_state = initial_state.detach()  # treat the state as a constant
    seq_len = v.shape[-2]
    d_head_v = v.shape[-1]
    q, k, v = map(lambda x: pad(x), [q, k, v])
    beta = pad_b(beta)
    o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
    o = o[..., :seq_len, :d_head_v]
    return o, final_state
404
+
405
+
406
def delta_rule_recurrence(q, k, v, beta):
    """Step-by-step reference implementation of the delta rule.

    q, k: [b, h, l, d_k]; v: [b, h, l, d_v]; beta: [b, h, l].
    Queries are scaled by d_k ** -0.5 and keys are L2-normalized.
    Returns the output [b, h, l, d_v].
    """
    batch, heads, length, key_dim = q.shape
    val_dim = v.shape[-1]
    out = torch.zeros_like(v)
    state = torch.zeros(batch, heads, key_dim, val_dim).to(v)
    q = q * (key_dim ** -0.5)
    k = torch.nn.functional.normalize(k, p=2, dim=-1)
    for t in range(length):
        k_t = k[:, :, t]
        q_t = q[:, :, t]
        # prediction error scaled by the write strength
        delta = v[:, :, t].clone() - (state.clone() * k_t[..., None]).sum(-2)
        delta = delta * beta[:, :, t][..., None]
        # rank-1 state update, then read out with the query
        state = state.clone() + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        out[:, :, t] = torch.einsum('bhd,bhdm->bhm', q_t, state)
    return out
423
+
424
+
425
+ if __name__ == "__main__":
426
+ import torch.nn.functional as F
427
+ seq_len = 128
428
+ b = 2
429
+ h = 4
430
+ q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
431
+ k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
432
+ v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
433
+ beta = torch.rand(b, h, seq_len).sigmoid()
434
+ q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
435
+ do = torch.rand_like(v)
436
+ o2 = delta_rule_recurrence(q, k, v.clone(), beta)
437
+ o2.backward(do, retain_graph=True)
438
+ q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
439
+ q.grad = k.grad = v.grad = beta.grad = None
440
+ o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
441
+ o.backward(do, retain_graph=True)
442
+ q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
443
+ q.grad = k.grad = v.grad = beta.grad = None
444
+ print((o - o2).abs().max())
445
+ print((q_grad - q_grad2).abs().max())
446
+ print((k_grad - k_grad2).abs().max())
447
+ print((v_grad - v_grad2).abs().max())
448
+ print((beta_grad - beta_grad2).abs().max())
opencompass/models/fla2/ops/delta_rule/naive.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
def delta_rule_recurrence(q, k, v, beta):
    """Sequential reference delta-rule recurrence.

    Unlike the chunkwise version, this runs one timestep at a time and
    supports both a scalar-per-token beta of shape (b, h, l) and a
    headwise beta of shape (b, h, l, d_v).

    Args:
        q, k: (b, h, l, d_k) query/key tensors (q is scaled by d_k ** -0.5).
        v: (b, h, l, d_v) value tensor.
        beta: (b, h, l) or (b, h, l, d_v) gate.

    Returns:
        (b, h, l, d_v) output tensor.
    """
    batch, heads, seq_len, key_dim = q.shape
    val_dim = v.shape[-1]
    out = torch.zeros_like(v)
    state = torch.zeros(batch, heads, key_dim, val_dim).to(v)
    q = q * (key_dim ** -0.5)

    # Promote a scalar-per-token beta to broadcast over the value dimension.
    if beta.ndim < v.ndim:
        beta = beta[..., None]

    for t in range(seq_len):
        k_t = k[:, :, t]
        q_t = q[:, :, t]
        v_t = v[:, :, t].clone()
        gate_t = beta[:, :, t]
        # beta-gated prediction error w.r.t. the current memory
        v_t = (v_t - (state.clone() * k_t[..., None]).sum(-2)) * gate_t
        state = state.clone() + k_t.unsqueeze(-1) * v_t.unsqueeze(-2)
        out[:, :, t] = torch.einsum('bhd,bhdm->bhm', q_t, state)

    return out
28
+
29
+
30
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
    """Chunk-parallel delta rule using the WY representation.

    Within each chunk a strictly-lower-triangular transition matrix is built
    and inverted via forward substitution; chunks are then scanned
    sequentially through the running state `S`.

    Args:
        q, k: (b, h, l, d_k); q is scaled by d_k ** -0.5 here.
        v: (b, h, l, d_v).
        beta: (b, h, l) per-token gate.
        chunk_size: intra-chunk block length; `l` must be divisible by it.

    Returns:
        (b, h, l, d_v) output tensor.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    v = v * beta[..., None]
    k_beta = k * beta[..., None]

    assert l % chunk_size == 0

    # note that diagonal is masked.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    # Reshape (b, h, l, d) -> (b, h, n_chunks, chunk_size, d).
    q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
    attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)

    # Forward substitution: row i accumulates contributions of rows < i,
    # effectively inverting (I - strictly-lower part) in place.
    for i in range(1, chunk_size):
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)

    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    # u
    k_cumsum = attn @ v
    # w
    k_cumdecay = attn @ k_beta

    v = k_cumsum
    S = k.new_zeros(b, h, d_k, d_v)
    o = torch.zeros_like(v)
    # Strictly upper mask: intra-chunk attention is causal (diagonal kept).
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
        # Inter-chunk correction of the pseudo values via the running state.
        v_prime = k_cumdecay[:, :, i] @ S
        v_new = v_i - v_prime
        o_inter = q_i @ S
        o[:, :, i] = o_inter + attn @ v_new
        # chunk state update
        S = S + k_i.transpose(-1, -2) @ v_new

    return rearrange(o, 'b h n c d -> b h (n c) d')
68
+
69
+
70
if __name__ == '__main__':
    # Cross-check the sequential recurrence against the chunkwise algorithm
    # for both outputs and gradients (requires CUDA).
    B = 2
    H = 4
    L = 256
    DK = 128
    DV = 128
    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
    k = (torch.randn(B, H, L, DK)).cuda()
    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)

    o = delta_rule_recurrence(q, k, v, beta)
    do = torch.randn(B, H, L, DV).cuda()
    o.backward(do, retain_graph=True)
    q_grad, q.grad = q.grad, None
    k_grad, k.grad = k.grad, None
    v_grad, v.grad = v.grad, None
    beta_grad, beta.grad = beta.grad, None

    o2 = delta_rule_chunkwise(q, k, v, beta)
    o2.backward(do)
    # BUG FIX: the original wrote `assert cond, breakpoint()`, which drops
    # into the debugger whenever a check fails (a committed debug hook).
    # Report the max deviation in the assertion message instead.
    assert torch.allclose(o, o2, atol=1e-4), f"output mismatch: {(o - o2).abs().max()}"
    assert torch.allclose(q.grad, q_grad, atol=1e-4), f"dq mismatch: {(q.grad - q_grad).abs().max()}"
    assert torch.allclose(k.grad, k_grad, atol=1e-4), f"dk mismatch: {(k.grad - k_grad).abs().max()}"
    assert torch.allclose(v.grad, v_grad, atol=1e-4), f"dv mismatch: {(v.grad - v_grad).abs().max()}"
    assert torch.allclose(beta.grad, beta_grad, atol=1e-4), f"dbeta mismatch: {(beta.grad - beta_grad).abs().max()}"
    print("All passed!")
opencompass/models/fla2/ops/delta_rule/recurrent_fuse.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023, Yu Zhang, Songlin Yang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...utils import contiguous
11
+
12
+ # on-the-fly computation without materializing hidden statets into HBMs
13
+
14
+
15
@triton.jit
def fused_recurrent_fwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, V]
    v,  # value [B, H, L, V]; NOTE: overwritten in place with the pseudo value (see store below)
    beta,  # beta [B, H, L] (scalar) or [B, H, L, V] (headwise)
    o,  # output [B, H, L, V]
    h0,  # initial hidden state [B, H, K, V] or unused
    ht,  # final hidden state [B, H, K, V] or unused
    s_qk_h,  # stride size: L * K
    s_vo_h,  # stride size: L * V
    scale,  # K ** -0.5
    B,  # batch size
    H,  # n_heads
    T,  # seq_len
    K: tl.constexpr,  # key head dim
    V: tl.constexpr,  # value head dim
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    IS_HEADWISE_BETA: tl.constexpr,  # whether beta is headwise vector or scalar
):
    # One program per (value block, key block, batch*head); the state tile
    # h is kept in registers across the whole sequence scan.
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    if IS_HEADWISE_BETA:
        p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    else:
        p_beta = beta + i_bh * T
    # Output is laid out as [NK, B, H, T, V]; i_k selects the NK slice.
    p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)

    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
    mask_kv = mask_bk[None, :] & mask_bv[:, None]

    # State tile, [BV, BK], accumulated in fp32.
    h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        # Delta rule: subtract the memory's current prediction for this key.
        _v_minus = tl.sum(h * b_k[None, :], axis=1)
        b_v -= _v_minus
        if IS_HEADWISE_BETA:
            b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        # in-place overwrite: store the pre-beta pseudo value back into v
        # (the backward pass reads this, not the original values)
        tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv)
        b_v *= b_beta
        # Rank-1 state update, then query readout.
        h += b_k[None, :] * b_v[:, None]
        _o = h * b_q[None, :]
        _o = tl.sum(_o, axis=1)
        tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)

        # Advance all pointers by one timestep.
        p_q += K
        p_k += K
        p_o += V
        p_v += V
        p_beta += V if IS_HEADWISE_BETA else 1

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)
89
+
90
+
91
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, V]
    v,  # value [B, H, L, V] (holds the forward kernel's pseudo values)
    beta,  # beta [B, H, L, (V)]

    do,  # gradient of output [B, H, L, V]
    dq,  # gradient of query [NV, B, H, L, K]
    dk,  # gradient of key [NV, B, H, L, K]
    dv,  # gradient of value [NK, B, H, L, V]
    dbeta,  # gradient of beta [NV, (NK), B, H, L]

    # initial hidden state initialization [B, H, K, V]
    h0,

    s_qk_h,  # stride size: L * K

    s_vo_h,  # stride size: L * V

    NK,  # number of K splits (used for dbeta indexing)
    scale,  # K ** -0.5

    B,  # batch_size
    H,  # n_heads
    T,  # seq_len
    K: tl.constexpr,  # key head dim
    V: tl.constexpr,  # value head dim
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    IS_HEADWISE_BETA: tl.constexpr,  # whether beta is headwise vector or scalar
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V

    # ---- Pass 1: reverse scan from t = T-1 down to 0, accumulating the
    # gradient of the state d_h and emitting dk/dv/dbeta per timestep. ----
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    if IS_HEADWISE_BETA:
        p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    else:
        p_beta = beta + i_bh * T + T - 1

    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    if IS_HEADWISE_BETA:
        p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V
    else:
        p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
    d_h = tl.zeros([BK, BV], dtype=tl.float32)

    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        if IS_HEADWISE_BETA:
            b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        # Output gradient flows into the state gradient.
        d_h += b_q[:, None] * b_do[None, :]
        d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1)
        d_v = tl.sum(d_h * b_k[:, None], axis=0)

        # dbeta from the product rule on (beta * pseudo_v).
        d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v)
        d_v = d_v * b_beta

        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
        if IS_HEADWISE_BETA:
            tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv)
        else:
            tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))

        # Remove this step's rank-1 contribution before the previous step.
        d_h -= b_k[:, None] * d_v[None, :]

        p_do -= V
        p_q -= K
        p_k -= K
        p_v -= V
        p_dk -= K
        p_dv -= V
        p_dbeta -= V if IS_HEADWISE_BETA else 1
        p_beta -= V if IS_HEADWISE_BETA else 1

    tl.debug_barrier()

    # ---- Pass 2: forward scan rebuilding the state to compute dq and to
    # apply the delta-rule correction to dk produced in pass 1. ----
    h = tl.zeros([BK, BV], dtype=tl.float32)

    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    if IS_HEADWISE_BETA:
        p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    else:
        p_beta = beta + i_bh * T
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
    # dk/dv corrections are read one step ahead (offset by +V / +K).
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + V
    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K

    if USE_INITIAL_STATE:
        mask_kv = mask_bk[:, None] & mask_bv[None, :]
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    for i in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        if IS_HEADWISE_BETA:
            b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta

        h += b_k[:, None] * b_v[None, :]
        _d_q = h * b_do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)

        if i < T - 1:
            # Correct the next step's dk with the state seen up to now.
            d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)
            d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)
            d_k -= tl.sum(d_v[None, :] * h, axis=1)
            tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)

        p_k += K
        p_do += V
        p_v += V
        p_dk += K
        p_dv += V
        p_dq += K
        p_beta += V if IS_HEADWISE_BETA else 1
231
+
232
+
233
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper launching the fused recurrent delta-rule kernels.

    forward(q, k, v, beta, scale, initial_state, output_final_state)
        -> (o, final_state)

    NOTE(review): the forward kernel overwrites `v` in place with pseudo
    values (see the `tl.store(p_v, ...)` in the kernel), and `backward`
    reads the saved (mutated) `v` — callers should not rely on `v` being
    unchanged after the call.
    """

    # NOTE(review): @contiguous is applied outside @staticmethod here; this
    # relies on staticmethod objects being callable — confirm against the
    # project's minimum Python version.
    @contiguous
    @staticmethod
    def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False):
        B, H, T, K, V = *q.shape, v.shape[-1]

        # Whole K in one block; V split into blocks of at most 8.
        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 1
        assert NK == 1, "NK > 1 is not supported yet"
        # Extra leading NK axis matches the kernel's output layout.
        o = q.new_empty(NK, B, H, T, V)

        if output_final_state:
            final_state = q.new_empty(B, H, K, V)
        else:
            final_state = None

        grid = (NV, NK, B * H)
        fused_recurrent_fwd_kernel[grid](
            q, k, v, beta, o, initial_state, final_state,
            q.stride(1),
            v.stride(1),
            scale,
            B=B, H=H, T=T, K=K, V=V,
            BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None,
            IS_HEADWISE_BETA=beta.ndim == v.ndim,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        # Collapse the (size-1) NK axis.
        o = o.sum(0)
        ctx.save_for_backward(q, k, v, beta, initial_state)
        ctx.scale = scale
        return o, final_state

    @contiguous
    @staticmethod
    def backward(ctx, do, dht=None):
        q, k, v, beta, initial_state = ctx.saved_tensors
        B, H, T, K, V = *q.shape, v.shape[-1]
        scale = ctx.scale
        # Backward uses a larger V block (32) than forward.
        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        assert NK == 1, "NK > 1 is not supported yet"
        num_stages = 1
        num_warps = 2

        # Headwise beta has the same rank as v (see forward's IS_HEADWISE_BETA).
        beta_vector = beta.ndim == v.ndim

        # Per-split gradient buffers; reduced over the split axes below.
        dq = q.new_empty(NV, B, H, T, K)
        dk = q.new_empty(NV, B, H, T, K)
        dv = q.new_empty(NK, B, H, T, V)
        if beta_vector:
            dbeta = q.new_empty(NV, NK, B, H, T, V)
        else:
            dbeta = q.new_empty(NV, B, H, T)
        grid = (NV, NK, B * H)

        fused_recurrent_bwd_kernel[grid](
            q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
            q.stride(1),
            v.stride(1),
            NK, scale,
            B=B, H=H, T=T, K=K, V=V,
            BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            IS_HEADWISE_BETA=beta_vector,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)
        dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0)
        # No gradients for scale / initial_state / output_final_state.
        return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None
311
+
312
+
313
def fused_recurrent_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor = None,
    scale: float = -1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    normalize: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Public entry point for the fused recurrent delta rule.

    Args:
        q, k, v: attention inputs of shape [B, H, T, K] / [B, H, T, V].
        beta: optional gate; defaults to all-ones per token.
        scale: query scale; -1 (the sentinel default) means K ** -0.5.
        initial_state: optional [B, H, K, V] state (detached before use).
        output_final_state: whether to also return the final state.
        normalize: accepted for interface compatibility; not used here.

    Returns:
        (output, final_state) — final_state is None unless requested.
    """
    scale = q.shape[-1] ** -0.5 if scale == -1 else scale
    # The autograd Function must not backprop into the provided state.
    initial_state = None if initial_state is None else initial_state.detach()
    beta = torch.ones_like(q[..., 0]) if beta is None else beta
    o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state)
    return o, final_state
opencompass/models/fla2/ops/delta_rule/utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from einops import rearrange
7
+
8
+ from ...ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
9
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
10
+
11
+
12
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
    k,      # keys [B*H, T, K] (flattened batch*head leading dim)
    v,      # values [B*H, T, V]
    beta,   # per-token gate [B*H, T]
    o,      # output w (cumdecay) [B*H, T, K]
    o2,     # output u (new values) [B*H, T, V]
    T,
    K,
    V,
    BT: tl.constexpr,   # chunk length
    BK: tl.constexpr,   # >= K (single key block)
    BV: tl.constexpr    # >= V (single value block)
):
    # One program per (chunk, batch*head).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    # Raw-pointer addressing of the [BT, BK]/[BT, BV] tiles for this chunk.
    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
    mask_bk = tl.arange(0, BK) < K
    mask_bv = tl.arange(0, BV) < V
    mask_bk = mask_bk[None, :] & mask_bt[:, None]
    mask_bv = mask_bv[None, :] & mask_bt[:, None]
    # [BT, BK]
    b_k = tl.load(p_k, mask=mask_bk, other=0)
    # [BT,]
    b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
    # [BT, BV]
    b_v = tl.load(p_v, mask=mask_bv, other=0)
    b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
    # [BT, BK]
    b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
    # [BT, BT] negated strictly-lower Gram matrix of beta-scaled keys.
    b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)

    # Forward substitution row by row: inverts (I + strictly-lower part).
    for i in range(BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity to complete the inverse transition matrix.
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    b_A = b_A.to(b_k.dtype)
    b_w = tl.dot(b_A, b_kb, allow_tf32=False)
    b_u = tl.dot(b_A, b_v, allow_tf32=False)

    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)
77
+
78
+
79
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
    k, v, beta,           # forward inputs
    o, o2, do, do2,       # forward outputs (w, u) and their gradients
    dk, dv, dbeta,        # gradient outputs
    NT, K, V, T,
    BT: tl.constexpr,     # chunk length
    BK: tl.constexpr,     # key block (== K at launch)
    BV: tl.constexpr,     # value block (>= V)
):
    """Backward of the WY-representation preparation for one chunk.

    Reverse-substitutes the gradients through the in-chunk triangular
    solve, then accumulates dk, dv and dbeta.
    """
    # BUG FIX: the original unpacked THREE program ids into TWO names
    # (`i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)`),
    # which fails at trace time; the launch grid is 2-D (NT, b*h).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]

    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
    mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
    mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
    b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)

    b_beta = b_beta.to(tl.float32)
    # Strictly-lower beta-scaled Gram matrix, as in the forward pass.
    A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
    A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
    b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
    b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
    dA = tl.zeros([BT, BT], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # Reverse substitution: propagate output grads back through the solve.
    for i in range(BT-1, -1, -1):
        mask = tl.arange(0, BT) == i
        attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
        do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
        dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
        b_do = b_do - attn[:, None] * do_[None, :]
        b_dv = b_dv - attn[:, None] * dv_[None, :]
    tl.debug_barrier()
    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    b_v = tl.load(p_v, mask=mask_bv)
    # Direct contributions of the beta-scaled inputs.
    b_dk += b_do * b_beta[:, None]
    b_dbeta = tl.sum(b_do * b_k, axis=1)
    b_dbeta += tl.sum(b_dv * b_v, axis=1)
    b_v = None

    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    b_o = tl.load(p_o, mask=mask_bk)
    b_o2 = tl.load(p_o2, mask=mask_bv)

    # Gradient w.r.t. the triangular matrix from both outputs (w and u).
    dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
    dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
                 allow_tf32=False)
    dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
    b_dv *= b_beta[:, None]
    p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)

    # dbeta and dk contributions from dA through A = beta * (k k^T).
    b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
    dA = dA * b_beta[:, None]
    # k appears on both sides of the Gram matrix, hence two terms.
    b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
    b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
    p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
    p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
154
+
155
+
156
def fwd_prepare_wy_repr(k, v, beta, chunk_size):
    """Launch the forward WY-representation kernel.

    Returns (o_cumdecay, v_new): the per-chunk w and u tensors, with the
    same shapes as k and v respectively.
    """
    B, H, T, K, V = *k.shape, v.shape[-1]
    v_new = torch.empty_like(v)
    o_cumdecay = torch.empty_like(k)
    BT = chunk_size
    NT = triton.cdiv(T, BT)
    # Single key/value block per chunk: round dims up to a power of two.
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    grid = (NT, B * H)
    fwd_prepare_wy_repr_kernel[grid](
        k, v, beta, o_cumdecay, v_new,
        T, K, V, BT, BK, BV
    )
    return o_cumdecay, v_new
169
+
170
+
171
def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
    """Launch the backward WY-representation kernel.

    Args:
        k, v, beta: forward inputs.
        o_cumdecay, v_new: forward outputs (w, u).
        do, do2: gradients w.r.t. o_cumdecay and v_new.
        chunk_size: chunk length used in the forward pass.

    Returns:
        (dk, dv, dbeta) gradients matching k, v, beta.
    """
    b, h, l, d_k = do.shape
    d_v = v.shape[-1]
    # BUG FIX: the original computed BK = triton.next_power_of_2(d_k) and
    # immediately overwrote it with d_k; the dead assignment is removed.
    # The kernel is launched with BK == d_k (full key width per block).
    BK = d_k
    BV = triton.next_power_of_2(d_v)
    NT = triton.cdiv(l, chunk_size)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.zeros_like(beta)
    bwd_prepare_wy_repr_kernel[(NT, b*h)](
        k, v, beta,
        o_cumdecay, v_new, do, do2,
        dk, dv, dbeta,
        NT, d_k, d_v, l, chunk_size, BK, BV
    )
    return dk, dv, dbeta
189
+
190
+
191
class WYRepresentationPrepration(torch.autograd.Function):
    """Autograd wrapper around the WY-representation Triton kernels.

    forward(k, v, beta, chunk_size) -> (o_cumdecay, v_new)

    NOTE(review): the class name keeps the original spelling ("Prepration")
    because external code may reference it by name.
    """

    # NOTE(review): @contiguous/@autocast_custom_fwd wrap a staticmethod
    # object here; this relies on staticmethod being callable — confirm
    # against the project's minimum Python version.
    @contiguous
    @autocast_custom_fwd
    @staticmethod
    def forward(ctx, k, v, beta, chunk_size):
        o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
        ctx.chunk_size = chunk_size
        # k is cast to v's dtype before saving so backward sees one dtype.
        ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
        return o_cumdecay, v_new

    @contiguous
    @autocast_custom_bwd
    @staticmethod
    def backward(ctx, do, do2):
        k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
        dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
        # No gradient for chunk_size.
        return dk, dv, dbeta, None


# Public entry point.
prepare_wy_repr = WYRepresentationPrepration.apply
211
+
212
+
213
def naive(k, v, beta, chunk_size):
    """Pure-PyTorch reference for the WY-representation preparation.

    Pads the sequence to the next power of two, runs the cumulative
    substitution chunk by chunk, and slices back to the original length.

    NOTE(review): this returns a `map` object (lazy), not a tuple —
    callers must unpack it exactly once.
    """
    l_org = k.shape[2]
    l_new = triton.next_power_of_2(l_org)
    # pad k, v, beta up to l_new along the time axis with zeros
    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)

    # Reshape into (b, h, n_chunks, chunk_size, d).
    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
    # k = torch.nn.functional.normalize(k, dim=-1, p=2)
    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
    # Upper-triangular mask (diagonal included): keep strictly-lower part.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
    k_beta = k * beta[..., None]
    v = v * beta[..., None]
    attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
    attn = attn * beta[..., None]
    x = attn @ v

    o = torch.zeros_like(k)
    o2 = torch.zeros_like(v)

    # Row-by-row cumulative substitution (clones avoid in-place aliasing).
    o[..., 0, :] = k_beta[..., 0, :].clone()
    o2[..., 0, :] = x[..., 0, :].clone()
    for i in range(1, chunk_size):
        o_i = (o[..., :i, :]).clone()
        o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
        o2_i = (o2[..., :i, :]).clone()
        o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
    # Flatten chunks back and drop the padding; second output is v - o2.
    return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
242
+
243
+
244
if __name__ == "__main__":
    # Correctness spot-check + timing comparison between this module's
    # prepare_wy_repr and the wy_fast variant (requires CUDA).
    torch.set_default_dtype(torch.bfloat16)
    seq_len = 2048
    b = 4
    h = 8
    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
    v = torch.randn(b, h, seq_len, 256)
    beta = torch.rand(b, h, seq_len).sigmoid()
    require_grad = True
    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
    do = torch.rand_like(k)
    do2 = torch.rand_like(v)

    print("Start warmup.")
    # NOTE(review): warmup/accuracy check runs with chunk size 32 while the
    # timed loops below use 64 — presumably intentional, verify.
    o1, o2 = prepare_wy_repr(k, v, beta, 32)
    # (o1 * do + o2 * do2).sum().backward()
    o3, o4 = prepare_wy_repr2(k, v, beta, 32)
    # (o1 * do + o2 * do2).sum().backward()
    print((o1 - o3).abs().max())
    print((o2 - o4).abs().max())

    # Warm both implementations (triggers autotuning/compilation).
    for i in range(30):
        o1, o2 = prepare_wy_repr(k, v, beta, 32)
        (o1 * do + o2 * do2).sum().backward()
        o1, o2 = prepare_wy_repr2(k, v, beta, 32)
        (o1 * do + o2 * do2).sum().backward()

    print("Done warmup.")

    import time
    torch.cuda.synchronize()
    start = time.time()

    # Timed: this module's implementation.
    for i in range(200):
        o1, o2 = prepare_wy_repr(k, v, beta, 64)
        (o1 * do + o2 * do2).sum().backward()

    torch.cuda.synchronize()
    print(time.time() - start)

    torch.cuda.synchronize()
    start = time.time()

    # Timed: the wy_fast implementation.
    for i in range(200):
        o1, o2 = prepare_wy_repr2(k, v, beta, 64)
        (o1 * do + o2 * do2).sum().backward()

    torch.cuda.synchronize()
    print(time.time() - start)
opencompass/models/fla2/ops/delta_rule/wy_fast.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from einops import rearrange
7
+
8
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
9
+
10
+
11
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
    k,        # keys
    v,        # values
    beta,     # per-token gate [B*H, T]
    w,        # output: A @ (beta * k)
    u,        # output: A @ (beta * v)
    A,        # output: per-chunk [BT, BT] transition matrix (saved for bwd)
    s_qk_h,   # k/w strides: head, time, feature
    s_qk_t,
    s_qk_d,
    s_vo_h,   # v/u strides: head, time, feature
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    BT: tl.constexpr,   # chunk length
    BK: tl.constexpr,   # key tile width (K is looped over in BK steps)
    BV: tl.constexpr    # value tile width (V is looped over in BV steps)
):
    # One program per (chunk, batch*head); uses block pointers with
    # boundary checks instead of hand-built masks.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Accumulate the beta-scaled Gram matrix over key tiles.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)

    # Keep (negated) strictly-lower part only.
    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)

    # Forward substitution, row by row (row 0 needs no correction).
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)

    # Add the identity to complete the inverse transition matrix.
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    # Persist A so the backward pass does not have to re-derive it.
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
    b_A = b_A.to(k.dtype.element_ty)

    # u = A @ (beta * v), tile by tile over V.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (beta * k), tile by tile over K.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb, allow_tf32=False)
        p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
86
+
87
+
88
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_recompute_w_u_kernel(
    k,        # keys
    v,        # values
    beta,     # per-token gate [B*H, T]
    w,        # output: A @ (beta * k)
    u,        # output: A @ (beta * v)
    A,        # input: per-chunk transition matrix saved by the fwd kernel
    s_qk_h,   # k/w strides: head, time, feature
    s_qk_t,
    s_qk_d,
    s_vo_h,   # v/u strides: head, time, feature
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    BT: tl.constexpr,   # chunk length
    BK: tl.constexpr,   # key tile width
    BV: tl.constexpr    # value tile width
):
    # Recomputes w and u from the previously stored A (cheaper than
    # re-running the triangular solve in fwd_prepare_wy_repr_kernel).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Load the saved transition matrix for this chunk.
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)

    # u = A @ (beta * v), tile by tile over V.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (beta * k), tile by tile over K.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb, allow_tf32=False)
        p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
142
+
143
+
144
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16)
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
    k, v, beta, A,
    dw, du,
    dk, dv, dbeta,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Backward of the WY representation: given upstream grads dw, du of
    # w = A @ (beta*k) and u = A @ (beta*v), accumulate dk, dv, dbeta and
    # back-propagate through the stored block matrix A.
    # One program per (time-chunk, batch*head); launch grid is (NT, B*H).
    #
    # FIX: the original unpacked three program ids into two names
    # (`i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)`),
    # which is a hard unpacking error; the grid is 2-D, so only two ids exist.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)

    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # grads flowing through u = A @ (beta * v)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)
        b_dv = b_dv_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)
        # store partial dv (later loops only touch dk)
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # grads flowing through w = A @ (beta * k)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
        b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
        b_dk = b_dk_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        # store partial dk; re-loaded and augmented in the loop below
        p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # Back-propagate through the inverse: for A = (I + L)^{-1}-style blocks,
    # dL = -A^T (strict_lower(dA) A^T)^T pattern; the strictly-lower masks
    # keep only the causal part of the block.
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

    # dA feeds back into k (A was built from beta*k @ k^T in the forward)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)

        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
229
+
230
+
231
def fwd_prepare_wy_repr(k, v, beta, BT):
    """Launch the forward WY-representation kernel.

    Args:
        k: keys of shape [B, H, T, K].
        v: values of shape [B, H, T, V].
        beta: per-token scaling of shape [B, H, T].
        BT: time-chunk (block) length.

    Returns:
        (w, u, A): the WY tensors and the per-chunk [BT, BT] block matrices
        saved for recomputation in the backward pass.
    """
    B, H, T, K = k.shape
    V = v.shape[-1]
    # tile sizes along the feature dimensions, capped at 64
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    num_chunks = triton.cdiv(T, BT)
    w = torch.empty_like(k)
    u = torch.empty_like(v)
    A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
    grid = (num_chunks, B * H)
    fwd_prepare_wy_repr_kernel[grid](
        k, v, beta, w, u, A,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        T, K, V, BT, BK, BV
    )
    return w, u, A
246
+
247
+
248
def fwd_recompute_w_u(k, v, beta, A, BT):
    """Recompute (w, u) from a saved block matrix ``A``.

    Cheaper than `fwd_prepare_wy_repr` because the sequential inversion that
    produced ``A`` is skipped; only the two matmuls are redone.

    Args:
        k: keys of shape [B, H, T, K].
        v: values of shape [B, H, T, V].
        beta: per-token scaling of shape [B, H, T].
        A: saved block matrices of shape [B, H, T, BT].
        BT: time-chunk (block) length.

    Returns:
        (w, u) with the same shapes/dtypes as (k, v).
    """
    B, H, T, K = k.shape
    V = v.shape[-1]
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    num_chunks = triton.cdiv(T, BT)
    w = torch.empty_like(k)
    u = torch.empty_like(v)
    grid = (num_chunks, B * H)
    fwd_recompute_w_u_kernel[grid](
        k, v, beta, w, u, A,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        T, K, V, BT, BK, BV
    )
    return w, u
262
+
263
+
264
def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
    """Launch the backward WY-representation kernel.

    Args:
        k: keys of shape [B, H, T, K].
        v: values of shape [B, H, T, V].
        beta: per-token scaling of shape [B, H, T].
        A: saved block matrices of shape [B, H, T, BT] from the forward pass.
        dw: upstream gradient w.r.t. w, same shape as k.
        du: upstream gradient w.r.t. u, same shape as v.
        BT: time-chunk (block) length.

    Returns:
        (dk, dv, dbeta): gradients w.r.t. the inputs.
    """
    B, H, T, K, V = *k.shape, v.shape[-1]
    # FIX: NT was computed twice in the original; compute it once.
    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v).contiguous()
    # dbeta is accumulated in the kernel, so start from zeros
    dbeta = torch.zeros_like(beta)

    bwd_prepare_wy_repr_kernel[(NT, B*H)](
        k, v, beta, A,
        dw, du,
        dk, dv, dbeta,
        k.stride(1), k.stride(2), k.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        T, K, V, BT, BK, BV
    )
    return dk, dv, dbeta
284
+
285
+
286
class WYRepresentationPrepration(torch.autograd.Function):
    """Autograd wrapper around the Triton WY-representation kernels.

    ``forward`` returns ``(w, u)`` computed by `fwd_prepare_wy_repr` and saves
    the per-chunk block matrix ``A`` so the backward pass does not need to
    redo the sequential inversion.
    """
    # NOTE(review): "Prepration" is a typo for "Preparation"; kept as-is since
    # the class name is referenced by external code via `prepare_wy_repr`.

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, k, v, beta, chunk_size=64):
        """Compute (w, u) for keys k [B,H,T,K], values v [B,H,T,V] and
        per-token scalars beta [B,H,T], using time chunks of `chunk_size`."""
        ctx.BT = chunk_size
        w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
        # A is saved alongside the inputs for the backward recomputation path
        ctx.save_for_backward(k, v, beta, A)
        return w, u

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, dw, du):
        """Return (dk, dv, dbeta, None); the trailing None matches the
        non-tensor `chunk_size` argument of forward."""
        k, v, beta, A = ctx.saved_tensors
        BT = ctx.BT
        dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)
        return dk, dv, dbeta, None
305
+
306
+
307
+ prepare_wy_repr = WYRepresentationPrepration.apply
308
+
309
+
310
def naive(k, v, beta, chunk_size):
    """Pure-PyTorch reference implementation of the WY representation.

    Pads the sequence to the next power of two, processes it in chunks of
    `chunk_size`, and unrolls the sequential recurrence explicitly.  Used only
    to validate the Triton kernels (see the __main__ block below).

    Returns a `map` object yielding the two result tensors, each cropped back
    to the original sequence length.
    """
    l_org = k.shape[2]
    l_new = triton.next_power_of_2(l_org)
    # pad k, v, beta along the time dimension up to l_new
    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)

    # split the (padded) time axis into chunks of length `chunk_size`
    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
    # k = torch.nn.functional.normalize(k, dim=-1, p=2)
    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
    # upper-triangular mask (incl. diagonal): zero out the non-causal part
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
    k_beta = k * beta[..., None]
    v = v * beta[..., None]
    # strictly-lower-triangular beta-scaled Gram matrix of keys
    attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
    attn = attn * beta[..., None]
    x = attn @ v

    o = torch.zeros_like(k)
    o2 = torch.zeros_like(v)

    # sequential recurrence, row by row within each chunk:
    # row i subtracts the attn-weighted sum of all previous rows
    o[..., 0, :] = k_beta[..., 0, :].clone()
    o2[..., 0, :] = x[..., 0, :].clone()
    for i in range(1, chunk_size):
        o_i = (o[..., :i, :]).clone()
        o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
        o2_i = (o2[..., :i, :]).clone()
        o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
    # merge chunks back and drop the padding; note the second output is v - o2
    return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
339
+
340
+
341
if __name__ == "__main__":
    # Numerical sanity check of the Triton WY-representation kernels against
    # the naive reference implementation (requires a CUDA device).
    torch.set_default_dtype(torch.bfloat16)
    seq_len = 1024
    b = 4
    h = 4
    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
    v = torch.randn(b, h, seq_len, 128)
    beta = torch.rand(b, h, seq_len).sigmoid()
    require_grad = True

    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
    do = torch.rand_like(k)
    do2 = torch.rand_like(v)

    # reference path
    o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
    if require_grad:
        o1.backward(do, retain_graph=True)
        o2.backward(do2, retain_graph=True)
        k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
        k.grad = v.grad = beta.grad = None

    # Triton path
    o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), 64)
    print((o1 - o3).abs().max())
    print((o2 - o4).abs().max())

    if require_grad:
        o3.backward(do, retain_graph=True)
        o4.backward(do2, retain_graph=True)
        k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
        print((k_grad2 - k_grad).abs().max())
        print((v_grad2 - v_grad).abs().max())
        print((beta_grad2 - beta_grad).abs().max())
    # FIX: removed a leftover `breakpoint()` debug statement that would drop
    # into pdb on every run of this script.
opencompass/models/fla2/ops/generalized_delta_rule/README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generalized Delta Rule
2
+
3
+ In delta rule we have the recurrence:
4
+
5
+ ```math
6
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T
7
+ ```
8
+
9
+ This repository implements a delta rule variant where $\mathbf{I}$ is not necessarily an identity matrix; $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ might be different from input $\mathbf{k}_t$ in $\mathbf{v}_t\mathbf{k}_t^T$.
10
+
11
+ ## IPLR (Identity Plus Low Rank)
12
+
13
+ The first variant is IPLR, where we have:
14
+
15
+ ```math
16
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
17
+ ```
18
+
19
+ When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$, $\mathbf{v}_t= \beta_t \mathbf{v}_t$, we recover the original delta rule. Since here the transition matrix is identity-plus-low-rank, we refer to this variant as IPLR.
20
+
21
+ ### Numerical Stability
22
+
23
+ $\mathbf{a}_t$ and $\mathbf{b}_t$ must be in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ where $\lambda_t < 0$. For an understanding of why this is necessary, you can derive the eigenvalues of the transition matrix.
24
+
25
+ ## DPLR (Diagonal Plus Low Rank)
26
+
27
+ The second variant is DPLR, where we have:
28
+
29
+ ```math
30
+ \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
31
+ ```
32
+
33
+ Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7.
34
+
35
+ ## Efficient Chunkwise Implementation
36
+
37
+ For detailed information about efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing).
opencompass/models/fla2/ops/generalized_delta_rule/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule
2
+ from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule',
7
+ 'chunk_iplr_delta_rule',
8
+ 'fused_recurrent_iplr_delta_rule'
9
+ ]
opencompass/models/fla2/ops/generalized_delta_rule/dplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_dplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_dplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule'
7
+ ]
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import triton
9
+ from einops import rearrange
10
+
11
+ from ....ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
12
+ from ....ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_dplr_fwd_intra
13
+ from ....ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu
14
+ from ....ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h
15
+ from ....ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o
16
+ from ....ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o
17
+ from ....ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy
18
+ from ....ops.generalized_delta_rule.dplr.wy_fast_fwd import prepare_wy_repr_fwd
19
+ from ....ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum
20
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
21
+
22
+
23
def chunk_dplr_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Chunkwise forward pass of the DPLR (diagonal-plus-low-rank) delta rule.

    Pipeline:
      1. cumulative log-space decay sums (gi, ge);
      2. intra-chunk matrices (A_*) and decay-gated projections (qg/kg/ag/bg);
      3. WY representation (w, u);
      4. chunkwise state recurrence -> hidden states h, updated values v_new,
         and optionally the final state;
      5. output projection -> o.

    Intermediates are `del`-ed as soon as they are no longer needed to keep
    peak GPU memory down.
    """
    T = q.shape[1]
    # effective chunk length: at least 16, a power of two for short sequences,
    # never larger than the requested chunk_size
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
    )
    del ge

    # A_ab, A_ak, gi, ge torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
    w, u, _ = prepare_wy_repr_fwd(
        ag=ag,
        A_ab=A_ab,
        A_ak=A_ak,
        v=v,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    del A_ab, A_ak
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
        v=v,
        w=w,
        u=u,
        gk=gi,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    del u, kg, bg, gi

    o = chunk_dplr_fwd_o(
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    del v_new, h, A_qk, A_qb

    return o, final_state
91
+
92
+
93
class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper for the chunkwise DPLR delta rule.

    The forward pass saves only the raw inputs; all intermediates are
    recomputed in ``backward`` to bound peak GPU memory.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # NOTE: chunk size is fixed at 16 here (smaller than the usual 64);
        # presumably chosen for the DPLR kernels — confirm before changing.
        chunk_size = 16
        o, final_state = chunk_dplr_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # save raw inputs only; intermediates are recomputed in backward
        ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
        ctx.cu_seqlens = cu_seqlens
        ctx.scale = scale
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        """Gradients w.r.t. (q, k, v, a, b, gk, initial_state); the trailing
        Nones match the non-differentiable forward args (scale,
        output_final_state, cu_seqlens)."""
        q, k, v, a, b, gk, initial_state = ctx.saved_tensors
        BT = ctx.chunk_size
        cu_seqlens = ctx.cu_seqlens
        scale = ctx.scale

        # ---- recompute the forward intermediates (memory/compute trade-off:
        # storing them during forward would exhaust GPU memory) ----
        gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens)

        A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
            q=q,
            k=k,
            a=a,
            b=b,
            gi=gi,
            ge=ge,
            scale=scale,
            cu_seqlens=cu_seqlens,
            chunk_size=BT,
        )
        w, u, A_ab_inv = prepare_wy_repr_fwd(
            ag=ag,
            A_ab=A_ab,
            A_ak=A_ak,
            v=v,
            cu_seqlens=cu_seqlens,
            chunk_size=BT
        )
        del A_ab
        h, v_new, _ = chunk_dplr_fwd_h(
            kg=kg,
            bg=bg,
            v=v,
            w=w,
            u=u,
            gk=gi,
            initial_state=initial_state,
            cu_seqlens=cu_seqlens,
            chunk_size=BT
        )
        del u
        # ---- end of recomputation ----
        # A_ak, A_ab_inv, gi, ge torch.float32
        # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16

        # gradients through the output projection and intra-chunk attention
        dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
            v=v,
            v_new=v_new,
            do=do,
            A_qb=A_qb,
            scale=scale,
            cu_seqlens=cu_seqlens,
            chunk_size=BT
        )

        # gradients through the chunkwise state recurrence (and initial state)
        dh, dh0, dv_new = chunk_dplr_bwd_dhu(
            qg=qg,
            bg=bg,
            w=w,
            gk=gi,
            h0=initial_state,
            dht=dht,
            do=do,
            dv=dv_new_intra,
            cu_seqlens=cu_seqlens,
            chunk_size=BT
        )

        dv = chunk_dplr_bwd_dv(
            A_qk=A_qk,
            kg=kg,
            do=do,
            dh=dh,
            cu_seqlens=cu_seqlens,
            chunk_size=BT
        )
        del A_qk

        # gradients through the output combination (qg, kg, w, bg, last decay)
        dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
            k=kg,
            b=bg,
            v=v,
            v_new=v_new,
            do=do,
            h=h,
            dh=dh,
            dv=dv_new,
            w=w,
            gk=gi,
            cu_seqlens=cu_seqlens,
            chunk_size=BT,
            scale=scale,
        )
        del v_new

        # gradients through the WY representation
        dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
            A_ab_inv=A_ab_inv,
            A_ak=A_ak,
            v=v,
            ag=ag,
            dw=dw,
            du=dv_new,
            dv0=dv,
            cu_seqlens=cu_seqlens,
            chunk_size=BT
        )
        del A_ak

        # fold all intra-chunk A-gradients back to the raw inputs
        dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
            q=q,
            k=k,
            a=a,
            b=b,
            gi=gi,
            ge=ge,
            dAqk=dA_qk,
            dAqb=dA_qb,
            dAak=dA_ak,
            dAab=dA_ab,
            dgk_last=dgk_last,
            dqg=dqg,
            dkg=dkg,
            dag=dag,
            dbg=dbg,
            chunk_size=BT,
            scale=scale,
            cu_seqlens=cu_seqlens,
        )

        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None
268
+
269
+
270
@torch.compiler.disable
def chunk_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            activations of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        b (torch.Tensor):
            betas of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. decay term in log space!
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    if head_first:
        # NOTE: deliberately *raised* (not warned) — the head-first layout is
        # no longer supported by this entry point.
        # FIX: removed the rearrange calls that followed this raise (and the
        # matching `if head_first:` rearrange of `o` before return); they were
        # unreachable dead code.
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    # heuristic guard against silently passing head-first tensors
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if q.dtype == torch.float32:
        raise DeprecationWarning(
            """ChunkDeltaRuleFunction does not support float32. Please use bfloat16.
            If you want to use float32, please solve the issue by yourself."""
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    )
    return o, final_state
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_A_bwd.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import exp, gather
12
+ from ....utils import check_shared_mem, is_gather_supported, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
21
+ for num_warps in [2, 4, 8, 16, 32]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BK', 'BT', 'K'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_bwd_kernel_intra(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi,
34
+ ge,
35
+ dAqk,
36
+ dAqb,
37
+ dAak,
38
+ dAab,
39
+ dq,
40
+ dk,
41
+ da,
42
+ db,
43
+ dqg,
44
+ dkg,
45
+ dag,
46
+ dbg,
47
+ dgk,
48
+ dgk_offset,
49
+ cu_seqlens,
50
+ chunk_indices,
51
+ scale: tl.constexpr,
52
+ T,
53
+ H: tl.constexpr,
54
+ K: tl.constexpr,
55
+ BT: tl.constexpr,
56
+ BC: tl.constexpr,
57
+ BK: tl.constexpr,
58
+ IS_VARLEN: tl.constexpr,
59
+ GATHER_SUPPORTED: tl.constexpr
60
+ ):
61
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ i_b, i_h = i_bh // H, i_bh % H
63
+ if IS_VARLEN:
64
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
65
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
66
+ T = eos - bos
67
+ else:
68
+ bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)
69
+
70
+ if i_t * BT >= T:
71
+ return
72
+
73
+ # offset calculation
74
+ ge += (bos*H + i_h) * K
75
+ gi += (bos*H + i_h) * K
76
+ q += (bos*H + i_h) * K
77
+ a += (bos*H + i_h) * K
78
+ b += (bos*H + i_h) * K
79
+ k += (bos*H + i_h) * K
80
+ dq += (bos*H + i_h) * K
81
+ dk += (bos*H + i_h) * K
82
+ da += (bos*H + i_h) * K
83
+ db += (bos*H + i_h) * K
84
+ dqg += (bos*H + i_h) * K
85
+ dag += (bos*H + i_h) * K
86
+ dkg += (bos*H + i_h) * K
87
+ dbg += (bos*H + i_h) * K
88
+ dgk += (bos*H + i_h) * K
89
+ dgk_offset += (bos*H + i_h) * K
90
+ dAqk += (bos*H + i_h) * BT
91
+ dAqb += (bos*H + i_h) * BT
92
+ dAak += (bos*H + i_h) * BT
93
+ dAab += (bos*H + i_h) * BT
94
+
95
+ stride_qk = H*K
96
+ stride_A = H*BT
97
+
98
+ p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
99
+ p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
100
+ # [BC, BK]
101
+ b_ge = tl.load(p_ge, boundary_check=(0, 1))
102
+ b_gi = tl.load(p_gi, boundary_check=(0, 1))
103
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
104
+ b_da = tl.zeros([BC, BK], dtype=tl.float32)
105
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
106
+ b_db = tl.zeros([BC, BK], dtype=tl.float32)
107
+ # intra chunk gradient calculation
108
+ p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
109
+ p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
110
+ p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
111
+ p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0))
112
+ o_i = tl.arange(0, BC)
113
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
114
+ p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
115
+ p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
116
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0))
117
+ b_k = tl.load(p_k, boundary_check=(0, 1))
118
+ b_b = tl.load(p_b, boundary_check=(0, 1))
119
+ b_q = tl.load(p_q, boundary_check=(0, 1))
120
+ b_a = tl.load(p_a, boundary_check=(0, 1))
121
+ b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1))
122
+ b_dAab = tl.load(p_dAab, boundary_check=(0, 1))
123
+ b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1))
124
+ b_dAak = tl.load(p_dAak, boundary_check=(0, 1))
125
+
126
+ # inter chunk gradient calculation
127
+ o_k = i_k * BK + tl.arange(0, BK)
128
+ m_k = o_k < K
129
+ # intra chunk gradient calculation
130
+ for j in range(0, min(BC, T - i_t * BT)):
131
+ # trick to index the block
132
+ if GATHER_SUPPORTED:
133
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
134
+ col_idx = tl.full([BC, 1], j, dtype=tl.int16)
135
+ row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
136
+ # [1, BK]
137
+ b_kj = gather(b_k, row_idx, axis=0)
138
+ b_bj = gather(b_b, row_idx, axis=0)
139
+ b_gij = gather(b_gi, row_idx, axis=0)
140
+ b_gej = gather(b_ge, row_idx, axis=0)
141
+ b_qj = gather(b_q, row_idx, axis=0)
142
+ b_aj = gather(b_a, row_idx, axis=0)
143
+ # [BC, 1]
144
+ b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
145
+ b_dAab_j = gather(b_dAab, col_idx, axis=1)
146
+ b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
147
+ b_dAak_j = gather(b_dAak, col_idx, axis=1)
148
+ # [1, BC] -> [BC, 1]
149
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
150
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
151
+ b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
152
+ b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
153
+ b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
154
+ else:
155
+ mask_idx = tl.arange(0, BC) == j
156
+ b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
157
+ b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
158
+ b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
159
+ b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
160
+ b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
161
+ b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
162
+ b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
163
+ b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
164
+ b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
165
+ b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
166
+ b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
167
+ b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
168
+ # [1, BK] b_qj, b_aj
169
+ b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
170
+ b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
171
+
172
+ m_e = o_i[:, None] > j
173
+ m_i = o_i[:, None] >= j
174
+ tmp1 = exp(b_gi - b_gij)
175
+ tmp2 = exp(b_ge - b_gij)
176
+ b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
177
+ b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
178
+ b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
179
+ b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)
180
+
181
+ m_i = o_i[:, None] <= j
182
+ m_e = o_i[:, None] < j
183
+ tmp1 = exp(b_gij - b_gi)
184
+ tmp2 = exp(b_gej - b_gi)
185
+ b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
186
+ b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
187
+ b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
188
+ b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
189
+
190
+ # post processing
191
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
192
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
193
+ p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
194
+ p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
195
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
196
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
197
+ p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
198
+ p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
199
+ p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
200
+ p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
201
+ p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k
202
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
203
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
204
+ b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
205
+ b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
206
+ tmp = exp(b_gn[None, :] - b_gi)
207
+ b_dk += tl.load(p_dkg, boundary_check=(0, 1)).to(tl.float32) * tmp
208
+ b_db += tl.load(p_dbg, boundary_check=(0, 1)).to(tl.float32) * tmp
209
+ tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
210
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
211
+ tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
212
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
213
+ b_dgk = (b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b).to(tl.float32)
214
+ b_dgk_offset = b_da * b_a
215
+ tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
216
+ tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1))
217
+
218
+
219
+ @triton.heuristics({
220
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
221
+ })
222
+ @triton.autotune(
223
+ configs=[
224
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
225
+ for num_warps in [2, 4, 8, 16, 32]
226
+ for num_stages in [2, 3, 4]
227
+ for BK in [32, 64]
228
+ ],
229
+ key=['BK', 'BT', 'K'],
230
+ use_cuda_graph=use_cuda_graph,
231
+ )
232
+ @triton.jit(do_not_specialize=['T'])
233
+ def chunk_dplr_bwd_dgk_kernel(
234
+ dgk,
235
+ dgk_offset,
236
+ dgk_last,
237
+ dgk_output,
238
+ cu_seqlens,
239
+ chunk_indices,
240
+ T,
241
+ H: tl.constexpr,
242
+ K: tl.constexpr,
243
+ BT: tl.constexpr,
244
+ BK: tl.constexpr,
245
+ IS_VARLEN: tl.constexpr,
246
+ ):
247
+ i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
248
+ i_b, i_h = i_bh // H, i_bh % H
249
+ if IS_VARLEN:
250
+ i_tg = i_t
251
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
252
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
253
+ T = eos - bos
254
+ NT = tl.cdiv(T, BT)
255
+ else:
256
+ NT = tl.cdiv(T, BT)
257
+ i_tg = (i_b * NT + i_t).to(tl.int32)
258
+ bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)
259
+
260
+ stride_qk = H * K
261
+ dgk += (bos * H + i_h) * K
262
+ dgk_offset += (bos * H + i_h) * K
263
+ dgk_last += (i_tg * H + i_h) * K
264
+ dgk_output += (bos * H + i_h) * K
265
+ p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
266
+ m_k = tl.arange(0, BK) + i_k * BK < K
267
+ b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
268
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
269
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
270
+ b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
271
+ b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
272
+ # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
273
+ # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
274
+ b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
275
+ b_dgk_cumsum += b_dgk_last[None, :]
276
+ b_dgk_cumsum -= b_dgk_offset
277
+ p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
278
+ tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1))
279
+
280
+
281
+ def chunk_dplr_bwd_dqk_intra(
282
+ q: torch.Tensor,
283
+ k: torch.Tensor,
284
+ a: torch.Tensor,
285
+ b: torch.Tensor,
286
+ gi: torch.Tensor,
287
+ ge: torch.Tensor,
288
+ dAqk: torch.Tensor,
289
+ dAqb: torch.Tensor,
290
+ dAak: torch.Tensor,
291
+ dAab: torch.Tensor,
292
+ dqg: torch.Tensor,
293
+ dkg: torch.Tensor,
294
+ dag: torch.Tensor,
295
+ dbg: torch.Tensor,
296
+ dgk_last: torch.Tensor,
297
+ scale: float = 1.0,
298
+ cu_seqlens: Optional[torch.LongTensor] = None,
299
+ chunk_size: int = 64,
300
+ ):
301
+ B, T, H, K = q.shape
302
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
303
+ BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K))
304
+
305
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
306
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
307
+ NK = triton.cdiv(K, BK)
308
+
309
+ dq = torch.empty_like(q)
310
+ dk = torch.empty_like(k)
311
+ da = torch.empty_like(a)
312
+ db = torch.empty_like(b)
313
+ dgk = torch.empty_like(gi, dtype=torch.float)
314
+ dgk_offset = torch.empty_like(gi, dtype=torch.float)
315
+
316
+ grid = (NK, NT, B * H)
317
+ chunk_dplr_bwd_kernel_intra[grid](
318
+ q=q,
319
+ k=k,
320
+ a=a,
321
+ b=b,
322
+ gi=gi,
323
+ ge=ge,
324
+ dAqk=dAqk,
325
+ dAqb=dAqb,
326
+ dAak=dAak,
327
+ dAab=dAab,
328
+ dq=dq,
329
+ dk=dk,
330
+ dgk=dgk,
331
+ dgk_offset=dgk_offset,
332
+ dqg=dqg,
333
+ dkg=dkg,
334
+ dag=dag,
335
+ dbg=dbg,
336
+ da=da,
337
+ db=db,
338
+ cu_seqlens=cu_seqlens,
339
+ chunk_indices=chunk_indices,
340
+ scale=scale,
341
+ T=T,
342
+ H=H,
343
+ K=K,
344
+ BT=BT,
345
+ BC=BT,
346
+ BK=BK,
347
+ GATHER_SUPPORTED=is_gather_supported
348
+ )
349
+
350
+ dgk_output = torch.empty_like(dgk)
351
+
352
+ def grid(meta): return (NT, triton.cdiv(K, meta['BK']), B * H)
353
+ chunk_dplr_bwd_dgk_kernel[grid](
354
+ dgk=dgk,
355
+ dgk_offset=dgk_offset,
356
+ dgk_last=dgk_last,
357
+ dgk_output=dgk_output,
358
+ cu_seqlens=cu_seqlens,
359
+ chunk_indices=chunk_indices,
360
+ T=T,
361
+ H=H,
362
+ K=K,
363
+ BT=BT,
364
+ )
365
+ return dq, dk, da, db, dgk_output
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import exp, gather
12
+ from ....utils import is_gather_supported, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
21
+ for num_warps in [2, 4, 8, 16, 32]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BK', 'BT'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_fwd_A_kernel_intra_sub_intra(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi,
34
+ ge,
35
+ qg,
36
+ kg,
37
+ ag,
38
+ bg,
39
+ Aqk,
40
+ Aqb,
41
+ Aab,
42
+ Aak,
43
+ cu_seqlens,
44
+ chunk_indices,
45
+ scale: tl.constexpr,
46
+ T,
47
+ H: tl.constexpr,
48
+ K: tl.constexpr,
49
+ BT: tl.constexpr,
50
+ BC: tl.constexpr,
51
+ BK: tl.constexpr,
52
+ IS_VARLEN: tl.constexpr,
53
+ GATHER_SUPPORTED: tl.constexpr
54
+ ):
55
+ i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
56
+
57
+ if IS_VARLEN:
58
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
59
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ else:
62
+ bos, eos = i_b * T, i_b * T + T
63
+
64
+ if i_t * BT >= T:
65
+ return
66
+
67
+ o_i = tl.arange(0, BC)
68
+ o_k = tl.arange(0, BK)
69
+ m_k = o_k < K
70
+ m_A = (i_t * BT + tl.arange(0, BC)) < T
71
+ last_idx = min((i_t+1) * BT, T) - 1
72
+ o_A = (bos + i_t * BT + tl.arange(0, BC)) * H*BT + i_h * BT
73
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
74
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
75
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
76
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
77
+ p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
78
+ p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
79
+ p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
80
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
81
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
82
+ p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
83
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
84
+ p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
85
+
86
+ b_q = tl.load(p_q, boundary_check=(0, 1))
87
+ b_q = b_q * scale
88
+ b_k = tl.load(p_k, boundary_check=(0, 1))
89
+ b_a = tl.load(p_a, boundary_check=(0, 1))
90
+ b_b = tl.load(p_b, boundary_check=(0, 1))
91
+ b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
92
+ b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
93
+
94
+ # deal with decay term.
95
+ g_exp = exp(b_gi)
96
+ g_exp_inv = exp(-b_gi + b_g_last[None, :])
97
+ b_qg = b_q * g_exp
98
+ b_kg = b_k * g_exp_inv
99
+ b_bg = b_b * g_exp_inv
100
+ b_ag = b_a * exp(b_ge)
101
+ tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
102
+ tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
103
+ tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
104
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
105
+ # tl.debug_barrier()
106
+
107
+ b_q = b_q.to(b_k.dtype)
108
+ # inner attn
109
+ for j in range(0, min(BC, T - i_t * BT)):
110
+ # a trick to index the j-th row of b_k, b_g, b_b
111
+ if GATHER_SUPPORTED:
112
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
113
+ # [1, BK]
114
+ b_k_j = gather(b_k, row_idx, axis=0)
115
+ b_gk_j = gather(b_gi, row_idx, axis=0)
116
+ b_b_j = gather(b_b, row_idx, axis=0)
117
+ else:
118
+ mask = tl.arange(0, BC) == j
119
+ b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
120
+ b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
121
+ b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
122
+ tmp = exp(b_gi - b_gk_j)
123
+ b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
124
+ m_i = (o_i >= j).to(tl.float32)
125
+ b_A_qk = b_A_qk * m_i
126
+ b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
127
+ b_A_qb = b_A_qb * m_i
128
+ tmp2 = exp(b_ge - b_gk_j)
129
+ b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
130
+ m_i2 = (o_i > j).to(tl.float32)
131
+ b_A_ak = b_A_ak * m_i2
132
+ b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
133
+ b_A_ab = b_A_ab * m_i2
134
+
135
+ tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
136
+ tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
137
+ tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
138
+ tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
139
+
140
+
141
+ def chunk_dplr_fwd_intra(
142
+ q: torch.Tensor,
143
+ k: torch.Tensor,
144
+ a: torch.Tensor,
145
+ b: torch.Tensor,
146
+ gi: torch.Tensor,
147
+ ge: torch.Tensor,
148
+ scale: float,
149
+ chunk_size: int,
150
+ cu_seqlens: Optional[torch.LongTensor] = None,
151
+ ):
152
+ B, T, H, K = k.shape
153
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
154
+
155
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
156
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
157
+
158
+ Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype)
159
+ Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype)
160
+ # involving matrix inverse and it'd be better to use float here.
161
+ Aab = q.new_empty(B, T, H, BT, dtype=torch.float)
162
+ Aak = q.new_empty(B, T, H, BT, dtype=torch.float)
163
+
164
+ grid = (NT, B, H)
165
+ BK = triton.next_power_of_2(K)
166
+ qg = torch.empty_like(q)
167
+ kg = torch.empty_like(k, dtype=q.dtype)
168
+ ag = torch.empty_like(a, dtype=q.dtype)
169
+ bg = torch.empty_like(b, dtype=q.dtype)
170
+ chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
171
+ q=q,
172
+ k=k,
173
+ a=a,
174
+ b=b,
175
+ gi=gi,
176
+ ge=ge,
177
+ Aqk=Aqk,
178
+ Aqb=Aqb,
179
+ Aab=Aab,
180
+ Aak=Aak,
181
+ qg=qg,
182
+ kg=kg,
183
+ ag=ag,
184
+ bg=bg,
185
+ cu_seqlens=cu_seqlens,
186
+ chunk_indices=chunk_indices,
187
+ scale=scale,
188
+ T=T,
189
+ H=H,
190
+ K=K,
191
+ BT=BT,
192
+ BC=BT,
193
+ BK=BK,
194
+ GATHER_SUPPORTED=is_gather_supported
195
+ )
196
+ return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
17
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
18
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV', "V"],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_bwd_kernel_dhu(
31
+ qg,
32
+ bg,
33
+ w,
34
+ gk,
35
+ dht,
36
+ dh0,
37
+ do,
38
+ dh,
39
+ dv,
40
+ dv2,
41
+ cu_seqlens,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ IS_VARLEN: tl.constexpr,
54
+ ):
55
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+ if IS_VARLEN:
58
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ NT = tl.cdiv(T, BT)
61
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
62
+ else:
63
+ bos, eos = i_n * T, i_n * T + T
64
+ NT = tl.cdiv(T, BT)
65
+ boh = i_n * NT
66
+
67
+ # [BK, BV]
68
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
69
+ if USE_FINAL_STATE_GRADIENT:
70
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
71
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
72
+
73
+ mask_k = tl.arange(0, BK) < K
74
+ for i_t in range(NT - 1, -1, -1):
75
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
77
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
78
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
79
+ p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
80
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
81
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
82
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
83
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
84
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
85
+ # [BK, BT]
86
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
87
+ # [BT, BK]
88
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
89
+ b_w = tl.load(p_w, boundary_check=(0, 1))
90
+ # [BT, V]
91
+ b_do = tl.load(p_do, boundary_check=(0, 1))
92
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
93
+ b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
94
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
95
+ # [BK, BV]
96
+ b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
97
+ b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
98
+ last_idx = min((i_t + 1) * BT, T) - 1
99
+ bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
100
+ b_dh *= exp(bg_last)[:, None]
101
+ b_dh += b_dh_tmp
102
+
103
+ if USE_INITIAL_STATE:
104
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
105
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
106
+
107
+
108
+ def chunk_dplr_bwd_dhu(
109
+ qg: torch.Tensor,
110
+ bg: torch.Tensor,
111
+ w: torch.Tensor,
112
+ gk: torch.Tensor,
113
+ h0: torch.Tensor,
114
+ dht: Optional[torch.Tensor],
115
+ do: torch.Tensor,
116
+ dv: torch.Tensor,
117
+ cu_seqlens: Optional[torch.LongTensor] = None,
118
+ chunk_size: int = 64
119
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
120
+ B, T, H, K, V = *qg.shape, do.shape[-1]
121
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
122
+ BK = triton.next_power_of_2(K)
123
+ assert BK <= 256, "current kernel does not support head dimension being larger than 256."
124
+ # H100
125
+ if check_shared_mem('hopper', qg.device.index):
126
+ BV = 64
127
+ BC = 64 if K <= 128 else 32
128
+ elif check_shared_mem('ampere', qg.device.index): # A100
129
+ BV = 32
130
+ BC = 32
131
+ else: # Etc: 4090
132
+ BV = 16
133
+ BC = 16
134
+
135
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
136
+ # N: the actual number of sequences in the batch with either equal or variable lengths
137
+ if cu_seqlens is None:
138
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
139
+ else:
140
+ N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
141
+
142
+ BC = min(BT, BC)
143
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
144
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
145
+
146
+ dh = qg.new_empty(B, NT, H, K, V)
147
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
148
+ dv2 = torch.zeros_like(dv)
149
+
150
+ grid = (NK, NV, N * H)
151
+ chunk_dplr_bwd_kernel_dhu[grid](
152
+ qg=qg,
153
+ bg=bg,
154
+ w=w,
155
+ gk=gk,
156
+ dht=dht,
157
+ dh0=dh0,
158
+ do=do,
159
+ dh=dh,
160
+ dv=dv,
161
+ dv2=dv2,
162
+ cu_seqlens=cu_seqlens,
163
+ chunk_offsets=chunk_offsets,
164
+ T=T,
165
+ H=H,
166
+ K=K,
167
+ V=V,
168
+ BT=BT,
169
+ BC=BC,
170
+ BK=BK,
171
+ BV=BV,
172
+ )
173
+ return dh, dh0, dv2
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ cu_seqlens,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ IS_VARLEN: tl.constexpr,
54
+ ):
55
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+ if IS_VARLEN:
58
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ NT = tl.cdiv(T, BT)
61
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
62
+ else:
63
+ bos, eos = i_n * T, i_n * T + T
64
+ NT = tl.cdiv(T, BT)
65
+ boh = i_n * NT
66
+ o_k = i_k * BK + tl.arange(0, BK)
67
+
68
+ # [BK, BV]
69
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
70
+ if USE_INITIAL_STATE:
71
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
73
+
74
+ for i_t in range(NT):
75
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
79
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
80
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
81
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
82
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
83
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
84
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
85
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
86
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
87
+ # [BK, BC]
88
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
89
+ b_v = tl.load(p_v, boundary_check=(0, 1))
90
+ b_w = tl.load(p_w, boundary_check=(0, 1))
91
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
92
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
93
+ b_hc += tl.dot(b_kg, b_v)
94
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
95
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
96
+
97
+ last_idx = min((i_t + 1) * BT, T) - 1
98
+ b_g_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k, mask=o_k < K).to(tl.float32)
99
+ b_h *= exp(b_g_last[:, None])
100
+ b_h += b_hc
101
+
102
+ if STORE_FINAL_STATE:
103
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
104
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
105
+
106
+
107
+ def chunk_dplr_fwd_h(
108
+ kg: torch.Tensor,
109
+ v: torch.Tensor,
110
+ w: torch.Tensor,
111
+ u: torch.Tensor,
112
+ bg: torch.Tensor,
113
+ gk: torch.Tensor,
114
+ initial_state: Optional[torch.Tensor] = None,
115
+ output_final_state: bool = False,
116
+ cu_seqlens: Optional[torch.LongTensor] = None,
117
+ chunk_size: int = 64
118
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
119
+ B, T, H, K, V = *kg.shape, u.shape[-1]
120
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
121
+
122
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
123
+ # N: the actual number of sequences in the batch with either equal or variable lengths
124
+ if cu_seqlens is None:
125
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
126
+ else:
127
+ N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
128
+ BK = triton.next_power_of_2(K)
129
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
130
+ # H100 can have larger block size
131
+
132
+ if check_shared_mem('hopper', kg.device.index):
133
+ BV = 64
134
+ BC = 64 if K <= 128 else 32
135
+ elif check_shared_mem('ampere', kg.device.index): # A100
136
+ BV = 32
137
+ BC = 32
138
+ else:
139
+ BV = 16
140
+ BC = 16
141
+
142
+ BC = min(BT, BC)
143
+ NK = triton.cdiv(K, BK)
144
+ NV = triton.cdiv(V, BV)
145
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
146
+
147
+ h = kg.new_empty(B, NT, H, K, V)
148
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
149
+ v_new = torch.empty_like(u)
150
+ grid = (NK, NV, N * H)
151
+ chunk_dplr_fwd_kernel_h[grid](
152
+ kg=kg,
153
+ v=v,
154
+ w=w,
155
+ bg=bg,
156
+ u=u,
157
+ v_new=v_new,
158
+ h=h,
159
+ gk=gk,
160
+ h0=initial_state,
161
+ ht=final_state,
162
+ cu_seqlens=cu_seqlens,
163
+ chunk_offsets=chunk_offsets,
164
+ T=T,
165
+ H=H,
166
+ K=K,
167
+ V=V,
168
+ BT=BT,
169
+ BC=BC,
170
+ BK=BK,
171
+ BV=BV,
172
+ )
173
+ return h, v_new, final_state
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BV', 'BT'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
# Backward kernel for the intra-chunk attention-like matrices.
# Produces dA_qk (grad of the q·k matrix), dA_qb (grad of the q·b matrix) and a
# partial dv_new = A_qb^T @ do, which the recurrent backward step consumes.
# Grid: (num_chunks, B * H); each program owns one (chunk, batch-head) pair and
# tiles the value dimension in steps of BV.
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dAu(
    v,              # [B, T, H, V] original values
    do,             # [B, T, H, V] gradient w.r.t. the output
    v_new,          # [B, T, H, V] pseudo-values from the forward pass
    A_qb,           # [B, T, H, BT] forward intra-chunk q·b matrix
    dA_qk,          # [B, T, H, BT] out: gradient of A_qk
    dA_qb,          # [B, T, H, BT] out: gradient of A_qb
    dv_new,         # [B, T, H, V] out: partial gradient w.r.t. v_new
    cu_seqlens,     # [N + 1] cumulative sequence lengths (varlen only)
    chunk_indices,  # [NT, 2] (sequence idx, chunk idx) pairs (varlen only)
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    # Local sequence length (unchanged in the fixed-length case).
    T = eos - bos

    b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32)
    b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32)

    p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1))
    # causal mask (keep lower triangle incl. diagonal)
    b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # v and v_new are loaded transposed ([BV, BT]) so tl.dot yields [BT, BT].
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        # Accumulate dA_qk = do @ v^T and dA_qb = do @ v_new^T over value tiles.
        b_dA_qk += tl.dot(b_do, b_v)
        b_dA_qb += tl.dot(b_do, b_v_new)
        # dv_new = A_qb^T @ do (masked A_qb), stored per value tile.
        b_dv_new = tl.dot(tl.trans(b_A_qb), b_do)
        # for recurrent
        tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1))

    p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # Re-apply the causal mask and the attention scale before storing.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.)
    tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1))
    b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.)
    tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1))
87
+
88
+
89
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_o_kernel(
    # Backward kernel for the chunked output projection. For one (key-tile,
    # chunk, batch-head) program it accumulates dq, dk, db, dw and the
    # per-chunk gradient of the last-position log-gate (dgk_last).
    v,          # [B, T, H, V] values
    v_new,      # [B, T, H, V] pseudo-values
    h,          # [B, NT, H, K, V] per-chunk hidden states
    do,         # [B, T, H, V] output gradient
    dh,         # [B, NT, H, K, V] hidden-state gradient
    dk,         # [B, T, H, K] out: key gradient (partial)
    db,         # [B, T, H, K] out: b gradient (partial)
    w,          # [B, T, H, K]
    dq,         # [B, T, H, K] out: query gradient
    dv,         # [B, T, H, V] in: value gradient (used for dw)
    dw,         # [B, T, H, K] out: w gradient
    gk,         # [B, T, H, K] log-gates
    dgk_last,   # [B, NT, H, K] out: per-chunk last-position gate gradient
    k,          # [B, T, H, K] keys
    b,          # [B, T, H, K]
    cu_seqlens,     # [N + 1] (varlen only)
    chunk_indices,  # [NT, 2] (varlen only)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        # i_tg indexes the global chunk (used for h/dh/dgk_last); i_t becomes
        # the chunk index within the current sequence.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation: shift base pointers to this batch-head (and, for
    # h/dh/dgk_last, to this global chunk).
    v += (bos * H + i_h) * V
    v_new += (bos * H + i_h) * V
    do += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    dh += (i_tg * H + i_h) * K * V
    dk += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    dw += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V
    dq += (bos * H + i_h) * K
    w += (bos * H + i_h) * K

    dgk_last += (i_tg * H + i_h) * K
    gk += (bos * H + i_h) * K

    stride_qk = H*K
    stride_vo = H*V

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
    b_db = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk_last = tl.zeros([BK], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # h/dh loaded transposed ([BV, BK]) for the dot products below.
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # Contribution of the state/state-grad inner product to the gate grad.
        b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0)

        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
        b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
        p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

    # Scale the gate gradient by exp(gk at the chunk's last valid position).
    m_k = (i_k*BK+tl.arange(0, BK)) < K
    last_idx = min(i_t * BT + BT, T) - 1
    b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf'))
    b_dgk_last *= exp(b_gk_last)
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_dgk_last += tl.sum(b_k * b_dk, axis=0)
    b_dgk_last += tl.sum(b_b * b_db, axis=0)
    tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k)

    p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
214
+
215
+
216
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
        for BK in BK_LIST
        for BV in BK_LIST
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_kernel_dv(
    # Backward kernel for value gradients:
    #   dv = kg @ dh  (inter-chunk)  +  A_qk^T @ do  (intra-chunk, strictly
    # masked with the *upper* triangle since A_qk is used transposed).
    A_qk,           # [B, T, H, BT] intra-chunk q·k matrix
    kg,             # [B, T, H, K] gated keys
    do,             # [B, T, H, V] output gradient
    dv,             # [B, T, H, V] out: value gradient
    dh,             # [B, NT, H, K, V] hidden-state gradient
    cu_seqlens,     # [N + 1] (varlen only)
    chunk_indices,  # [NT, 2] (varlen only)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # offset calculation: move base pointers to this batch-head / chunk.
    A_qk += (bos * H + i_h) * BT
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V
    kg += (bos * H + i_h) * K
    dh += (i_tg * H + i_h) * K*V

    stride_qk = H*K
    stride_vo = H*V
    stride_A = H*BT

    # Inter-chunk term: dv += kg @ dh, tiled over the key dimension.
    for i_k in range(tl.cdiv(K, BK)):
        p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_kg = tl.load(p_kg, boundary_check=(0, 1))
        b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype))

    # Intra-chunk term: A_qk is read transposed; masking keeps row <= col,
    # i.e. the transpose of the forward causal (lower-triangular) mask.
    p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1))
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0)
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch ``chunk_dplr_bwd_kernel_dv`` and return the value gradient.

    Args:
        A_qk: intra-chunk q·k matrix, ``[B, T, H, BT]``.
        kg: gated keys, ``[B, T, H, K]``.
        do: output gradient, ``[B, T, H, V]``.
        dh: per-chunk hidden-state gradient.
        cu_seqlens: optional cumulative sequence lengths for varlen input.
        chunk_size: nominal chunk length (clamped to a power of two >= 16).
    """
    B, T, H, K = kg.shape
    V = do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    dv = torch.empty_like(do)

    def grid(meta):
        # BV is chosen by the autotuner, hence the deferred grid.
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_bwd_kernel_dv[grid](
        A_qk=A_qk,
        kg=kg,
        do=do,
        dv=dv,
        dh=dh,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return dv
322
+
323
+
324
def chunk_dplr_bwd_o(
    k: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    gk: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the backward output-projection kernel.

    Computes gradients w.r.t. ``q``, ``k``, ``w``, ``b`` plus the per-chunk
    last-position gate gradient ``dgk_last`` from the output gradient ``do``
    and the hidden-state gradients ``dh``.

    Returns:
        ``(dq, dk, dw, db, dgk_last)``.
    """
    B, T, H, K, V = *w.shape, v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # Hoist the device-capability probe: it is loop/branch invariant and was
    # previously evaluated twice.
    large_smem = check_shared_mem()
    BK = min(triton.next_power_of_2(K), 64 if large_smem else 32)
    # Bug fix: the low-shared-memory branch previously capped BV using
    # next_power_of_2(K) instead of next_power_of_2(V), mis-sizing the value
    # tile whenever K != V.
    BV = min(triton.next_power_of_2(V), 64 if large_smem else 32)
    NK = triton.cdiv(K, BK)
    dq = torch.empty_like(k)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    db = torch.empty_like(b)
    grid = (NK, NT, B * H)

    # Accumulated in fp32 by the kernel, one vector per (batch, chunk, head).
    dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)

    chunk_dplr_bwd_o_kernel[grid](
        k=k,
        b=b,
        v=v,
        v_new=v_new,
        h=h,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        db=db,
        dgk_last=dgk_last,
        w=w,
        dv=dv,
        dw=dw,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, dk, dw, db, dgk_last
384
+
385
+
386
def chunk_dplr_bwd_dAu(
    v: torch.Tensor,
    v_new: torch.Tensor,
    do: torch.Tensor,
    A_qb: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch ``chunk_dplr_bwd_kernel_dAu``.

    Returns the triple ``(dv_new, dA_qk, dA_qb)`` where the dA_* tensors are
    fp32 of shape ``[B, T, H, BT]``.
    """
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    # Value-tile width scales with the shared-memory budget of the GPU.
    if check_shared_mem('ampere'):  # A100
        BV = min(triton.next_power_of_2(V), 128)
    elif check_shared_mem('ada'):  # 4090
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BV = min(triton.next_power_of_2(V), 32)

    dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dv_new = torch.empty_like(v_new)

    chunk_dplr_bwd_kernel_dAu[(NT, B * H)](
        v=v,
        do=do,
        v_new=v_new,
        A_qb=A_qb,
        dA_qk=dA_qk,
        dA_qb=dA_qb,
        dv_new=dv_new,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dv_new, dA_qk, dA_qb
opencompass/models/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, use_cuda_graph
12
+
13
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
14
+
15
+
16
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    # Forward output kernel. Per (value-tile, chunk, batch-head) program:
    #   o = qg @ h  (inter-chunk)  +  A_qk @ v + A_qb @ v_new  (intra-chunk,
    # both A matrices causally masked).
    qg,             # [B, T, H, K] gated queries
    v,              # [B, T, H, V] values
    v_new,          # [B, T, H, V] pseudo-values
    A_qk,           # [B, T, H, BT] intra-chunk q·k matrix
    A_qb,           # [B, T, H, BT] intra-chunk q·b matrix
    h,              # [B, NT, H, K, V] per-chunk hidden states
    o,              # [B, T, H, V] out: output
    cu_seqlens,     # [N + 1] (varlen only)
    chunk_indices,  # [NT, 2] (varlen only)
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Inter-chunk contribution: qg @ h, tiled over the key dimension.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    # Intra-chunk contribution with causal (lower-triangular) masking.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
87
+
88
+
89
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the forward output kernel.

    The kernel computes ``o = qg @ h + A_qk @ v + A_qb @ v_new`` with causal
    masking on the intra-chunk A matrices.
    """
    B, T, H, K = qg.shape
    V = v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    o = torch.empty_like(v)

    def grid(meta):
        # BV is resolved by the autotuner at launch time.
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_fwd_kernel_o[grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
opencompass/models/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils.op import exp
11
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_dplr_delta_rule_fwd_kernel(
    # Step-by-step DPLR recurrence, one (value-tile, sequence-head) program:
    #   h_t = exp(gk_t) * h_{t-1} + (h_{t-1} @ a_t) b_t^T + v_t k_t^T
    #   o_t = h_t @ (scale * q_t)
    # The full key dimension (BK = next_pow2(K)) is kept in registers.
    q,           # [B, T, H, K]
    k,           # [B, T, H, K]
    v,           # [B, T, H, V]
    a,           # [B, T, H, K]
    b,           # [B, T, H, K]
    gk,          # [B, T, H, K] log-space decay
    o,           # [B, T, H, V] out
    h0,          # [N, H, K, V] optional initial state
    ht,          # [N, H, K, V] optional final-state output
    cu_seqlens,  # [N + 1] (varlen only)
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,       # if set, walk the sequence back-to-front
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # int64 program ids to avoid overflow in long-sequence pointer math.
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Start at the last timestep when REVERSE, else the first.
    p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # State tile kept as [BV, BK] (value-major) in fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        # tmp = h @ a; then decay the state and add the rank-1 updates.
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = tl.sum(b_h * b_q[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Advance (or rewind) all pointers by one timestep.
        p_q += (-1 if REVERSE else 1) * H*K
        p_k += (-1 if REVERSE else 1) * H*K
        p_a += (-1 if REVERSE else 1) * H*K
        p_b += (-1 if REVERSE else 1) * H*K
        p_gk += (-1 if REVERSE else 1) * H*K
        p_v += (-1 if REVERSE else 1) * H*V
        p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
106
+
107
+
108
def fused_recurrent_dplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the fused recurrent DPLR forward kernel.

    Returns the output tensor and, when requested, the fp32 final state of
    shape ``[N, H, K, V]``.
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    # Number of independent sequences: batch size, or varlen segment count.
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK = triton.next_power_of_2(K)

    h0 = initial_state
    ht = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), N * H)

    fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        a,
        b,
        gk,
        o,
        h0,
        ht,
        cu_seqlens,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
    )
    return o, ht
154
+
155
+
156
class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
    # Autograd wrapper around the fused recurrent DPLR forward kernel.
    # Forward-only: the backward explicitly raises, as this path is intended
    # for inference (see the message below).

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: Optional[float] = 1.0,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # Delegate straight to the kernel launcher; nothing is saved on ctx
        # because no backward pass exists.
        o, ht = fused_recurrent_dplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        return o, ht

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
            "This kernel is only for inference. "
            "For training, please use `chunk_dplr_delta_rule`."
        )
199
+
200
+
201
def fused_recurrent_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]`.
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]`. decay term in log space!
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: 1.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.

    Raises:
        ValueError: if `cu_seqlens` is given with a batch size other than 1,
            or with an `initial_state` whose leading dimension does not match
            the number of sequences.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # Bug fix: the two adjacent f-strings previously rendered as
            # "...`cu_seqlens`.Please flatten..." with no separating space.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
opencompass/models/fla2/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
+ def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
12
+ orig_dtype = q.dtype
13
+ b, h, l, d_k = q.shape
14
+ q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk])
15
+ d_v = v.shape[-1]
16
+ o = torch.zeros_like(v)
17
+ S = torch.zeros(b, h, d_k, d_v).to(v)
18
+ q = q * (d_k ** -0.5)
19
+
20
+ if initial_state is not None:
21
+ S += initial_state
22
+
23
+ for i in range(l):
24
+ _k = k[:, :, i]
25
+ _q = q[:, :, i]
26
+ _v = v[:, :, i]
27
+ _alpha = alpha[:, :, i].clone()
28
+ _beta = beta[:, :, i].clone()
29
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
30
+ S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
31
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
32
+ S = None if output_final_state is False else S
33
+ return o.to(orig_dtype), S
34
+
35
+
36
+ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
37
+ b, h, l, d_k = q.shape
38
+ d_v = v.shape[-1]
39
+ q = q * (d_k ** -0.5)
40
+ v = v
41
+ assert l % chunk_size == 0
42
+
43
+ S = k.new_zeros(b, h, d_k, d_v).to(q)
44
+ if initial_state is not None:
45
+ S += initial_state
46
+
47
+ # note that diagonal is masked.
48
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
49
+ q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
50
+ c=chunk_size).float(), [q, k, v, alpha, beta, gk])
51
+
52
+ gk_cumsum = gk.cumsum(-2)
53
+
54
+ # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
55
+ A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
56
+ A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
57
+ A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
58
+ A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
59
+
60
+ for i in range(chunk_size):
61
+ alpha_i = alpha[:, :, :, i, None]
62
+ q_i = q[:, :, :, i, None]
63
+ gk_i = gk_cumsum[:, :, :, i, None]
64
+ mask = (torch.arange(chunk_size) <= i).to(q.device)
65
+ attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
66
+ A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
67
+ A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
68
+ mask = (torch.arange(chunk_size) < i).to(q.device)
69
+ # shift by one.
70
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
71
+ A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
72
+ A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()
73
+
74
+ A_ab = A_ab
75
+ for i in range(1, chunk_size):
76
+ A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)
77
+
78
+ A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
79
+ u = A_ab @ (A_ak @ v)
80
+ w = A_ab @ ((gk_cumsum-gk).exp() * alpha)
81
+
82
+ o = torch.zeros_like(v)
83
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
84
+ for i in range(0, l // chunk_size):
85
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
86
+ v2_i = u_i + w_i @ S
87
+
88
+ o_1 = A_qk[:, :, i] @ v_i
89
+ o_2 = A_qb[:, :, i] @ v2_i
90
+ o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
91
+ o[:, :, i] = o_1 + o_2 + o_3
92
+ decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
93
+ S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
94
+ (beta_i * decay).transpose(-1, -2) @ v2_i
95
+ S = None if output_final_state is False else S
96
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
opencompass/models/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
12
+
13
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
14
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
15
+
16
+
17
# Backward kernel of the WY representation for the chunked DPLR delta rule.
# One program instance handles a single (chunk i_t, batch*head i_bh) tile:
#   1. recompute b_A_tmp_t = A_ak^T @ A_ab_inv^T from the saved forward tensors;
#   2. accumulate dv = dv0 + A_tmp_t @ du over all value sub-blocks, while
#      collecting b_dA_tmp = du @ v^T;
#   3. produce dag = A_ab_inv^T @ dw over all key sub-blocks, while folding
#      dw @ ag^T into the gradient of the inverse matrix;
#   4. emit dA_ak and dA_ab via the inverse-gradient identity documented below.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_bwd_kernel(
    A_ab_inv,       # per-chunk (BT, BT) tiles of (I - lower(A_ab))^-1
    A_ak,           # per-chunk (BT, BT) tiles, strictly lower part is used
    ag,             # [*, T, H, K]
    v,              # [*, T, H, V]
    dw,             # incoming gradient w.r.t. w
    du,             # incoming gradient w.r.t. u
    dv,             # output: gradient w.r.t. v
    dv0,            # additive gradient contribution already accumulated for v
    dag,            # output: gradient w.r.t. ag
    dAak,           # output: gradient w.r.t. A_ak (float32)
    dAab,           # output: gradient w.r.t. A_ab (float32)
    cu_seqlens,     # [N+1] cumulative sequence lengths, or None
    chunk_indices,  # (sequence, chunk) pairs for varlen mode
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap (chunk id) -> (sequence id, in-sequence chunk id) and clamp T
        # to the length of this particular sequence.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Note: the `_t` pointers are built with swapped shape/order, so the loads
    # below yield transposed (column-major) views of the per-chunk tiles.
    p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
    b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
    # In transposed view: keep strictly-upper for A_ak^T (strictly-lower in the
    # original), and upper-including-diagonal for A_ab_inv^T.
    b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
    b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
    b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
    b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)

    # Value-dimension sweep: emit dv and accumulate du @ v^T.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
        b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
        b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # Only the strictly-lower part of the chunk-local matrices carries gradient.
    m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :]
    b_dA_tmp = tl.where(m_i, b_dA_tmp, 0)
    b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
    b_dA_ak = tl.where(m_i, b_dA_ak, 0)
    tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
    b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)

    # Key-dimension sweep: emit dag and fold dw @ ag^T into dL/d(A_ab_inv).
    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
        b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
        tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))

    # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
    # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
    # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
    # denote A = I - lower(A_ab), B = A^-1
    # in the backward pass.
    # dL/dA = -(B)^T @ (dL/dB) @ B^T
    # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
    # (the leading minus signs cancel because A = I - lower(A_ab))
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
115
+
116
+
117
def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Host-side launcher for `prepare_wy_repr_bwd_kernel`.

    Backpropagates gradients of the WY representation (``dw``, ``du``, plus the
    pre-accumulated ``dv0``) to ``v``, ``ag`` and the chunk-local matrices
    ``A_ab``/``A_ak``.

    Returns:
        (dA_ab, dA_ak, dv, dag) — gradients w.r.t. A_ab, A_ak, v and ag.
    """
    # The kernel assumes contiguous layouts for all read tensors.
    A_ab_inv, A_ak, v, ag, dw, du = (t.contiguous() for t in (A_ab_inv, A_ak, v, ag, dw, du))

    B, T, H, K = dw.shape
    V = du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    BK = min(triton.next_power_of_2(K), 64)
    # Smaller value tiles on devices with limited shared memory.
    BV = min(triton.next_power_of_2(V), 64 if check_shared_mem() else 32)

    # Chunk-matrix gradients are accumulated in fp32 for numerical stability.
    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
opencompass/models/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import gather
12
+ from ....utils import is_gather_supported, use_cuda_graph
13
+
14
+
15
# Forward-substitution kernel for chunk sizes up to 32: for each (chunk, head)
# tile it computes A_ab_inv = (I - strictly_lower(A_ab))^-1 in place of the
# strictly lower-triangular tile A_ab.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    A_ab,          # input per-chunk (BT, BT) tiles
    A_ab_inv,      # output per-chunk (BT, BT) tiles
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen: translate the global chunk id into (sequence, local chunk).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    # Restrict to the strictly lower triangle before the substitution loop.
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    # Row-by-row forward substitution: row i of the inverse is expressed via
    # the already-finished rows 0..i-1.
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    # Add the identity to complete (I - L)^-1.
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
57
+
58
+
59
# Forward-substitution kernel for chunk size 64: the (64, 64) tile is split
# into four (BC, BC) sub-blocks, the two diagonal ones are inverted with the
# same row-wise substitution as the chunk32 kernel, and the off-diagonal block
# is recovered from the 2x2 block-triangular inversion identity (see below).
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    A_ab,          # input per-chunk (BT, BT) tiles
    A_ab_inv,      # output per-chunk (BT, BT) tiles
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # sub-block size (BT // 2)
    IS_VARLEN: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Sub-block pointers: 1 = top-left, 2 = bottom-right, 3 = bottom-left,
    # 4 = top-right (zeroed at the end — above the causal diagonal).
    p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    # Keep only the strictly lower triangle of the two diagonal sub-blocks.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # Row-wise forward substitution on both diagonal sub-blocks in lockstep.
    for i in range(1, BC):
        if GATHER_SUPPORTED:
            # Extract row i with the hardware gather when available.
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BK] -> [BK]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        mask = tl.arange(0, BC) == i
        # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    # tl.debug_barrier()
    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # causal mask
    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
136
+
137
+
138
# Per-chunk kernel computing the WY representation outputs:
#   w = A_ab_inv @ ag          (key-side)
#   u = (A_ab_inv @ A_ak) @ v  (value-side)
# where A_ab_inv is masked to lower-including-diagonal and A_ak to strictly
# lower, matching how they were produced.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,             # output [*, T, H, K]
    u,             # output [*, T, H, V]
    ag,            # input  [*, T, H, K]
    v,             # input  [*, T, H, V]
    A_ab_inv,      # per-chunk (BT, BT) tiles
    A_ak,          # per-chunk (BT, BT) tiles
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_s = tl.arange(0, BT)

    p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    # Enforce the triangular structure of both chunk matrices.
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    # w = A_ab_inv @ ag, streamed over key sub-blocks.
    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    # u = (A_ab_inv @ A_ak) @ v, streamed over value sub-blocks.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
205
+
206
+
207
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Host-side launcher for `wu_fwd_kernel`.

    Computes per chunk ``w = A_ab_inv @ ag`` and ``u = (A_ab_inv @ A_ak) @ v``.

    Returns:
        (w, u) with the same shapes/dtypes as ``ag`` and ``v`` respectively.
    """
    B, T, H, K = ag.shape
    V = v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    wu_fwd_kernel[(NT, B * H)](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
243
+
244
+
245
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the WY representation.

    First inverts ``I - strictly_lower(A_ab)`` per chunk (choosing the 64- or
    32-sized kernel based on the effective chunk size), then derives ``w`` and
    ``u`` from the inverse via :func:`wu_fwd`.

    Returns:
        (w, u, A_ab_inv).
    """
    B, T, H, _ = ag.shape
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    # Sub-block size for the blockwise (chunk64) inversion kernel.
    BC = min(BT, 32)
    if BT == 64:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk64
    else:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk32

    A_ab_inv = torch.empty_like(A_ab)
    fwd_fn[(NT, B * H)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        BC=BC,
    )
    w, u = wu_fwd(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    return w, u, A_ab_inv
280
+
281
+
282
# Backward-compatible aliases kept for older call sites.
fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
opencompass/models/fla2/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
opencompass/models/fla2/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+
12
+ from ....ops.generalized_delta_rule.iplr.wy_fast import prepare_wy_repr_fwd
13
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
14
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
15
+
16
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
17
+
18
+
19
# Recurrent state-propagation kernel of the chunked IPLR delta rule.
# Sequentially over chunks it stores the per-chunk initial state h, computes
# v_new = d @ h + u for the chunk, and accumulates the state update
# h += k @ v + b @ v_new, optionally seeding from h0 and writing ht at the end.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
    k,              # keys   [*, T, H, K]
    v,              # values [*, T, H, V]
    d,              # w from the WY representation [*, T, H, K]
    b,              # [*, T, H, K]
    u,              # u from the WY representation [*, T, H, V]
    v_new,          # output: per-chunk corrected values [*, T, H, V]
    h,              # output: per-chunk initial states [*, NT, H, K, V]
    h0,             # optional initial state [N, H, K, V]
    ht,             # optional final state [N, H, K, V]
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # sub-chunk size (SRAM pressure control, see loop below)
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running state for this (key-tile, value-tile) pair.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # The state *entering* chunk i_t is stored, then updated below.
        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_b = tl.load(p_b, boundary_check=(0, 1))
            # v_new = d @ h + u for this sub-chunk.
            b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_k, b_v)
            b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
101
+
102
+
103
# Output kernel of the chunked IPLR delta rule. Per (value-tile, chunk, head):
#   o = (q @ h) + causal(q @ k^T) @ v + causal(q @ b^T) @ u, all scaled.
# The first term is the inter-chunk (state) contribution; the two masked
# matmuls are the intra-chunk contributions.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
    q,              # queries [*, T, H, K]
    k,              # keys    [*, T, H, K]
    v,              # values  [*, T, H, V]
    u,              # corrected values (v_new) [*, T, H, V]
    b,              # [*, T, H, K]
    h,              # per-chunk states [*, NT, H, K, V]
    o,              # output [*, T, H, V]
    cu_seqlens,
    chunk_indices,
    scale,          # attention scale applied to the final output
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    u += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    stride_qk = H*K
    stride_vo = H*V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
    b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)

    # Accumulate over key sub-blocks: inter-chunk term and the two intra-chunk
    # score matrices in one sweep.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqk += tl.dot(b_q, b_k)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqb += tl.dot(b_q, b_b)

    # Causal (lower-triangular, diagonal inclusive) masking of both matrices.
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_Aqk = tl.where(m_A, b_Aqk, 0)
    b_Aqb = tl.where(m_A, b_Aqb, 0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_u = tl.load(p_u, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
198
+
199
+
200
def chunk_generalized_iplr_delta_rule_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    b: torch.Tensor,
    h: torch.Tensor,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the output kernel of the chunked IPLR delta rule.

    Combines the inter-chunk contribution ``q @ h`` with the causally masked
    intra-chunk terms over ``v`` and ``v_new``; the scale defaults to
    ``K ** -0.5`` when not provided.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    o = torch.empty_like(v)

    def grid(meta):
        # BV is picked by the autotuner, so the first grid axis is meta-dependent.
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        u=v_new,
        b=b,
        h=h,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o
244
+
245
+
246
def chunk_generalized_iplr_delta_rule_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    b: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the state-propagation kernel of the chunked IPLR delta rule.

    Returns:
        (h, v_new, final_state) where ``h`` holds the per-chunk initial states
        ``[B, NT, H, K, V]``, ``v_new`` the corrected values, and
        ``final_state`` is ``None`` unless ``output_final_state`` is set.
    """
    B, T, H, K, V = *k.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size

    # Pick value/sub-chunk tile sizes by available shared memory per SM.
    if check_shared_mem('hopper', k.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', k.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)

    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = k.new_empty(B, NT, H, K, V)
    # Final state is kept in fp32 regardless of the input dtype.
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
        k=k,
        v=v,
        d=w,
        b=b,
        u=u,
        v_new=v_new,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
315
+
316
+
317
def chunk_generalized_iplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Full forward pass of the chunked IPLR delta rule.

    Pipeline: WY representation (w, u) -> chunkwise state propagation
    (h, v_new, final_state) -> output computation (o).
    """
    seq_len = q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))

    # Stage 1: WY representation.
    w, u, _ = prepare_wy_repr_fwd(
        a=a,
        b=b,
        k=k,
        v=v,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    # Stage 2: recurrent state propagation across chunks.
    h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
        k=k,
        v=v,
        b=b,
        w=w,
        u=u,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    # Stage 3: combine inter- and intra-chunk contributions into the output.
    o = chunk_generalized_iplr_delta_rule_fwd_o(
        q=q,
        k=k,
        v=v,
        v_new=v_new,
        b=b,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT
    )
    return o, final_state
363
+
364
+
365
class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked IPLR delta-rule forward pass.

    Only the forward pass is implemented; calling `.backward()` raises
    ``NotImplementedError``.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # Fixed chunk length; the fwd implementation clamps it to the sequence
        # length internally.
        chunk_size = 64

        o, final_state = chunk_generalized_iplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # Cast back to the query dtype (the kernels may produce fp32 tiles).
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        raise NotImplementedError(
            "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
            "Stay tuned!"
        )
410
+
411
+
412
@torch.compiler.disable
def chunk_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
    r"""
    Chunked forward pass of the generalized IPLR delta rule.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        a (torch.Tensor):
            activations of shape `[B, T, H, K]`.
        b (torch.Tensor):
            betas of shape `[B, T, H, K]`.
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Deprecated. Passing `True` raises; inputs must be in `[B, T, H, ...]` format.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Raises:
        DeprecationWarning: if `head_first=True` is passed.
        ValueError: on invalid `cu_seqlens`/`initial_state` combinations.
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."

    if head_first:
        # head_first support was removed; fail fast instead of silently
        # rearranging (the old rearrange code after this raise was unreachable).
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    # head_first is guaranteed False past this point.
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    )
    return o, final_state
opencompass/models/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....utils import input_guard
11
+
12
+
13
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BK"],
)
@triton.jit
def fused_recurrent_fwd_kernel(
    q,  # query [B, T, H, K] (per-step stride is H*K, so time is the outer axis)
    k,  # key [B, T, H, K]
    v,  # value [B, T, H, V]
    a,  # a [B, T, H, K]
    b,  # b [B, T, H, K]
    o,  # output [B, T, H, V]
    ha,  # tmp [B, T, H, V]: per-step (h * a).sum over K, saved for the backward pass
    h0,  # initial hidden state [N, H, K, V], or None
    ht,  # final hidden state [N, H, K, V], or None
    cu_seqlens,  # varlen cumulative sequence lengths [N+1], or None
    scale,  # query scaling factor, typically K ** -0.5
    H,  # n_heads
    T,  # seq_len (overridden per sequence in varlen mode)
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension (covers all of K)
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    IS_VARLEN: tl.constexpr,
):
    # One program per (V-block, sequence*head) pair.
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        # Per-sequence [bos, eos) span; T becomes this sequence's length.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    # Pointers to timestep 0 for this head; advanced by H*K / H*V per step.
    p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
    p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
    p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
    p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
    p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
    p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)

    mask_k = tl.arange(0, BK) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]

    # Running hidden state slice, laid out [BV, BK] (V rows, K cols), fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        # h0 is stored [K, V] per (sequence, head); load transposed into [BV, BK].
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        # tmp = S^T a_t (length-BV vector); also stored to `ha` for the backward pass.
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        # State update: S += (S^T a_t) b_t^T + v_t k_t^T (IPLR delta rule).
        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        # Output: o_t = S^T q_t for this V block.
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
        # Advance one timestep.
        p_q += K*H
        p_k += K*H
        p_o += V*H
        p_v += V*H
        p_ha += V*H
        p_a += K*H
        p_b += K*H

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
102
+
103
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_DHT': lambda args: args['dht'] is not None,
    'USE_DH0': lambda args: args['dh0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=["BK", "BV"],
)
@triton.jit
def fused_recurrent_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len
    # NV: number of splits in the V dimension (leading axis of dq/dk/da/db)
    q,  # query [B, T, H, K]
    k,  # key [B, T, H, K]
    v,  # value [B, T, H, V]
    a,  # a [B, T, H, K]
    b,  # b [B, T, H, K]
    ha,  # ha [B, T, H, V], the (h^T a_t) products saved by the forward kernel
    dht,  # gradient of final state [N, H, K, V]
    dh0,  # gradient of initial state [N, H, K, V]
    do,  # gradient of output [B, T, H, V]
    dq,  # gradient of query [NV, B, T, H, K]
    dk,  # gradient of key [NV, B, T, H, K]
    dv,  # gradient of value [B, T, H, V]
    da,  # gradient of a [NV, B, T, H, K]
    db,  # gradient of b [NV, B, T, H, K]
    dha,  # gradient of ha [B, T, H, V]
    h0,  # initial state [N, H, K, V]
    scale,  # K ** -0.5
    cu_seqlens,  # varlen cumulative sequence lengths [N+1], or None
    B,  # batch_size
    H,  # n_heads
    T,  # seq_len
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state h0
    USE_DH0: tl.constexpr,  # whether to write dh0
    USE_DHT: tl.constexpr,  # whether to seed with dht
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    # Select this V-block's private copy of the duplicated K-indexed gradients
    # (the host reduces over the leading NV axis afterwards).
    dk += i_v * B * H * K * T
    db += i_v * B * H * K * T
    dq += i_v * B * H * K * T
    da += i_v * B * H * K * T
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
    mask_k = tl.arange(0, BK) < K
    mask_v = (tl.arange(0, BV) + i_v * BV) < V

    # Rebase every pointer to this (sequence, head) and this V block.
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V + i_v * BV
    ha += (bos * H + i_h) * V + i_v * BV
    a += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    do += (bos * H + i_h) * V + i_v * BV
    dq += (bos * H + i_h) * K
    dk += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V + i_v * BV
    da += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    dha += (bos * H + i_h) * V + i_v * BV

    # Pass 1 runs backward in time: start at the last timestep.
    p_q = q + tl.arange(0, BK) + (T - 1) * H*K
    p_k = k + tl.arange(0, BK) + (T - 1) * H*K
    p_v = v + tl.arange(0, BV) + (T - 1) * H*V
    p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V
    p_a = a + tl.arange(0, BK) + (T - 1) * H*K
    p_b = b + tl.arange(0, BK) + (T - 1) * H*K
    p_do = do + tl.arange(0, BV) + (T - 1) * H*V
    p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K
    p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V
    p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V
    p_db = db + tl.arange(0, BK) + (T - 1) * H*K
    p_da = da + tl.arange(0, BK) + (T - 1) * H*K
    p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K

    # Running gradient of the hidden state, [BK, BV] in fp32.
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_DHT:
        p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)

    # Pass 1 (reverse time): accumulate dS and emit dk, dv, db, dha.
    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)

        # Contribution of o_t = S^T q_t to dS.
        b_dh += b_q[:, None] * b_do[None, :]
        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)

        # Gradients through the (S^T a_t) b_t^T term.
        b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
        tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
        b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)

        # Propagate dS one step back through the state update.
        b_dh += b_dha[None, :] * b_a[:, None]
        p_do -= H*V
        p_q -= H*K
        p_k -= H*K
        p_v -= H*V
        p_dk -= H*K
        p_dv -= H*V
        p_b -= H*K
        p_db -= H*K
        p_a -= H*K
        p_dha -= H*V
        p_ha -= H*V

    if USE_DH0:
        p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])

    # Separate the two passes: pass 2 reads the dha values stored above.
    tl.debug_barrier()

    # Pass 2 (forward time): rebuild the hidden state to emit da and dq.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        mask_kv = mask_k[:, None] & mask_v[None, :]
        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    p_k = k + tl.arange(0, BK)
    p_v = v + tl.arange(0, BV)
    p_ha = ha + tl.arange(0, BV)
    p_do = do + tl.arange(0, BV)
    p_dha = dha + tl.arange(0, BV)
    p_da = da + tl.arange(0, BK)
    p_dq = dq + tl.arange(0, BK)
    p_b = b + tl.arange(0, BK)

    for i in range(0, T):
        # da_t uses the state *before* step t's update, so it is emitted first.
        b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
        d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
        tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
        # Replay the forward state update using the saved `ha` values.
        b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
        _d_q = b_h * b_do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += H*K
        p_do += H*V
        p_v += H*V
        p_da += H*K
        p_dha += H*V
        p_ha += H*V
        p_dq += H*K
        p_b += H*K
+
276
+
277
+ class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
278
+
279
+ @staticmethod
280
+ @input_guard
281
+ def forward(
282
+ ctx,
283
+ q: torch.Tensor,
284
+ k: torch.Tensor,
285
+ v: torch.Tensor,
286
+ a: torch.Tensor,
287
+ b: torch.Tensor,
288
+ scale: Optional[float] = None,
289
+ initial_state: Optional[torch.Tensor] = None,
290
+ output_final_state: bool = False,
291
+ cu_seqlens: Optional[torch.LongTensor] = None
292
+ ):
293
+ B, T, H, K, V = *k.shape, v.shape[-1]
294
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
295
+
296
+ BK = triton.next_power_of_2(K)
297
+ if output_final_state:
298
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float32)
299
+ else:
300
+ final_state = None
301
+
302
+ ha = torch.empty_like(v, dtype=torch.float32)
303
+
304
+ def grid(meta): return (
305
+ triton.cdiv(V, meta['BV']),
306
+ N * H
307
+ )
308
+ o = torch.empty_like(v)
309
+ fused_recurrent_fwd_kernel[grid](
310
+ q=q,
311
+ k=k,
312
+ v=v,
313
+ a=a,
314
+ b=b,
315
+ o=o,
316
+ ha=ha,
317
+ h0=initial_state,
318
+ ht=final_state,
319
+ scale=scale,
320
+ cu_seqlens=cu_seqlens,
321
+ H=H,
322
+ T=T,
323
+ K=K,
324
+ V=V,
325
+ BK=BK,
326
+ )
327
+ ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
328
+ ctx.scale = scale
329
+ ctx.cu_seqlens = cu_seqlens
330
+ return o, final_state
331
+
332
+ @staticmethod
333
+ @input_guard
334
+ def backward(ctx, do, dht):
335
+ q, k, v, a, b, ha, initial_state = ctx.saved_tensors
336
+ B, T, H, K, V = *q.shape, v.shape[-1]
337
+ N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1
338
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
339
+ NV = triton.cdiv(V, BV)
340
+ scale = ctx.scale
341
+
342
+ dq = q.new_empty(NV, *q.shape)
343
+ dk = k.new_empty(NV, *k.shape)
344
+ da = a.new_empty(NV, *a.shape)
345
+ db = b.new_empty(NV, *b.shape)
346
+ dv = torch.empty_like(v)
347
+ dha = torch.empty_like(ha)
348
+ grid = (NV, N * H)
349
+
350
+ if initial_state is not None and initial_state.requires_grad:
351
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
352
+ else:
353
+ dh0 = None
354
+
355
+ fused_recurrent_bwd_kernel[grid](
356
+ q=q,
357
+ k=k,
358
+ v=v,
359
+ a=a,
360
+ b=b,
361
+ ha=ha,
362
+ dht=dht,
363
+ dh0=dh0,
364
+ do=do,
365
+ dq=dq,
366
+ dk=dk,
367
+ dv=dv,
368
+ da=da,
369
+ db=db,
370
+ dha=dha,
371
+ h0=initial_state,
372
+ scale=scale,
373
+ cu_seqlens=ctx.cu_seqlens,
374
+ B=B,
375
+ H=H,
376
+ T=T,
377
+ K=K,
378
+ V=V,
379
+ BK=BK,
380
+ BV=BV,
381
+ )
382
+ dq = dq.sum(0)
383
+ dk = dk.sum(0)
384
+ da = da.sum(0)
385
+ db = db.sum(0)
386
+ return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None
387
+
388
+
389
def fused_recurrent_iplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T one step at a time.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        a (torch.Tensor):
            as of shape `[B, T, H, K]`
        b (torch.Tensor):
            bs of shape `[B, T, H, K]`
        scale (Optional[int]):
            Scale factor applied to the attention scores.
            Falls back to `1 / sqrt(K)` when omitted. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to return the final state of shape `[B, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` for variable-length
            training, consistent with the FlashAttention API.
    """
    if cu_seqlens is not None:
        # Varlen mode packs all sequences along the time axis of one batch row.
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is not None:
        assert scale > 0, "scale must be positive"
    else:
        scale = q.shape[-1] ** -0.5
    return FusedRecurrentIPLRDeltaRuleFunction.apply(
        q, k, v, a, b, scale, initial_state, output_final_state, cu_seqlens
    )
opencompass/models/fla2/ops/generalized_delta_rule/iplr/naive.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
# S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
# q, k, alpha, beta [B, H, L, D_K]
# v [B, H, L, D_V]
def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True):
    """Naive step-by-step reference implementation of the IPLR delta rule.

    Args:
        q, k, alpha, beta: tensors of shape ``[B, H, L, D_K]``.
        v: tensor of shape ``[B, H, L, D_V]``.
        initial_state: optional state of shape ``[B, H, D_K, D_V]`` added to S.
        output_final_state: when False, the returned state is ``None``.

    Returns:
        Tuple of the output ``[B, H, L, D_V]`` (cast back to q's dtype) and
        the final state (fp32) or ``None``.
    """
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    # Compute in fp32 for numerical stability; `alpha` is included in the
    # cast as well so all operands share one dtype (it was omitted before).
    q, k, v, alpha, beta = map(lambda x: x.float(), [q, k, v, alpha, beta])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    # Allocate S directly on v's device/dtype instead of CPU-then-move.
    S = torch.zeros(b, h, d_k, d_v, dtype=v.dtype, device=v.device)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i]
        _beta = beta[:, :, i]
        # Rank-1 update: k_t v_t^T plus the IPLR correction (S^T alpha_t) beta_t^T.
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        S = S + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S
33
+
34
+
35
def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32):
    """Chunk-parallel reference implementation of the IPLR delta rule.

    Processes the sequence in chunks of `chunk_size` (L must divide evenly),
    combining intra-chunk attention with an inter-chunk recurrent state S of
    shape [B, H, D_K, D_V]. Inputs are laid out [B, H, L, D] here (head-first),
    matching `iplr_recurrence` above.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    v = v  # no-op, kept from the original
    # The chunked formulation requires whole chunks only.
    assert l % chunk_size == 0

    S = k.new_zeros(b, h, d_k, d_v)
    if initial_state is not None:
        S += initial_state

    # note that diagonal is masked.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta])

    # Intra-chunk products: v2 = tril(alpha k^T) v; attn = tril(alpha beta^T).
    v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
    attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0)
    # Forward substitution, row by row — order matters, each row i depends on
    # the already-updated rows < i (effectively inverting a unit lower-
    # triangular system).
    for i in range(1, chunk_size):
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)

    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = attn @ v2
    w = attn @ alpha
    o = torch.zeros_like(v)
    # Strictly-upper mask (diagonal kept) for the query-side intra-chunk terms.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        # o = intra-chunk (o_1 + o_2) + inter-chunk via the running state (o_3).
        o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i
        v2_i = u_i + w_i @ S
        o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ (v2_i)
        o_3 = q_i @ S
        o[:, :, i] = o_1 + o_2 + o_3
        # Advance the inter-chunk state with this chunk's contributions.
        S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return rearrange(o, 'b h n c d -> b h (n c) d'), S
opencompass/models/fla2/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ....ops.utils import prepare_chunk_indices
12
+ from ....utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    a,  # [B, T, H, K]
    b,  # [B, T, H, K]
    A,  # output [B, T, H, BT]
    cu_seqlens,  # varlen cumulative sequence lengths [N+1], or None
    chunk_indices,  # varlen (sequence, chunk) index pairs
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size (<= 32 for this variant)
    BK: tl.constexpr,
    BC: tl.constexpr,  # dummy placeholder
    IS_VARLEN: tl.constexpr,
):
    # One program per (chunk, batch*head).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Accumulate A = a @ b^T for this chunk, tiled over K.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_A += tl.dot(b_a, b_b)

    # Keep only the strictly lower-triangular part (diagonal masked).
    b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
    # Row-by-row forward substitution; each row i folds in the already
    # processed rows < i (sequential — do not reorder).
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity to complete the unit-triangular result.
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
69
+
70
+
71
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    a,  # [B, T, H, K]
    b,  # [B, T, H, K]
    A,  # output [B, T, H, BT]
    cu_seqlens,  # varlen cumulative sequence lengths [N+1], or None
    chunk_indices,  # varlen (sequence, chunk) index pairs
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,  # chunk size (64); processed as two BC=32 halves
    BK: tl.constexpr,
    BC: tl.constexpr,  # half-chunk size
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # The BT x BT matrix is handled as 2x2 blocks of size BC:
    # b_A = top-left, b_A2 = bottom-right, b_A3 = bottom-left (top-right is 0).
    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_A3 = tl.zeros([BC, BC], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
        p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
        p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
        b_a1 = tl.load(p_a1, boundary_check=(0, 1))
        b_a2 = tl.load(p_a2, boundary_check=(0, 1))
        b_b1 = tl.load(p_b1, boundary_check=(0, 1))
        b_b2 = tl.load(p_b2, boundary_check=(0, 1))
        b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
        b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
        b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)

    # Strictly lower-triangular parts of the two diagonal blocks.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    # Row-by-row forward substitution on both diagonal blocks (sequential).
    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)

    p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
    # causal mask: the top-right block is always zero.
    tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
149
+
150
+
151
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,  # output [B, T, H, K]: w = A @ a
    u,  # output [B, T, H, V]: u = A @ (tril(a k^T) @ v)
    a,  # [B, T, H, K]
    k,  # [B, T, H, K]
    v,  # [B, T, H, V]
    A,  # per-chunk matrix [B, T, H, BT] from prepare_wy_repr_fwd
    cu_seqlens,  # varlen cumulative sequence lengths [N+1], or None
    chunk_indices,  # varlen (sequence, chunk) index pairs
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # One program per (chunk, batch*head).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_Aak = tl.zeros([BT, BT], dtype=tl.float32)

    # Compute w = A @ a tile-by-tile over K, accumulating a @ k^T alongside.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_w = tl.dot(b_A, b_a)
        b_Aak += tl.dot(b_a, tl.trans(b_k))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    # Strictly lower-triangular part of a @ k^T (diagonal masked).
    b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
    b_Aak = b_Aak.to(k.dtype.element_ty)

    # u = A @ (tril(a k^T) @ v), tile-by-tile over V.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
        b_u = tl.dot(b_A, b_v)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
214
+
215
+
216
def prepare_wy_repr_fwd(
    a: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the chunked WY representation.

    Launches the chunk-32 or chunk-64 kernel to produce the per-chunk matrix
    ``A``, then derives ``w`` and ``u`` via :func:`wu_fwd`.

    Returns:
        Tuple ``(w, u, A)`` where ``w``/``u`` match the shapes of ``a``/``v``
        and ``A`` has shape ``[B, T, H, BT]``.
    """
    B, T, H, K = a.shape
    # Effective chunk length: capped by chunk_size, at least 16, power of two.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    BK = min(triton.next_power_of_2(K), 64)
    BC = min(BT, 32)

    A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype)
    if BT == 64:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk64
    else:
        fwd_fn = prepare_wy_repr_fwd_kernel_chunk32

    fwd_fn[(NT, B * H)](
        a=a,
        b=b,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
    )
    w, u = wu_fwd(
        a=a,
        v=v,
        k=k,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    return w, u, A
257
+
258
+
259
def wu_fwd(
    a: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``w = A @ a`` and ``u = A @ (tril(a k^T) @ v)`` per chunk.

    Thin launcher around :func:`wu_fwd_kernel`; output tensors mirror the
    shapes (and dtypes) of ``a`` and ``v``.
    """
    B, T, H, K, V = *a.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)
    # Larger tiles only when enough shared memory is available.
    tile_cap = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tile_cap)
    BV = min(triton.next_power_of_2(V), tile_cap)

    w = torch.empty_like(a)
    u = torch.empty_like(v)
    wu_fwd_kernel[(NT, B*H)](
        a=a,
        v=v,
        w=w,
        u=u,
        A=A,
        k=k,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
296
+
297
+
298
# Backward-compatible aliases for callers using the older `fwd_`-prefixed names.
fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
opencompass/models/fla2/ops/gla/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gla
4
+ from .chunk_fuse import fused_chunk_gla
5
+ from .recurrent_fuse import fused_recurrent_gla
6
+
7
+ __all__ = [
8
+ 'chunk_gla',
9
+ 'fused_chunk_gla',
10
+ 'fused_recurrent_gla'
11
+ ]
opencompass/models/fla2/ops/gla/chunk.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ...ops.utils import chunk_global_reversed_cumsum, chunk_local_cumsum
12
+ from ...ops.common.chunk_h import chunk_fwd_h_fn, chunk_bwd_dh_fn
13
+ from ...utils import contiguous
14
+
15
+
16
@triton.jit
def chunk_gla_fwd_kernel_intra(
    q,
    k,
    g,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr
):
    # Computes the intra-chunk attention scores A for GLA.
    # Each BT-long chunk is split into NC sub-blocks of length BC; the grid
    # axis 1 enumerates (chunk, row sub-block i_i, col sub-block i_j) triples.
    # Only the lower triangle (i_i >= i_j) is produced; K is tiled over BK
    # and partial results are written per-i_k slice of A (summed by the host).
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    n_bh = tl.num_programs(2)

    if i_i > i_j:
        # Strictly-below-diagonal sub-block: fully dense BC x BC dot product.
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
        p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        # [BK,] gate value at the first row of this sub-block, used as a
        # normalizer so both exp() factors below stay bounded.
        b_gn = tl.load(p_gn, boundary_check=(0,))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
        # [BC, BC]
        b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
        tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    elif i_i == i_j:
        # Diagonal sub-block: causal within the sub-block, so iterate over
        # key positions one row at a time and mask with o_i >= j.
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))

        o_i = tl.arange(0, BC)
        # Flat offsets into A for one column of the BC x BC diagonal tile.
        o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
        m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
        for j in range(0, BC):
            # [BK,] key and gate for position j of the sub-block
            b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
            b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
            # [BC,] scores of all rows against key j
            b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
            # Causal mask: only rows at or after position j contribute.
            b_A = tl.where(o_i >= j, b_A, 0.)
            tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)

            p_k = tl.advance(p_k, (K,))
            p_gk = tl.advance(p_gk, (K,))
80
+
81
+
82
@triton.jit
def chunk_gla_fwd_kernel_inter(
    q,
    v,
    g,
    h,
    o,
    A,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Combines the two halves of the GLA forward output for one chunk:
    #   o = (q * exp(g) * scale) @ h   (inter-chunk, via recurrent state h)
    #     + A @ v                      (intra-chunk, A precomputed above)
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK] query modulated by the per-chunk cumulative gate
        b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # NOTE(review): `i_k >= 0` is always true; the guard is kept
        # deliberately (appears to work around a Triton codegen quirk —
        # confirm before removing).
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h, allow_tf32=False)
    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT] intra-chunk score matrix (already summed over K tiles)
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
137
+
138
+
139
@triton.jit
def chunk_gla_bwd_kernel_intra(
    q,
    k,
    g,
    dA,
    dq,
    dk,
    dg,
    s_k_h,
    s_k_t,
    s_k_d,
    T: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr
):
    # Intra-chunk backward pass: accumulates the intra-chunk contributions of
    # dA into dq and dk (adding to the inter-chunk parts already stored by
    # chunk_gla_bwd_kernel_inter), and produces the gate gradient
    # dg = q * dq - k * dk for this sub-block.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_t, i_i = i_c // NC, i_c % NC

    p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
    # [BK,] gate at the first row of this sub-block (normalizer for exp())
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BK]
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    # dq from strictly-earlier sub-blocks of the same chunk.
    for i_j in range(0, i_i):
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        # [BC, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
        # [BC, BC]
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
        # [BC, BK]
        b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
    b_dq *= tl.exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    # Flat offsets into dA for the diagonal sub-block of this row block.
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # dq from the causal diagonal sub-block, one key position at a time.
    for j in range(0, BC):
        p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
        p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
        # [BC,]
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
        # [BK,]
        b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
        b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
        # [BC, BK] causal mask: row must be at or after position j
        m_i = o_i[:, None] >= j
        # [BC, BK]
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # Add to the inter-chunk dq already stored by the inter kernel.
    b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # Barrier before the dk phase, which re-reads dq written above.
    tl.debug_barrier()
    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
    # [BK,] gate at the last row of the sub-block (normalizer for dk)
    b_gn = tl.load(p_gn, boundary_check=(0,))
    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    # dk from strictly-later sub-blocks of the same chunk.
    for i_j in range(i_i + 1, NC):
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
        # [BC, BC]
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
        # [BC, BK]
        b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
    b_dk *= tl.exp(b_gn[None, :] - b_gk)

    # dk from the causal diagonal sub-block, one query position at a time.
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
    for j in range(0, BC):
        p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
        p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
        # [BC,]
        b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
        # [BK,]
        b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
        b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
        # [BC, BK] keys at or before query position j contribute
        m_i = o_i[:, None] <= j
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)

    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    # Add to the inter-chunk dk already stored by the inter kernel.
    b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
    # Gate gradient for this sub-block (per-element, before reversed cumsum
    # applied on the host side).
    b_dg = b_q * b_dq - b_k * b_dk
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
248
+
249
+
250
@triton.jit
def chunk_gla_bwd_kernel_inter(
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dA,
    s_k_h,
    s_k_t,
    s_k_d,
    s_v_h,
    s_v_t,
    s_v_d,
    s_h_h,
    s_h_t,
    s_h_d,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # Inter-chunk backward pass: using the saved states h and their gradients
    # dh, computes the inter-chunk parts of dq/dk, the full dv (per-i_k
    # slices, summed on the host), and the intra-chunk score gradient dA
    # (masked to the causal lower triangle; written once, by i_k == 0).
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    n_bh = tl.num_programs(2)

    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))

    # [BT, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    # Decay factor from each position to the end of the chunk.
    b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
    b_k = (b_k * b_gn).to(b_k.dtype)
    # [BT, BT] intra-chunk scores, loaded transposed for the A^T @ do term
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))

        # [BT, BV] inter-chunk dv through the state gradient
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
        # Intra-chunk dv term is K-independent; add it only once (i_k == 0).
        if i_k == 0:
            b_dv += tl.dot(b_A, b_do, allow_tf32=False)
        b_do = (b_do * scale).to(b_do.dtype)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        # [BT, BT]
        b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
        # [BT, BK]
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
    # Apply the gate factors: dq decays from chunk start, dk towards chunk end.
    b_dq = b_dq * tl.exp(b_gk)
    b_dk = b_dk * b_gn

    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]
    # [BT, BT] keep only the causal (lower-triangular) part of dA
    b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
    # dA does not depend on the K tile; write it from a single program.
    if i_k == 0:
        tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
342
+
343
class ChunkGLAFunction(torch.autograd.Function):
    """Autograd wrapper for the chunked GLA forward/backward Triton kernels.

    The forward pass materializes the per-chunk hidden states ``h`` and the
    intra-chunk score matrix ``A``; depending on ``checkpoint_level`` the fp32
    gate cumsum and/or ``h`` are dropped and recomputed in ``backward``.
    """

    @staticmethod
    @contiguous
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):
        B, H, T, K, V = *q.shape, v.shape[-1]
        # Chunk length and intra-chunk sub-block length.
        BT, BC = 64, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        num_warps = 4 if BK == 64 else 2
        num_stages = 1

        # Per-chunk (local) cumulative sum of the log-decay gates.
        g_cumsum = chunk_local_cumsum(g, BT=BT)
        g_org, g = g, g_cumsum

        # Recurrent chunk states h (and final state ht if requested).
        h, ht = chunk_fwd_h_fn(
            k=k, v=v, g=None, gk=g, gv=None, BT=BT, h0=initial_state, output_final_state=output_final_state
        )
        # Intra-chunk scores, one slice per K tile (summed below).
        A = q.new_zeros(NK, B, H, T, BT)
        grid = (NK, NT * NC * NC, B * H)
        chunk_gla_fwd_kernel_intra[grid](
            q, k, g, A,
            k.stride(1), k.stride(2), k.stride(3),
            scale,
            T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
            num_warps=num_warps,
            num_stages=num_stages
        )
        A = A.sum(0, dtype=A.dtype)
        o = torch.empty_like(v)
        grid = (NV, NT, B * H)
        chunk_gla_fwd_kernel_inter[grid](
            q, v, g, h, o, A,
            k.stride(1), k.stride(2), k.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2), h.stride(3),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        # Trade memory for recomputation according to the checkpoint level:
        # level >= 1 drops the fp32 gate cumsum, level > 1 also drops h.
        if checkpoint_level >= 1:
            del g
            g = g_org
        if checkpoint_level > 1:
            del h
            h = None

        ctx.save_for_backward(q, k, v, g, h, initial_state, A)
        ctx.BT = BT
        ctx.scale = scale
        ctx.checkpoint_level = checkpoint_level
        return o, ht

    @staticmethod
    @contiguous
    def backward(ctx, do, dht):
        q, k, v, g, h, initial_state, A = ctx.saved_tensors
        B, H, T, K, V = *q.shape, v.shape[-1]
        BT, BC = ctx.BT, 16
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
        NK = triton.cdiv(K, BK)
        num_warps = 4 if BK == 64 else 2
        num_stages = 1

        # Recompute whatever the forward pass dropped.
        if ctx.checkpoint_level >= 1:
            g_cumsum = chunk_local_cumsum(g, BT=BT)
            g_org, g = g, g_cumsum

        if h is None:
            h, _ = chunk_fwd_h_fn(
                k=k, v=v, g=None, gk=g, gv=None, BT=BT, h0=initial_state, output_final_state=False
            )

        scale = ctx.scale
        # Gradients of the chunk states (and of the initial state, if any).
        dh, dh0 = chunk_bwd_dh_fn(q=q, k=k, v=v, g=None, gk=g, gv=None, do=do, h0=initial_state, dht=dht, BT=BT, scale=scale)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dg = torch.empty_like(k, dtype=torch.float)
        # dv is accumulated per K tile, then summed.
        dv = v.new_empty(NK, *v.shape)
        dA = q.new_zeros(B, H, T, BT)
        grid = (NK, NT, B * H)
        chunk_gla_bwd_kernel_inter[grid](
            k, v, h, g, A, do, dh, dq, dk, dv, dA,
            k.stride(1), k.stride(2), k.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2), h.stride(3),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dv = dv.sum(0, dtype=v.dtype)
        grid = (NK, NT * NC, B * H)
        # Adds the intra-chunk parts to dq/dk and fills dg in place.
        chunk_gla_bwd_kernel_intra[grid](
            q, k, g, dA, dq, dk, dg,
            k.stride(1), k.stride(2), k.stride(3),
            T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
            num_warps=num_warps,
            num_stages=num_stages
        )
        # The gate gradient is a reversed cumulative sum over time.
        dg = chunk_global_reversed_cumsum(dg).to(k.dtype)
        return dq, dk, dv, dg, None, dh0, None, None
451
+
452
+
453
def chunk_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    checkpoint_level: Optional[int] = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        g (torch.Tensor):
            Forget gates of shape `(B, H, T, K)` applied to keys.
        scale (Optional[float]):
            Scale factor for the GLA attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `2`:
            - Level `0`: no memory saved, no recomputation.
            - Level `1`: recompute the fp32 cumulative values during backward.
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
    """
    assert checkpoint_level in [0, 1, 2]
    if scale is None:
        scale = q.shape[-1] ** -0.5
    o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
    return o, final_state
opencompass/models/fla2/ops/gla/chunk_fuse.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Songlin Yang
4
+ # Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
5
+ # on-the-fly computation without materializing hidden states into HBMs
6
+
7
+ from typing import Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import triton
12
+ import triton.language as tl
13
+ from einops import rearrange
14
+ from packaging import version
15
+
16
+ from .chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
17
+ prepare_qg_kg)
18
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
19
+
20
+
21
@triton.jit
def fused_chunk_gla_fwd_kernel(
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, K]
    v,  # value [B, H, L, V]
    g,  # cumulative sum of log decay [B, H, L, K]
    o,  # output [B, H, L, V]

    h0,  # initial state of the chunk [B, H, K, V]
    ht,  # final state of the chunk [B, H, K, V]

    s_qk_h,  # stride size: L * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1

    s_vo_h,  # stride size: L * V
    s_vo_t,  # stride size: V
    s_vo_d,  # stride size: 1

    B: tl.constexpr,  # batch size
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Fused-chunk GLA forward: sweeps the chunks sequentially, carrying the
    # [BK, BV] state b_h in registers instead of materializing it to HBM.
    # Only the inter-chunk term q @ h is produced here; the intra-chunk term
    # is computed separately by fwd_inner_chunk.
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    # make block pointers
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
    # pointer to the gate row at the last position of the current chunk
    p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
    # output is written per-i_k slice and reduced on the host
    p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K

    for i in range(0, tl.cdiv(T, BT)):
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_o = tl.zeros([BT, BV], dtype=tl.float32)
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # gate cumsum at the chunk's last position (decay over the chunk)
        d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
        # NOTE(review): both branches are identical; the CHECK split looks
        # like a legacy numerical-check scaffold — confirm before collapsing.
        if CHECK and i == 0:
            b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
            b_h = b_h * tl.exp(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
        else:
            b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
            b_h = b_h * tl.exp(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))
        p_db += BT * K

    if STORE_FINAL_STATE:
        p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
96
+
97
+
98
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
99
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_gla_bwd_kernel(
    q, k, v, g,
    do,  # gradient of output [B, H, L, V]
    dq,  # gradient of query [NV, B, H, L, K]
    dk,  # gradient of key [NV, B, H, L, K]
    dv,  # gradient of value [NK, B, H, L, V]

    h0,  # initial state of the chunk [B, H, K, V]

    s_qk_h,  # stride size: L * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1

    s_vo_h,  # stride size: L * V
    s_vo_t,  # stride size: V
    s_vo_d,  # stride size: 1
    scale,  # K ** -0.5

    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Fused-chunk GLA backward: a forward sweep carries the state (v^T k
    # accumulator) to produce dq, then a reverse sweep carries the state
    # gradient b_dh to produce dk and dv. Only inter-chunk terms here; the
    # intra-chunk terms come from bwd_inner_chunk.
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # [BV, BK] forward state, transposed relative to the fwd kernel
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K
    # Forward sweep over chunks: dq_i = do_i @ h_{i-1}, then update h.
    for i in range(0, tl.cdiv(T, BT)):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
        # gate cumsum at the last position of chunk i
        p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
        b_dq = tl.zeros([BT, BK], dtype=tl.float32)
        # [BT, K]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)

        # [V, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # NOTE(review): both branches are identical; the CHECK split looks
        # like a legacy numerical-check scaffold — confirm before collapsing.
        # [V, K]
        if CHECK and i == 0:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h * tl.exp(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
        else:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h * tl.exp(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
        b_dq *= scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # sync threads
    b_h = None
    tl.debug_barrier()
    # [BK, BV] gradient of the recurrent state, accumulated back-to-front
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)

    # Reverse sweep over chunks: dk/dv from b_dh, then update b_dh.
    for i in range(1, tl.cdiv(T, BT) + 1):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, K),
                                 (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, V),
                                 (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        # [K, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BT, K]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, V]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)

        # inter-chunk
        # [K, V]
        if CHECK and i == 1:
            b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
            b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
            b_dh = b_dh * tl.exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
        else:
            b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
            b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
            b_dh = b_dh * tl.exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
203
+
204
+
205
@triton.jit
def fwd_inner_chunk(
    q, k, g, A,
    s_qk_h,  # stride size: L * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1
    B,  # B
    H,  # H
    T,  # T
    scale,  # K ** -0.5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    K: tl.constexpr,  # K
):
    # Intra-chunk forward scores for fused-chunk GLA: for each chunk,
    # computes the causal BT x BT score matrix
    #   A[i, j] = sum_k q[i] * k[j] * exp(g[i] - g[j]) * scale  (j <= i),
    # iterating row by row; results are written per-i_k slice of A.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    # [BT, BK] keys of the whole chunk
    b_k = tl.load(p_k, boundary_check=(0, 1))

    p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    # [BT, BK] gate cumsums of the whole chunk
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K
    o_i = tl.arange(0, BT)

    # Row pointers for query / query-gate at position i of the chunk.
    p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)

    for i in range(BT):
        # [BK,] query row i, pre-scaled
        _q = tl.load(p_q, mask=mask, other=0) * scale
        gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
        # [BT, BK] contribution of each key to row i
        s = _q[None, :] * b_k * tl.exp(gq[None, :] - b_g)
        score = tl.sum(s, axis=1)
        # Causal mask: keep only keys at or before position i.
        score = tl.where(o_i <= i, score, 0)
        tl.store(p_A, score.to(p_A.dtype.element_ty))
        p_q += K
        p_gq += K
        p_A += BT
248
+
249
+
250
@triton.jit
def bwd_inner_chunk(
    q,
    k,
    g,
    dA,
    dq,
    dk,
    s_qk_h,  # stride size: L * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1
    T: tl.constexpr,  # T
    K: tl.constexpr,  # K
    # clamp_min, # minimum log value of the gate for numerical stability. default: -5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
):
    # Backward of the intra-chunk score computation: given dA (gradient of the
    # (BT, BT) score tile), accumulate dq row by row and dk across the chunk.
    # Uses the same decay weighting exp(g_i - g_j) as the forward pass.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K  # guard the tail K-block
    o_i = tl.arange(0, BT)

    p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    # dA rows of the (BT, BT) tile belonging to this chunk
    p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)

    b_dk = tl.zeros([BT, BK], dtype=tl.float32)

    for i in range(BT):
        _q = tl.load(p_q, mask=mask, other=0)
        gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
        score = tl.exp(gq[None, :] - b_g)
        # causal mask: query i only attends to keys at positions <= i
        score = tl.where(o_i[:, None] <= i, score, 0)
        _dA = tl.load(p_dA)
        _dA = tl.where(o_i <= i, _dA, 0)
        # dk accumulates contributions from every query row of the chunk
        b_dk += (_dA[:, None] * score * _q[None, :])
        # dq for query i is a reduction over all (masked) keys
        b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)
        tl.store(p_dq, b_dq, mask=mask)
        p_q += K
        p_dq += K
        p_gq += K
        p_dA += BT

    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
300
+
301
+
302
class FusedChunkGLAFunction(torch.autograd.Function):
    """Fused chunkwise gated linear attention (GLA).

    Inputs are head-first tensors: q/k/g of shape (B, H, T, K), v of shape
    (B, H, T, V). The output combines an inter-chunk recurrent pass (Triton
    kernels over 16-token chunks) with an intra-chunk attention-like pass.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
        ctx.g_dtype = g.dtype
        g_original = g
        # cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
        g = torch.empty_like(g, dtype=torch.float32)
        B, H, T, K, V = *k.shape, v.shape[-1]
        ctx.scale = scale

        # inter-chunk
        BT = 16  # chunk_size
        BK, BV = min(K, 64), min(V, 64)
        num_stages = 1
        num_warps = 2

        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        # partial outputs per K-block; reduced with o.sum(0) below
        o = q.new_empty(NK, B, H, T, V)
        q_g = torch.empty_like(q)
        k_g = torch.empty_like(k)
        grid = (NK, triton.cdiv(T, BT), B * H)

        # within-chunk cumulative sum of the log decays, in float32
        fwd_decay_cumsum[grid](
            g_original,
            g,
            T*K,
            K=K,
            BT=BT, BK=BK, num_warps=1
        )
        # pre-scale q by exp(g) and k by exp(last_decay - g) for the
        # inter-chunk recurrence
        prepare_qg_kg[grid](
            q, k, g, q_g, k_g,
            T*K,
            scale,
            K=K, BT=BT, BK=BK, num_warps=1
        )

        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
        else:
            final_state = None
        # the bug still exists even for Triton 2.2 on H100 GPUs
        # so we always enable initial checks
        CHECK = True
        if version.parse(triton.__version__) < version.parse('2.2.0'):
            import warnings
            warnings.warn(
                "Triton<2.2.0 detected for running this kernel, "
                "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
                "that lead to significant precision loss. "
                "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
                "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
            )
            CHECK = True

        grid = (NV, NK, B * H)
        fused_chunk_gla_fwd_kernel[grid](
            q_g, k_g, v, g, o, initial_state, final_state,
            # strides passed explicitly for contiguous (B, H, T, D) layout
            T*K, K, 1,
            T*V, V, 1,
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            CHECK=CHECK,
            num_warps=num_warps,
            num_stages=num_stages
        )

        o = o.sum(0)  # reduce partial results over the NK split dimension

        # intra-chunk
        chunk_size = 16
        num_chunk = T // chunk_size
        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
        BK = min(K, 64)
        NK = triton.cdiv(K, BK)
        A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT)
        grid = (NK, triton.cdiv(T, BT), B * H)
        fwd_inner_chunk[grid](
            q, k, g, A,
            T*K, K, 1,
            B, H, T, scale, BT=BT, BK=BK, K=K, num_stages=3,
            num_warps=4
        )
        A = A.sum(0)
        # per-chunk causal score matrix times values
        o2 = A @ v2
        o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
        # combine inner and inter
        o.add_(o2)
        ctx.save_for_backward(q, k, v, g_original, A, initial_state)
        ctx.CHECK = CHECK
        return o.to(v), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        q, k, v, g_origin, A, initial_state = ctx.saved_tensors
        B, H, T, K, V = *k.shape, v.shape[-1]
        scale = ctx.scale

        # recomputation
        # inter-chunk
        BT = 16  # chunk_size
        # recompute the float32 cumulative decays (cheaper than saving them)
        g = torch.empty_like(g_origin, dtype=torch.float32)
        BK, BV = min(K, 64), min(V, 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        q_g = torch.empty_like(q)
        k_g = torch.empty_like(k)
        grid = (NK, triton.cdiv(T, BT), B * H)
        fwd_decay_cumsum[grid](
            g_origin,
            g,
            T*K,
            K=K,
            BT=BT, BK=BK, num_warps=1
        )
        prepare_qg_kg[grid](
            q, k, g, q_g, k_g,
            T*K,
            scale,
            K=K, BT=BT, BK=BK, num_warps=1
        )

        # NOTE(review): the original author questioned whether this re-read
        # could introduce numerical error — verify against a reference impl.
        # inter-chunk
        BT = 16
        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 2
        dq = q.new_empty(NV, B, H, T, K)
        dk = q.new_empty(NV, B, H, T, K)
        dv = q.new_empty(NK, B, H, T, V)

        grid = (NV, NK, B * H)

        fused_chunk_gla_bwd_kernel[grid](
            q_g, k_g, v, g, do, dq, dk, dv, initial_state,
            T*K, K, 1,
            T*V, V, 1,
            scale,
            B=B, H=H, T=T, K=K, V=V,
            BT=BT, BK=BK, BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            CHECK=ctx.CHECK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        # reduce the partial gradients over their split dimensions
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)

        # intra chunk
        num_chunk = T // BT
        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
        do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
        dA2 = (do2 @ v2.transpose(-2, -1)) * scale
        dv2 = A.transpose(-1, -2) @ do2
        dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)

        BK = min(triton.next_power_of_2(K), 16)
        NK = triton.cdiv(K, BK)
        dk2 = torch.empty_like(k)
        dq2 = torch.empty_like(q)

        grid = (NK, triton.cdiv(T, BT), B * H)
        bwd_inner_chunk[grid](
            q, k, g,
            dA2, dq2, dk2,
            T*K, K, 1,
            T=T, K=K, BT=BT, BK=BK,
            num_warps=1,
            num_stages=3
        )

        BK = min(triton.next_power_of_2(K), 32)
        NK = triton.cdiv(K, BK)
        dg = torch.empty_like(g, dtype=torch.float32)
        grid = (NK, triton.cdiv(T, BT), B * H)
        # merge inner/inter dq and dk and accumulate the within-chunk
        # (reverse) gradient of the log decays
        bwd_decay_global_cumsum[grid](
            dq2, dq, dk2, dk, q, k, g, dg,
            T*K, K, 1,
            B, H, T, scale,
            BT=BT, K=K, BK=BK,
            num_warps=1,
            num_stages=1
        )
        dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)

        def rev_cumsum_exclusive(x):
            # reverse exclusive cumulative sum along the chunk axis
            cumsum_x = x.cumsum(-2)
            rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
            return rev_cumsum_x

        # propagate each chunk's leading dg contribution to later chunks
        rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
        dg.add_(rev_cumsum_dg.unsqueeze(-2))
        dv.add_(dv2)
        dg = rearrange(dg, 'b h n c d -> b h (n c) d')

        # Nones align with (scale, initial_state, output_final_state)
        return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
543
+
544
+
545
def pad(x, chunk_size=16):
    """Right-pad ``x`` with zeros along the sequence dim (-2) up to a
    multiple of ``chunk_size``; return ``x`` unchanged if already aligned."""
    seq_len = x.shape[-2]
    # ceiling division, inlined: smallest multiple of chunk_size >= seq_len
    target_len = -(seq_len // -chunk_size) * chunk_size
    if seq_len % chunk_size != 0:
        x = F.pad(x, (0, 0, 0, target_len - seq_len))
    return x
551
+
552
+
553
def ceildiv(a, b):
    """Integer ceiling division of ``a`` by ``b``."""
    quot, rem = divmod(a, b)
    return quot + (1 if rem else 0)
555
+
556
# head-first layout (B, H, T, D) is assumed by default
def fused_chunk_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float = -1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused chunkwise gated linear attention.

    Args:
        q, k, g: (B, H, T, K) tensors; g holds log decay gates.
        v: (B, H, T, V) tensor.
        scale: query scaling factor; -1 selects the default K ** -0.5.
            (Fix: annotated as float — K ** -0.5 is not an int.)
        initial_state: optional (B, H, K, V) recurrent state; detached,
            so no gradient flows into it.
        output_final_state: if True, also return the final (B, H, K, V) state.

    Returns:
        (o, final_state): o is (B, H, T, V); final_state is None unless
        output_final_state is True.
    """
    if scale == -1:
        scale = q.shape[-1] ** -0.5
    if initial_state is not None:
        initial_state = initial_state.detach()
    seq_len = q.shape[-2]
    # pad all inputs so T is a multiple of the 16-token chunk size
    q, k, v, g = map(pad, [q, k, v, g])
    o, final_state = FusedChunkGLAFunction.apply(
        q, k, v, g, scale, initial_state, output_final_state)
    # drop the padded tail positions
    o = o[..., :seq_len, :]
    return o, final_state
opencompass/models/fla2/ops/gla/chunk_util.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+ import triton.language as tl
3
+
4
+
5
@triton.jit
def fwd_decay_cumsum(
    g,
    g_o,
    s_qk_h,  # stride size for one (batch, head): T * K
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr
):
    # Write the within-chunk cumulative sum of log decays g into g_o.
    # One program per (K-block, chunk, batch*head); accumulates in float32.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    p_g = g + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_go = g_o + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    cum_decay = tl.zeros([BK], dtype=tl.float32)
    mask = (i_k * BK + tl.arange(0, BK)) < K  # guard the tail K-block

    for i in range(BT):
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        cum_decay += _g
        tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
        p_g += K
        p_go += K
26
+
27
+
28
@triton.jit
def prepare_qg_kg(
    q,
    k,
    g,
    qg,
    kg,
    s_qk_h,  # stride size for one (batch, head): T * K
    scale,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr
):
    # Pre-scale q and k for the inter-chunk recurrence:
    #   qg = q * exp(g) * scale
    #   kg = k * exp(last_decay - g)
    # where g is the within-chunk cumulative log decay and last_decay is the
    # cumulative decay at the final position of the chunk.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    p_q = q + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_g = g + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_qg = qg + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_kg = kg + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)

    mask = (i_k * BK + tl.arange(0, BK)) < K  # guard the tail K-block

    # cumulative decay at the last position of this chunk
    last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK))

    for i in range(BT):
        _q = tl.load(p_q, mask=mask, other=0)
        _k = tl.load(p_k, mask=mask, other=0)
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        _q *= tl.exp(_g) * scale
        _k *= tl.exp(last_decay - _g)
        tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
        tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
        p_q += K
        p_g += K
        p_k += K
        p_kg += K
        p_qg += K
68
+
69
@triton.jit
def bwd_decay_global_cumsum(
    dq_inner,
    dq_inter,
    dk_inner,
    dk_inter,
    q, k, g, dg,
    s_qk_h,  # stride size for one (batch, head): T * K
    s_qk_t,  # stride size: K
    s_qk_d,  # stride size: 1
    B,
    H,
    T,
    scale,
    BT: tl.constexpr,
    BK: tl.constexpr,
    K: tl.constexpr
):
    # Walk one chunk backwards (position BT-1 down to 0) and, per position:
    #   * fold the inter-chunk dq/dk (undoing the exp(g) / exp(last_g - g)
    #     pre-scaling from prepare_qg_kg) into the intra-chunk dq/dk,
    #     writing the merged results in-place into dq_inter / dk_inter;
    #   * accumulate the reverse cumulative gradient of the log decays into dg
    #     via dg_t = dq_t * q_t - dk_t * k_t.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # all pointers start at the LAST position of the chunk and step backwards
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
    mask = (i_k * BK + tl.arange(0, BK)) < K  # guard the tail K-block
    last_g = tl.zeros([BK], dtype=tl.float32)
    for j in range(BT-1, -1, -1):
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        if j == (BT-1):
            # cumulative decay at the chunk's final position
            last_g = _g
        _dq1 = tl.load(p_dq_inner, mask=mask, other=0)
        _dq2 = tl.load(p_dq_inter, mask=mask, other=0)
        _dq2 *= tl.exp(_g)  # undo the qg = q * exp(g) pre-scaling
        _dq = _dq1 + _dq2
        tl.store(p_dq_inter, _dq, mask=mask)
        _dk1 = tl.load(p_dk_inner, mask=mask, other=0)
        _dk2 = tl.load(p_dk_inter, mask=mask, other=0)
        _dk2 *= tl.exp(last_g - _g)  # undo the kg = k * exp(last_g - g) pre-scaling
        _dk = _dk1 + _dk2
        tl.store(p_dk_inter, _dk, mask=mask)
        _q = tl.load(p_q, mask=mask, other=0)
        _k = tl.load(p_k, mask=mask, other=0)
        _dg = _dq * _q - _dk * _k
        # running (reverse) cumulative sum: dg at position j gets contributions
        # from every later position within the chunk
        cum_grad_dg += _dg
        tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
        p_g -= K
        p_k -= K
        p_q -= K
        p_dq_inner -= K
        p_dk_inner -= K
        p_dq_inter -= K
        p_dk_inter -= K
        p_dg -= K
125
+ p_dg -= K
opencompass/models/fla2/ops/gla/naive.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ...ops.gla.recurrent_fuse import fused_recurrent_gla
7
+
8
+
9
+ def ceildiv(a, b):
10
+ return -(a // -b)
11
+
12
+
13
+ def naive_recurrent_gla(
14
+ q,
15
+ k,
16
+ v,
17
+ gk,
18
+ initial_state=None,
19
+ output_final_state=False,
20
+ causal=True
21
+ ):
22
+ orig_dtype = q.dtype
23
+ q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
24
+ batch_size, n_heads, seq_len, d_head_k = q.shape
25
+ _, _, _, d_head_v = v.shape
26
+ h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
27
+ o = torch.zeros_like(v)
28
+ scale = d_head_k ** -0.5
29
+
30
+ if initial_state is not None:
31
+ h += initial_state
32
+
33
+ for i in range(seq_len):
34
+ q_i = q[:, :, i, :] * scale
35
+ k_i = k[:, :, i]
36
+ v_i = v[:, :, i, :]
37
+ gk_i = gk[:, :, i].exp()
38
+ kv_i = k_i[..., None] * v_i[..., None, :]
39
+ h = h * gk_i[..., None] + kv_i
40
+ o_i = (q_i[..., None] * h).sum(-2)
41
+ o[:, :, i] = o_i
42
+
43
+ if causal:
44
+ return o.to(orig_dtype), h
45
+ else:
46
+ o_reverse = torch.zeros_like(v)
47
+ h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
48
+ for i in range(seq_len-1, -1, -1):
49
+ q_i = q[:, :, i, :] * scale
50
+ k_i = k[:, :, i]
51
+ v_i = v[:, :, i, :]
52
+ gk_i = gk[:, :, i].exp()
53
+ kv_i = k_i[..., None] * v_i[..., None, :]
54
+ h = h * gk_i[..., None] + kv_i
55
+ o_i = (q_i[..., None] * h).sum(-2)
56
+ o_reverse[:, :, i] = o_i
57
+
58
+ return o, o_reverse
59
+
60
+
61
+ if __name__ == "__main__":
62
+ B = 4
63
+ H = 4
64
+ L = 512
65
+ D = 128
66
+ dtype = torch.float32
67
+ q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
68
+ k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
69
+ v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
70
+ g = F.logsigmoid(torch.rand(B, H, L, D)).cuda(
71
+ ).clamp_min(-1).to(torch.float32).requires_grad_(True)
72
+
73
+ do = torch.rand_like(v).cuda()
74
+ do2 = torch.rand_like(v).cuda()
75
+ intial_state = torch.rand(B, H, D, D).cuda()
76
+
77
+ ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
78
+
79
+ ref.backward(do, retain_graph=True)
80
+ ref_rev.backward(do2, retain_graph=True)
81
+
82
+ ref_dq, q.grad = q.grad.clone(), None
83
+ ref_dk, k.grad = k.grad.clone(), None
84
+ ref_dv, v.grad = v.grad.clone(), None
85
+ ref_dg, g.grad = g.grad.clone(), None
86
+
87
+ tri, tri_rev = fused_recurrent_gla(
88
+ q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
89
+ tri.backward(do, retain_graph=True)
90
+ tri_rev.backward(do2, retain_graph=True)
91
+ tri_dq, q.grad = q.grad.clone(), None
92
+ tri_dk, k.grad = k.grad.clone(), None
93
+ tri_dv, v.grad = v.grad.clone(), None
94
+ tri_dg, g.grad = g.grad.clone(), None
95
+
96
+ assert ref.allclose(tri, 0, 1e-5), breakpoint()
97
+ assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
98
+ assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
99
+ assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
100
+ assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
101
+ assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
102
+
103
+ # tri = fused_chunk_gla(q, k, v, g)
104
+ # tri.backward(do, retain_graph=True)
105
+ # tri_dq, q.grad = q.grad.clone(), None
106
+ # tri_dk, k.grad = k.grad.clone(), None
107
+ # tri_dv, v.grad = v.grad.clone(), None
108
+ # tri_dg, g.grad = g.grad.clone(), None
109
+
110
+ # assert ref.allclose(tri, 0, 1e-5), breakpoint()
111
+ # assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
112
+ # assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
113
+ # assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
114
+ # assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
115
+ # breakpoint()
116
+ print("Pass")
opencompass/models/fla2/ops/gla/recurrent_fuse.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Songlin Yang
4
+
5
+ from typing import Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
11
+ from ...ops.common.fused_recurrent import fused_recurrent
12
+
13
+ def fused_recurrent_gla(
14
+ q: torch.Tensor,
15
+ k: torch.Tensor,
16
+ v: torch.Tensor,
17
+ gk: torch.Tensor = None,
18
+ gv: torch.Tensor = None,
19
+ scale: int = None,
20
+ initial_state: torch.Tensor = None,
21
+ output_final_state: bool = False,
22
+ reverse: bool = False
23
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ if scale is None:
25
+ scale = q.shape[-1] ** -0.5
26
+ o, final_state = fused_recurrent(q, k, v, None, gk, gv, scale, initial_state, output_final_state, reverse)
27
+ return o, final_state
opencompass/models/fla2/ops/hgrn/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_hgrn
4
+ from .recurrent_fuse import fused_recurrent_hgrn
5
+
6
+ __all__ = [
7
+ 'chunk_hgrn',
8
+ 'fused_recurrent_hgrn'
9
+ ]
opencompass/models/fla2/ops/hgrn/chunk.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2024, Yu Zhang, Songlin Yang
4
+
5
+ # this function implements the chunkwise form of HGRN, inspired by
6
+ # [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
7
+ # also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
8
+
9
+ # from tests on H800, with B, H, D = 16, 4, 128, we see that the chunk can be greatly faster than the recurrent:
10
+ #
11
+ # Performance:
12
+ # seq_len chunk recurrent chunk_bwd recurrent_bwd
13
+ # 0 128.0 0.039360 0.061056 0.312160 0.205008
14
+ # 1 256.0 0.045824 0.123712 0.308784 0.297696
15
+ # 2 512.0 0.058688 0.241952 0.310720 0.626528
16
+ # 3 1024.0 0.088288 0.476992 0.313184 1.333152
17
+ # 4 2048.0 0.169472 0.943264 0.452464 2.724864
18
+ # 5 4096.0 0.329920 1.886144 0.881600 5.551520
19
+ # 6 8192.0 0.647872 3.755040 1.740496 11.117184
20
+ # 7 16384.0 1.272064 7.520576 3.446608 22.362528
21
+
22
+ from typing import Tuple
23
+
24
+ import torch
25
+ import triton
26
+ import triton.language as tl
27
+
28
+ from fla.utils import contiguous
29
+
30
+
31
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def chunk_hgrn_fwd_kernel_h(
    x,
    g,
    gc,
    o,
    h0,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    # Per-chunk local scan of the HGRN recurrence h = exp(g) * h + x.
    # Each program handles one (D-block, chunk, batch*head) and writes:
    #   o  — the local (within-chunk) hidden states, started from zero
    #        (except chunk 0, which may start from h0),
    #   gc — the within-chunk cumulative sum of log gates, used later by
    #        chunk_hgrn_fwd_kernel_o to stitch chunks together.
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D  # guard the tail D-block

    p_x = x + i_bh * T * D + i_t * BT * D + o_d
    p_g = g + i_bh * T * D + i_t * BT * D + o_d
    p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
    p_o = o + i_bh * T * D + i_t * BT * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    b_gc = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        # only the very first chunk absorbs the provided initial state
        if i_t == 0:
            b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
    for i in range(0, BT):
        mask_t = mask & ((i_t * BT + i) < T)  # also guard the tail chunk
        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
        b_h = tl.exp(b_g) * b_h + b_x
        b_gc = b_gc + b_g
        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)

        p_x += D
        p_g += D
        p_gc += D
        p_o += D
88
+
89
+
90
@triton.jit
def chunk_hgrn_fwd_kernel_o(
    gc,
    o,
    s_h,
    s_t,
    s_d,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Sequential inter-chunk pass: propagate each chunk's carry-in state
    # (the last output of the previous chunk) through the local outputs via
    # o += exp(gc) * h_prev. Chunk 0 needs no correction and is skipped.
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(1, tl.cdiv(T, BT)):
        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,] — final (already corrected) state of the previous chunk
        b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
        # [BT, BD]
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
117
+
118
+
119
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def chunk_hgrn_bwd_kernel_h(
    g,
    gc,
    dx,
    do,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Per-chunk local BACKWARD scan (runs each chunk back-to-front):
    # dh = dh * exp(g) + do, and dx_t = dh before the gate multiply.
    # Also rewrites gc to hold the backward-shifted cumulative gates
    # (cumsum of g over positions strictly AFTER t), consumed by
    # chunk_hgrn_bwd_kernel_o.
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D
    BC = min(BT, T - i_t * BT)  # actual length of this (possibly tail) chunk
    NT = tl.num_programs(1)

    # pointers start at the LAST valid position of the chunk, stepping back
    p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d
    p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d
    p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d
    p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d

    if i_t == NT - 1:
        b_gc = tl.zeros([BD], dtype=tl.float32)
    else:
        # seed with the first gate of the NEXT chunk
        b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
    b_dh = tl.zeros([BD], dtype=tl.float32)
    for _ in range(BC - 1, -1, -1):
        tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)

        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)

        b_gc = b_gc + b_g
        b_dh = b_dh + b_do
        b_dx = b_dh
        # gate the carry for the next (earlier) position
        b_dh = b_dh * tl.exp(b_g)

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)

        p_g -= D
        p_gc -= D
        p_dx -= D
        p_do -= D
180
+
181
+
182
@triton.jit
def chunk_hgrn_bwd_kernel_o(
    g,
    gc,
    o,
    dx,
    dg,
    s_h,
    s_t,
    s_d,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Sequential inter-chunk backward pass (chunks processed last-to-first):
    #   dx += exp(gc) * dh_next  (carry the gradient from the next chunk in),
    #   dg  = o_{t-1} * dx_t * exp(g_t)   (gradient of the log gates,
    #         using the previous-position hidden state via the shifted o load).
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        # shifted by one row: reads o_{t-1}; row -1 of chunk 0 is masked out
        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,] — incoming gradient from the first position of the next chunk
        mask_t = mask & ((i_t + 1) * BT < T)
        b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
        # [BT, BD]
        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)

        b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
        b_dg = b_o * b_dx * tl.exp(b_g)
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
221
+
222
+
223
class ChunkHGRNFunction(torch.autograd.Function):
    """Chunkwise HGRN: h_t = exp(g_t) * h_{t-1} + x_t over (B, H, T, D).

    Forward runs a parallel per-chunk scan followed by a sequential
    inter-chunk stitch; backward mirrors the same two-phase structure.
    """

    @staticmethod
    @contiguous
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        B, H, T, D = x.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        # gc: within-chunk cumulative log gates; o: hidden states (float32)
        gc = torch.empty_like(g, dtype=torch.float)
        o = torch.empty_like(x, dtype=torch.float)
        # BD for the first kernel is chosen by its autotuner
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
        chunk_hgrn_fwd_kernel_h[grid](
            x, g, gc, o, initial_state,
            T=T, D=D, BT=BT,
            USE_INITIAL_STATE=initial_state is not None
        )
        # sequential pass stitching the chunks together
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        chunk_hgrn_fwd_kernel_o[grid](
            gc, o,
            o.stride(1), o.stride(2), o.stride(3),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        final_state = None
        if output_final_state:
            final_state = o[:, :, -1].clone()
        o = o.to(x.dtype)
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        B, H, T, D = do.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        # gc is recomputed (backward-shifted form) rather than saved
        gc = torch.empty_like(g, dtype=torch.float)
        dx = torch.empty_like(o, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
        chunk_hgrn_bwd_kernel_h[grid](
            g, gc, dx, do,
            T=T, D=D, BT=BT
        )

        dg = torch.empty_like(g, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        chunk_hgrn_bwd_kernel_o[grid](
            g, gc, o, dx, dg,
            o.stride(1), o.stride(2), o.stride(3),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        if initial_state is not None:
            # dg at t=0 uses the provided initial state instead of o_{-1}
            dg[:, :, 0] = (initial_state * dx[:, :, 0] * g[:, :, 0].float().exp()).to(dg.dtype)

        # Nones align with (initial_state, output_final_state)
        return dx.to(o.dtype), dg, None, None
282
+
283
+
284
def chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise HGRN over head-first (B, H, T, D) inputs.

    Thin public entry point delegating to :class:`ChunkHGRNFunction`;
    returns ``(o, final_state)`` where ``final_state`` is ``None`` unless
    ``output_final_state`` is True.
    """
    return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
opencompass/models/fla2/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    """Step-by-step reference HGRN: h[t] = exp(g[t]) * h[t-1] + x[t].

    All arithmetic is done in float32; outputs are cast back to the input
    dtype.

    Args:
        x: inputs of shape (B, H, T, D).
        g: log-forget gates of shape (B, H, T, D).
        initial_state: optional starting hidden state of shape (B, H, D).
        output_final_state: if True, also return the last hidden state.

    Returns:
        ``(o, final_state)``: outputs of shape (B, H, T, D) in x's dtype and
        the final float state (or ``None``).
    """
    in_dtype = x.dtype
    x = x.float()
    g = g.float()
    B, H, T, D = x.shape

    o = torch.zeros_like(x)
    if initial_state is None:
        h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
    else:
        h = initial_state.float().clone()

    for t in range(T):
        h = torch.exp(g[:, :, t]) * h + x[:, :, t]
        o[:, :, t] = h

    final_state = h if output_final_state else None
    return o.to(in_dtype), final_state
32
+
33
+
34
def naive_chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False,
    chunk_size: int = 64
) -> torch.Tensor:
    """Chunk-parallel reference HGRN: h[t] = exp(g[t]) * h[t-1] + x[t].

    Each chunk is evaluated as a zero-initialized local recurrence plus the
    carried-in state decayed by the within-chunk cumulative gates, matching
    the step-by-step recurrence exactly. ``T`` must be divisible by
    ``chunk_size`` (the cumulative-gate reshape below requires it).

    Args:
        x: inputs of shape (B, H, T, D).
        g: log-forget gates of shape (B, H, T, D).
        initial_state: optional starting state of shape (B, H, D).
        output_final_state: if True, also return the final hidden state.
        chunk_size: number of time steps per chunk.

    Returns:
        ``(o, final_state)``: outputs in x's dtype and the final float state
        (or ``None``).
    """
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, H, T, D = x.shape

    # cumulative log-gates, restarted at every chunk boundary
    gc = g.view(B, H, -1, chunk_size, D).cumsum(-2).view_as(g)
    h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state

    for i in range(0, T, chunk_size):
        # state carried in from all previous chunks
        hp = h
        # chunk-local state, rebuilt from zero inside the chunk
        h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
        for j in range(i, i + chunk_size):
            h = g[:, :, j].exp() * h + x[:, :, j]
            # full state = decayed carried-in state + chunk-local state
            o[:, :, j] = hp * gc[:, :, j].exp() + h
        # Fix: carry the full state forward exactly ONCE per chunk (after the
        # inner loop). Resetting h from o inside the inner loop would fold
        # the carried-in term back into the local recurrence and double-count
        # it on every later step of the chunk.
        h = o[:, :, i + chunk_size - 1].clone()

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
opencompass/models/fla2/ops/hgrn/recurrent_fuse.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Songlin Yang
4
+
5
+ from typing import Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.utils import contiguous
12
+
13
+
14
# Fused forward kernel for the HGRN recurrence h[t] = exp(g[t]) * h[t-1] + x[t].
# Each program owns a BD-wide slice of the hidden dimension for one
# (batch, head) pair and scans the full sequence keeping the state in registers.
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def fused_recurrent_hgrn_fwd_kernel(
    x,
    g,
    o,
    h0,
    ht,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    # program ids: hidden-dim tile, (batch * head) index
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # pointers to the first time step of this (batch, head)'s slice
    p_x = x + i_bh * T * D + o_d
    p_g = g + i_bh * T * D + o_d
    p_o = o + i_bh * T * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * D + o_d
        b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
    for _ in range(0, T):
        b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        # recurrence step: h = exp(g) * h + x
        b_h = tl.exp(b_g) * b_h + b_x
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)

        # advance all pointers by one time step
        p_x += D
        p_g += D
        p_o += D

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * D + o_d
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
69
+
70
+
71
# Fused backward kernel for the HGRN recurrence. Walks time in reverse,
# carrying the running hidden-state gradient b_dh in registers.
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def fused_recurrent_hgrn_bwd_kernel(
    g,
    o,
    dx,
    dg,
    do,
    h0,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # start at the last time step; p_o trails one step behind (o[t-1])
    p_g = g + (i_bh * T + T - 1) * D + o_d
    p_o = o + (i_bh * T + T - 2) * D + o_d
    p_dx = dx + (i_bh * T + T - 1) * D + o_d
    p_dg = dg + (i_bh * T + T - 1) * D + o_d
    p_do = do + (i_bh * T + T - 1) * D + o_d

    b_dh = tl.zeros([BD], dtype=tl.float32)
    for i in range(T - 1, -1, -1):
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
        # previous hidden state: o[t-1], or h0 / zeros at t == 0
        if i > 0:
            b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
        elif USE_INITIAL_STATE:
            b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
        else:
            b_o = tl.zeros([BD], dtype=tl.float32)

        # accumulate the direct output gradient into dL/dh[t]
        b_dh = b_dh + b_do
        # dL/dx[t] = dL/dh[t]
        b_dx = b_dh
        # propagate through the gate to the previous step
        b_dh = b_dh * tl.exp(b_g)
        # dL/dg[t] = dL/dh[t] * exp(g[t]) * h[t-1]
        b_dg = b_dh * b_o
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)

        # step all pointers one time step backwards
        p_g -= D
        p_o -= D
        p_dx -= D
        p_dg -= D
        p_do -= D
134
+
135
+
136
class FusedRecurrentHGRNFunction(torch.autograd.Function):
    """Autograd wrapper around the fused recurrent HGRN Triton kernels."""

    @staticmethod
    @contiguous
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        """Compute o[t] = exp(g[t]) * o[t-1] + x[t] in a single fused scan.

        Args:
            x: inputs of shape (B, H, T, D).
            g: log-forget gates of shape (B, H, T, D).
            initial_state: optional starting state of shape (B, H, D).
            output_final_state: whether to also return the last hidden state.

        Returns:
            (o, final_state); final_state is None unless requested.
        """
        B, H, T, D = x.shape

        final_state = None
        if output_final_state:
            final_state = x.new_empty(B, H, D)

        o = torch.empty_like(x)
        # one program per BD-slice of D per (batch * head); BD is autotuned
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        fused_recurrent_hgrn_fwd_kernel[grid](
            x, g, o, initial_state, final_state,
            T, D,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None
        )
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        """Backward pass; the final-state gradient `dht` is not used."""
        g, o, initial_state = ctx.saved_tensors
        B, H, T, D = do.shape

        dx = torch.empty_like(o, dtype=torch.float)
        dg = torch.empty_like(g, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        fused_recurrent_hgrn_bwd_kernel[grid](
            g, o, dx, dg, do, initial_state,
            T, D,
            USE_INITIAL_STATE=initial_state is not None,
        )

        # gradients for (x, g, initial_state, output_final_state)
        return dx, dg, None, None
174
+
175
+
176
def fused_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused recurrent HGRN: o[t] = exp(g[t]) * o[t-1] + x[t].

    Thin dispatch to :class:`FusedRecurrentHGRNFunction`.

    Args:
        x: inputs of shape (B, H, T, D).
        g: log-forget gates of shape (B, H, T, D).
        initial_state: optional starting hidden state of shape (B, H, D).
        output_final_state: whether to also return the last hidden state.

    Returns:
        A tuple ``(o, final_state)``; ``final_state`` is ``None`` unless
        requested.
    """
    return FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
opencompass/models/fla2/ops/linear_attn/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_linear_attn
4
+ from .chunk_fuse import fused_chunk_linear_attn
5
+ from .recurrent_fuse import fused_recurrent_linear_attn
6
+
7
+ __all__ = [
8
+ 'chunk_linear_attn',
9
+ 'fused_chunk_linear_attn',
10
+ 'fused_recurrent_linear_attn'
11
+ ]
opencompass/models/fla2/ops/linear_attn/chunk.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023, Yu Zhang, Songlin Yang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.linear_attn.utils import normalize_output
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
12
+
13
+
14
# Forward stage 1: materialize the inter-chunk running state. For every
# (BK, BV) tile, h[i_t] holds sum over earlier chunks of k^T @ v; the state
# ENTERING chunk i_t is stored before the chunk's own contribution is added.
@triton.jit
def chunk_linear_attn_fwd_kernel_h(
    k,
    v,
    h,
    h0,
    ht,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    # program ids: key tile, value tile, (batch * head)
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV] running state, kept in fp32
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # store the state seen by chunk i_t before updating it
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BK, BV]
        b_h += tl.dot(b_k, b_v, allow_tf32=False)

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
64
+
65
+
66
# Forward stage 2: per-chunk output
# o = (q @ h) + causal_mask(q @ k^T) @ v, all scaled by `scale`,
# where h is the inter-chunk state produced by stage 1.
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
    q,
    k,
    v,
    h,
    o,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    # program ids: value tile, time chunk, (batch * head)
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # lower-triangular (causal) mask within the chunk
    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # inter-chunk contribution and intra-chunk attention scores
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
        b_s += tl.dot(b_q, b_k, allow_tf32=False)
    # keep only causal (past/current) positions inside the chunk
    b_s = tl.where(m_s, b_s, 0)

    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale

    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
117
+
118
+
119
# Backward stage 1: gradient of the inter-chunk state. Walks chunks in
# reverse, storing the state-gradient ENTERING each chunk before accumulating
# that chunk's contribution (scaled q^T @ do).
@triton.jit
def chunk_linear_attn_bwd_kernel_dh(
    q,
    do,
    dh,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    # program ids: key tile, value tile, (batch * head)
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV] running state gradient
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # store the gradient seen by chunk i_t before updating it
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BK, BV]
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
158
+
159
+
160
# Backward stage 2: per-chunk gradients for q, k and v, combining the
# inter-chunk terms (via h and dh) with the intra-chunk causal attention
# terms. dv is written per key-tile (leading NK axis) and summed by the host.
@triton.jit
def chunk_linear_attn_bwd_kernel_dqkv(
    q,
    k,
    v,
    h,
    do,
    dh,
    dq,
    dk,
    dv,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    # program ids: key tile, time chunk, (batch * head)
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    n_bh = tl.num_programs(2)
    o_i = tl.arange(0, BT)

    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # intra-chunk scores (transposed layout), upper-triangular kept
    b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
    b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    # [BT, BT] causal-masked, scaled intra-chunk score gradient
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
    # [BT, BK] add the intra-chunk contributions
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))

    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
234
+
235
+
236
class ChunkLinearAttentionFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked linear-attention Triton kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, scale, initial_state, output_final_state):
        """Two-stage forward: build inter-chunk states, then per-chunk outputs.

        Args:
            q, k: of shape (B, H, T, K); v: of shape (B, H, T, V).
            scale: attention score scale.
            initial_state: optional (B, H, K, V) state.
            output_final_state: whether to also return the (B, H, K, V) state.
        """
        B, H, T, K, V = *q.shape, v.shape[-1]
        # tile sizes for time, key and value dimensions
        BT = 64
        BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
        NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4 if BK == 64 else 2
        ctx.scale = scale

        final_state = None
        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)

        # h stacks one (K, V) state per chunk along dim 2
        h = q.new_empty(B, H, NT * K, V)
        grid = (NK, NV, B * H)
        chunk_linear_attn_fwd_kernel_h[grid](
            k, v, h, initial_state, final_state,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2),
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            num_warps=num_warps,
            num_stages=num_stages
        )
        grid = (NV, NT, B * H)
        o = torch.empty_like(v)
        chunk_linear_attn_fwd_kernel_o[grid](
            q, k, v, h, o,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ctx.save_for_backward(q, k, v, h)
        return o.to(q.dtype), final_state

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        """Backward pass; the final-state gradient `dht` is not used."""
        q, k, v, h = ctx.saved_tensors

        B, H, T, K, V = *q.shape, v.shape[-1]
        BT = 64
        # smaller BV for fp32 inputs (larger register footprint)
        BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
        NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4 if BK == 64 else 2
        scale = ctx.scale

        # stage 1: inter-chunk state gradients (reverse scan over chunks)
        dh = q.new_empty(B, H, NT * K, V)
        grid = (NK, NV, B * H)
        chunk_linear_attn_bwd_kernel_dh[grid](
            q, do, dh,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            dh.stride(1), dh.stride(2),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps,
            num_stages=num_stages
        )

        # stage 2: per-chunk dq/dk/dv; dv carries a leading NK axis that is
        # reduced on the host below
        grid = (NK, NT, B * H)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = v.new_empty(NK, *v.shape)
        num_stages = 1
        num_warps = 4 if BK == 64 else 2
        chunk_linear_attn_bwd_kernel_dqkv[grid](
            q, k, v, h, do, dh, dq, dk, dv,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            dh.stride(1), dh.stride(2),
            scale,
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dv = dv.sum(0)
        # gradients for (q, k, v, scale, initial_state, output_final_state)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
327
+
328
+
329
def chunk_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Chunk-wise (causal) linear attention.

    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        scale (Optional[int]):
            Scale factor applied to the attention scores; when `None` it
            defaults to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`.
            Default: `False`.
        normalize (bool):
            Whether to normalize the output. Default: `True`.
    """
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkLinearAttentionFunction.apply(
        q, k, v, scale, initial_state, output_final_state
    )
    if normalize:
        o = normalize_output(q * scale, k, o)
    return o, final_state