msj19 commited on
Commit
e28eee0
·
verified ·
1 Parent(s): 983f7b3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla3/modules/__pycache__/l2norm.cpython-312.pyc +0 -0
  2. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc +0 -0
  3. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc +0 -0
  4. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-310.pyc +0 -0
  5. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc +0 -0
  6. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc +0 -0
  7. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc +0 -0
  8. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc +0 -0
  9. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
  10. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc +0 -0
  11. fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc +0 -0
  12. fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  13. fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  14. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc +0 -0
  15. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc +0 -0
  16. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc +0 -0
  17. fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  18. fla3/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +196 -0
  19. fla3/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +173 -0
  20. fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py +273 -0
  21. fla3/ops/generalized_delta_rule/dplr/naive.py +96 -0
  22. fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +284 -0
  23. fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  24. fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc +0 -0
  25. fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc +0 -0
  26. fla3/ops/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  27. fla3/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  28. fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc +0 -0
  29. fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  30. fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  31. fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  32. fla3/ops/gla/chunk.py +1300 -0
  33. fla3/ops/gla/fused_chunk.py +625 -0
  34. fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  35. fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc +0 -0
  36. fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc +0 -0
  37. fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  38. fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  39. fla3/ops/gsa/chunk.py +1136 -0
  40. fla3/ops/gsa/fused_recurrent.py +525 -0
  41. fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc +0 -0
  42. fla3/ops/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  43. fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc +0 -0
  44. fla3/ops/hgrn/__pycache__/chunk.cpython-312.pyc +0 -0
  45. fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  46. fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  47. fla3/ops/hgrn/chunk.py +282 -0
  48. fla3/ops/hgrn/fused_recurrent.py +308 -0
  49. fla3/ops/hgrn/naive.py +63 -0
  50. fla3/ops/lightning_attn/__init__.py +9 -0
fla3/modules/__pycache__/l2norm.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc ADDED
Binary file (5.19 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc ADDED
Binary file (10.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc ADDED
Binary file (10.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc ADDED
Binary file (10.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc ADDED
Binary file (25 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc ADDED
Binary file (3.66 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc ADDED
Binary file (7.6 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc ADDED
Binary file (4.76 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc ADDED
Binary file (7.69 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla3/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import exp, gather
12
+ from ....utils import is_gather_supported, use_cuda_graph
13
+
14
+
15
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_intra(
    q,
    k,
    a,
    b,
    gi,
    ge,
    qg,
    kg,
    ag,
    bg,
    Aqk,
    Aqb,
    Aab,
    Aak,
    cu_seqlens,
    chunk_indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr
):
    """Intra-chunk forward kernel for the DPLR delta rule.

    For one (chunk, batch, head) program it (a) writes decay-weighted copies
    of q/k/a/b into qg/kg/ag/bg, and (b) fills four [BT, BT] score matrices
    Aqk, Aqb, Aab, Aak one column `j` at a time.
    Grid: (num_chunks, batch, heads).  gi/ge are log-space decay terms
    (they are fed to `exp` below).
    """
    i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs;
        # cu_seqlens gives the flattened sequence boundaries.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Chunk lies entirely past the (possibly shortened) sequence: nothing to do.
    if i_t * BT >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K                                # K may not be a power of two
    m_A = (i_t * BT + tl.arange(0, BC)) < T      # rows inside the sequence
    last_idx = min((i_t+1) * BT, T) - 1          # last valid position of this chunk
    o_A = (bos + i_t * BT + tl.arange(0, BC)) * H*BT + i_h * BT
    p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    # Decay accumulated up to the last position of the chunk.
    p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
    b_g_last = tl.load(p_g_last, mask=m_k, other=0)
    p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))
    p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = b_q * scale
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_a = tl.load(p_a, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    # Decays are exponentiated below, so accumulate them in fp32.
    b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
    b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)

    # deal with decay term: q is scaled forward by exp(gi); k and b are scaled
    # by the remaining decay up to the chunk end; a uses the separate ge term.
    g_exp = exp(b_gi)
    g_exp_inv = exp(-b_gi + b_g_last[None, :])
    b_qg = b_q * g_exp
    b_kg = b_k * g_exp_inv
    b_bg = b_b * g_exp_inv
    b_ag = b_a * exp(b_ge)
    # Round-to-nearest-even on the fp downcast to reduce bias.
    tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # tl.debug_barrier()

    b_q = b_q.to(b_k.dtype)
    # inner attn: compute column j of each score matrix against all rows.
    for j in range(0, min(BC, T - i_t * BT)):
        # a trick to index the j-th row of b_k, b_g, b_b
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BK], j, dtype=tl.int16)
            # [1, BK]
            b_k_j = gather(b_k, row_idx, axis=0)
            b_gk_j = gather(b_gi, row_idx, axis=0)
            b_b_j = gather(b_b, row_idx, axis=0)
        else:
            # Fallback: one-hot reduction selects row j without hardware gather.
            mask = tl.arange(0, BC) == j
            b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
            b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
            b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
        # Relative decay from position j to each row.
        tmp = exp(b_gi - b_gk_j)
        b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
        # q-side scores keep the diagonal (i >= j) ...
        m_i = (o_i >= j).to(tl.float32)
        b_A_qk = b_A_qk * m_i
        b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
        b_A_qb = b_A_qb * m_i
        tmp2 = exp(b_ge - b_gk_j)
        b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
        # ... while a-side scores are strictly lower-triangular (i > j).
        m_i2 = (o_i > j).to(tl.float32)
        b_A_ak = b_A_ak * m_i2
        b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
        b_A_ab = b_A_ab * m_i2

        tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
+
140
+
141
def chunk_dplr_fwd_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    scale: float,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the intra-chunk DPLR forward kernel.

    Allocates the decay-weighted copies (qg, kg, ag, bg) and the four
    [B, T, H, BT] score matrices, then runs one kernel program per
    (chunk, batch, head).  Returns (Aab, Aqk, Aak, Aqb, qg, kg, ag, bg).
    """
    B, T, H, K = k.shape
    # Chunk length: at most `chunk_size`, at least 16, and a power of two.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = len(chunk_indices)

    BK = triton.next_power_of_2(K)

    # Outputs written by the kernel: decay-weighted inputs ...
    qg = torch.empty_like(q)
    kg = torch.empty_like(k, dtype=q.dtype)
    ag = torch.empty_like(a, dtype=q.dtype)
    bg = torch.empty_like(b, dtype=q.dtype)
    # ... and the block-local score matrices.
    Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype)
    Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype)
    # Aab/Aak feed a matrix inverse downstream, so keep them in float.
    Aab = q.new_empty(B, T, H, BT, dtype=torch.float)
    Aak = q.new_empty(B, T, H, BT, dtype=torch.float)

    chunk_dplr_fwd_A_kernel_intra_sub_intra[(NT, B, H)](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        Aqk=Aqk,
        Aqb=Aqb,
        Aab=Aab,
        Aak=Aak,
        qg=qg,
        kg=kg,
        ag=ag,
        bg=bg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BT,
        BK=BK,
        GATHER_SUPPORTED=is_gather_supported
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
fla3/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets
11
+ from ....ops.utils.op import exp
12
+ from ....utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_h(
    kg,
    v,
    w,
    bg,
    u,
    v_new,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Chunk-level state recurrence for the DPLR forward pass.

    Each program owns one [BK, BV] tile of the state for one (batch, head).
    Per chunk it: stores the incoming state into `h`, computes
    v_new = u + w @ h, accumulates kg @ v + bg @ v_new, then decays the
    state by exp(gk at the chunk's last position) and adds the accumulator.
    Grid: (NK, NV, N * H); NK must be 1 (asserted by the launcher).
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        # boh: index of this sequence's first chunk in the flattened `h` layout.
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
    o_k = i_k * BK + tl.arange(0, BK)

    # [BK, BV] running state tile, kept in fp32 across the whole sequence.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # `h` holds the state *entering* each chunk (stored before the update).
        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # since we need to keep all of DK in SRAM we face a severe shared-memory
        # burden; sub-chunking (BC <= BT) alleviates it.
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_kg = tl.load(p_kg, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            # v_new = u + w @ h  (uses the state entering the chunk)
            b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            b_hc += tl.dot(b_kg, b_v)
            b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))

        # Decay the old state by the cumulative gk at the chunk's last valid
        # position, then fold in this chunk's contribution.
        last_idx = min((i_t + 1) * BT, T) - 1
        b_g_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k, mask=o_k < K).to(tl.float32)
        b_h *= exp(b_g_last[:, None])
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
105
+
106
+
107
def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the chunk-level state recurrence kernel.

    Returns (h, v_new, final_state): per-chunk entering states of shape
    [B, NT, H, K, V], the updated values v_new (same shape as u), and the
    fp32 final state [N, H, K, V] (or None when not requested).
    """
    B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    # N: the actual number of sequences in the batch (equal- or variable-length).
    if cu_seqlens is None:
        N = B
        NT = triton.cdiv(T, BT)
        chunk_offsets = None
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        N = len(cu_seqlens) - 1
        NT = len(chunk_indices)
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."

    # Tile sizes scale with available shared memory; H100 can take the largest.
    if check_shared_mem('hopper', kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', kg.device.index):
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16
    BC = min(BT, BC)

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = kg.new_empty(B, NT, H, K, V)
    final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u)
    chunk_dplr_fwd_kernel_h[(NK, NV, N * H)](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils.op import exp
11
+ from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_dplr_delta_rule_fwd_kernel(
    q,
    k,
    v,
    a,
    b,
    gk,
    o,
    h0,
    ht,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Token-by-token forward kernel for the DPLR delta rule.

    Per step: h = exp(gk) * h + ((h . a) b + v k^T), o = (h . q) summed over K.
    Each program owns one BV slice of V for one (sequence, head); the whole
    key dimension (BK >= K) lives in registers.  REVERSE walks the sequence
    backwards.  Grid: (cdiv(V, BV), N * H).
    """
    # int64 indices to avoid overflow on long flattened sequences.
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Start at the last timestep when REVERSE, else the first.
    p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    # State tile laid out [BV, BK] (value rows, key columns), fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        # tmp = h a^T (uses the pre-decay state), then
        # h <- exp(gk) * h + (tmp b^T + v k^T); o = h q^T.
        tmp = tl.sum(b_h * b_a[None, :], axis=1)
        b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        b_o = tl.sum(b_h * b_q[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Advance (or rewind, when REVERSE) all pointers by one timestep.
        p_q += (-1 if REVERSE else 1) * H*K
        p_k += (-1 if REVERSE else 1) * H*K
        p_a += (-1 if REVERSE else 1) * H*K
        p_b += (-1 if REVERSE else 1) * H*K
        p_gk += (-1 if REVERSE else 1) * H*K
        p_v += (-1 if REVERSE else 1) * H*V
        p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
106
+
107
+
108
def fused_recurrent_dplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Allocate outputs and launch the fused recurrent DPLR kernel.

    Returns (o, ht): the output tensor (same shape/dtype as v) and the fp32
    final state [N, H, K, V], or None when `output_final_state` is False.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK = triton.next_power_of_2(K)

    ht = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    o = torch.empty_like(v)

    # BV is chosen by the autotuner, so the grid must be meta-dependent.
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), N * H)

    fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
        q,
        k,
        v,
        a,
        b,
        gk,
        o,
        initial_state,
        ht,
        cu_seqlens,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
    )
    return o, ht
154
+
155
+
156
class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
    """Forward-only autograd wrapper around the fused recurrent DPLR kernel.

    Backward deliberately raises: this op is inference-only.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: Optional[float] = 1.0,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        # Nothing is saved on ctx: there is no backward pass to feed.
        return fused_recurrent_dplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
            "This kernel is only for inference. "
            "For training, please use `chunk_dplr_delta_rule`."
        )
199
+
200
+
201
def fused_recurrent_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]`.
        gk (torch.Tensor):
            gk of shape `[B, T, H, K]`. decay term in log space!
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.

    Raises:
        ValueError: if `cu_seqlens` is given with a batch size other than 1,
            or if `initial_state` disagrees with the number of sequences.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # Fix: the two message fragments were previously concatenated
            # without a separating space.
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
fla3/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
    """Naive step-by-step reference for the DPLR delta rule.

    Implements, per timestep i (see the module-level comment):
    a rank-1 `v k^T` write plus a low-rank `(alpha, beta)` correction of the
    decayed state `S`, then reads out `o_i = q_i S`.

    Args:
        q, k, alpha, beta: `[B, H, L, D_K]` tensors.
        v: `[B, H, L, D_V]` tensor.
        gk: `[B, H, L, D_K]` decay term in log space.
        initial_state: optional `[B, H, D_K, D_V]` starting state.
        output_final_state: return the final state `S` (else `None`).

    Returns:
        (o, S): output in the original dtype of `q`, and the final state.
    """
    orig_dtype = q.dtype
    b, h, l, d_k = q.shape
    # Run the whole recurrence in fp32 for numerical stability.
    # Fix: `alpha` was previously left out of this cast while every other
    # operand was converted, leaving it in its input dtype.
    q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk])
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)

    if initial_state is not None:
        S += initial_state

    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i]
        _alpha = alpha[:, :, i].clone()
        _beta = beta[:, :, i].clone()
        # v k^T plus the (alpha, beta) low-rank correction of the old state.
        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
        # Only the old state is decayed; the correction uses pre-decay S.
        S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    S = None if output_final_state is False else S
    return o.to(orig_dtype), S
34
+
35
+
36
def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
    """Chunk-parallel reference for the DPLR delta rule.

    Computes the same recurrence as `dplr_recurrence` by splitting the
    sequence into chunks of `chunk_size`, forming intra-chunk score matrices
    A_qk/A_qb/A_ab/A_ak, resolving the implicit triangular system for A_ab,
    and carrying the state `S` across chunks.

    Args:
        q, k, alpha, beta, gk: `[B, H, L, D_K]` (gk is log-space decay).
        v: `[B, H, L, D_V]`.
        initial_state: optional `[B, H, D_K, D_V]` starting state.
        output_final_state: return the final state `S` (else `None`).
        chunk_size: must divide `L` exactly (asserted).

    Returns:
        (o, S): `[B, H, L, D_V]` output and the final state.

    Cleanup vs the previous version: removed the no-op statements `v = v` and
    `A_ab = A_ab`, and two `torch.triu` masks that were always overwritten
    inside the loops (or never used) before any read.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    assert l % chunk_size == 0

    # State is created in the *input* dtype of q/k, matching prior behavior.
    S = k.new_zeros(b, h, d_k, d_v).to(q)
    if initial_state is not None:
        S += initial_state

    q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
                                                       c=chunk_size).float(), [q, k, v, alpha, beta, gk])

    gk_cumsum = gk.cumsum(-2)

    A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
    A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)

    for i in range(chunk_size):
        alpha_i = alpha[:, :, :, i, None]
        q_i = q[:, :, :, i, None]
        gk_i = gk_cumsum[:, :, :, i, None]
        # q-side scores keep the diagonal (positions <= i).
        mask = (torch.arange(chunk_size) <= i).to(q.device)
        attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
        A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
        # alpha-side scores are strictly lower-triangular (shift by one).
        mask = (torch.arange(chunk_size) < i).to(q.device)
        attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
        A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
        A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()

    # Forward substitution: fold the strictly-triangular system so that
    # (I + A_ab) acts as the resolvent of the intra-chunk dependencies.
    for i in range(1, chunk_size):
        A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)

    A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    u = A_ab @ (A_ak @ v)
    w = A_ab @ ((gk_cumsum-gk).exp() * alpha)

    o = torch.zeros_like(v)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
        v2_i = u_i + w_i @ S

        o_1 = A_qk[:, :, i] @ v_i        # intra-chunk, direct
        o_2 = A_qb[:, :, i] @ v2_i       # intra-chunk, via corrected values
        o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S  # cross-chunk state readout
        o[:, :, i] = o_1 + o_2 + o_3
        # Carry the state: decay to chunk end, add both rank groups.
        decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
        S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
            (beta_i * decay).transpose(-1, -2) @ v2_i
    S = None if output_final_state is False else S
    return rearrange(o, 'b h n c d -> b h (n c) d'), S
fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ....ops.utils import prepare_chunk_indices
11
+ from ....ops.utils.op import gather
12
+ from ....utils import is_gather_supported, use_cuda_graph
13
+
14
+
15
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk32(
    A_ab,
    A_ab_inv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete
    IS_VARLEN: tl.constexpr,
):
    """
    One program per (chunk, batch*head): reads the BT x BT block of A_ab for
    this chunk, keeps its strictly lower triangle, and writes a unit lower
    triangular matrix into A_ab_inv.

    The row loop is a forward substitution, so the stored matrix is
    presumably the inverse of (I - tril(A_ab, -1)) — the identity diagonal
    is added explicitly at the end.  NOTE(review): confirm the exact matrix
    identity against the chunk64 variant below.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen mode: map the flat chunk id to (sequence, chunk-in-sequence)
        # and recover this sequence's [bos, eos) span from cu_seqlens.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    # Block pointers into the (T, BT) per-head slices of A_ab / A_ab_inv.
    p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    # Keep only the strictly lower triangle; the diagonal is handled below.
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    for i in range(1, BT):
        # Select row i, then fold in contributions from all earlier
        # (already-processed) rows; the (arange < i) factor restricts the
        # update to columns strictly left of the diagonal.
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    # Add the identity so the stored matrix is unit lower triangular.
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
57
+
58
+
59
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def prepare_wy_repr_fwd_kernel_chunk64(
    A_ab,
    A_ab_inv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
):
    """
    64-wide variant of `prepare_wy_repr_fwd_kernel_chunk32`.

    The BT x BT (BT == 2*BC) problem is split into four BC x BC sub-blocks:
    the two diagonal blocks are inverted independently with the same
    forward-substitution loop as the chunk32 kernel, and the off-diagonal
    block is reconstructed from them via the blockwise formula below.
    The upper-right block of the result is explicitly zeroed (causality).
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen mode: resolve (sequence, chunk) and the sequence span.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Sub-block pointers: A1 = top-left, A2 = bottom-right, A3 = bottom-left.
    p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
    p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
    p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
    p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    # Keep only the strictly lower triangle of the two diagonal blocks.
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    for i in range(1, BC):
        if GATHER_SUPPORTED:
            # Fast path: extract row i via hardware gather.
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BK] -> [BK]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            # Fallback: extract row i with a masked reduction.
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        mask = tl.arange(0, BC) == i
        # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        # Forward substitution step, same recurrence as the chunk32 kernel,
        # applied to both diagonal blocks in lockstep.
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    # NOTE(review): the code uses a positive sign here (no negation) —
    # presumably consistent with the (I - L)^-1 convention; verify against
    # the naive reference implementation.
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    # tl.debug_barrier()
    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    # causal mask
    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
136
+
137
+
138
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def wu_fwd_kernel(
    w,
    u,
    ag,
    v,
    A_ab_inv,
    A_ak,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    One program per (chunk, batch*head): writes
        w = tril(A_ab_inv) @ ag                  (over K in BK tiles)
        u = (tril(A_ab_inv) @ stril(A_ak)) @ v   (over V in BV tiles)
    where tril keeps the lower triangle incl. diagonal and stril the
    strictly lower triangle, enforcing causality within the chunk.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen mode: resolve (sequence, chunk) and the sequence span.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_s = tl.arange(0, BT)

    p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    # Causal masking: A_ab_inv keeps the diagonal, A_ak is strictly lower.
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    # w = A_ab_inv @ ag, tiled over the key dimension.
    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    # u = (A_ab_inv @ A_ak) @ v, tiled over the value dimension.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
205
+
206
+
207
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Launch `wu_fwd_kernel` over every (chunk, batch*head) pair to build the
    WY-representation tensors w and u from the precomputed chunk matrices.

    Args:
        ag: [B, T, H, K] gated a-vectors.
        v: [B, T, H, V] values.
        A_ak: per-chunk a/k interaction matrices, last dim == chunk size.
        A_ab_inv: per-chunk inverse produced by `prepare_wy_repr_fwd`.
        cu_seqlens: optional cumulative sequence lengths (varlen batches).
        chunk_size: requested chunk length.

    Returns:
        (w, u): tensors shaped like `ag` and `v` respectively.
    """
    batch_size, seq_len, num_heads, key_dim = ag.shape
    value_dim = v.shape[-1]
    # Effective chunk length: at most `chunk_size`, padded up to the next
    # power of two and never below 16 (kernel tile requirements).
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))

    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, block_t)
        num_chunks = len(chunk_indices)
    else:
        chunk_indices = None
        num_chunks = triton.cdiv(seq_len, block_t)
    block_k = min(triton.next_power_of_2(key_dim), 64)
    block_v = min(triton.next_power_of_2(value_dim), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    wu_fwd_kernel[(num_chunks, batch_size * num_heads)](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=seq_len,
        H=num_heads,
        K=key_dim,
        V=value_dim,
        BT=block_t,
        BK=block_k,
        BV=block_v,
    )
    return w, u
243
+
244
+
245
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Invert the per-chunk A_ab matrices and derive the WY representation.

    First runs the chunk-inversion kernel (the blockwise 64-wide variant
    when the effective chunk length is 64, otherwise the 32-wide one),
    then forwards the result to `wu_fwd`.

    Args:
        ag: [B, T, H, K] gated a-vectors.
        v: [B, T, H, V] values.
        A_ak: per-chunk a/k interaction matrices.
        A_ab: per-chunk a/b interaction matrices to invert.
        cu_seqlens: optional cumulative sequence lengths (varlen batches).
        chunk_size: requested chunk length (default 64).

    Returns:
        (w, u, A_ab_inv).
    """
    batch_size, seq_len, num_heads = ag.shape[:3]
    # Effective chunk length, padded to a power of two with a floor of 16.
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(seq_len, block_t)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, block_t)
        num_chunks = len(chunk_indices)
    block_c = min(block_t, 32)
    if block_t == 64:
        inv_kernel = prepare_wy_repr_fwd_kernel_chunk64
    else:
        inv_kernel = prepare_wy_repr_fwd_kernel_chunk32
    A_ab_inv = torch.empty_like(A_ab)
    inv_kernel[(num_chunks, batch_size * num_heads)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=seq_len,
        H=num_heads,
        BT=block_t,
        BC=block_c,
    )
    w, u = wu_fwd(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )
    return w, u, A_ab_inv
280
+
281
+
282
# Backward-compatible aliases kept for callers that import the old names.
fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd
fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (25.6 kB). View file
 
fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc ADDED
Binary file (7.86 kB). View file
 
fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (19.7 kB). View file
 
fla3/ops/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (371 Bytes). View file
 
fla3/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (66.3 kB). View file
 
fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (35 kB). View file
 
fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (4.83 kB). View file
 
fla3/ops/gla/chunk.py ADDED
@@ -0,0 +1,1300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
11
+ from fla.ops.utils import prepare_chunk_indices
12
+ from fla.ops.utils.cumsum import chunk_local_cumsum
13
+ from fla.ops.utils.op import exp, safe_exp
14
+ from fla.utils import check_shared_mem, input_guard
15
+
16
# Candidate key-dim block sizes for autotuning; the larger pair is used only
# when enough shared memory is available.
BK_LIST = [32, 64] if check_shared_mem() else [16, 32]
# Candidate value-dim block sizes; the larger pair requires Ampere-class
# shared-memory capacity.
BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32]
18
+
19
+
20
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_inter(
    q,
    k,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fills the sub-chunk blocks of the intra-chunk attention matrix A that
    lie strictly BELOW the block diagonal: one program per
    (chunk i_t, sub-chunk pair (i_i, i_j) with i_i > i_j, batch*head).
    Each block is A[i,j] = (q_i * exp(g_i - gn)) @ (k_j * exp(gn - g_j))^T
    accumulated over K tiles, where gn is the cumulative gate at the start
    of sub-chunk i_i (log-space stabilization).
    """
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Decode the flat sub-chunk pair index into (row, column) sub-chunks.
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        # Varlen mode: resolve (sequence, chunk) and the sequence span.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Skip rows past the end of the sequence.
    if i_t * BT + i_i * BC >= T:
        return
    # Only strictly-lower sub-chunk blocks are handled here; the diagonal
    # blocks are computed by the intra_sub_intra kernels.
    if i_i <= i_j:
        return

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
        # Gate values at the first row of sub-chunk i_i (the anchor gn).
        p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = b_k * exp(b_gn[:, None] - b_gk)
        # [BC, BC] using tf32 to improve precision here.
        b_A += tl.dot(b_qg, b_kg)

    p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
91
+
92
+
93
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=["BK", "BT"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_intra(
    q,
    k,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fills the DIAGONAL sub-chunk blocks of the intra-chunk attention
    matrix A (i_j == i_i), for the case where the whole key dimension fits
    in one BK tile.  Iterates over the columns j of the block one at a
    time, masking to the causal (row >= j) part, and writes each column
    with a scalar store.
    """
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Diagonal block: column sub-chunk equals row sub-chunk.
    i_j = i_i
    if IS_VARLEN:
        # Varlen mode: resolve (sequence, chunk) and the sequence span.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Skip sub-chunks past the end of the sequence.
    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    # Flat offsets of this block's rows in A (column j added in the loop).
    o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
    p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
    p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
    p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))
    # One column of the BC x BC diagonal block per iteration; the loop
    # bound also clips to the sequence tail.
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        # A[:, j] = sum_k q * k_j * exp(g - g_j), causal rows only.
        b_A = tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1)
        b_A = tl.where(o_i >= j, b_A * scale, 0.)

        tl.store(A + o_A + j, b_A, mask=m_A)
        # Advance k/g pointers to the next row within the sub-chunk.
        p_k += H*K
        p_gk += H*K
156
+
157
+
158
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BC', 'BK']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_intra_split(
    q,
    k,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    K-split variant of `chunk_gla_fwd_A_kernel_intra_sub_intra` for large
    key dims: each program handles one BK slice (i_k) of one diagonal
    sub-chunk block and writes its PARTIAL result into a scratch A laid
    out as [NK, B*T or total_T, H, BC]; the partials are summed afterwards
    by `chunk_gla_fwd_A_kernel_intra_sub_intra_merge`.
    """
    i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Decode the fused (chunk, sub-chunk) program index.
    i_t, i_i = i_tc // NC, i_tc % NC
    # Diagonal block: column sub-chunk equals row sub-chunk.
    i_j = i_i
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # `all` is the total token count across the batch (shadows the
        # builtin `all`; kept for consistency with the rest of the file).
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T

    # Scratch offsets: the i_k-th partial lives at a stride of `all` rows.
    o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC
    p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))
    # One column j of the diagonal block per iteration (causal rows only).
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_A = tl.zeros([BC], dtype=tl.float32)
        b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_A += tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1)
        b_A = tl.where(o_i >= j, b_A * scale, 0.)
        tl.store(A + o_A + j, b_A, mask=m_A)
        # Advance k/g pointers to the next row within the sub-chunk.
        p_k += H*K
        p_gk += H*K
227
+
228
+
229
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BC']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_A_kernel_intra_sub_intra_merge(
    A,
    A2,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    NK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Reduction pass paired with the `_split` kernel: sums the NK partial
    diagonal blocks from scratch tensor A into the final attention matrix
    A2 at their diagonal (i_c, i_c) position.
    """
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # `all` = total token rows in the scratch layout (shadows builtin).
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    # Skip sub-chunks past the end of the sequence.
    if i_t * BT + i_c * BC >= T:
        return

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    # Accumulate the NK key-slice partials (stride `all` rows apart).
    for i_k in range(0, NK):
        p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0))
        b_A += tl.load(p_A, boundary_check=(0, 1))
    p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0))
    tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1))
275
+
276
+
277
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Output pass: for one (value tile, chunk, batch*head) program,
        o = (q * scale * exp(g)) @ h        (inter-chunk, via state h)
          + tril(A) @ v                     (intra-chunk attention)
    accumulated in fp32 over K tiles for the first term.
    """
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg indexes the global chunk (the state tensor h is laid out
        # per global chunk); i_t becomes the chunk index within the seq.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Causal (lower-triangular incl. diagonal) mask for the intra term.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # [BT, BV]
        # NOTE(review): `i_k >= 0` is always true — presumably kept to
        # work around a compiler quirk (see comment above); do not remove
        # without re-verifying.
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype)
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
353
+
354
+
355
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BK', 'NC', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_intra(
    q,
    k,
    g,
    dA,
    dq,
    dk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Intra-chunk backward pass: one program per (key tile i_k, sub-chunk
    i_i of chunk i_t, batch*head) computes the intra-chunk contributions
    to dq and dk from dA.

    First half: dq for sub-chunk i_i — a block part from all sub-chunks
    strictly below the diagonal (j < i) plus a scalar loop over the
    diagonal sub-chunk itself.  Second half (after the barrier): dk,
    symmetric, gathering from sub-chunks strictly above (j > i) plus the
    diagonal.  All gate factors are applied as exp-differences against a
    per-sub-chunk anchor gn for numerical stability.
    """
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_t, i_i = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    # Sequence length of this (sub)sequence; for the fixed-length branch
    # this reassignment is a no-op (eos - bos == T).
    T = eos - bos
    if i_t * BT + i_i * BC >= T:
        return

    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K

    p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # [BC, BK]
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    # --- dq: contributions from sub-chunks strictly below the diagonal ---
    if i_i > 0:
        # Anchor gate gn at the first row of sub-chunk i_i.
        p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        for i_j in range(0, i_i):
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
            p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
            p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
            # [BC, BK]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_kg = (b_k * exp(b_gn[None, :] - b_gk))
            # [BC, BC]
            b_dA = tl.load(p_dA, boundary_check=(0, 1))
            # [BC, BK]
            b_dq += tl.dot(b_dA, b_kg)
        # Re-apply the anchor so the net factor is exp(b_g - b_gk).
        b_dq *= exp(b_g - b_gn[None, :])

    o_i = tl.arange(0, BC)
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC
    p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # --- dq: diagonal sub-chunk, column by column (causal rows >= j) ---
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
        # [BK,]
        b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32)
        b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32)
        # [BC, BK]
        m_i = o_i[:, None] >= j
        # [BC, BK]
        # (SY 09/17) important to not use bf16 here to have a good precision.
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.)
        p_kj += H*K
        p_gkj += H*K
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # Barrier between the dq and dk phases (register reuse safety).
    tl.debug_barrier()
    p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)

    # Clamp NC to the number of sub-chunks actually present in this chunk.
    NC = min(NC, tl.cdiv(T - i_t * BT, BC))
    # --- dk: contributions from sub-chunks strictly above the diagonal ---
    if i_i < NC - 1:
        # Anchor gate gn at the LAST row of sub-chunk i_i (clipped to T).
        p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        for i_j in range(i_i + 1, NC):
            p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0))
            p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0))
            # dA block transposed via swapped strides (column-major view).
            p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            # [BC, BK]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_gq = tl.load(p_gq, boundary_check=(0, 1))
            b_qg = b_q * safe_exp(b_gq - b_gn[None, :])
            # [BC, BC]
            b_dA = tl.load(p_dA, boundary_check=(0, 1))
            # [BC, BK]
            # (SY 09/17) important to not use bf16 here to have a good precision.
            b_dk += tl.dot(b_dA, b_qg)
        b_dk *= exp(b_gn[None, :] - b_gk)
    o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC)
    p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
    p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # --- dk: diagonal sub-chunk, row by row (anti-causal rows <= j) ---
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_dA = tl.load(dA + o_dA + j * H*BT)
        # [BK,]
        b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32)
        b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32)
        # [BC, BK]
        m_i = o_i[:, None] <= j
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.)
        p_qj += H*K
        p_gqj += H*K
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
489
+
490
+
491
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BV', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_dA(
    v,
    do,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Gradient of the intra-chunk attention matrix A for one chunk:
    #   dA[t, s] = scale * sum_v do[t, v] * v[s, v]
    # masked to the lower triangle (causal within the chunk).
    # Grid: (num_chunks, B * H).  Output dA is [B, T, H, BT] in fp32.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence id, chunk id within the sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    # From here on, T is the length of the current sequence.
    T = eos - bos

    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    # Accumulate do @ v^T over all value blocks.
    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dA += tl.dot(b_do, b_v)
    p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # Keep only the causal (lower-triangular, diagonal inclusive) part.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_dA = tl.where(m_s, b_dA * scale, 0.)
    tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
538
+
539
+
540
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in BK_LIST
        for BV in BV_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_dv(
    k,
    g,
    A,
    do,
    dh,
    dv,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Value gradient for one (chunk, value-block, batch*head) program.
    # Two contributions:
    #   1. intra-chunk: A^T @ do (upper-triangular mask after the transpose);
    #   2. inter-chunk: (k, decayed to the chunk end) @ dh (chunk-state grads).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg indexes the global chunk (used to address per-chunk states dh).
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # A is addressed transposed: (BT, T) with the query index along columns.
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    # Causal mask on the transposed A: keep entries where key index <= query index.
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.)
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # (SY 09/17) important to disallow tf32 here to maintain a good precision.
    b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False)

    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        # Cumulative gate at the last valid position of this chunk.
        p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k
        p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        # Decay each key from its position to the end of the chunk.
        b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk)
        b_k = (b_k * b_gn).to(b_k.dtype)
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BV]
        # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here
        b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
612
+
613
+
614
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in BK_LIST
        for BV in BV_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_gla_bwd_kernel_inter(
    q,
    k,
    v,
    h,
    g,
    do,
    dh,
    dq,
    dk,
    dq2,
    dk2,
    dg,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Inter-chunk backward: adds state-based contributions to dq/dk (reading
    # the intra-chunk partials from dq/dk, writing merged results to dq2/dk2)
    # and computes the gate gradient dg via a reverse cumulative sum.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg indexes the global chunk (addresses the per-chunk states h/dh).
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K

    p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # Cumulative gate at the last valid position of this chunk.
    p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k
    b_gn = tl.load(p_gn, mask=m_k, other=0)
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk = tl.zeros([BK,], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BK] — accumulate <h, dh> per key dim for the gate gradient.
        b_dgk += tl.sum(b_h * b_dh, axis=0)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
    b_dgk *= exp(b_gn)
    b_dq *= scale
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    # Apply the gate decays: dq decays from chunk start, dk towards chunk end.
    b_dq = b_dq * exp(b_gk)
    b_dk = b_dk * exp(b_gn[None, :] - b_gk)

    p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dgk += tl.sum(b_dk * b_k, axis=0)
    # Merge with the intra-chunk partials computed by the intra kernel.
    b_dq += tl.load(p_dq, boundary_check=(0, 1))
    b_dk += tl.load(p_dk, boundary_check=(0, 1))
    b_dg = b_q * b_dq - b_k * b_dk
    # tl.debug_barrier()
    # Reverse (suffix) cumulative sum over the chunk, plus the state term.
    b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :]
    # Buggy due to strange triton compiler issue.
    # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.)
    # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :]
    p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
717
+
718
+
719
def chunk_gla_fwd_intra_gk(
    q: torch.Tensor,
    k: torch.Tensor,
    g: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Compute the intra-chunk attention matrix A (``[B, T, H, BT]``, fp32).

    A is assembled by two kernel families: one for sub-chunk blocks strictly
    below the diagonal (inter sub-chunk) and one for the diagonal sub-chunk
    blocks (intra sub-chunk).  For large K the diagonal pass is additionally
    split over key blocks and merged afterwards.
    """
    B, T, H, K = k.shape
    # Chunk length: a power of two, at least 16, at most `chunk_size`.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    # BC: sub-chunk length; NC: number of sub-chunks per chunk.
    BC = min(16, BT)
    NC = triton.cdiv(BT, BC)

    A = q.new_empty(B, T, H, BT, dtype=torch.float)
    grid = (NT, NC * NC, B * H)
    chunk_gla_fwd_A_kernel_intra_sub_inter[grid](
        q,
        k,
        g,
        A,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    grid = (NT, NC, B * H)
    # load the entire [BC, K] blocks into SRAM at once
    if K <= 256:
        BK = triton.next_power_of_2(K)
        chunk_gla_fwd_A_kernel_intra_sub_intra[grid](
            q,
            k,
            g,
            A,
            cu_seqlens,
            chunk_indices,
            scale,
            T=T,
            H=H,
            K=K,
            BT=BT,
            BC=BC,
            BK=BK,
        )
    # split then merge
    else:
        BK = min(128, triton.next_power_of_2(K))
        NK = triton.cdiv(K, BK)
        # Per-key-block partial results, merged into A below.
        A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float)

        grid = (NK, NT * NC, B * H)
        chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid](
            q,
            k,
            g,
            A_intra,
            cu_seqlens,
            chunk_indices,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            BT=BT,
            BC=BC,
            BK=BK,
            NC=NC,
        )

        grid = (NT, NC, B * H)
        chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid](
            A_intra,
            A,
            cu_seqlens,
            chunk_indices,
            T=T,
            B=B,
            H=H,
            BT=BT,
            BC=BC,
            NK=NK,
        )
    return A
811
+
812
+
813
def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Compute the forward output ``o`` from the intra-chunk matrix A and the
    per-chunk states h by launching `chunk_gla_fwd_kernel_o`.
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    # Chunk length: a power of two, at least 16, at most `chunk_size`.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, block_t)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    o = torch.empty_like(v)

    def grid(meta):
        # BV is selected by the kernel's autotuner.
        return (triton.cdiv(V, meta['BV']), num_chunks, B * H)

    chunk_gla_fwd_kernel_o[grid](
        q,
        v,
        g,
        h,
        o,
        A,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=block_t,
    )
    return o
848
+
849
+
850
def chunk_gla_bwd_dA(
    v: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Compute dA, the gradient w.r.t. the intra-chunk attention matrix, by
    launching `chunk_gla_bwd_kernel_dA`.  Returns a float32 tensor of shape
    ``[B, T, H, BT]``.
    """
    B, T, H, V = v.shape
    # Chunk length: a power of two, at least 16, at most `chunk_size`.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(T)))
    block_v = min(64, triton.next_power_of_2(V))

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, block_t)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    dA = v.new_empty(B, T, H, block_t, dtype=torch.float)
    chunk_gla_bwd_kernel_dA[(num_chunks, B * H)](
        v,
        do,
        dA,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        V=V,
        BT=block_t,
        BV=block_v,
    )
    return dA
880
+
881
+
882
def chunk_gla_bwd_dv(
    k: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Compute the value gradient dv via `chunk_gla_bwd_kernel_dv`, combining
    the intra-chunk contribution (from A and do) with the inter-chunk
    contribution carried by the chunk-state gradients dh.
    """
    B, T, H, K = k.shape
    V = do.shape[-1]
    # Chunk length: a power of two, at least 16, at most `chunk_size`.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, block_t)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    dv = torch.empty_like(do)

    def grid(meta):
        # BV is selected by the kernel's autotuner.
        return (triton.cdiv(V, meta['BV']), num_chunks, B * H)

    chunk_gla_bwd_kernel_dv[grid](
        k,
        g,
        A,
        do,
        dh,
        dv,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=block_t,
    )
    return dv
915
+
916
+
917
def chunk_gla_bwd_dqk_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    g: torch.Tensor,
    dA: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Compute the intra-chunk parts of dq and dk from dA via
    `chunk_gla_bwd_kernel_intra`.  Both outputs are float32.
    """
    B, T, H, K = q.shape
    # Chunk length: a power of two, at least 16, at most `chunk_size`.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(T)))
    block_c = min(16, block_t)
    block_k = min(64, triton.next_power_of_2(K))

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, block_t)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)
    num_sub = triton.cdiv(block_t, block_c)   # sub-chunks per chunk
    num_kblk = triton.cdiv(K, block_k)        # key blocks

    dq = torch.empty_like(q, dtype=torch.float)
    dk = torch.empty_like(k, dtype=torch.float)
    chunk_gla_bwd_kernel_intra[(num_kblk, num_chunks * num_sub, B * H)](
        q,
        k,
        g,
        dA,
        dq,
        dk,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=block_t,
        BC=block_c,
        BK=block_k,
        NC=num_sub,
    )
    return dq, dk
956
+
957
+
958
def chunk_gla_bwd_dqkg(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Add the inter-chunk (state-based) contributions to dq/dk and compute
    the gate gradient dg via `chunk_gla_bwd_kernel_inter`.

    The incoming dq/dk hold the intra-chunk partials; merged results are
    returned as fresh tensors (dq2, dk2) alongside dg.
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    # Chunk length: a power of two, at least 16, at most `chunk_size`.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(T)))

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, block_t)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    dg = torch.empty_like(g)
    dq2 = torch.empty_like(dq)
    dk2 = torch.empty_like(dk)

    def grid(meta):
        # BK is selected by the kernel's autotuner.
        return (triton.cdiv(K, meta['BK']), num_chunks, B * H)

    chunk_gla_bwd_kernel_inter[grid](
        q,
        k,
        v,
        h,
        g,
        do,
        dh,
        dq,
        dk,
        dq2,
        dk2,
        dg,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=block_t,
    )
    return dq2, dk2, dg
1005
+
1006
+
1007
def chunk_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: Optional[torch.Tensor],
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Full chunked-GLA forward pass.

    Returns ``(g_cumsum, A, h, ht, o)``: the within-chunk cumulative gates,
    the intra-chunk attention matrix, the per-chunk states, the final state
    and the output.
    """
    seq_len = q.shape[1]
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))

    # Materialize the within-chunk cumulative log-gates when not provided.
    if g_cumsum is None:
        g_cumsum = chunk_local_cumsum(g, block_t, cu_seqlens=cu_seqlens)

    h, ht = chunk_fwd_h(
        k=k,
        v=v,
        g=None,
        gk=g_cumsum,
        gv=None,
        h0=initial_state,
        output_final_state=output_final_state,
        states_in_fp32=False,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )

    # The intra-chunk A is kept in fp32; the extra computation has a very
    # marginal effect on overall throughput.
    A = chunk_gla_fwd_intra_gk(
        q=q,
        k=k,
        g=g_cumsum,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v,
        g=g_cumsum,
        A=A,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )
    return g_cumsum, A, h, ht, o
1058
+
1059
+
1060
def chunk_gla_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: Optional[torch.Tensor],
    scale: float,
    initial_state: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Full chunked-GLA backward pass; returns ``(dq, dk, dv, dg, dh0)``."""
    seq_len = q.shape[1]
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))

    # Recompute the cumulative gates when they were not kept from forward.
    if g_cumsum is None:
        g_cumsum = chunk_local_cumsum(g, block_t, cu_seqlens=cu_seqlens)

    # Recompute the per-chunk states in fp32 when they were not saved.
    if h is None:
        h, _ = chunk_fwd_h(
            k=k,
            v=v,
            g=None,
            gk=g_cumsum,
            gv=None,
            h0=initial_state,
            output_final_state=False,
            cu_seqlens=cu_seqlens,
            chunk_size=block_t,
            states_in_fp32=True
        )
    dh, dh0 = chunk_bwd_dh(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=g_cumsum,
        gv=None,
        do=do,
        h0=initial_state,
        dht=dht,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t,
        states_in_fp32=True
    )

    dv = chunk_gla_bwd_dv(
        k=k,
        g=g_cumsum,
        A=A,
        do=do,
        dh=dh,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )

    # dq/dk are produced in fp32: first the intra-chunk parts from dA, then
    # the inter-chunk merge together with the gate gradient.
    dA = chunk_gla_bwd_dA(
        v=v,
        do=do,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )
    dq, dk = chunk_gla_bwd_dqk_intra(
        q=q,
        k=k,
        g=g_cumsum,
        dA=dA,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )
    dq, dk, dg = chunk_gla_bwd_dqkg(
        q=q,
        k=k,
        v=v,
        h=h,
        g=g_cumsum,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=block_t
    )
    return dq, dk, dv, dg, dh0
1150
+
1151
+
1152
class ChunkGLAFunction(torch.autograd.Function):
    """Autograd glue tying `chunk_gla_fwd` and `chunk_gla_bwd` together."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q,
        k,
        v,
        g,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    ):
        T = q.shape[1]
        # Chunk length: a power of two, at least 16, at most 64.
        chunk_size = min(64, max(16, triton.next_power_of_2(T)))

        g_cumsum, A, h, ht, o = chunk_gla_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            g_cumsum=None,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # recompute g_cumsum in bwd pass
        # Exactly one of (g, g_cumsum) is saved: if g is not fp32, keep g and
        # recompute g_cumsum in backward; otherwise g_cumsum is exact and g
        # can be dropped.
        if g.dtype != torch.float:
            g_cumsum = None
        else:
            g = None
        ctx.save_for_backward(q, k, v, g, g_cumsum, initial_state, A)
        ctx.chunk_size = chunk_size
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        return o, ht

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors
        chunk_size, scale, cu_seqlens = ctx.chunk_size, ctx.scale, ctx.cu_seqlens
        dq, dk, dv, dg, dh0 = chunk_gla_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            g_cumsum=g_cumsum,
            scale=scale,
            h=None,
            A=A,
            initial_state=initial_state,
            do=do,
            dht=dht,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # Gradients align with forward's inputs:
        # (q, k, v, g, scale, initial_state, output_final_state, cu_seqlens).
        return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None
1214
+
1215
+
1216
@torch.compiler.disable
def chunk_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, H, K]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gla import chunk_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = chunk_gla(
            q, k, v, g,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gla(
            q, k, v, g,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # FIX: the two f-string fragments previously concatenated without a
            # separating space ("...cu_seqlens.Please flatten...").
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = q.shape[-1] ** -0.5
    o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens)
    return o, final_state
fla3/ops/gla/fused_chunk.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+ from packaging import version
12
+
13
+ from fla.ops.utils import chunk_local_cumsum
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
@triton.jit(do_not_specialize=['T'])
def prepare_qg_kg(
    q,
    k,
    g,
    qg,
    kg,
    scale,
    T,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr
):
    # Precompute gate-decayed q/k for the inter-chunk pass:
    #   qg[t] = q[t] * exp(g[t]) * scale
    #   kg[t] = k[t] * exp(g_last - g[t])
    # where g holds within-chunk cumulative log-gates and g_last is the value
    # at the final position of the chunk.  One program handles one
    # (key-block, chunk, batch*head) triple and walks the chunk row by row.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    p_q = q + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_g = g + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_qg = qg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)
    p_kg = kg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK)

    # Guard the tail key block.
    mask = (i_k * BK + tl.arange(0, BK)) < K

    # Cumulative log-gate at the last position of this chunk.
    last_decay = tl.load(g + i_bh * T*K + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK))

    for _ in range(BT):
        b_q = tl.load(p_q, mask=mask, other=0)
        b_k = tl.load(p_k, mask=mask, other=0)
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_q *= exp(b_g) * scale
        b_k *= exp(last_decay - b_g)
        tl.store(p_kg, b_k.to(p_kg.dtype.element_ty), mask=mask)
        tl.store(p_qg, b_q.to(p_qg.dtype.element_ty), mask=mask)
        # Advance all pointers one timestep.
        p_q += K
        p_g += K
        p_k += K
        p_kg += K
        p_qg += K
55
+
56
+
57
@triton.jit(do_not_specialize=['T'])
def bwd_decay_global_cumsum(
    dq_inner,
    dq_inter,
    dk_inner,
    dk_inter,
    q,
    k,
    g,
    dg,
    T,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr
):
    # Walk one chunk backwards in time to:
    #   1. undo the q/k pre-decay applied in `prepare_qg_kg` on the
    #      inter-chunk gradients and merge them with the intra-chunk ones:
    #        dq = dq_inner + dq_inter * exp(g)
    #        dk = dk_inner + dk_inter * exp(g_last - g)
    #      (merged results are written in place into dq_inter / dk_inter);
    #   2. accumulate the gate gradient dg as a reverse (suffix) cumulative
    #      sum of dq*q - dk*k within the chunk.
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # All pointers start at the LAST row of the chunk and step backwards.
    p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_g = g + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dg = dg + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dq_inner = dq_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dk_inner = dk_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dq_inter = dq_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    p_dk_inter = dk_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K
    cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
    # Guard the tail key block.
    mask = (i_k * BK + tl.arange(0, BK)) < K
    last_g = tl.zeros([BK], dtype=tl.float32)
    for j in range(BT-1, -1, -1):
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        if j == (BT-1):
            # Cumulative gate at the chunk's final position (first iteration).
            last_g = b_g
        b_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
        b_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
        b_dq2 *= exp(b_g)
        b_dq = b_dq1 + b_dq2
        tl.store(p_dq_inter, b_dq, mask=mask)
        b_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
        b_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
        b_dk2 *= safe_exp(last_g - b_g)
        b_dk = b_dk1 + b_dk2
        tl.store(p_dk_inter, b_dk, mask=mask)
        b_q = tl.load(p_q, mask=mask, other=0)
        b_k = tl.load(p_k, mask=mask, other=0)
        b_dg = b_dq * b_q - b_dk * b_k
        # Running suffix sum over the chunk (iterating back-to-front).
        cum_grad_dg += b_dg
        tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
        p_g -= K
        p_k -= K
        p_q -= K
        p_dq_inner -= K
        p_dk_inner -= K
        p_dq_inter -= K
        p_dk_inter -= K
        p_dg -= K
111
+
112
+
113
@triton.jit(do_not_specialize=['T'])
def fused_chunk_gla_fwd_kernel(
    q,
    k,
    v,
    g,
    o,
    h0,
    ht,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # Fused-chunk GLA forward: iterate over time chunks, carrying a running
    # [BK, BV] state b_h.  For each chunk:
    #   o_chunk = q_chunk @ h                       (inter-chunk contribution)
    #   h       = h * exp(g_n)[:, None] + k_chunk^T @ v_chunk
    # where g_n is the cumulative log-gate at the last row of the chunk.
    # NOTE(review): q/k are presumably pre-scaled/decayed by `prepare_qg_kg`
    # — confirm against the host-side launcher.
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    # make block pointers
    p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
    p_gn = g + i_bh * T*K + (BT - 1) * K + i_k * BK + tl.arange(0, BK)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
    p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
    # Each key block writes to its own output slab; summed on the host side.
    p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K

    for i in range(0, tl.cdiv(T, BT)):
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32)
        # FIX: the original branched on `CHECK and i == 0` with two
        # byte-identical bodies (a leftover workaround for an old Triton
        # compiler bug); the dead branch is collapsed here.  `CHECK` is kept
        # in the signature for interface compatibility.
        b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
        b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))
        p_gn += BT * K

    if STORE_FINAL_STATE:
        p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
176
+
177
+
178
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
# Inter-chunk backward pass. Phase 1 scans forward to rebuild the state and
# produce dq; phase 2 scans backward accumulating the state gradient b_dh to
# produce dk and dv. Partial results per V-/K-block are summed by the caller.
# NOTE(review): the duplicated CHECK branches are a deliberate Triton
# compiler-bug workaround (see FusedChunkGLAFunction.forward) — keep them.
@triton.jit(do_not_specialize=['T'])
def fused_chunk_gla_bwd_kernel(
    q, k, v, g,
    do,   # gradient of the output, [B, H, T, V]
    dq,   # out: [NV, B, H, T, K]
    dk,   # out: [NV, B, H, T, K]
    dv,   # out: [NK, B, H, T, V]
    h0,   # optional initial state
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    # clamp_min, # minimum log value of the gate for numerical stability. default: -5
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # [BV, BK]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        # initial state is loaded transposed ([V, K]) for this phase
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K
    # --- phase 1: forward scan over chunks, producing dq ---
    for i in range(0, tl.cdiv(T, BT)):
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
        # decay at the last row of chunk i
        p_gn = g + i_bh * T*K + ((i+1) * BT - 1) * K + i_k * BK + tl.arange(0, BK)
        p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
        b_dq = tl.zeros([BT, BK], dtype=tl.float32)
        # [BT, K]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32)

        # [V, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [V, K]
        if CHECK and i == 0:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
        else:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
        b_dq *= scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    # sync threads
    b_h = None
    tl.debug_barrier()
    # --- phase 2: backward scan over chunks, producing dk/dv ---
    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)

    # cum = tl.zeros([BK], dtype=tl.float32)
    for i in range(1, tl.cdiv(T, BT) + 1):
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_gn = g + i_bh * T*K + (T - (i-1) * BT - 1) * K + i_k * BK + tl.arange(0, BK)
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * T*K, (T, K),
                                 (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * T*V, (T, V),
                                 (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        # [K, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BT, K]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, V]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_db = tl.load(p_gn, mask=mask, other=0).to(tl.float32)

        # inter-chunk
        # [K, V]
        if CHECK and i == 1:
            b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
            b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
            b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
        else:
            b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
            b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
            b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
274
+
275
# Intra-chunk forward: for each chunk, computes the lower-triangular score
# matrix A[i, j] = sum_k q_i * k_j * exp(g_i - g_j) (j <= i), one row per
# loop iteration. Partial A per K-block is summed by the caller.
@triton.jit
def fwd_inner_chunk(
    q, k, g, A,
    scale,  # K ** -0.5
    B: tl.constexpr,  # B
    H: tl.constexpr,  # H
    T,  # T
    K: tl.constexpr,  # K
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr  # BLOCK SIZE along the K dimension
):

    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K
    o_i = tl.arange(0, BT)

    # row-wise pointers into q/g, advanced by one timestep (K elements) per iteration
    p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)

    for i in range(BT):
        b_q = tl.load(p_q, mask=mask, other=0) * scale
        b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
        # safe_exp guards against overflow from the gate difference
        s = b_q[None, :] * b_k * safe_exp(b_gq[None, :] - b_g)
        score = tl.sum(s, axis=1)
        # causal mask: row i only attends to positions j <= i within the chunk
        score = tl.where(o_i <= i, score, 0)
        tl.store(p_A, score.to(p_A.dtype.element_ty))
        p_q += K
        p_gq += K
        p_A += BT
312
+
313
+
314
# Intra-chunk backward: given dA (gradient of the intra-chunk score matrix),
# accumulates dq row by row and dk over the whole chunk, mirroring
# fwd_inner_chunk's causal, gate-weighted scores.
@triton.jit
def bwd_inner_chunk(
    q,
    k,
    g,
    dA,   # gradient of the intra-chunk scores, [B, H, NT, BT, BT]-shaped buffer
    dq,   # out: gradient wrt q (intra-chunk part)
    dk,   # out: gradient wrt k (intra-chunk part)
    T,  # T
    K: tl.constexpr,  # K
    # clamp_min, # minimum log value of the gate for numerical stability. default: -5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)

    mask = (i_k * BK + tl.arange(0, BK)) < K
    o_i = tl.arange(0, BT)

    # row-wise pointers, stepped one timestep per iteration
    p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_dq = dq + (i_bh) * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK)
    p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)

    b_dk = tl.zeros([BT, BK], dtype=tl.float32)

    for i in range(BT):
        b_q = tl.load(p_q, mask=mask, other=0)
        b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
        # gate-difference weighting, causally masked to j <= i
        score = safe_exp(b_gq[None, :] - b_g)
        score = tl.where(o_i[:, None] <= i, score, 0)
        b_dA = tl.load(p_dA)
        b_dA = tl.where(o_i <= i, b_dA, 0)
        # dk accumulates over all rows; dq is complete for the current row
        b_dk += (b_dA[:, None] * score * b_q[None, :])
        b_dq = tl.sum(b_dA[:, None] * score * b_k, axis=0)
        tl.store(p_dq, b_dq, mask=mask)
        p_q += K
        p_dq += K
        p_gq += K
        p_dA += BT

    p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
361
+
362
+
363
class FusedChunkGLAFunction(torch.autograd.Function):
    """Autograd wrapper for the fused-chunk GLA (gated linear attention) op.

    The output is split into an inter-chunk part (recurrent kernel) and an
    intra-chunk part (score matrix ``A`` times values), computed by separate
    Triton kernels and summed. Inputs are assumed to be [B, H, T, ...] with T
    already padded to a multiple of the 16-step chunk size by the caller.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
        """Compute o = GLA(q, k, v, g); optionally return the final K x V state."""
        ctx.g_dtype = g.dtype
        ctx.scale = scale
        B, H, T, K, V = *k.shape, v.shape[-1]
        BT = 16  # chunk_size
        BK, BV = min(K, 64), min(V, 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 2

        g_org = g
        # cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
        g = chunk_local_cumsum(g_org, chunk_size=BT)
        # leading NK dim holds per-K-block partial outputs, reduced below
        o = q.new_empty(NK, B, H, T, V)
        q_g = torch.empty_like(q)
        k_g = torch.empty_like(k)

        # fold the decay (and scale) into q/k once, so the main kernel
        # only does matmuls
        grid = (NK, triton.cdiv(T, BT), B * H)
        prepare_qg_kg[grid](
            q,
            k,
            g,
            q_g,
            k_g,
            scale,
            T=T,
            K=K,
            BT=BT,
            BK=BK,
            num_warps=1
        )

        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
        else:
            final_state = None
        # the bug still exists even for Triton 2.2 on H100 GPUs
        # so we always enable initial checks
        CHECK = True
        if version.parse(triton.__version__) < version.parse('2.2.0'):
            import warnings
            warnings.warn(
                "Triton<2.2.0 detected for running this kernel, "
                "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
                "that lead to significant precision loss. "
                "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
                "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
            )
            CHECK = True

        # inter-chunk (recurrent) contribution
        grid = (NV, NK, B * H)
        fused_chunk_gla_fwd_kernel[grid](
            q_g, k_g, v, g, o, initial_state, final_state,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BT=BT,
            BK=BK,
            BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            CHECK=CHECK,
            num_warps=num_warps,
            num_stages=num_stages
        )

        # reduce per-K-block partials
        o = o.sum(0)

        # intra-chunk
        chunk_size = 16
        num_chunk = T // chunk_size
        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
        BK = min(K, 64)
        NK = triton.cdiv(K, BK)
        A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT)
        grid = (NK, triton.cdiv(T, BT), B * H)
        fwd_inner_chunk[grid](
            q, k, g, A,
            scale,
            B=B,
            H=H,
            T=T,
            K=K,
            BT=BT,
            BK=BK,
            num_stages=3,
            num_warps=4
        )
        A = A.sum(0)
        o2 = A @ v2
        o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
        # combine inner and inter
        o.add_(o2)
        # save g BEFORE the cumsum (g_org); backward recomputes the cumsum
        ctx.save_for_backward(q, k, v, g_org, A, initial_state)
        ctx.CHECK = CHECK
        return o.to(v), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht=None):
        """Backward pass; recomputes q_g/k_g and the chunked decay, then mirrors
        the forward split into inter-chunk (Triton) and intra-chunk parts.
        NOTE(review): dht (final-state gradient) is accepted but unused."""
        q, k, v, g_org, A, initial_state = ctx.saved_tensors
        B, H, T, K, V = *k.shape, v.shape[-1]
        scale = ctx.scale

        # recomputation
        # inter-chunk
        BT = 16  # chunk_size
        g = chunk_local_cumsum(g_org, chunk_size=BT)
        BK, BV = min(K, 64), min(V, 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        q_g = torch.empty_like(q)
        k_g = torch.empty_like(k)
        grid = (NK, triton.cdiv(T, BT), B * H)
        prepare_qg_kg[grid](
            q,
            k,
            g,
            q_g,
            k_g,
            scale,
            T=T,
            K=K,
            BT=BT,
            BK=BK,
            num_warps=1
        )

        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 2
        # leading NV/NK dims hold per-block partials, reduced after the launch
        dq = q.new_empty(NV, B, H, T, K)
        dk = q.new_empty(NV, B, H, T, K)
        dv = q.new_empty(NK, B, H, T, V)

        grid = (NV, NK, B * H)

        fused_chunk_gla_bwd_kernel[grid](
            q_g,
            k_g,
            v,
            g,
            do,
            dq,
            dk,
            dv,
            initial_state,
            scale,
            T=T,
            B=B,
            H=H,
            K=K,
            V=V,
            BT=BT,
            BK=BK,
            BV=BV,
            USE_INITIAL_STATE=initial_state is not None,
            CHECK=ctx.CHECK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)

        # intra chunk
        NT = T // BT
        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=NT)
        do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=NT)
        # gradient wrt the intra-chunk score matrix A (and values via A^T)
        dA2 = (do2 @ v2.transpose(-2, -1)) * scale
        dv2 = A.transpose(-1, -2) @ do2
        dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=NT)

        BK = min(triton.next_power_of_2(K), 16)
        NK = triton.cdiv(K, BK)
        dk2 = torch.empty_like(k)
        dq2 = torch.empty_like(q)

        grid = (NK, NT, B * H)
        bwd_inner_chunk[grid](
            q, k, g,
            dA2,
            dq2,
            dk2,
            T=T,
            K=K,
            BT=BT,
            BK=BK,
            num_warps=1,
            num_stages=3
        )

        # fold intra- and inter-chunk dq/dk through the decay cumsum and
        # produce the within-chunk gate gradient dg
        BK = min(triton.next_power_of_2(K), 32)
        NK = triton.cdiv(K, BK)
        dg = torch.empty_like(g, dtype=torch.float32)
        grid = (NK, triton.cdiv(T, BT), B * H)
        bwd_decay_global_cumsum[grid](
            dq2,
            dq,
            dk2,
            dk,
            q,
            k,
            g,
            dg,
            T=T,
            K=K,
            BT=BT,
            BK=BK,
            num_warps=1,
            num_stages=1
        )
        dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)

        def rev_cumsum_exclusive(x):
            # exclusive reverse cumulative sum along the chunk axis
            cumsum_x = x.cumsum(-2)
            rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
            return rev_cumsum_x

        # propagate each chunk's leading dg to all later chunks
        rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
        dg.add_(rev_cumsum_dg.unsqueeze(-2))
        dv.add_(dv2)
        dg = rearrange(dg, 'b h n c d -> b h (n c) d')

        # None grads for scale, initial_state, output_final_state
        return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
596
+
597
+
598
def ceildiv(a, b):
    """Return the ceiling of a / b using only integer arithmetic."""
    quot, rem = divmod(a, b)
    return quot + (1 if rem else 0)
600
+
601
+
602
def pad(x, chunk_size=16):
    """Zero-pad ``x`` along dim -2 so its length is a multiple of ``chunk_size``.

    Returns ``x`` unchanged (same object) when the length already divides evenly.
    """
    seq_len = x.shape[-2]
    remainder = seq_len % chunk_size
    if remainder:
        # append (chunk_size - remainder) zero rows at the end of dim -2
        x = F.pad(x, (0, 0, 0, chunk_size - remainder))
    return x
608
+
609
+
610
def fused_chunk_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float = -1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused-chunk gated linear attention.

    Args:
        q, k, g: [B, H, T, K] tensors (g is the log decay gate).
        v: [B, H, T, V] values.
        scale: query scale; -1 selects the default K ** -0.5.
        initial_state: optional [B, H, K, V] recurrent state.
        output_final_state: if True, also return the final state.

    Returns:
        (o, final_state): output of shape [B, H, T, V] (trimmed back to the
        original sequence length) and the final state or ``None``.
    """
    if scale == -1:
        scale = q.shape[-1] ** -0.5
    seq_len = q.shape[-2]
    # pad T up to a multiple of the 16-step chunk size; extra rows are
    # sliced off the output below
    q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
    o, final_state = FusedChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
    o = o[..., :seq_len, :].contiguous()
    return o, final_state
fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (320 Bytes). View file
 
fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (27.3 kB). View file
 
fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (57 kB). View file
 
fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (23.1 kB). View file
 
fla3/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange, reduce
11
+
12
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
13
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
14
+ from fla.ops.utils import prepare_chunk_indices
15
+ from fla.ops.utils.cumsum import chunk_local_cumsum
16
+ from fla.ops.utils.op import exp, safe_exp
17
+ from fla.ops.utils.softmax import softmax_bwd, softmax_fwd
18
+ from fla.utils import input_guard
19
+
20
+
21
# GSA forward, inter-chunk part: o = (q @ h) * exp(g) plus the raw causal
# score matrix A = tril(q @ k^T), accumulated over K-blocks inside the kernel.
# Supports grouped heads (HQ query heads share H kv heads via NG) and
# variable-length batches through cu_seqlens/chunk_indices.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [32, 64]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_fwd_k_kernel_inter(
    q,
    k,
    h,   # per-chunk recurrent states, indexed by global chunk id i_tg
    g,
    o,
    A,   # out: [*, BT] causal score buffer, written only by the i_v == 0 program
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,  # query heads per kv head (grouped attention)
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        # remap the flat chunk index to (sequence, local chunk) coordinates
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    # lower-triangular (causal) mask for the BT x BT score tile
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BT]
        b_A += tl.dot(b_q, b_k)
    p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    # [BT, BV]
    b_g = tl.load(p_g, boundary_check=(0, 1))
    b_o = b_o * exp(b_g)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    # [BT, BT]
    b_A = tl.where(m_s, b_A, 0.)
    # A is independent of the V-block; store once to avoid duplicate writes
    if i_v == 0:
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
103
+
104
+
105
# GSA forward, intra-chunk part: adds A @ (gate-weighted v) to the output.
# Each chunk of BT rows is split into NC sub-chunks of BC rows; sub-chunk
# pairs (i_i > i_j) use matmuls, while the diagonal sub-chunk is handled
# row by row with an explicit causal mask.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_fwd_k_kernel_intra(
    v,
    g,
    o,   # in/out: inter-chunk output, updated in place with the intra part
    A,   # causal scores from chunk_gsa_fwd_k_kernel_inter
    cu_seqlens,
    chunk_indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # sub-chunk size
    BV: tl.constexpr,
    NC: tl.constexpr,  # sub-chunks per chunk (BT // BC)
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    i_t, i_i = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # sub-chunk lies entirely past the end of this sequence
    if i_t * BT + i_i * BC > T:
        return

    p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # gate value at the first row of this sub-chunk, used as the reference
    # point so exponent arguments stay small
    p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
    # [BV,]
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_o = tl.zeros([BC, BV], dtype=tl.float32)
    # contributions from strictly-earlier sub-chunks: dense matmuls
    for i_j in range(0, i_i):
        p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
        p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        b_o += tl.dot(b_A, b_vg)
    # [BC, BV]
    b_g = tl.load(p_g, boundary_check=(0, 1))
    # re-apply the per-row gate relative to the reference point
    b_o *= exp(b_g - b_gn[None, :])

    # diagonal sub-chunk: handled one column j at a time with a causal mask
    o_i = tl.arange(0, BC)
    o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
        p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
        # [BC,]
        b_A = tl.load(A + o_A + j, mask=m_A, other=0)
        # [BV,]
        b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
        b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
        # [BC, BV]
        b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
        # avoid 0 * inf = inf
        b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
    p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # accumulate onto the inter-chunk output already in o
    b_o += tl.load(p_o, boundary_check=(0, 1))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
183
+
184
+
185
# GSA backward: gradient of the intra-chunk score matrix, dA = do @ v^T with
# gate weighting. One program per (V-block, sub-chunk pair, head); the
# strictly-lower pair (i_i > i_j) is a matmul, the diagonal pair is built
# column by column, and upper pairs stay zero (causality). Per-V-block
# partials land in separate dA slices (leading i_v*all offset) for the
# caller to reduce.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8]
    ],
    key=["BT"]
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dA(
    v,
    g,
    do,
    dA,
    chunk_indices,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    # decompose the flat sub-chunk-pair index into (chunk, row sub, col sub)
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    if i_t * BT + i_i * BC > T:
        return

    p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))

    # [BC, BC]
    b_dA = tl.zeros([BC, BC], dtype=tl.float32)
    if i_i > i_j:
        # strictly-lower sub-chunk pair: full matmul, gates referenced to the
        # first row of sub-chunk i_i (b_gn) to keep exponents bounded
        p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
        p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
        p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
        p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0.)
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
        # [BC, BC]
        b_dA = tl.dot(b_do, b_vg)
    elif i_i == i_j:
        # diagonal sub-chunk: build one column per iteration, causal mask applied
        p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
        p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
        p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
        m_v = o_v < V

        o_i = tl.arange(0, BC)
        # [BC, BC]
        m_dA = o_i[:, None] >= o_i[None, :]
        for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
            # [BV,]
            b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
            b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
            # [BC,]
            b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
            # scatter column j into the accumulator
            b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)

            p_v += H*V
            p_gv += H*V
        b_dA = tl.where(m_dA, b_dA, 0.)
    tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
282
+
283
+
284
# GSA backward: inter-chunk gradients for q, k, v and the gate g. Recomputes
# the causal score tile A, then loops over V-blocks folding the per-chunk
# state h and state-gradient dh into dq/dk/dv and the gate gradient, and
# finally adds the intra-chunk dA contributions.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4]
        for num_stages in [2, 3, 4]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,    # per-chunk forward states
    g,
    A,    # out: recomputed causal scores, per-K-block slices
    do,
    dh,   # gradient of the per-chunk states
    dq,
    dk,
    dv,   # out: per-K-block slices, reduced by the caller
    dg,   # gate gradient accumulated so far (read when i_k == 0)
    dgv,  # out: gate gradient including this kernel's contribution
    dA,   # gradient of the causal scores (already reduced over V-blocks)
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    # index of the row just past this chunk (clamped to T)
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # recompute the causal score tile for this K-block
    # [BT, BT]
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # gate at the last row of the chunk, the inter-chunk reference point
        p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV] — gate gradient from the state recurrence h' = h * exp(gn) + ...
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        # NOTE(review): only the i_k == 0 program folds the pre-existing dg
        # into dgv — presumably to avoid double-counting across K-block
        # slices that are summed later; confirm against the caller.
        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    # intra-chunk contribution through the score gradient dA
    p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
416
+
417
+
418
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.jit(do_not_specialize=['T'])
def chunk_gsa_bwd_k_kernel_intra_dvg(
    v,
    g,
    o,
    A,
    do,
    dv,
    dg,
    cu_seqlens,
    chunk_indices,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BV: tl.constexpr,
    NC: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Intra-chunk backward pass: accumulates the intra-chunk contribution to
    ``dv`` and computes the gate gradient ``dg = o*do - v*dv``.

    Grid: (value blocks, chunks * sub-chunks, batch * query heads).
    Each program handles one [BC, BV] tile. NG = HQ // H maps query heads
    onto shared key/value heads (GQA).
    """
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    # i_t: chunk index, i_i: sub-chunk index within the chunk
    i_t, i_i = i_c // NC, i_c % NC
    if IS_VARLEN:
        # remap (sequence, chunk) from the flattened chunk table
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_v = i_v * BV + tl.arange(0, BV)
    m_v = o_v < V

    # sub-chunk starts beyond this (possibly shortened) sequence: nothing to do
    if i_t * BT + i_i * BC > T:
        return

    p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
    # gate value at the last valid row of this sub-chunk (reference point)
    p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
    # [BV,]
    b_gn = tl.load(p_gn, mask=m_v, other=0)
    # [BC, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)
    # contributions from strictly later sub-chunks within the same chunk
    for i_j in range(i_i + 1, NC):
        p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
        p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
        # [BC, BV]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # rescale do by exp(g - g_n); safe_exp presumably clamps to avoid overflow
        b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
        # [BC, BC]
        b_A = tl.load(p_A, boundary_check=(0, 1))
        # [BC, BV]
        b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
    # undo the g_n reference: final factor exp(g_n - g) per row
    b_dv *= exp(b_gn[None, :] - b_gv)

    o_i = tl.arange(0, BC)
    o_c = i_i * BC + tl.arange(0, BC)

    p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
    p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
    p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v
    # row-by-row scan over this sub-chunk: rows <= j contribute (causal mask)
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_A = tl.load(p_A)
        # [BV,]
        b_g = tl.load(p_g, mask=m_v, other=0)
        b_do = tl.load(p_do, mask=m_v, other=0)
        # [BC, BV]
        m_i = o_i[:, None] <= j
        b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)

        # advance raw pointers one time step
        p_g += H * V
        p_A += HQ * BT
        p_do += HQ * V
    p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
    p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))

    b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
    b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
    # fold in the inter-chunk dv accumulated by the previous kernel
    b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
    b_dg = b_o * b_do - b_v * b_dv
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
512
+
513
+
514
def chunk_gsa_fwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float = 1.,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the value-side GSA recurrence.

    Thin wrapper over ``chunk_gla_fwd``: ``g`` already holds cumulative gate
    values, so it is supplied through the ``g_cumsum`` slot while the raw-gate
    slot stays empty.

    Returns:
        ``(A, h, ht, o)`` — attention scores, chunk states, final state, output.
    """
    results = chunk_gla_fwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    # drop the first element of the GLA result; keep (A, h, ht, o)
    return results[1:]
538
+
539
+
540
def chunk_gsa_fwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass of the key-side GSA recurrence.

    Computes chunk-level hidden states with value-side gating (``gv=g``),
    then launches the inter- and intra-chunk output kernels.

    Returns:
        A: intra-chunk attention scores, ``[B, T, HQ, BT]``.
        h: per-chunk hidden states.
        ht: final state (or ``None`` if not requested).
        o: outputs, ``[B, T, HQ, V]``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    # chunk length: capped by chunk_size, at least 16, power of two
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[2]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    # GQA group size: query heads per key/value head
    NG = HQ // H

    h, ht = chunk_fwd_h(
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        h0=h0,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
        states_in_fp32=False
    )
    o = v.new_empty(B, T, HQ, V)
    A = q.new_empty(B, T, HQ, BT)
    # inter-chunk pass: one program per (value block, chunk, batch*head)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_gsa_fwd_k_kernel_inter[grid](
        q,
        k,
        h,
        g,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
    )

    # intra-chunk pass: additionally parallel over sub-chunks (NC)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_gsa_fwd_k_kernel_intra[grid](
        v,
        g,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        num_warps=4,
        num_stages=2
    )
    return A, h, ht, o
617
+
618
+
619
def chunk_gsa_bwd_v(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h0: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Backward pass of the value-side GSA recurrence.

    Delegates directly to ``chunk_gla_bwd`` with ``g`` passed as the
    cumulative gate (``g_cumsum``). The ``dg`` argument is accepted for
    signature symmetry with ``chunk_gsa_bwd_k`` but is not consumed here.

    Returns:
        ``(dq, dk, dv, dg, dh0)``.
    """
    return chunk_gla_bwd(
        q=q,
        k=k,
        v=v,
        g=None,
        g_cumsum=g,
        scale=scale,
        initial_state=h0,
        h=h,
        A=A,
        do=do,
        dht=dht,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
650
+
651
+
652
def chunk_gsa_bwd_k(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    h: torch.Tensor,
    h0: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dg: torch.Tensor,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Backward pass of the key-side GSA recurrence.

    Pipeline:
      1. recompute hidden states ``h`` if they were checkpointed away;
      2. ``chunk_bwd_dh``: state gradients ``dh`` (+ ``dh0``);
      3. ``..._kernel_dA``: intra-chunk score gradients, summed over V blocks;
      4. ``..._kernel_dqkvg``: dq/dk/dv and inter-chunk gate gradients,
         replicated over K blocks then reduced;
      5. ``..._kernel_intra_dvg``: intra-chunk dv and ``dg = o*do - v*dv``;
      6. combine gate gradients with a reversed local cumsum of ``dg``.

    Returns:
        ``(dq, dk, dv, dg, dh0)``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BK = min(64, triton.next_power_of_2(K))
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[2]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # GQA group size: query heads per key/value head
    NG = HQ // H

    # h is None when checkpoint_level > 1: recompute forward states
    if h is None:
        h, _ = chunk_fwd_h(
            k=k,
            v=v,
            g=None,
            gk=None,
            gv=g,
            h0=h0,
            output_final_state=False,
            cu_seqlens=cu_seqlens,
            chunk_size=BT,
            states_in_fp32=False
        )
    dh, dh0 = chunk_bwd_dh(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=None,
        gv=g,
        do=do,
        h0=h0,
        dht=dht,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=BT,
        states_in_fp32=True
    )
    # dA replicated over NV value blocks, reduced below
    dA = q.new_empty(NV, B, T, HQ, BT)
    grid = (NV, NT * NC * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_dA[grid](
        v,
        g,
        do,
        dA,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
    )
    dA = dA.sum(0, dtype=dA.dtype)

    # A/dv/dgv replicated over NK key blocks, reduced below
    A = do.new_empty(NK, B, T, HQ, BT)
    dq = torch.empty_like(q)
    dk = k.new_empty(B, T, HQ, K)
    dv = v.new_empty(NK, B, T, HQ, V)
    dgv = g.new_empty(NK, B, T, HQ, V, dtype=torch.float)
    grid = (NK, NT, B * HQ)
    chunk_gsa_bwd_k_kernel_dqkvg[grid](
        q,
        k,
        v,
        h,
        g,
        A,
        do,
        dh,
        dq,
        dk,
        dv,
        dg,
        dgv,
        dA,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NG=NG,
    )
    A = A.sum(0, dtype=A.dtype)
    dv = dv.sum(0, dtype=dv.dtype)
    dgv = dgv.sum(0, dtype=dgv.dtype)

    # intra-chunk dv/dg pass; updates dv and dg in place on the GPU
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_gsa_bwd_k_kernel_intra_dvg[grid](
        v,
        g,
        o,
        A,
        do,
        dv,
        dg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        num_warps=4,
        num_stages=2
    )
    # total gate grad = inter-chunk part (dgv) + reverse-cumsum of intra part
    dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, cu_seqlens=cu_seqlens))

    return dq, dk, dv, dg, dh0
797
+
798
+
799
def chunk_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Full GSA forward: key-side recurrence -> softmax over slots ->
    value-side recurrence.

    Returns:
        ``(Ak, hk, hkt, ok, p, Av, hv, hvt, ov)`` where ``p`` is the softmax
        of the key-side output and ``ov`` is the final attention output.
    """
    hk0, hv0 = initial_state if initial_state is not None else (None, None)
    Ak, hk, hkt, ok = chunk_gsa_fwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        output_final_state=output_final_state,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

    # p is kept in fp32 for safe softmax backward
    p = softmax_fwd(ok, dtype=torch.float)

    # the softmax-ed slot scores act as the "queries" of the value recurrence
    qv = p.to(q.dtype)
    Av, hv, hvt, ov = chunk_gsa_fwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        scale=1.,
        initial_state=hv0,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
    return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
842
+
843
+
844
def chunk_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    ok: torch.Tensor,
    p: torch.Tensor,
    A: Tuple[torch.Tensor, torch.Tensor],
    h: Tuple[torch.Tensor, torch.Tensor],
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
    scale: float,
    do: torch.Tensor,
    dht: Tuple[torch.Tensor, torch.Tensor],
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64
):
    """Full GSA backward: value-side backward -> softmax backward ->
    key-side backward, mirroring ``chunk_gsa_fwd`` in reverse.

    Returns:
        ``(dq, dk, dv, ds, dg, dhk0, dhv0)``.
    """
    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state

    # key-side A is not needed for backward; only the value-side scores are
    _, Av = A
    hk, hv = h
    dhkt, dhvt = dht

    qv = p.to(q.dtype)
    dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
        q=qv,
        k=s,
        v=v,
        g=g,
        h0=hv0,
        h=hv,
        A=Av,
        do=do,
        dht=dhvt,
        dg=None,
        scale=1.,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

    # softmax gradient, equivalent to:
    # dok = qv * (dqv - (qv * dqv).sum(-1, True))
    dok = softmax_bwd(p, dqv, dtype=ok.dtype)

    # key-side backward consumes the gate grad accumulated so far (dg)
    dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
        q=q,
        k=k,
        v=s,
        g=g,
        h0=hk0,
        h=hk,
        o=ok,
        do=dok,
        dht=dhkt,
        dg=dg,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

    # slot grad: value-side contribution (in place) + key-side contribution
    ds = dsv.add_(dsk)
    # NOTE(review): this reduces over dim 1 via the 'b (h g) ...' pattern —
    # presumably a GQA head reduction; confirm the dim matches [B, T, H, ...].
    if q.shape[1] != k.shape[1]:
        dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
    dg = dg.to(s.dtype)
    return dq, dk, dv, ds, dg, dhk0, dhv0
911
+
912
+
913
class ChunkGSAFunction(torch.autograd.Function):
    """Autograd wrapper for chunked GSA.

    ``checkpoint_level`` controls what is stashed for backward:
      0 — keep fp32 cumulative gates and hidden states;
      1 — keep hidden states, recompute the gate cumsum in backward;
      2 — recompute both (lowest memory, most recomputation).
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: float,
        hk0: Optional[torch.Tensor],
        hv0: Optional[torch.Tensor],
        output_final_state: bool,
        checkpoint_level: int,
        cu_seqlens: Optional[torch.LongTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        T = q.shape[1]
        chunk_size = min(64, max(16, triton.next_power_of_2(T)))

        # keep the raw gates around in case level >= 1 discards the cumsum
        g_org, g = g, chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens)
        Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )

        if checkpoint_level >= 1:
            # save raw gates; backward recomputes the cumsum
            del g
            g = g_org
        if checkpoint_level > 1:
            # drop hidden states; backward recomputes them from (hk0, hv0)
            del hk
            del hv
            hk, hv = None, None
        else:
            # states kept -> initial states are not needed for recompute
            hk0, hv0 = None, None

        ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
        ctx.checkpoint_level = checkpoint_level
        ctx.scale = scale
        ctx.cu_seqlens = cu_seqlens
        ctx.chunk_size = chunk_size
        return ov, hkt, hvt

    @staticmethod
    @input_guard
    def backward(ctx, dov, dhkt=None, dhvt=None):
        q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
        scale = ctx.scale
        cu_seqlens = ctx.cu_seqlens
        chunk_size = ctx.chunk_size

        # level >= 1 stored raw gates: rebuild the local cumsum
        if ctx.checkpoint_level >= 1:
            g = chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens)
        dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            ok=ok,
            p=p,
            A=(None, Av),
            h=(hk, hv),
            initial_state=(hk0, hv0),
            scale=scale,
            do=dov,
            dht=(dhkt, dhvt),
            cu_seqlens=cu_seqlens,
            chunk_size=chunk_size
        )
        # grads align with forward's 11 inputs (None for non-tensor args)
        return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
993
+
994
+
995
@torch.compiler.disable
def chunk_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    checkpoint_level: Optional[int] = 2,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is performed if `H` is not equal to `HQ`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
            If not provided, this function is equivalent to vanilla ABC.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
            Default: `False`.
        checkpoint_level (Optional[int]):
            Checkpointing level; higher values will save more memories and do more recomputations during backward.
            Default: `2`:
            - Level `0`: no memory saved, no recomputation.
            - Level `1`: recompute the fp32 cumulative values during backward.
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Deprecated: passing `True` raises. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (Tuple[torch.Tensor]):
            Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
            `None` otherwise.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gsa import fused_recurrent_gsa
        # inputs with equal lengths
        >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> s = torch.randn(B, T, H, M, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
        >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
        >>> o, (hk, hv) = chunk_gsa(
            q, k, v, s, g,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, (hk_var, hv_var) = chunk_gsa(
            q, k, v, s, g,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert hk.allclose(hk_var)
        >>> assert hv.allclose(hv_var)
    """
    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    # NOTE: the raise above makes head_first always False past this point, so
    # the previously-present rearrange fallbacks were unreachable and removed.
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}."
            )
    assert checkpoint_level in [0, 1, 2]
    if g is None:
        # Vanilla-ABC fallback: derive gates from a running softmax normalizer.
        # Inputs are [B, T, H, M], so the cumulative logsumexp runs over the
        # time dim (1). The original code mixed dim-2 slicing with a dim-1
        # concat, which both cumulated over heads and raised a shape mismatch.
        # TODO: this 3 steps took huge amount of time, ought to be optimized
        z = s.float().logcumsumexp(1)
        # g[t] = z[t-1] - z[t] with g[0] = 0 (log-ratio of normalizers)
        g = torch.cat((z[:, :1], z[:, :-1]), 1) - z
        s = torch.exp(s - z).to(k.dtype)
    if scale is None:
        scale = q.shape[-1] ** -0.5

    hk0, hv0 = None, None
    if initial_state is not None:
        hk0, hv0 = initial_state
    o, *final_state = ChunkGSAFunction.apply(
        q,
        k,
        v,
        s,
        g,
        scale,
        hk0,
        hv0,
        output_final_state,
        checkpoint_level,
        cu_seqlens
    )
    return o, final_state
fla3/ops/gsa/fused_recurrent.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.fused_recurrent import fused_recurrent_bwd_kernel, fused_recurrent_fwd_kernel
11
+ from fla.ops.utils import chunk_global_cumsum
12
+ from fla.ops.utils.op import exp
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
14
+
15
+
16
@triton.jit
def fused_recurrent_gsa_inference_kernel(
    q,
    k,
    v,
    s,
    g,
    o,
    hk0,
    hv0,
    hkt,
    hvt,
    scale,
    K: tl.constexpr,
    V: tl.constexpr,
    M: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr
):
    """Single-step (T=1) GSA decoding kernel.

    Per program (one per batch*query-head): update the key-side state
    ``hk`` and score slots, softmax over slots, then update the value-side
    state ``hv`` and emit the output. States are written back only by the
    first query head of each GQA group (``i_bh % NG == 0``).
    """
    i_bh = tl.program_id(0)
    # key/value head shared by this query head's GQA group
    i_bg = i_bh // NG

    b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32)
    b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32)
    # gates are stored in log space; decay factor = exp(g)
    b_g = exp(b_g)

    b_ok = tl.zeros([M], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)

        p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None]
        # [BK,]
        mask_k = o_k < K
        # [M, BK]
        mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :]
        # [M, BK]
        b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32)
        # [BK,]
        b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale
        b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32)
        # recurrence: hk = hk * exp(g) + s ⊗ k
        b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]
        b_ok += tl.sum(b_hk * b_q[None, :], axis=1)

        if i_bh % NG == 0:
            p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None]
            tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk)

    # slot attention weights
    b_qv = tl.softmax(b_ok)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)

        p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
        # [BV,]
        mask_v = o_v < V
        # [BV, M]
        mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :]
        # [BV, M]
        b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32)
        # [BV,]
        b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32)
        # recurrence: hv = hv * exp(g) + v ⊗ s
        b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]
        b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)

        tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v)

        if i_bh % NG == 0:
            p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
            tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv)
85
+
86
+
87
def fused_recurrent_gsa_inference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
) -> torch.Tensor:
    """Launch the single-step GSA decoding kernel (T == 1).

    Returns:
        ``(o, (hkt, hvt))`` — output of shape ``[B, T, HQ, V]`` and the
        updated key/value states (``None`` each unless requested).
    """
    B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    HQ = q.shape[2]
    BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
    # GQA group size: query heads per key/value head
    NG = HQ // H

    if initial_state != (None, None) and initial_state is not None:
        hk0, hv0 = initial_state
    else:
        hk0, hv0 = q.new_zeros(B, H, K, M, dtype=torch.float), q.new_zeros(B, H, M, V, dtype=torch.float)

    hkt, hvt = None, None
    if output_final_state:
        if NG == 1:
            # no grouping: alias the inputs so the kernel updates them in place
            hkt, hvt = hk0, hv0
        else:
            hkt, hvt = q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float)

    o = v.new_empty(B, T, HQ, V)
    # one program per (batch, query head)
    grid = (B * HQ,)
    fused_recurrent_gsa_inference_kernel[grid](
        q,
        k,
        v,
        s,
        g,
        o,
        hk0,
        hv0,
        hkt,
        hvt,
        scale=scale,
        K=K,
        V=V,
        M=M,
        BK=BK,
        BV=BV,
        NG=NG
    )
    return o, (hkt, hvt)
136
+
137
+
138
def fused_recurrent_gsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_final_state: bool = False,
    scale: float = 1.,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
    """Token-recurrent GSA forward: key-side recurrence -> slot softmax ->
    value-side recurrence, both via the shared ``fused_recurrent_fwd_kernel``.

    Returns:
        ``(ok, hkt, qv, ov, hvt)`` — key-side scores, key-side final state,
        softmax-ed slot weights, outputs, value-side final state.
    """
    B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    HQ = q.shape[2]
    if HQ != H:
        raise ValueError("GQA not supported yet.")

    BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64)
    NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)

    hk0, hv0 = None, None
    if initial_state != (None, None) and initial_state is not None:
        hk0, hv0 = initial_state
    hkt, hvt = None, None
    if output_final_state:
        hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float)

    # ok replicated over NK key blocks, reduced below
    ok = q.new_empty(NK, *s.shape, dtype=torch.float)
    # key-side pass gates the "value" (slot) dim -> USE_GV
    gk, gv = None, g
    grid = (NM, NK, N * H)
    fused_recurrent_fwd_kernel[grid](
        q=q,
        k=k,
        v=s,
        g=None,
        gk=gk,
        gv=gv,
        o=ok,
        h0=hk0,
        ht=hkt,
        cu_seqlens=cu_seqlens,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=M,
        BK=BK,
        BV=BM,
        USE_G=False,
        USE_GK=False,
        USE_GV=True,
        REVERSE=reverse
    )
    ok = ok.sum(0)

    # slot attention weights become the queries of the value-side pass
    qv = ok.softmax(-1, dtype=torch.float)
    ov = q.new_empty(NM, *v.shape, dtype=torch.float)
    # value-side pass gates the "key" (slot) dim -> USE_GK
    gk, gv = g, None
    grid = (NV, NM, N * H)
    fused_recurrent_fwd_kernel[grid](
        q=qv,
        k=s,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        o=ov,
        h0=hv0,
        ht=hvt,
        cu_seqlens=cu_seqlens,
        scale=1.,
        B=B,
        T=T,
        H=H,
        K=M,
        V=V,
        BK=BM,
        BV=BV,
        USE_G=False,
        USE_GK=True,
        USE_GV=False,
        REVERSE=reverse,
    )
    ov = ov.sum(0)
    return ok, hkt, qv, ov, hvt
225
+
226
+
227
def fused_recurrent_gsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: torch.Tensor,
    qv: torch.Tensor,
    hk0: Optional[torch.Tensor] = None,
    hv0: Optional[torch.Tensor] = None,
    ok: Optional[torch.Tensor] = None,
    do: Optional[torch.Tensor] = None,
    dhkt: Optional[torch.Tensor] = None,
    dhvt: Optional[torch.Tensor] = None,
    scale: float = 1.,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
    """Token-recurrent GSA backward: value-side backward, manual softmax
    backward, then key-side backward — mirroring ``fused_recurrent_gsa_fwd``.

    Returns:
        ``(dq, dk, dv, ds, dg, dhk0, dhv0)``.
    """
    B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)
    NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)

    # grads replicated over the parallelized blocks, reduced after launch
    dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
    dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
    dv = q.new_empty(NM, B, T, H, V, dtype=torch.float)
    dhk0 = torch.empty_like(hk0) if hk0 is not None else None
    dhv0 = torch.empty_like(hv0) if hv0 is not None else None

    # value-side backward: gates on the slot ("key") dim -> USE_GK
    gk, gv = g, None
    grid = (NV, NM, N * H)
    fused_recurrent_bwd_kernel[grid](
        q=qv,
        k=s,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        h0=hv0,
        do=do,
        dq=dqv,
        dk=dsv,
        dv=dv,
        dht=dhvt,
        dh0=dhv0,
        cu_seqlens=cu_seqlens,
        scale=1.,
        B=B,
        T=T,
        H=H,
        K=M,
        V=V,
        BK=BM,
        BV=BV,
        USE_G=False,
        USE_GK=True,
        USE_GV=False,
        REVERSE=reverse,
    )
    dqv = dqv.sum(0)
    dsv = dsv.sum(0)
    dv = dv.sum(0)
    # gate grad from the value-side recurrence (reverse-direction cumsum)
    dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens)

    # softmax backward: d(ok) from d(qv)
    dok = qv * (dqv - (qv * dqv).sum(-1, True))
    dq = q.new_empty(NM, B, T, H, K, dtype=torch.float)
    dk = q.new_empty(NM, B, T, H, K, dtype=torch.float)
    dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float)
    # key-side backward: gates on the slot ("value") dim -> USE_GV
    gk, gv = None, g
    grid = (NM, NK, N * H)
    fused_recurrent_bwd_kernel[grid](
        q=q,
        k=k,
        v=s,
        g=None,
        gk=gk,
        gv=gv,
        h0=hk0,
        do=dok,
        dq=dq,
        dk=dk,
        dv=dsk,
        dht=dhkt,
        dh0=dhk0,
        cu_seqlens=cu_seqlens,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=M,
        BK=BK,
        BV=BM,
        USE_G=False,
        USE_GK=False,
        USE_GV=True,
        REVERSE=reverse,
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dsk = dsk.sum(0)

    # gate grad from the key-side recurrence
    dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens)

    # combine slot and gate grads from both recurrences (in place)
    ds = dsk.add_(dsv)
    dg = dgk.add_(dgv)

    return dq, dk, dv, ds, dg, dhk0, dhv0
335
+
336
+
337
class FusedRecurrentGSAFunction(torch.autograd.Function):
    """Autograd wrapper tying `fused_recurrent_gsa_fwd` to `fused_recurrent_gsa_bwd`."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        s: torch.Tensor,
        g: torch.Tensor,
        scale: Optional[float] = None,
        hk0: Optional[torch.Tensor] = None,
        hv0: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        T = q.shape[1]
        # single-token decoding with no grad required: take the cheaper
        # inference-only path (no intermediates are saved)
        if T == 1 and not q.requires_grad:
            o, (hkt, hvt) = fused_recurrent_gsa_inference(
                q=q,
                k=k,
                v=v,
                s=s,
                g=g,
                initial_state=(hk0, hv0),
                output_final_state=output_final_state,
                scale=scale,
            )
            return o, hkt, hvt
        ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            initial_state=(hk0, hv0),
            output_final_state=output_final_state,
            scale=scale,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        # ok (pre-softmax scores) and qv (softmax) are needed by backward
        ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok)
        ctx.scale = scale
        ctx.reverse = reverse
        ctx.cu_seqlens = cu_seqlens
        return ov.to(q.dtype), hkt, hvt

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dhkt=None, dhvt=None):
        q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
        scale = ctx.scale
        reverse = ctx.reverse
        cu_seqlens = ctx.cu_seqlens

        # not supported yet.
        # final-state gradients cannot be combined with trainable gates
        if dhkt is not None or dhvt is not None:
            if g is not None:
                assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
        dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd(
            q=q,
            k=k,
            v=v,
            s=s,
            g=g,
            qv=qv,
            hk0=hk0,
            hv0=hv0,
            ok=ok,
            do=do,
            dhkt=dhkt,
            dhvt=dhvt,
            scale=scale,
            reverse=reverse,
            cu_seqlens=cu_seqlens,
        )
        # one gradient slot per forward argument (None for non-tensor args)
        return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None
418
+
419
+
420
def fused_recurrent_gsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[Tuple[torch.Tensor]] = None,
    output_final_state: Optional[bool] = False,
    reverse: Optional[bool] = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Fused recurrent Gated Slot Attention.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        s (torch.Tensor):
            slot representations of shape `[B, T, H, M]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T, M]` applied to keys.
        scale (Optional[int]):
            Scale factor for the attention scores; defaults to `1 / sqrt(K)`.
        initial_state (Optional[Tuple[torch.Tensor]]):
            Pair of initial states, shapes `[N, H, K, M]` and `[N, H, M, V]`
            for `N` input sequences (for equal-length inputs `N == B`).
        output_final_state (Optional[bool]):
            Whether to also return the final states. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths `[N+1]` for variable-length training,
            consistent with the FlashAttention API; requires batch size 1.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (Tuple[torch.Tensor]):
            Final states of shapes `[N, H, K, M]` and `[N, H, M, V]`.
    """
    # varlen mode: flattened batch of size 1 and one initial state per sequence
    if cu_seqlens is not None:
        num_seqs = len(cu_seqlens) - 1
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state[0].shape[0] != num_seqs:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    hk0, hv0 = (None, None) if initial_state is None else initial_state
    o, *final_state = FusedRecurrentGSAFunction.apply(
        q,
        k,
        v,
        s,
        g,
        scale,
        hk0,
        hv0,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (308 Bytes). View file
 
fla3/ops/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (323 Bytes). View file
 
fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (7.07 kB). View file
 
fla3/ops/hgrn/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (16.3 kB). View file
 
fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (8.26 kB). View file
 
fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla3/ops/hgrn/chunk.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # this function implements the chunkwise form of HGRN, inspired by
5
+ # [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
6
+ # also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
7
+
8
+ # from tests on H800, with B, D = 16, 128, we see that the chunk can be greatly faster than the recurrent:
9
+ #
10
+ # Performance:
11
+ # seq_len chunk recurrent chunk_bwd recurrent_bwd
12
+ # 0 128.0 0.039360 0.061056 0.312160 0.205008
13
+ # 1 256.0 0.045824 0.123712 0.308784 0.297696
14
+ # 2 512.0 0.058688 0.241952 0.310720 0.626528
15
+ # 3 1024.0 0.088288 0.476992 0.313184 1.333152
16
+ # 4 2048.0 0.169472 0.943264 0.452464 2.724864
17
+ # 5 4096.0 0.329920 1.886144 0.881600 5.551520
18
+ # 6 8192.0 0.647872 3.755040 1.740496 11.117184
19
+ # 7 16384.0 1.272064 7.520576 3.446608 22.362528
20
+
21
+ from typing import Tuple
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+ from fla.ops.utils.op import exp
28
+ from fla.utils import input_guard
29
+
30
+
31
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_fwd_kernel_h(
    x,
    g,
    gc,
    o,
    h0,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    # Phase 1 of the chunked HGRN forward: each program runs the recurrence
    # h[t] = exp(g[t]) * h[t-1] + x[t] *locally* within one BT-sized chunk,
    # starting from a zero state (except chunk 0, which may load h0).
    # It writes the local result to `o` and the within-chunk inclusive
    # cumulative gate sum to `gc`; phase 2 (kernel_o) then stitches chunks.
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # flat pointers at (batch i_b, row i_t*BT, feature slice o_d); layout [B, T, D]
    p_x = x + i_b * T * D + i_t * BT * D + o_d
    p_g = g + i_b * T * D + i_t * BT * D + o_d
    p_gc = gc + i_b * T * D + i_t * BT * D + o_d
    p_o = o + i_b * T * D + i_t * BT * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    b_gc = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        # only the first chunk sees the initial state; later chunks start at 0
        if i_t == 0:
            b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32)
    for i in range(0, BT):
        # mask_t also guards the ragged last chunk (rows beyond T)
        mask_t = mask & ((i_t * BT + i) < T)
        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
        b_h = exp(b_g) * b_h + b_x
        # gc[t] = sum of g over [chunk_start, t] (inclusive)
        b_gc = b_gc + b_g
        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)

        p_x += D
        p_g += D
        p_gc += D
        p_o += D
89
+
90
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_fwd_kernel_o(
    gc,
    o,
    s_b,
    s_t,
    s_d,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Phase 2 of the chunked HGRN forward: sequentially (over chunks) inject
    # the carried state into each chunk's local result:
    #     o[t] += exp(gc[t]) * h_{chunk_start - 1}
    # where gc[t] is the within-chunk cumulative gate from kernel_h.
    # s_b/s_t/s_d are the [B, T, D] strides of `o` (and `gc`).
    i_d, i_b = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # chunk 0 needs no correction; start from chunk 1
    for i_t in range(1, tl.cdiv(T, BT)):
        p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,] — last (already corrected) output of the previous chunk,
        # i.e. the true state entering this chunk
        b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
        # [BT, BD]
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_o = b_o + exp(b_gc) * b_h0[None, :]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
117
+
118
+
119
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['D']
)
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_bwd_kernel_h(
    g,
    gc,
    dx,
    do,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Phase 1 of the chunked HGRN backward: per chunk, run the reverse scan
    #     dh[t] = do[t] + exp(g[t+1]) * dh[t+1]
    # locally (zero incoming dh), writing dx[t] = dh[t]. It also fills `gc`
    # with the *suffix* gate sums used by phase 2 to propagate gradients
    # across the chunk boundary: gc[t] accumulates g over (t, chunk_end],
    # seeded with the first gate of the next chunk (0 for the last chunk).
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D
    # BC = actual number of valid rows in this (possibly ragged last) chunk
    BC = min(BT, T - i_t * BT)
    NT = tl.num_programs(1)

    # pointers at the last valid row of the chunk; walked backwards by -D
    p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d

    if i_t == NT - 1:
        b_gc = tl.zeros([BD], dtype=tl.float32)
    else:
        # seed with g at the first row of the next chunk
        # NOTE(review): loads from `g` (raw gate), not `gc` — appears
        # intentional given the recurrence above, but worth confirming
        b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
    b_dh = tl.zeros([BD], dtype=tl.float32)
    for _ in range(BC - 1, -1, -1):
        # store gc *before* folding in this row's gate -> exclusive suffix sum
        tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)

        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)

        b_gc = b_gc + b_g
        b_dh = b_dh + b_do
        b_dx = b_dh
        # propagate one step back through the gate
        b_dh = b_dh * exp(b_g)

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)

        p_g -= D
        p_gc -= D
        p_dx -= D
        p_do -= D
172
+
173
@triton.jit(do_not_specialize=['T'])
def chunk_hgrn_bwd_kernel_o(
    g,
    gc,
    o,
    dx,
    dg,
    s_b,
    s_t,
    s_d,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    # Phase 2 of the chunked HGRN backward: walk chunks from last to first,
    # adding the cross-chunk gradient contribution
    #     dx[t] += exp(gc[t]) * dx[next_chunk_first_row]
    # and then computing the gate gradient dg[t] = o[t-1] * dx[t] * exp(g[t]).
    # s_b/s_t/s_d are the [B, T, D] strides shared by all buffers here.
    i_d, i_b = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        # o shifted one row up -> row r holds o[t-1]; out-of-range rows
        # (t == 0) read as zero via boundary_check, fixed up on the host side
        p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,] — (already corrected) gradient at the first row of the next
        # chunk; masked to zero when this is the last chunk
        mask_t = mask & ((i_t + 1) * BT < T)
        b_ht = tl.load(dx + i_b * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
        # [BT, BD]
        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)

        b_dx = b_dx + exp(b_gc) * b_ht[None, :]
        b_dg = b_o * b_dx * exp(b_g)
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
212
+
213
+
214
class ChunkHGRNFunction(torch.autograd.Function):
    """Autograd wrapper launching the two-phase chunked HGRN kernels."""

    @staticmethod
    @input_guard
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        B, T, D = x.shape
        # BT: chunk length; BD here only sizes the non-autotuned kernel_o
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        # gc: within-chunk cumulative gates (fp32 scratch reused by kernel_o)
        gc = torch.empty_like(g, dtype=torch.float)
        o = torch.empty_like(x, dtype=torch.float)
        # kernel_h is autotuned over BD, so BD is not passed explicitly
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
        chunk_hgrn_fwd_kernel_h[grid](
            x, g, gc, o, initial_state,
            T=T, D=D, BT=BT,
            USE_INITIAL_STATE=initial_state is not None
        )
        # phase 2: stitch chunks sequentially
        def grid(meta): return (triton.cdiv(D, meta['BD']), B)
        chunk_hgrn_fwd_kernel_o[grid](
            gc, o,
            o.stride(-3), o.stride(-2), o.stride(-1),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        final_state = None
        if output_final_state:
            # capture in fp32 before down-casting the outputs
            final_state = o[:, -1].clone()
        o = o.to(x.dtype)
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        B, T, D = do.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        # gc is recomputed here as the backward suffix gate sums
        gc = torch.empty_like(g, dtype=torch.float)
        dx = torch.empty_like(o, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
        chunk_hgrn_bwd_kernel_h[grid](
            g, gc, dx, do,
            T=T, D=D, BT=BT
        )

        dg = torch.empty_like(g, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B)
        chunk_hgrn_bwd_kernel_o[grid](
            g, gc, o, dx, dg,
            o.stride(-3), o.stride(-2), o.stride(-1),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        # the kernel read o[-1] as zero for t == 0; substitute h0 here
        if initial_state is not None:
            dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype)

        # no gradients for initial_state / output_final_state
        return dx.to(o.dtype), dg, None, None
273
+
274
+
275
@torch.compiler.disable
def chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise HGRN: h[t] = exp(g[t]) * h[t-1] + x[t].

    Args:
        x: inputs of shape `[B, T, D]`.
        g: log forget gates of shape `[B, T, D]` (expected <= 0).
        initial_state: optional initial state of shape `[B, D]`.
        output_final_state: whether to also return the final state `[B, D]`.

    Returns:
        `(o, final_state)` with `o` of shape `[B, T, D]`;
        `final_state` is `None` unless requested.
    """
    return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
fla3/ops/hgrn/fused_recurrent.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BD': BD}, num_warps=num_warps)
22
+ for BD in [32, 64, 128]
23
+ for num_warps in [1, 2, 4, 8]
24
+ ],
25
+ key=['D']
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_hgrn_fwd_kernel(
29
+ x,
30
+ g,
31
+ o,
32
+ h0,
33
+ ht,
34
+ cu_seqlens,
35
+ T,
36
+ D: tl.constexpr,
37
+ BD: tl.constexpr,
38
+ USE_INITIAL_STATE: tl.constexpr,
39
+ STORE_FINAL_STATE: tl.constexpr,
40
+ IS_VARLEN: tl.constexpr
41
+ ):
42
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
43
+ if IS_VARLEN:
44
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_n * T, i_n * T + T
48
+
49
+ o_d = i_d * BD + tl.arange(0, BD)
50
+ mask = o_d < D
51
+
52
+ p_x = x + bos * D + o_d
53
+ p_g = g + bos * D + o_d
54
+ p_o = o + bos * D + o_d
55
+
56
+ b_h = tl.zeros([BD], dtype=tl.float32)
57
+ if USE_INITIAL_STATE:
58
+ p_h0 = h0 + i_n * D + o_d
59
+ b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
60
+ for _ in range(0, T):
61
+ b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
62
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
63
+ b_h = exp(b_g) * b_h + b_x
64
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
65
+
66
+ p_x += D
67
+ p_g += D
68
+ p_o += D
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_n * D + o_d
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
73
+
74
+
75
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['D']
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_hgrn_bwd_kernel(
    g,
    o,
    h0,
    dx,
    dg,
    do,
    dht,
    dh0,
    cu_seqlens,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Sequential HGRN backward: reverse scan
    #     dh[t] = do[t] + exp(g[t+1]) * dh[t+1]
    # with dx[t] = dh[t] and dg[t] = dh[t] * exp(g[t]) * h[t-1].
    i_d, i_n = tl.program_id(0), tl.program_id(1)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # p_o lags one step behind the others: it reads h[t-1] (= o[t-1])
    p_g = g + (bos + T - 1) * D + o_d
    p_o = o + (bos + T - 2) * D + o_d
    p_dx = dx + (bos + T - 1) * D + o_d
    p_dg = dg + (bos + T - 1) * D + o_d
    p_do = do + (bos + T - 1) * D + o_d

    b_dh = tl.zeros([BD], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = dht + i_n * D + o_d
        b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)

    for i in range(T - 1, -1, -1):
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
        # h[t-1]: previous output, the initial state at t == 0, or zero
        if i > 0:
            b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
        elif USE_INITIAL_STATE:
            b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
        else:
            b_o = tl.zeros([BD], dtype=tl.float32)

        b_dh = b_dh + b_do
        b_dx = b_dh
        b_dh = b_dh * exp(b_g)
        b_dg = b_dh * b_o
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)

        p_g -= D
        p_o -= D
        p_dx -= D
        p_dg -= D
        p_do -= D

    if USE_INITIAL_STATE:
        # leftover dh after the scan is the gradient w.r.t. h0
        p_dh0 = dh0 + i_n * D + o_d
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
153
+
154
+
155
def fused_recurrent_hgrn_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent HGRN forward kernel.

    Args:
        x, g: inputs/log-gates of shape `[B, T, D]`.
        initial_state: optional `[N, D]` initial state.
        output_final_state: allocate and fill a `[N, D]` final state.
        cu_seqlens: optional varlen offsets `[N+1]` (then `B == 1`).

    Returns:
        `(o, final_state)` where `o` matches `x` and `final_state` may be `None`.
    """
    B, T, D = x.shape
    # N: number of sequences (batch size, or len(cu_seqlens) - 1 in varlen mode)
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    o = torch.empty_like(x)
    final_state = x.new_empty(N, D) if output_final_state else None

    # one program per (feature block, sequence); BD chosen by autotune
    def grid(meta): return (triton.cdiv(D, meta['BD']), N)
    fused_recurrent_hgrn_fwd_kernel[grid](
        x=x,
        g=g,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        T=T,
        D=D
    )
    return o, final_state
180
+
181
+
182
def fused_recurrent_hgrn_bwd(
    g: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor = None,
    initial_state: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent HGRN backward kernel.

    Args:
        g: log-gates `[B, T, D]` from the forward pass.
        o: forward outputs `[B, T, D]` (serve as the hidden states).
        do: gradient of the outputs.
        dht: optional gradient of the final state `[N, D]`.
        initial_state: optional `[N, D]` initial state used in forward.
        cu_seqlens: optional varlen offsets `[N+1]`.

    Returns:
        `(dx, dg, dh0)` in fp32; `dh0` is `None` without an initial state.
    """
    B, T, D = do.shape
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    dx = torch.empty_like(o, dtype=torch.float)
    dg = torch.empty_like(g, dtype=torch.float)
    dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
    def grid(meta): return (triton.cdiv(D, meta['BD']), N)
    fused_recurrent_hgrn_bwd_kernel[grid](
        g=g,
        o=o,
        h0=initial_state,
        dx=dx,
        dg=dg,
        do=do,
        dht=dht,
        dh0=dh0,
        cu_seqlens=cu_seqlens,
        T=T,
        D=D
    )
    return dx, dg, dh0
211
+
212
+
213
class FusedRecurrentHGRNFunction(torch.autograd.Function):
    """Autograd wrapper for the fused recurrent HGRN fwd/bwd launchers."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        g: torch.Tensor,
        initial_state: torch.Tensor = None,
        output_final_state: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None
    ):
        o, ht = fused_recurrent_hgrn_fwd(
            x=x,
            g=g,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens
        )
        # backward needs g and the outputs (as hidden states), plus h0
        ctx.save_for_backward(g, o, initial_state)
        ctx.cu_seqlens = cu_seqlens
        return o, ht

    @staticmethod
    @input_guard
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        cu_seqlens = ctx.cu_seqlens

        dx, dg, dh0 = fused_recurrent_hgrn_bwd(
            g=g,
            o=o,
            do=do,
            dht=dht,
            initial_state=initial_state,
            cu_seqlens=cu_seqlens
        )
        # no gradients for output_final_state / cu_seqlens
        return dx, dg, dh0, None, None
251
+
252
+
253
@torch.compiler.disable
def fused_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        x (torch.Tensor):
            inputs of shape `[B, T, D]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, D]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, D]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, D]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, D]`.
        final_state (torch.Tensor):
            Final state of shape `[N, D]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.hgrn import fused_recurrent_hgrn
        # inputs with equal lengths
        >>> B, T, D = 4, 2048, 512
        >>> x = torch.randn(B, T, D, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
        >>> h0 = torch.randn(B, D, device='cuda')
        >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    return FusedRecurrentHGRNFunction.apply(
        x,
        g,
        initial_state,
        output_final_state,
        cu_seqlens
    )
fla3/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    """Step-by-step reference implementation of the HGRN recurrence.

    Computes `h[t] = exp(g[t]) * h[t-1] + x[t]` in float32 and returns the
    per-step states as outputs.

    Args:
        x: inputs of shape `[B, T, D]`.
        g: log forget gates of shape `[B, T, D]`.
        initial_state: optional initial state of shape `[B, D]`.
        output_final_state: whether to also return the final state.

    Returns:
        `(o, final_state)`: outputs `[B, T, D]` in `x`'s dtype, and the
        float32 final state (or `None`).
    """
    in_dtype = x.dtype
    x = x.float()
    g = g.float()
    B, T, D = x.shape

    state = torch.zeros(B, D, dtype=torch.float, device=x.device)
    if initial_state is not None:
        state = state + initial_state

    o = torch.zeros_like(x)
    for t in range(T):
        state = torch.exp(g[:, t]) * state + x[:, t]
        o[:, t] = state

    final_state = state if output_final_state else None
    return o.to(in_dtype), final_state
32
+
33
+
34
def naive_chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False,
    chunk_size: int = 64
) -> torch.Tensor:
    """Chunkwise reference implementation of the HGRN recurrence.

    Splits the sequence into chunks of `chunk_size`; within each chunk the
    local recurrence is run from a zero state, and the state carried in from
    previous chunks is added back scaled by the within-chunk cumulative gate:
    `o[t] = h_chunk_start * exp(gc[t]) + h_local[t]`.

    Fixes w.r.t. the previous version:
    - the cumulative gates were computed with `g.view(B, chunk_size, D)`,
      which is only a valid view when `T == chunk_size`; the chunk axis must
      be kept separate: `(B, T // chunk_size, chunk_size, D)`.
    - the carried state was folded back into `h` *inside* the inner loop,
      so `hp * exp(gc)` was re-added on every later step of the chunk
      (double counting whenever the carried state was nonzero); the carry
      must happen once, after the chunk finishes.

    Args:
        x: inputs of shape `[B, T, D]` with `T` divisible by `chunk_size`.
        g: log forget gates of shape `[B, T, D]`.
        initial_state: optional initial state of shape `[B, D]`.
        output_final_state: whether to also return the final state.
        chunk_size: chunk length. Default: 64.

    Returns:
        `(o, final_state)`: outputs `[B, T, D]` in `x`'s dtype, and the
        float32 final state (or `None`).

    Raises:
        ValueError: if `T` is not a multiple of `chunk_size`.
    """
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, T, D = x.shape
    if T % chunk_size != 0:
        raise ValueError(f"sequence length {T} must be divisible by chunk_size {chunk_size}")

    # within-chunk inclusive cumulative sum of the log gates
    gc = g.view(B, T // chunk_size, chunk_size, D).cumsum(-2).view_as(g)
    h = torch.zeros(B, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state

    for i in range(0, T, chunk_size):
        hp = h  # state entering this chunk
        h = torch.zeros(B, D, dtype=torch.float, device=x.device)
        for j in range(i, i + chunk_size):
            # local recurrence, ignoring the carried state
            h = g[:, j].exp() * h + x[:, j]
            # inject the carried state, decayed by the cumulative gate
            o[:, j] = hp * gc[:, j].exp() + h
        # the chunk's last corrected output is the true state for the next chunk
        h = o[:, i + chunk_size - 1].clone()

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
fla3/ops/lightning_attn/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_lightning_attn
4
+ from .fused_recurrent import fused_recurrent_lightning_attn
5
+
6
+ __all__ = [
7
+ 'chunk_lightning_attn',
8
+ 'fused_recurrent_lightning_attn'
9
+ ]