zaydzuhri committed
Commit 3332857 (verified)
Parent: 05685b7

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. fla/models/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/models/bitnet/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc +0 -0
  4. fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  5. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc +0 -0
  6. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc +0 -0
  7. fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
  8. fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc +0 -0
  9. fla/models/mamba2/__pycache__/__init__.cpython-312.pyc +0 -0
  10. fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc +0 -0
  11. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  12. fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  13. fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc +0 -0
  14. fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc +0 -0
  15. fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  16. fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  17. fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  19. fla/modules/__pycache__/activations.cpython-312.pyc +0 -0
  20. fla/modules/__pycache__/fused_cross_entropy.cpython-312.pyc +0 -0
  21. fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc +0 -0
  22. fla/modules/__pycache__/layernorm.cpython-312.pyc +0 -0
  23. fla/modules/__pycache__/mlp.cpython-312.pyc +0 -0
  24. fla/ops/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  25. fla/ops/attn/__init__.py +7 -0
  26. fla/ops/based/__pycache__/__init__.cpython-312.pyc +0 -0
  27. fla/ops/common/chunk_delta_h.py +399 -0
  28. fla/ops/common/chunk_h_parallel.py +650 -0
  29. fla/ops/delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  30. fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  31. fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  32. fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  33. fla/ops/generalized_delta_rule/dplr/__init__.py +7 -0
  34. fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc +0 -0
  35. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc +0 -0
  36. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc +0 -0
  37. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc +0 -0
  38. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc +0 -0
  39. fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  40. fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +324 -0
  41. fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +196 -0
  42. fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +197 -0
  43. fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +138 -0
  44. fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +184 -0
  45. fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +318 -0
  46. fla/ops/generalized_delta_rule/iplr/__init__.py +7 -0
  47. fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-312.pyc +0 -0
  48. fla/ops/generalized_delta_rule/iplr/chunk.py +528 -0
  49. fla/ops/generalized_delta_rule/iplr/naive.py +69 -0
  50. fla/ops/generalized_delta_rule/iplr/wy_fast.py +338 -0
fla/models/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (653 Bytes).

fla/models/bitnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (678 Bytes).

fla/models/bitnet/__pycache__/configuration_bitnet.cpython-312.pyc ADDED
Binary file (2.36 kB).

fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (813 Bytes).

fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-312.pyc ADDED
Binary file (2.49 kB).

fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc ADDED
Binary file (17.2 kB).

fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (742 Bytes).

fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc ADDED
Binary file (3.33 kB).

fla/models/mamba2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (691 Bytes).

fla/models/mamba2/__pycache__/configuration_mamba2.cpython-312.pyc ADDED
Binary file (7.5 kB).

fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (683 Bytes).

fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (683 Bytes).

fla/models/samba/__pycache__/configuration_samba.cpython-312.pyc ADDED
Binary file (3.39 kB).

fla/models/samba/__pycache__/modeling_samba.cpython-312.pyc ADDED
Binary file (20.9 kB).

fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.51 kB).

fla/models/transformer/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (17.1 kB).

fla/models/transformer_mtp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (791 Bytes).

fla/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.34 kB).

fla/modules/__pycache__/activations.cpython-312.pyc ADDED
Binary file (23 kB).

fla/modules/__pycache__/fused_cross_entropy.cpython-312.pyc ADDED
Binary file (16 kB).

fla/modules/__pycache__/fused_linear_listnet_loss.cpython-312.pyc ADDED
Binary file (17.8 kB).

fla/modules/__pycache__/layernorm.cpython-312.pyc ADDED
Binary file (43.4 kB).

fla/modules/__pycache__/mlp.cpython-312.pyc ADDED
Binary file (6.23 kB).

fla/ops/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (208 Bytes).
fla/ops/attn/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # -*- coding: utf-8 -*-
+
+ from .parallel import parallel_attn
+
+ __all__ = [
+     'parallel_attn'
+ ]
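The whole file is a package-level re-export, so downstream code can import the kernel directly:

    from fla.ops.attn import parallel_attn

(The parallel module it imports from is not among the 50 files shown in this view.)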
fla/ops/based/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (282 Bytes).

fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.common.utils import prepare_chunk_offsets
+ from fla.ops.utils.op import exp
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
+
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
+
+
+ @triton.heuristics({
+     'USE_G': lambda args: args['g'] is not None,
+     'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+     'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+     'USE_OFFSETS': lambda args: args['offsets'] is not None,
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+         for num_warps in NUM_WARPS
+         for num_stages in [2, 3, 4]
+     ],
+     key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
+     use_cuda_graph=use_cuda_graph,
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_gated_delta_rule_fwd_kernel_h(
+     k,
+     v,
+     d,
+     v_new,
+     g,
+     h,
+     h0,
+     ht,
+     offsets,
+     chunk_offsets,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BC: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     NT: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_INITIAL_STATE: tl.constexpr,
+     STORE_FINAL_STATE: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_n, i_h = i_nh // H, i_nh % H
+     if USE_OFFSETS:
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+         boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+         NT = tl.cdiv(T, BT)
+         boh = i_n * NT
+
+     # [BK, BV]
+     b_h = tl.zeros([BK, BV], dtype=tl.float32)
+     if USE_INITIAL_STATE:
+         p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+     for i_t in range(NT):
+         if HEAD_FIRST:
+             p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         else:
+             p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+         b_hc = tl.zeros([BK, BV], dtype=tl.float32)
+         if USE_G:
+             last_idx = min((i_t + 1) * BT, T) - 1
+             if HEAD_FIRST:
+                 b_g_last = tl.load(g + i_nh * T + last_idx)
+             else:
+                 b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
+         else:
+             b_g_last = None
+             last_idx = None
+         # keeping the whole [BK, *] slice of K in SRAM imposes a severe memory burden; sub-chunking alleviates it
+         for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
+             if HEAD_FIRST:
+                 p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+                 p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+                 p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_v_new = tl.make_block_ptr(v_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
+             else:
+                 p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+                 p_d = tl.make_block_ptr(d + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+                 p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
+             b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
+             # [BK, BC]
+             b_k = tl.load(p_k, boundary_check=(0, 1))
+             b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
+             # [BC, BK]
+             b_d = tl.load(p_d, boundary_check=(0, 1))
+             b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
+             # [BC, BV]
+             b_v = tl.load(p_v, boundary_check=(0, 1))
+             b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
+             # [BK, BV]
+             tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+             b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
+         b_h *= exp(b_g_last) if USE_G else 1
+         b_h += b_hc
+
+     if STORE_FINAL_STATE:
+         p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'USE_G': lambda args: args['g'] is not None,
+     'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
+     'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
+     'USE_OFFSETS': lambda args: args['offsets'] is not None,
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+         for num_warps in NUM_WARPS
+         for num_stages in [2, 3, 4]
+     ],
+     key=['BT', 'BK', 'BV', 'USE_G'],
+     use_cuda_graph=use_cuda_graph,
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
+     q,
+     k,
+     d,
+     g,
+     dht,
+     dh0,
+     do,
+     dh,
+     dv,
+     dv2,
+     offsets,
+     chunk_offsets,
+     scale,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BC: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_INITIAL_STATE: tl.constexpr,
+     USE_FINAL_STATE_GRADIENT: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_n, i_h = i_nh // H, i_nh % H
+     if USE_OFFSETS:
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+         boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+         NT = tl.cdiv(T, BT)
+         boh = i_n * NT
+
+     # [BK, BV]
+     b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+     if USE_FINAL_STATE_GRADIENT:
+         p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_dh += tl.load(p_dht, boundary_check=(0, 1))
+
+     for i_t in range(NT - 1, -1, -1):
+         if HEAD_FIRST:
+             p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         else:
+             p_dh = tl.make_block_ptr(dh + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+         b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+         if USE_G:
+             last_idx = min((i_t + 1) * BT, T) - 1
+             if HEAD_FIRST:
+                 bg_last = tl.load(g + i_nh * T + last_idx)
+             else:
+                 bg_last = tl.load(g + (bos + last_idx) * H + i_h)
+         else:
+             bg_last = None
+             last_idx = None
+         for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+             if HEAD_FIRST:
+                 p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+                 p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+                 p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+                 p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
+                 p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+             else:
+                 p_q = tl.make_block_ptr(q + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+                 p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+                 p_d = tl.make_block_ptr(d + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+                 p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                 p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
+                 p_dv2 = tl.make_block_ptr(dv2 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+             b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
+             # [BK, BT]
+             b_q = tl.load(p_q, boundary_check=(0, 1))
+             b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
+             # [BT, BK]
+             b_k = tl.load(p_k, boundary_check=(0, 1))
+             b_d = tl.load(p_d, boundary_check=(0, 1))
+             b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
+             b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
+             # [BT, V]
+             b_do = tl.load(p_do, boundary_check=(0, 1))
+             b_dv = tl.load(p_dv, boundary_check=(0, 1))
+             b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
+             tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+             # [BK, BV]
+             b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+             b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
+         b_dh *= exp(bg_last) if USE_G else 1
+         b_dh += b_dh_tmp
+
+     if USE_INITIAL_STATE:
+         p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def chunk_gated_delta_rule_fwd_h(
+     k: torch.Tensor,
+     w: torch.Tensor,
+     u: torch.Tensor,
+     g: Optional[torch.Tensor] = None,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *k.shape, u.shape[-1]
+     else:
+         B, T, H, K, V = *k.shape, u.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     # N: the actual number of sequences in the batch with either equal or variable lengths
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, "current kernel does not support head dimensions larger than 256."
+     # H100 can use a larger block size
+     if check_shared_mem('hopper', k.device.index):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     # A100
+     elif check_shared_mem('ampere', k.device.index):
+         BV = 32
+         BC = 64
+     else:
+         BV = 32
+         BC = 32 if K <= 128 else 16
+     BC = min(BT, BC)
+     NK = triton.cdiv(K, BK)
+     NV = triton.cdiv(V, BV)
+     assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+     if head_first:
+         h = k.new_empty(B, H, NT, K, V)
+     else:
+         h = k.new_empty(B, NT, H, K, V)
+     final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
+
+     v_new = torch.empty_like(u)
+     grid = (NK, NV, N * H)
+
+     chunk_gated_delta_rule_fwd_kernel_h[grid](
+         k=k,
+         v=u,
+         d=w,
+         v_new=v_new,
+         g=g,
+         h=h,
+         h0=initial_state,
+         ht=final_state,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         NT=NT,
+         HEAD_FIRST=head_first
+     )
+     return h, v_new, final_state
+
+
+ def chunk_gated_delta_rule_bwd_dhu(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     w: torch.Tensor,
+     g: torch.Tensor,
+     h0: torch.Tensor,
+     dht: Optional[torch.Tensor],
+     do: torch.Tensor,
+     dv: torch.Tensor,
+     scale: float,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *q.shape, do.shape[-1]
+     else:
+         B, T, H, K, V = *q.shape, do.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     # N: the actual number of sequences in the batch with either equal or variable lengths
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
+
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, "current kernel does not support head dimensions larger than 256."
+
+     # H100
+     if check_shared_mem('hopper', q.device.index):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     # A100
+     elif check_shared_mem('ampere', q.device.index):
+         BV = 32
+         BC = 64 if K <= 128 else 32
+     else:
+         BV = 32 if K <= 128 else 16
+         BC = 16
+
+     BC = min(BT, BC)
+     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+     assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+     if head_first:
+         dh = q.new_empty(B, H, NT, K, V)
+     else:
+         dh = q.new_empty(B, NT, H, K, V)
+     dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
+     dv2 = torch.empty_like(dv)
+
+     grid = (NK, NV, N * H)
+     chunk_gated_delta_rule_bwd_kernel_dhu[grid](
+         q=q,
+         k=k,
+         d=w,
+         g=g,
+         dht=dht,
+         dh0=dh0,
+         do=do,
+         dh=dh,
+         dv=dv,
+         dv2=dv2,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         HEAD_FIRST=head_first
+     )
+     return dh, dh0, dv2
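For orientation, here is a minimal smoke-test sketch for the forward helper added above. The shapes follow the head_first=True branch of the wrapper; the device, dtypes, and the log-sigmoid decay initialization are illustrative assumptions, not part of this commit:

    # Sketch only: exercise chunk_gated_delta_rule_fwd_h with random inputs.
    import torch
    from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h

    B, H, T, K, V = 2, 4, 256, 64, 64
    k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    w = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)  # passed to the kernel as d
    u = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
    g = torch.randn(B, H, T, device='cuda').sigmoid().log()  # log-space decay, (B, H, T) layout assumed

    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k, w=w, u=u, g=g,
        output_final_state=True,
        head_first=True,
        chunk_size=64,
    )
    # h: (B, H, NT, K, V) per-chunk states; v_new: (B, H, T, V); final_state: (B, H, K, V) in fp32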
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ """
+ Fully parallelized state passing.
+ """
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.utils.op import exp
+
+
+ @triton.heuristics({
+     'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+     'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+         for BK in [32, 64, 128]
+         for BV in [32, 64, 128]
+         for num_warps in [2, 4, 8]
+         for num_stages in [2, 3, 4]
+     ],
+     key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_fwd_kernel_h_parallel(
+     k,
+     v,
+     h,
+     g,
+     gk,
+     gv,
+     h0,
+     ht,
+     offsets,
+     indices,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_GK: tl.constexpr,
+     USE_GV: tl.constexpr,
+     USE_INITIAL_STATE: tl.constexpr,
+     STORE_FINAL_STATE: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+     NV = tl.cdiv(V, BV)
+     # i_b: batch index
+     # i_h: head index
+     # i_n: sequence index
+     # i_t: chunk index within the current sequence
+     # i_tg: (global) chunk index across all sequences
+     i_k, i_v = i_kv // NV, i_kv % NV
+     i_b, i_h = i_bh // H, i_bh % H
+     if USE_OFFSETS:
+         i_tg = i_t
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+     else:
+         bos, eos = i_b * T, i_b * T + T
+         NT = tl.cdiv(T, BT)
+         i_n, i_tg = i_b, i_b * NT + i_t
+     i_nh = i_n * H + i_h
+
+     if HEAD_FIRST:
+         p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+     else:
+         p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+
+     if i_t == 0:
+         if USE_INITIAL_STATE:
+             p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+             b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+         else:
+             b_h = tl.zeros([BK, BV], dtype=tl.float32)
+         tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+
+     # [BK, BT]
+     b_k = tl.load(p_k, boundary_check=(0, 1))
+     # [BT, BV]
+     b_v = tl.load(p_v, boundary_check=(0, 1))
+
+     last_idx = min(i_t * BT + BT, T) - 1
+     # scalar decay
+     if USE_G:
+         if HEAD_FIRST:
+             b_g_last = tl.load(g + i_bh * T + last_idx)
+             p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
+             p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
+         else:
+             b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
+             p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
+         b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
+         b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
+
+     # vector decay, h = Diag(gk) @ h
+     if USE_GK:
+         if HEAD_FIRST:
+             p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+             p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
+             p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
+         else:
+             p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+             p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
+
+         b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
+
+         b_gk = tl.load(p_gk, boundary_check=(0, 1))
+         b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
+
+     # vector decay, h = h @ Diag(gv)
+     if USE_GV:
+         if HEAD_FIRST:
+             p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+             p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
+             p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
+         else:
+             p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+             p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+
+         b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
+
+         b_gv = tl.load(p_gv, boundary_check=(0, 1))
+         b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
+
+     b_h = tl.dot(b_k, b_v)
+     if i_t < NT - 1:
+         if HEAD_FIRST:
+             p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         else:
+             p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+     elif STORE_FINAL_STATE:
+         p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+         for BK in [32, 64, 128]
+         for BV in [32, 64, 128]
+         for num_warps in [2, 4, 8, 16]
+         for num_stages in [2, 3]
+     ],
+     key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_fwd_kernel_h_reduction(
+     h,
+     g,
+     gk,
+     gv,
+     kvt,
+     ht,
+     offsets,
+     chunk_offsets,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_GK: tl.constexpr,
+     USE_GV: tl.constexpr,
+     STORE_FINAL_STATE: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_n, i_h = i_nh // H, i_nh % H
+     if USE_OFFSETS:
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+         boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+         NT = tl.cdiv(T, BT)
+         boh = i_n * NT
+
+     # [BK, BV]
+     b_h = tl.zeros([BK, BV], dtype=tl.float32)
+     for i_t in range(NT):
+         if HEAD_FIRST:
+             p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         else:
+             p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+         if i_t > 0:
+             tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+
+         last_idx = min(i_t * BT + BT, T) - 1
+         # scalar decay
+         if USE_G:
+             if HEAD_FIRST:
+                 b_g_last = tl.load(g + i_nh * T + last_idx)
+             else:
+                 b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
+             b_h *= exp(b_g_last)
+
+         # vector decay, h = Diag(gk) @ h
+         if USE_GK:
+             if HEAD_FIRST:
+                 p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
+                 p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
+             else:
+                 p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
+
+             b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
+             b_h *= exp(b_gk_last)[:, None]
+
+         # vector decay, h = h @ Diag(gv)
+         if USE_GV:
+             if HEAD_FIRST:
+                 p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
+                 p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
+             else:
+                 p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+
+             b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
+             b_h *= exp(b_gv_last)[None, :]
+
+     if STORE_FINAL_STATE:
+         p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
+         tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
+     'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+         for BK in [32, 64, 128]
+         for BV in [32, 64, 128]
+         for num_warps in [2, 4, 8]
+         for num_stages in [2, 3, 4]
+     ],
+     key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_bwd_kernel_dh_parallel(
+     q,
+     g,
+     gk,
+     gv,
+     do,
+     dh,
+     dht,
+     dh0,
+     offsets,
+     indices,
+     scale,
+     T,
+     HQ: tl.constexpr,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     NG: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_GK: tl.constexpr,
+     USE_GV: tl.constexpr,
+     STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
+     USE_FINAL_STATE_GRADIENT: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+     NV = tl.cdiv(V, BV)
+     i_k, i_v = i_kv // NV, i_kv % NV
+     i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
+     i_h = i_hq // NG
+     if USE_OFFSETS:
+         i_tg = i_t
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+     else:
+         bos, eos = i_b * T, i_b * T + T
+         NT = tl.cdiv(T, BT)
+         i_n, i_tg = i_b, i_b * NT + i_t
+     i_nh = i_n * HQ + i_hq
+
+     if HEAD_FIRST:
+         p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+     else:
+         p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+
+     if i_t == NT - 1:
+         if USE_FINAL_STATE_GRADIENT:
+             p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+             b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
+         else:
+             b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+         tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+
+     # [BK, BT]
+     b_q = tl.load(p_q, boundary_check=(0, 1))
+     b_q = (b_q * scale).to(b_q.dtype)
+     # [BT, BV]
+     b_do = tl.load(p_do, boundary_check=(0, 1))
+
+     if USE_G:
+         if HEAD_FIRST:
+             p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
+             p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
+         else:
+             p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
+         b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
+         b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
+
+     if USE_GK:
+         if HEAD_FIRST:
+             p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         else:
+             p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         b_gk = tl.load(p_gk, boundary_check=(0, 1))
+         b_q = (b_q * exp(b_gk)).to(b_q.dtype)
+
+     if USE_GV:
+         if HEAD_FIRST:
+             p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         else:
+             p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         b_gv = tl.load(p_gv, boundary_check=(0, 1))
+         b_do = (b_do * exp(b_gv)).to(b_do.dtype)
+
+     b_dh = tl.dot(b_q, b_do)
+     if i_t > 0:
+         if HEAD_FIRST:
+             p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         else:
+             p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+     elif STORE_INITIAL_STATE_GRADIENT:
+         p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+         for BK in [32, 64, 128]
+         for BV in [32, 64, 128]
+         for num_warps in [2, 4, 8, 16]
+         for num_stages in [2, 3]
+     ],
+     key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_bwd_kernel_dh_reduction(
+     g,
+     gk,
+     gv,
+     dh,
+     doq0,
+     dh0,
+     offsets,
+     chunk_offsets,
+     T,
+     HQ: tl.constexpr,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     NG: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_GK: tl.constexpr,
+     USE_GV: tl.constexpr,
+     STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_bg = i_nh // NG
+     i_n, i_hq = i_nh // HQ, i_nh % HQ
+     i_h = i_hq // NG
+     if USE_OFFSETS:
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+         boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+         NT = tl.cdiv(T, BT)
+         boh = i_n * NT
+
+     # [BK, BV]
+     b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+     for i_t in range(NT - 1, -1, -1):
+         if HEAD_FIRST:
+             p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         else:
+             p_dh = tl.make_block_ptr(dh + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
+         if i_t < NT - 1:
+             tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+
+         last_idx = min(i_t * BT + BT, T) - 1
+         if USE_G:
+             if HEAD_FIRST:
+                 b_g_last = tl.load(g + i_bg * T + last_idx)
+             else:
+                 b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
+             b_dh *= exp(b_g_last)
+
+         if USE_GK:
+             if HEAD_FIRST:
+                 p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
+                 p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
+             else:
+                 p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
+
+             b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
+             b_dh *= exp(b_gk_last)[:, None]
+
+         if USE_GV:
+             if HEAD_FIRST:
+                 p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
+                 p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
+             else:
+                 p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
+
+             b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
+             b_dh *= exp(b_gv_last)[None, :]
+
+     if STORE_INITIAL_STATE_GRADIENT:
+         p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
+         tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def chunk_fwd_h(
+     k: torch.Tensor,
+     v: torch.Tensor,
+     g: torch.Tensor,
+     gk: torch.Tensor,
+     gv: torch.Tensor,
+     h0: torch.Tensor,
+     output_final_state: bool,
+     states_in_fp32: bool = False,
+     offsets: Optional[torch.Tensor] = None,
+     indices: Optional[torch.Tensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *k.shape, v.shape[-1]
+     else:
+         B, T, H, K, V = *k.shape, v.shape[-1]
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     # N: the actual number of sequences in the batch with either equal or variable lengths
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         if indices is None:
+             indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
+             indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
+         N, NT = len(offsets) - 1, len(indices)
+         chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
+
+     h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
+     ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
+     def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
+     chunk_fwd_kernel_h_parallel[grid](
+         k=k,
+         v=v,
+         h=h,
+         g=g,
+         gk=gk,
+         gv=gv,
+         h0=h0,
+         ht=ht,
+         offsets=offsets,
+         indices=indices,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         USE_G=g is not None,
+         USE_GK=gk is not None,
+         USE_GV=gv is not None,
+         HEAD_FIRST=head_first
+     )
+     kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
+     def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
+     chunk_fwd_kernel_h_reduction[grid](
+         h=h,
+         g=g,
+         gk=gk,
+         gv=gv,
+         kvt=kvt,
+         ht=ht,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         USE_G=g is not None,
+         USE_GK=gk is not None,
+         USE_GV=gv is not None,
+         HEAD_FIRST=head_first
+     )
+     h = h.to(k.dtype) if not states_in_fp32 else h
+     return h, ht
+
+
+ def chunk_bwd_dh(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     g: torch.Tensor,
+     gk: torch.Tensor,
+     gv: torch.Tensor,
+     do: torch.Tensor,
+     h0: torch.Tensor,
+     dht: torch.Tensor,
+     scale: float,
+     states_in_fp32: bool = False,
+     offsets: Optional[torch.Tensor] = None,
+     indices: Optional[torch.Tensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *k.shape, v.shape[-1]
+         HQ = q.shape[1]
+     else:
+         B, T, H, K, V = *k.shape, v.shape[-1]
+         HQ = q.shape[2]
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     # N: the actual number of sequences in the batch with either equal or variable lengths
+     # NG: number of groups in GQA
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         if indices is None:
+             indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
+             indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
+         N, NT = len(offsets) - 1, len(indices)
+         chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
+     NG = HQ // H
+
+     if head_first:
+         dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
+     else:
+         dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
+     dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
+
+     def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
+     chunk_bwd_kernel_dh_parallel[grid](
+         q=q,
+         g=g,
+         gk=gk,
+         gv=gv,
+         do=do,
+         dh=dh,
+         dht=dht,
+         dh0=dh0,
+         offsets=offsets,
+         indices=indices,
+         scale=scale,
+         T=T,
+         HQ=HQ,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         NG=NG,
+         USE_G=g is not None,
+         USE_GK=gk is not None,
+         USE_GV=gv is not None,
+         HEAD_FIRST=head_first
+     )
+
+     doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
+     def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
+     chunk_bwd_kernel_dh_reduction[grid](
+         g=g,
+         gk=gk,
+         gv=gv,
+         dh=dh,
+         doq0=doq0,
+         dh0=dh0,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         T=T,
+         HQ=HQ,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         NG=NG,
+         USE_G=g is not None,
+         USE_GK=gk is not None,
+         USE_GV=gv is not None,
+         HEAD_FIRST=head_first
+     )
+     dh = dh.to(q.dtype) if not states_in_fp32 else dh
+     return dh, dh0
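The fully parallel state-passing helper can be exercised the same way (again a sketch under assumed shapes and dtypes; g is a scalar per-step decay, while gk/gv would be per-dimension decays):

    # Sketch only: chunk-level states via the parallel + reduction kernels.
    import torch
    from fla.ops.common.chunk_h_parallel import chunk_fwd_h

    B, H, T, K, V = 2, 4, 256, 64, 64
    k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
    g = torch.randn(B, H, T, device='cuda').sigmoid().log()  # scalar decay in log space

    h, ht = chunk_fwd_h(
        k=k, v=v, g=g, gk=None, gv=None, h0=None,
        output_final_state=True,
        head_first=True,
        chunk_size=64,
    )
    # h: (B, H, NT, K, V) states entering each chunk; ht: (B, H, K, V) final state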
fla/ops/delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (13.3 kB).

fla/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (388 Bytes).

fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (34 kB).

fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (315 Bytes).

fla/ops/generalized_delta_rule/dplr/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .chunk import chunk_dplr_delta_rule
+ from .fused_recurrent import fused_recurrent_dplr_delta_rule
+
+ __all__ = [
+     'chunk_dplr_delta_rule',
+     'fused_recurrent_dplr_delta_rule'
+ ]
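As with fla.ops.attn above, this package file only fixes the public import path for the DPLR delta-rule kernels:

    from fla.ops.generalized_delta_rule.dplr import (
        chunk_dplr_delta_rule,
        fused_recurrent_dplr_delta_rule,
    )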
fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (324 Bytes).

fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc ADDED
Binary file (12.2 kB).

fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc ADDED
Binary file (12.5 kB).

fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-312.pyc ADDED
Binary file (28 kB).

fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc ADDED
Binary file (8.91 kB).

fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.5 kB).

fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,324 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.ops.utils.op import exp, gather
+ from fla.utils import is_gather_supported, use_cuda_graph
+
+
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
+         for BK in [32, 64]
+         for num_warps in [2, 4, 8, 16]
+         for num_stages in [2, 3, 4]
+     ],
+     key=['BC', 'K'],
+     use_cuda_graph=use_cuda_graph,
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_dplr_fwd_A_kernel_intra_sub_inter(
+     q,
+     k,
+     a,
+     b,
+     gi,  # cumsum
+     ge,  # before cumsum
+     Aqk,
+     Aqb,
+     Aab,
+     Aak,
+     offsets,
+     indices,
+     scale: tl.constexpr,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     BT: tl.constexpr,
+     BC: tl.constexpr,
+     BK: tl.constexpr,
+     NC: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_b, i_h = i_bh // H, i_bh % H
+     i_i, i_j = i_c // NC, i_c % NC
+     if USE_OFFSETS:
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     if i_t * BT + i_i * BC >= T:
+         return
+     if i_i <= i_j:
+         return
+
+     b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
+     b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
+     b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
+     b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
+     for i_k in range(tl.cdiv(K, BK)):
+         o_k = i_k * BK + tl.arange(0, BK)
+         m_k = o_k < K
+
+         if HEAD_FIRST:
+             p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+             p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+             p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+             p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
+         else:
+             p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+             p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+             p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+             p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+             p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
+         # [BK,]
+         b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
+         # [BC, BK]
+         b_q = tl.load(p_q, boundary_check=(0, 1))
+         b_a = tl.load(p_a, boundary_check=(0, 1))
+         b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
+         b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
+         b_ag = b_a * exp(b_gq_e - b_gn[None, :])
+         b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
+         # [BK, BC]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_b = tl.load(p_b, boundary_check=(0, 1))
+         b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
+         tmp = exp(b_gn[:, None] - b_gk)
+         b_kg = b_k * tmp
+         b_bg = b_b * tmp
+         # [BC, BC] using tf32 to improve precision here.
+         b_Aab += tl.dot(b_ag, b_bg)
+         b_Aak += tl.dot(b_ag, b_kg)
+         b_Aqk += tl.dot(b_qg, b_kg)
+         b_Aqb += tl.dot(b_qg, b_bg)
+
+     if HEAD_FIRST:
+         p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+         p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+         p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+         p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+     else:
+         p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+         p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+         p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+         p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+     tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+         for num_warps in [2, 4, 8, 16, 32]
+         for num_stages in [2, 3, 4]
+     ],
+     key=['BK', 'BT'],
+     use_cuda_graph=use_cuda_graph,
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_dplr_fwd_A_kernel_intra_sub_intra(
+     q,
+     k,
+     a,
+     b,
+     gi,
+     ge,
+     qg,
+     kg,
+     ag,
+     bg,
+     Aqk,
+     Aqb,
+     Aab,
+     Aak,
+     offsets,
+     indices,
+     scale: tl.constexpr,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     BT: tl.constexpr,
+     BC: tl.constexpr,
+     BK: tl.constexpr,
+     NC: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+     GATHER_SUPPORTED: tl.constexpr
+ ):
+     i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_b, i_h = i_bh // H, i_bh % H
+     i_j = i_i
+     if USE_OFFSETS:
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     if i_t * BT + i_i * BC >= T:
+         return
+
+     o_i = tl.arange(0, BC)
+     o_k = tl.arange(0, BK)
+     m_k = o_k < K
+     m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
+     last_idx = min((i_t+1) * BT, T) - 1
+     if HEAD_FIRST:
+         o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
+         p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
+         b_g_last = tl.load(p_g_last, mask=m_k, other=0)
+
+         p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+     else:
+         o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
+         p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
+         b_g_last = tl.load(p_g_last, mask=m_k, other=0)
+         p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+         p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
+
+     b_q = tl.load(p_q, boundary_check=(0, 1))
+     b_q = b_q * scale
+     b_k = tl.load(p_k, boundary_check=(0, 1))
+     b_a = tl.load(p_a, boundary_check=(0, 1))
+     b_b = tl.load(p_b, boundary_check=(0, 1))
+     b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
+     b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
+
+     # deal with the decay term.
+     g_exp = exp(b_gi)
+     g_exp_inv = exp(-b_gi + b_g_last[None, :])
+     b_qg = b_q * g_exp
+     b_kg = b_k * g_exp_inv
+     b_bg = b_b * g_exp_inv
+     b_ag = b_a * exp(b_ge)
+     tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+     tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
239
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
240
+ # tl.debug_barrier()
241
+
242
+ b_q = b_q.to(b_k.dtype)
243
+ # inner attn
244
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
245
+ # a trick to index the j-th row of b_k, b_gi, b_b
246
+ if GATHER_SUPPORTED:
247
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
248
+ # [1, BK]
249
+ b_k_j = gather(b_k, row_idx, axis=0)
250
+ b_gk_j = gather(b_gi, row_idx, axis=0)
251
+ b_b_j = gather(b_b, row_idx, axis=0)
252
+ else:
253
+ mask = tl.arange(0, BC) == j
254
+ b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
255
+ b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
256
+ b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
258
+ tmp = exp(b_gi - b_gk_j)
259
+ b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
260
+ b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
261
+ b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
262
+ b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
263
+ tmp2 = exp(b_ge - b_gk_j)
264
+ b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
265
+ b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
266
+ b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
267
+ b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
268
+ tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
269
+ tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
270
+ tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
271
+ tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
272
+
273
+
274
+ def chunk_fwd_intra_dplr_fn(
275
+ q: torch.Tensor,
276
+ k: torch.Tensor,
277
+ a: torch.Tensor,
278
+ b: torch.Tensor,
279
+ gi: torch.Tensor,
280
+ ge: torch.Tensor,
281
+ scale: float,
282
+ chunk_size: int,
283
+ offsets: Optional[torch.LongTensor] = None,
284
+ indices: Optional[torch.LongTensor] = None,
285
+ head_first: bool = True,
286
+ ):
287
+ if head_first:
288
+ B, H, T, K = k.shape
289
+ else:
290
+ B, T, H, K = k.shape
291
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
292
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
293
+ BC = min(16, BT)
294
+ NC = triton.cdiv(BT, BC)
295
+
296
+ Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
297
+ Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
298
+ # these feed into a matrix inverse, so it's better to keep them in float here.
299
+ Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
300
+ Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
301
+ grid = (NT, NC * NC, B * H)
302
+
303
+ chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
304
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
305
+ offsets=offsets, indices=indices,
306
+ scale=scale,
307
+ T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
308
+ HEAD_FIRST=head_first
309
+ )
310
+ grid = (NT, NC, B * H)
311
+ BK = triton.next_power_of_2(K)
312
+ qg = torch.empty_like(q)
313
+ kg = torch.empty_like(k, dtype=q.dtype)
314
+ ag = torch.empty_like(a, dtype=q.dtype)
315
+ bg = torch.empty_like(b, dtype=q.dtype)
316
+ chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
317
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
318
+ qg=qg, kg=kg, ag=ag, bg=bg,
319
+ offsets=offsets, indices=indices,
320
+ scale=scale,
321
+ T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
322
+ GATHER_SUPPORTED=is_gather_supported
323
+ )
324
+ return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
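
For orientation, a minimal shape-level usage sketch of the wrapper above (hypothetical sizes; assumes CUDA tensors in the non-head-first layout; not part of the diff):

import torch

B, T, H, K = 2, 128, 4, 64
q, k, a, b = (torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) for _ in range(4))
# gi/ge carry log-decays, hence float and typically non-positive
gi, ge = (-torch.rand(B, T, H, K, device='cuda', dtype=torch.float) for _ in range(2))
Aab, Aqk, Aak, Aqb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
    q=q, k=k, a=a, b=b, gi=gi, ge=ge,
    scale=K ** -0.5, chunk_size=64, head_first=False,
)
# Aab/Aak come back as float (they later feed a matrix inverse); the rest keep q's dtype
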
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,196 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
17
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV', "V"],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_bwd_kernel_dhu(
31
+ qg,
32
+ bg,
33
+ w,
34
+ gk,
35
+ dht,
36
+ dh0,
37
+ do,
38
+ dh,
39
+ dv,
40
+ dv2,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
57
+ i_n, i_h = i_nh // H, i_nh % H
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ NT = tl.cdiv(T, BT)
66
+ boh = i_n * NT
67
+
68
+ # [BK, BV]
69
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
70
+ if USE_FINAL_STATE_GRADIENT:
71
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
73
+
74
+ mask_k = tl.arange(0, BK) < K
75
+ for i_t in range(NT - 1, -1, -1):
76
+ if HEAD_FIRST:
77
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
81
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
82
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
83
+ if HEAD_FIRST:
84
+ p_qg = tl.make_block_ptr(qg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
85
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
88
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ else:
91
+ p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
92
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
93
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
95
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ # [BK, BT]
98
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
99
+ # [BT, BK]
100
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
101
+ b_w = tl.load(p_w, boundary_check=(0, 1))
102
+ # [BT, V]
103
+ b_do = tl.load(p_do, boundary_check=(0, 1))
104
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
105
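+ # dv2 = dv + bg @ dh: route the incoming state gradient back into the value gradient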
+ b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
106
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
107
+ # [BK, BV]
108
+ b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
109
+ b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
110
+ last_idx = min((i_t + 1) * BT, T) - 1
111
+ if HEAD_FIRST:
112
+ bg_last = tl.load(gk + (i_nh * T + last_idx) * K + tl.arange(0, BK), mask=mask_k)
113
+ else:
114
+ bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
115
+ b_dh *= exp(bg_last)[:, None]
116
+ b_dh += b_dh_tmp
117
+
118
+ if USE_INITIAL_STATE:
119
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_bwd_dhu(
124
+ qg: torch.Tensor,
125
+ bg: torch.Tensor,
126
+ w: torch.Tensor,
127
+ gk: torch.Tensor,
128
+ h0: torch.Tensor,
129
+ dht: Optional[torch.Tensor],
130
+ do: torch.Tensor,
131
+ dv: torch.Tensor,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *qg.shape, do.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *qg.shape, do.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ BK = triton.next_power_of_2(K)
143
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
144
+ # H100
145
+ if check_shared_mem('hopper', qg.device.index):
146
+ BV = 64
147
+ BC = 64 if K <= 128 else 32
148
+ elif check_shared_mem('ampere', qg.device.index): # A100
149
+ BV = 32
150
+ BC = 32
151
+ else: # others, e.g. RTX 4090
152
+ BV = 16
153
+ BC = 16
154
+
155
+ # N: the actual number of sequences in the batch with either equal or variable lengths
156
+ if offsets is None:
157
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
158
+ else:
159
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
160
+
161
+ BC = min(BT, BC)
162
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
163
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
164
+
165
+ if head_first:
166
+ dh = qg.new_empty(B, H, NT, K, V)
167
+ else:
168
+ dh = qg.new_empty(B, NT, H, K, V)
169
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
170
+ dv2 = torch.zeros_like(dv)
171
+
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_bwd_kernel_dhu[grid](
174
+ qg=qg,
175
+ bg=bg,
176
+ w=w,
177
+ gk=gk,
178
+ dht=dht,
179
+ dh0=dh0,
180
+ do=do,
181
+ dh=dh,
182
+ dv=dv,
183
+ dv2=dv2,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ HEAD_FIRST=head_first
195
+ )
196
+ return dh, dh0, dv2
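
As a sanity reference, a minimal single-(batch, head) PyTorch sketch of the reversed recurrence this kernel materializes (hypothetical helper names; fixed-length case only):

import torch

def dplr_bwd_dh_reference(qg, bg, w, gk_last, do, dv, dht):
    # qg, bg, w: [NT, BT, K]; do, dv: [NT, BT, V]; gk_last: [NT, K]; dht: [K, V]
    NT = qg.shape[0]
    dh = dht.clone()                                # gradient flowing in from the final state
    dhs, dv2 = [None] * NT, [None] * NT
    for t in range(NT - 1, -1, -1):
        dhs[t] = dh.clone()                         # stored before the chunk update, as in the kernel
        dv2[t] = dv[t] + bg[t] @ dh                 # state gradient routed into the value gradient
        dh = torch.exp(gk_last[t])[:, None] * dh + qg[t].T @ do[t] + w[t].T @ dv2[t]
    return torch.stack(dhs), torch.stack(dv2), dh   # the trailing dh is dh0
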
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,197 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+
82
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
83
+ # we need to keep the full DK dimension in SRAM, which is a severe memory burden; subchunking alleviates it
84
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
85
+ if HEAD_FIRST:
86
+ p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
88
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
92
+ else:
93
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
95
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
98
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
99
+ # [BK, BC]
100
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+ b_w = tl.load(p_w, boundary_check=(0, 1))
103
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
104
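+ # v_new = u + w @ h: apply the WY correction against the chunk-initial state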
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
105
+ b_hc += tl.dot(b_kg, b_v)
106
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
107
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
108
+
109
+ last_idx = min((i_t + 1) * BT, T) - 1
110
+ if HEAD_FIRST:
111
+ b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
112
+ else:
113
+ b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K +
114
+ tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
115
+ b_h *= exp(b_g_last[:, None])
116
+ b_h += b_hc
117
+
118
+ if STORE_FINAL_STATE:
119
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_fwd_h(
124
+ kg: torch.Tensor,
125
+ v: torch.Tensor,
126
+ w: torch.Tensor,
127
+ u: torch.Tensor,
128
+ bg: torch.Tensor,
129
+ gk: torch.Tensor,
130
+ initial_state: Optional[torch.Tensor] = None,
131
+ output_final_state: bool = False,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *kg.shape, u.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *kg.shape, u.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ # N: the actual number of sequences in the batch with either equal or variable lengths
143
+ if offsets is None:
144
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
145
+ else:
146
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
147
+ BK = triton.next_power_of_2(K)
148
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
149
+ # H100 can have larger block size
150
+
151
+ if check_shared_mem('hopper', kg.device.index):
152
+ BV = 64
153
+ BC = 64 if K <= 128 else 32
154
+ elif check_shared_mem('ampere', kg.device.index): # A100
155
+ BV = 32
156
+ BC = 32
157
+ else:
158
+ BV = 16
159
+ BC = 16
160
+
161
+ BC = min(BT, BC)
162
+ NK = triton.cdiv(K, BK)
163
+ NV = triton.cdiv(V, BV)
164
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
165
+
166
+ if head_first:
167
+ h = kg.new_empty(B, H, NT, K, V)
168
+ else:
169
+ h = kg.new_empty(B, NT, H, K, V)
170
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
171
+ v_new = torch.empty_like(u)
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_fwd_kernel_h[grid](
174
+ kg=kg,
175
+ v=v,
176
+ w=w,
177
+ bg=bg,
178
+ u=u,
179
+ v_new=v_new,
180
+ h=h,
181
+ gk=gk,
182
+ h0=initial_state,
183
+ ht=final_state,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ NT=NT,
195
+ HEAD_FIRST=head_first
196
+ )
197
+ return h, v_new, final_state
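
For comparison, a minimal single-(batch, head) PyTorch sketch of the chunkwise recurrence computed above (hypothetical helper names; fixed-length case only):

import torch

def dplr_fwd_h_reference(kg, v, w, u, bg, gk_last, h0):
    # kg, w, bg: [NT, BT, K]; v, u: [NT, BT, V]; gk_last: [NT, K]; h0: [K, V]
    NT = kg.shape[0]
    h = h0.clone()
    hs, v_new = [], []
    for t in range(NT):
        hs.append(h.clone())                        # chunk-initial state, as stored by the kernel
        v2 = u[t] + w[t] @ h                        # the WY-corrected values (v_new)
        h = torch.exp(gk_last[t])[:, None] * h + kg[t].T @ v[t] + bg[t].T @ v2
        v_new.append(v2)
    return torch.stack(hs), torch.stack(v_new), h   # trailing h is the final state
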
fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, use_cuda_graph
11
+
12
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in BK_LIST
22
+ for BV in BK_LIST
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_o(
31
+ qg,
32
+ v,
33
+ v_new,
34
+ A_qk,
35
+ A_qb,
36
+ h,
37
+ o,
38
+ offsets,
39
+ indices,
40
+ T,
41
+ H: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ USE_OFFSETS: tl.constexpr,
48
+ HEAD_FIRST: tl.constexpr,
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_h = i_bh // H, i_bh % H
52
+
53
+ if USE_OFFSETS:
54
+ i_tg = i_t
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ NT = tl.cdiv(T, BT)
59
+ else:
60
+ NT = tl.cdiv(T, BT)
61
+ i_tg = i_b * NT + i_t
62
+ bos, eos = i_b * T, i_b * T + T
63
+
64
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
65
+ for i_k in range(tl.cdiv(K, BK)):
66
+ if HEAD_FIRST:
67
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
68
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
69
+ else:
70
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
71
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
73
+ b_h = tl.load(p_h, boundary_check=(0, 1))
74
+ b_o += tl.dot(b_qg, b_h)
75
+
76
+ if HEAD_FIRST:
77
+ p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
78
+ p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
79
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
80
+ p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
81
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
82
+ else:
83
+ p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
84
+ p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+
89
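+ # inclusive causal mask: within the chunk, only keys at positions <= the query position contribute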
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
90
+ b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
91
+ b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
92
+ b_Aqk = tl.where(m_s, b_Aqk, 0)
93
+ b_Aqb = tl.where(m_s, b_Aqb, 0)
94
+ b_v = tl.load(p_v, boundary_check=(0, 1))
95
+ b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
96
+ b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
97
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
98
+
99
+
100
+ def chunk_dplr_fwd_o(
101
+ qg: torch.Tensor,
102
+ v: torch.Tensor,
103
+ v_new: torch.Tensor,
104
+ A_qk: torch.Tensor,
105
+ A_qb: torch.Tensor,
106
+ h: torch.Tensor,
107
+ offsets: Optional[torch.LongTensor] = None,
108
+ indices: Optional[torch.LongTensor] = None,
109
+ head_first: bool = True,
110
+ chunk_size: int = 64
111
+ ) -> torch.Tensor:
112
+ if head_first:
113
+ B, H, T, K, V = *qg.shape, v.shape[-1]
114
+ else:
115
+ B, T, H, K, V = *qg.shape, v.shape[-1]
116
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
117
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
118
+
119
+ o = torch.empty_like(v)
120
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
121
+ chunk_dplr_fwd_kernel_o[grid](
122
+ qg=qg,
123
+ v=v,
124
+ v_new=v_new,
125
+ A_qk=A_qk,
126
+ A_qb=A_qb,
127
+ h=h,
128
+ o=o,
129
+ offsets=offsets,
130
+ indices=indices,
131
+ T=T,
132
+ H=H,
133
+ K=K,
134
+ V=V,
135
+ BT=BT,
136
+ HEAD_FIRST=head_first
137
+ )
138
+ return o
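
The output assembly above reduces to three terms; a minimal per-(batch, head) sketch (hypothetical helper names; assumes qg already carries the scale and decay factors):

import torch

def dplr_fwd_o_reference(qg, v, v_new, A_qk, A_qb, h):
    # qg: [NT, BT, K]; v, v_new: [NT, BT, V]; A_qk, A_qb: [NT, BT, BT]; h: [NT, K, V]
    BT = qg.shape[1]
    m = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=qg.device))
    o = qg @ h                                      # inter-chunk: queries against the chunk-initial state
    o = o + torch.where(m, A_qk, 0.0) @ v           # intra-chunk q-k term, causally masked
    o = o + torch.where(m, A_qb, 0.0) @ v_new       # intra-chunk q-b term against the corrected values
    return o
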
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py ADDED
@@ -0,0 +1,184 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph
11
+
12
+ # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
13
+ triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [2, 4, 8, 16, 32]
23
+ for num_stages in [2, 3, 4]
24
+ ],
25
+ key=['BT', 'BK', 'BV'],
26
+ use_cuda_graph=use_cuda_graph,
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def bwd_prepare_wy_repr_kernel(
30
+ A_ab_inv,
31
+ A_ak,
32
+ ag,
33
+ v,
34
+ dw,
35
+ du,
36
+ dv,
37
+ dv0,
38
+ dag,
39
+ dAak,
40
+ dAab,
41
+ offsets,
42
+ indices,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if HEAD_FIRST:
63
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
64
+ p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
65
+ p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
66
+ p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
67
+ else:
68
+ p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
69
+ p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
70
+ p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
72
+
73
+ b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
74
+ b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
75
+ b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0)
76
+ b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0)
77
+ b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
78
+ b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_v in range(tl.cdiv(V, BV)):
81
+ if HEAD_FIRST:
82
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
83
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
86
+ else:
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
89
+ p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+ b_v = tl.load(p_v, boundary_check=(0, 1))
92
+ b_du = tl.load(p_du, boundary_check=(0, 1))
93
+ b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
94
+ b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
95
+ b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
96
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0)
99
+ b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
100
+ b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0)
101
+ tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
102
+ b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)
103
+
104
+ for i_k in range(tl.cdiv(K, BK)):
105
+ if HEAD_FIRST:
106
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
107
+ p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
108
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
109
+ else:
110
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
111
+ p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
112
+ p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
113
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
114
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
115
+ b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
116
+ b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
117
+ tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))
118
+
119
+ # if we know dL/dA^(-1), for dL/dA, we can use the following formula:
120
+ # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
121
+ # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1.
122
+ # denote A = I - lower(A_ab), B = A^-1
123
+ # in the backward pass.
124
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T
125
+ # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
126
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
127
+ b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
128
+ b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
129
+ b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
130
+ tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
131
+
132
+
133
+ def chunk_dplr_bwd_wy(
134
+ A_ab_inv: torch.Tensor,
135
+ A_ak: torch.Tensor,
136
+ v: torch.Tensor,
137
+ ag: torch.Tensor,
138
+ dw: torch.Tensor,
139
+ du: torch.Tensor,
140
+ dv0: torch.Tensor,
141
+ offsets: Optional[torch.LongTensor],
142
+ indices: Optional[torch.LongTensor],
143
+ head_first: bool,
144
+ chunk_size: int,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
147
+ if head_first:
148
+ B, H, T, K, V = *dw.shape, du.shape[-1]
149
+ else:
150
+ B, T, H, K, V = *dw.shape, du.shape[-1]
151
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
152
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
153
+ BK = min(triton.next_power_of_2(K), 64)
154
+ BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)
155
+
156
+ dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
157
+ dA_ak = torch.empty_like(A_ak, dtype=torch.float)
158
+ dv = torch.empty_like(v)
159
+ dag = torch.empty_like(ag)
160
+
161
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
162
+ A_ab_inv=A_ab_inv,
163
+ A_ak=A_ak,
164
+ ag=ag,
165
+ v=v,
166
+ dw=dw,
167
+ du=du,
168
+ dv=dv,
169
+ dv0=dv0,
170
+ dag=dag,
171
+ dAak=dA_ak,
172
+ dAab=dA_ab,
173
+ offsets=offsets,
174
+ indices=indices,
175
+ T=T,
176
+ H=H,
177
+ K=K,
178
+ V=V,
179
+ BT=BT,
180
+ BK=BK,
181
+ BV=BV,
182
+ HEAD_FIRST=head_first
183
+ )
184
+ return dA_ab, dA_ak, dv, dag
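
The gradient identity in the comments above can be checked numerically in a few self-contained lines (illustrative only, independent of the kernel):

import torch

BT = 16
strict = torch.tril(torch.ones(BT, BT, dtype=torch.bool), diagonal=-1)
A_ab = torch.randn(BT, BT, dtype=torch.float64, requires_grad=True)
B = torch.linalg.inv(torch.eye(BT, dtype=torch.float64) - A_ab * strict)  # B = (I - lower(A_ab))^-1
dB = torch.randn(BT, BT, dtype=torch.float64)  # some upstream gradient dL/dB
(B * dB).sum().backward()
Bd = B.detach()
# closed form from the comments: dL/dA_ab = lower(B^T @ dL/dB @ B^T)
assert torch.allclose(A_ab.grad, (Bd.T @ dB @ Bd.T) * strict)
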
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py ADDED
@@ -0,0 +1,318 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps)
20
+ for num_warps in [1, 2, 4, 8, 16]
21
+ ],
22
+ key=['BT'],
23
+ use_cuda_graph=use_cuda_graph,
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def fwd_prepare_wy_repr_kernel_chunk32(
27
+ A_ab,
28
+ A_ab_inv,
29
+ offsets,
30
+ indices,
31
+ T,
32
+ H: tl.constexpr,
33
+ BT: tl.constexpr,
34
+ BC: tl.constexpr, # placeholder, do not delete
35
+ USE_OFFSETS: tl.constexpr,
36
+ HEAD_FIRST: tl.constexpr
37
+ ):
38
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
39
+ i_b, i_h = i_bh // H, i_bh % H
40
+ if USE_OFFSETS:
41
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
42
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
43
+ T = eos - bos
44
+ else:
45
+ bos, eos = i_b * T, i_b * T + T
46
+ if HEAD_FIRST:
47
+ p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
48
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
49
+ else:
50
+ p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
51
+ p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
53
+ b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
54
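+ # forward substitution: builds (I - strict_lower(A_ab))^{-1} row by row; the identity is added below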
+ for i in range(1, BT):
55
+ mask = tl.arange(0, BT) == i
56
+ b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
57
+ b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
58
+ b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
59
+ b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
60
+ tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
61
+
62
+
63
+ @triton.heuristics({
64
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
65
+ })
66
+ @triton.autotune(
67
+ configs=[
68
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
69
+ for num_warps in [2, 4, 8]
70
+ for num_stages in [2, 3, 4]
71
+ ],
72
+ key=['BC'],
73
+ use_cuda_graph=use_cuda_graph,
74
+ )
75
+ @triton.jit(do_not_specialize=['T'])
76
+ def fwd_prepare_wy_repr_kernel_chunk64(
77
+ A_ab,
78
+ A_ab_inv,
79
+ offsets,
80
+ indices,
81
+ T,
82
+ H: tl.constexpr,
83
+ BT: tl.constexpr,
84
+ BC: tl.constexpr,
85
+ USE_OFFSETS: tl.constexpr,
86
+ HEAD_FIRST: tl.constexpr,
87
+ GATHER_SUPPORTED: tl.constexpr = is_gather_supported
88
+ ):
89
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
90
+ i_b, i_h = i_bh // H, i_bh % H
91
+ if USE_OFFSETS:
92
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
93
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
94
+ T = eos - bos
95
+ else:
96
+ bos, eos = i_b * T, i_b * T + T
97
+
98
+ if HEAD_FIRST:
99
+
100
+ p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
101
+ p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
102
+ p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
103
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
104
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
105
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
106
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
107
+ else:
108
+ p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
109
+ p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
110
+ p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
111
+ p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
112
+ p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
113
+ p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
114
+ p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
115
+
116
+ b_A = tl.load(p_A1, boundary_check=(0, 1))
117
+ b_A2 = tl.load(p_A2, boundary_check=(0, 1))
118
+ b_A3 = tl.load(p_A3, boundary_check=(0, 1))
119
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
120
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
121
+
122
+ for i in range(1, BC):
123
+ if GATHER_SUPPORTED:
124
+ row_idx = tl.full([1, BC], i, dtype=tl.int16)
125
+ # [1, BC] -> [BC]
126
+ b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
127
+ b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
128
+ else:
129
+ mask = tl.arange(0, BC) == i
130
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
131
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
132
+ mask = tl.arange(0, BC) == i
133
+ # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
134
+ # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
135
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
136
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
137
+ b_A = tl.where(mask[:, None], b_a, b_A)
138
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
139
+
140
+ # blockwise computation of lower triangular matrix's inverse
141
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
142
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
143
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
144
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
145
+ # tl.debug_barrier()
146
+ tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
147
+ tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
148
+ tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
149
+ # zero the upper-right block so the blockwise inverse stays lower-triangular
150
+ tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
155
+ })
156
+ @triton.autotune(
157
+ configs=[
158
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
159
+ for num_warps in [2, 4, 8, 16, 32]
160
+ for num_stages in [2, 3, 4]
161
+ ],
162
+ key=['BT', 'BK', 'BV'],
163
+ use_cuda_graph=use_cuda_graph,
164
+ )
165
+ @triton.jit(do_not_specialize=['T'])
166
+ def fwd_wu_kernel(
167
+ u,
168
+ w,
169
+ ag,
170
+ v,
171
+ A_ab_inv,
172
+ A_ak,
173
+ offsets,
174
+ indices,
175
+ T,
176
+ H: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BK: tl.constexpr,
181
+ BV: tl.constexpr,
182
+ USE_OFFSETS: tl.constexpr,
183
+ HEAD_FIRST: tl.constexpr,
184
+ ):
185
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
186
+ i_b, i_h = i_bh // H, i_bh % H
187
+ if USE_OFFSETS:
188
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
189
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
190
+ T = eos - bos
191
+ else:
192
+ bos, eos = i_b * T, i_b * T + T
193
+
194
+ if HEAD_FIRST:
195
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
196
+ p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
197
+ else:
198
+ p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
199
+ p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
200
+ b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
201
+ b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
202
+ o_s = tl.arange(0, BT)
203
+ b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
204
+ b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
205
+ # let's use tf32 here
206
+ b_Aak = tl.dot(b_Aab_inv, b_Aak)
207
+ # (SY 01/04) should be bf16 or tf32? To verify.
208
+ b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
209
+ b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")
210
+
211
+ for i_k in range(tl.cdiv(K, BK)):
212
+ if HEAD_FIRST:
213
+ p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
214
+ p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
215
+ else:
216
+ p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
217
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
218
+ b_ag = tl.load(p_ag, boundary_check=(0, 1))
219
+ b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16
220
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
221
+
222
+ for i_v in range(tl.cdiv(V, BV)):
223
+ if HEAD_FIRST:
224
+ p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
225
+ p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
226
+ else:
227
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
228
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
229
+ b_v = tl.load(p_v, boundary_check=(0, 1))
230
+ b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16
231
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
232
+
233
+
234
+ def fwd_prepare_wy_repr(
235
+ ag: torch.Tensor,
236
+ v: torch.Tensor,
237
+ A_ak: torch.Tensor,
238
+ A_ab: torch.Tensor,
239
+ offsets: Optional[torch.LongTensor],
240
+ indices: Optional[torch.LongTensor],
241
+ head_first: bool = True,
242
+ chunk_size: int = 64
243
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
244
+ if head_first:
245
+ B, H, T, K = ag.shape
246
+ else:
247
+ B, T, H, K = ag.shape
248
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
249
+
250
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
251
+ BC = min(BT, 32)
252
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
253
+ A_ab_inv = torch.empty_like(A_ab)
254
+ fwd_fn[(NT, B * H)](
255
+ A_ab=A_ab,
256
+ A_ab_inv=A_ab_inv,
257
+ offsets=offsets,
258
+ indices=indices,
259
+ T=T,
260
+ H=H,
261
+ BT=BT,
262
+ BC=BC,
263
+ HEAD_FIRST=head_first
264
+ )
265
+ w, u = fwd_wu(
266
+ ag=ag,
267
+ v=v,
268
+ A_ak=A_ak,
269
+ A_ab_inv=A_ab_inv,
270
+ offsets=offsets,
271
+ indices=indices,
272
+ head_first=head_first,
273
+ chunk_size=BT
274
+ )
275
+ return w, u, A_ab_inv
276
+
277
+
278
+ def fwd_wu(
279
+ ag: torch.Tensor,
280
+ v: torch.Tensor,
281
+ A_ak: torch.Tensor,
282
+ A_ab_inv: torch.Tensor,
283
+ offsets: Optional[torch.LongTensor],
284
+ indices: Optional[torch.LongTensor],
285
+ head_first: bool,
286
+ chunk_size: int
287
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
288
+ if head_first:
289
+ B, H, T, K, V = *ag.shape, v.shape[-1]
290
+ else:
291
+ B, T, H, K, V = *ag.shape, v.shape[-1]
292
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
293
+
294
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
295
+ BK = min(triton.next_power_of_2(K), 64)
296
+ BV = min(triton.next_power_of_2(V), 64)
297
+
298
+ u = torch.empty_like(v)
299
+ w = torch.empty_like(ag)
300
+ fwd_wu_kernel[(NT, B*H)](
301
+ ag=ag,
302
+ v=v,
303
+ A_ak=A_ak,
304
+ A_ab_inv=A_ab_inv,
305
+ w=w,
306
+ u=u,
307
+ offsets=offsets,
308
+ indices=indices,
309
+ T=T,
310
+ H=H,
311
+ K=K,
312
+ V=V,
313
+ BT=BT,
314
+ BK=BK,
315
+ BV=BV,
316
+ HEAD_FIRST=head_first
317
+ )
318
+ return w, u
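
For reference, a dense per-chunk sketch of what fwd_wu computes (hypothetical helper names; the masks mirror the kernel's inclusive and strict triangles):

import torch

def wy_fwd_wu_reference(ag, v, A_ak, A_ab_inv):
    # ag: [BT, K]; v: [BT, V]; A_ak, A_ab_inv: [BT, BT]
    BT = ag.shape[0]
    incl = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=ag.device))
    strict = torch.tril(incl, diagonal=-1)
    inv = torch.where(incl, A_ab_inv, 0.0)          # inclusive lower triangle (diagonal kept)
    ak = torch.where(strict, A_ak, 0.0)             # strictly lower triangle
    w = inv @ ag                                    # w = A_ab_inv @ ag
    u = (inv @ ak) @ v                              # u = (A_ab_inv @ A_ak) @ v
    return w, u
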
fla/ops/generalized_delta_rule/iplr/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .chunk import chunk_iplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_iplr_delta_rule',
6
+ 'fused_recurrent_iplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (27 kB).
 
fla/ops/generalized_delta_rule/iplr/chunk.py ADDED
@@ -0,0 +1,528 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_delta_h import prepare_chunk_offsets
11
+ from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph
13
+
14
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({}, num_warps=num_warps)
25
+ for num_warps in [2, 4, 8, 16]
26
+ ],
27
+ key=['BT', 'BK', 'BV'],
28
+ use_cuda_graph=use_cuda_graph,
29
+ )
30
+ @triton.jit(do_not_specialize=['T'])
31
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
32
+ k,
33
+ v,
34
+ d,
35
+ b,
36
+ u,
37
+ v_new,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
82
+ # we need to keep the full DK dimension in SRAM, which is a severe memory burden; subchunking alleviates it
83
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
86
+ p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ else:
92
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
93
+ p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
95
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
98
+ # [BK, BC]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ b_d = tl.load(p_d, boundary_check=(0, 1))
102
+ b_b = tl.load(p_b, boundary_check=(0, 1))
103
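+ # v_new = u + d @ h: the same correction as in the DPLR kernel, without the gated decay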
+ b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
104
+ b_hc += tl.dot(b_k, b_v)
105
+ b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
106
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
107
+ b_h += b_hc
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None,
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+         for BK in BKV_LIST
+         for BV in BKV_LIST
+         for num_warps in [2, 4, 8]
+         for num_stages in [2, 3]
+     ],
+     key=['BT'],
+     use_cuda_graph=use_cuda_graph,
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
+     q,
+     k,
+     v,
+     u,
+     b,
+     h,
+     o,
+     offsets,
+     indices,
+     scale,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+     i_b, i_h = i_bh // H, i_bh % H
+
+     if USE_OFFSETS:
+         i_tg = i_t
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+     else:
+         NT = tl.cdiv(T, BT)
+         i_tg = i_b * NT + i_t
+         bos, eos = i_b * T, i_b * T + T
+
+     # offset calculation
+     q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
+     k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
+     b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
+     v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
+     u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
+     o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
+     h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
+     stride_qk = K if HEAD_FIRST else H*K
+     stride_vo = V if HEAD_FIRST else H*V
+
+     b_o = tl.zeros([BT, BV], dtype=tl.float32)
+     b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
+     b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)
+
+     for i_k in range(tl.cdiv(K, BK)):
+         p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+         p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         # [BT, BK]
+         b_q = tl.load(p_q, boundary_check=(0, 1))
+         # [BK, BT]
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_b = tl.load(p_b, boundary_check=(0, 1))
+         # [BK, BV]
+         b_h = tl.load(p_h, boundary_check=(0, 1))
+         # [BT, BK] @ [BK, BV] -> [BT, BV]
+         b_o += tl.dot(b_q, b_h)
+         # [BT, BK] @ [BK, BT] -> [BT, BT]
+         b_Aqk += tl.dot(b_q, b_k)
+         # [BT, BK] @ [BK, BT] -> [BT, BT]
+         b_Aqb += tl.dot(b_q, b_b)
+
+     o_i = tl.arange(0, BT)
+     m_A = o_i[:, None] >= o_i[None, :]
+     b_Aqk = tl.where(m_A, b_Aqk, 0)
+     b_Aqb = tl.where(m_A, b_Aqb, 0)
+
+     p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+     p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+     p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+     b_v = tl.load(p_v, boundary_check=(0, 1))
+     b_u = tl.load(p_u, boundary_check=(0, 1))
+     b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
+     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
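For the intra-chunk output above, a dense single-head reference may help; this is a hypothetical sketch (note that the mask `m_A` is inclusive of the diagonal, hence plain `torch.tril`):

```python
import torch

def chunk_output_reference(q, k, v, u, b, h, scale):
    # q, k, b: [C, K]; v, u: [C, V]; h: [K, V] state entering the chunk
    A_qk = torch.tril(q @ k.t())   # inclusive causal mask, as in m_A
    A_qb = torch.tril(q @ b.t())
    return (q @ h + A_qk @ v + A_qb @ u) * scale
```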
+ def chunk_generalized_iplr_delta_rule_fwd_o(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     v_new: torch.Tensor,
+     b: torch.Tensor,
+     h: torch.Tensor,
+     scale: Optional[float] = None,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> torch.Tensor:
+     if head_first:
+         B, H, T, K, V = *q.shape, v.shape[-1]
+     else:
+         B, T, H, K, V = *q.shape, v.shape[-1]
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     NT = triton.cdiv(T, BT) if offsets is None else len(indices)
+
+     o = torch.empty_like(v)
+
+     def grid(meta): return (
+         triton.cdiv(V, meta['BV']),
+         NT,
+         B * H
+     )
+     chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
+         q=q,
+         k=k,
+         v=v,
+         u=v_new,
+         b=b,
+         h=h,
+         o=o,
+         offsets=offsets,
+         indices=indices,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         HEAD_FIRST=head_first
+     )
+     return o
+
+
+ def chunk_generalized_iplr_delta_rule_fwd_h(
+     k: torch.Tensor,
+     v: torch.Tensor,
+     w: torch.Tensor,
+     u: torch.Tensor,
+     b: torch.Tensor,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *k.shape, u.shape[-1]
+     else:
+         B, T, H, K, V = *k.shape, u.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     # N: the actual number of sequences in the batch, with either equal or variable lengths
+     if offsets is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
+
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, "current kernel does not support head dimension larger than 256."
+     # H100 can afford larger block sizes
+     if check_shared_mem('hopper', k.device.index):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     elif check_shared_mem('ampere', k.device.index):  # A100
+         BV = 32
+         BC = 32
+     else:
+         BV = 16
+         BC = 16
+
+     BC = min(BT, BC)
+     NK = triton.cdiv(K, BK)
+     NV = triton.cdiv(V, BV)
+
+     assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+     if head_first:
+         h = k.new_empty(B, H, NT, K, V)
+     else:
+         h = k.new_empty(B, NT, H, K, V)
+     final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
+
+     v_new = torch.empty_like(u)
+     grid = (NK, NV, N * H)
+
+     chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
+         k=k,
+         v=v,
+         d=w,
+         b=b,
+         u=u,
+         v_new=v_new,
+         h=h,
+         h0=initial_state,
+         ht=final_state,
+         offsets=offsets,
+         chunk_offsets=chunk_offsets,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         NT=NT,
+         HEAD_FIRST=head_first
+     )
+     return h, v_new, final_state
+
+
+ def chunk_generalized_iplr_delta_rule_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale: float,
+     initial_state: torch.Tensor,
+     output_final_state: bool,
+     offsets: Optional[torch.LongTensor] = None,
+     indices: Optional[torch.LongTensor] = None,
+     head_first: bool = True,
+     chunk_size: int = 64
+ ):
+     T = q.shape[2] if head_first else q.shape[1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     w, u, _ = fwd_prepare_wy_repr(
+         a=a,
+         b=b,
+         k=k,
+         v=v,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+
+     h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
+         k=k,
+         v=v,
+         b=b,
+         w=w,
+         u=u,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     o = chunk_generalized_iplr_delta_rule_fwd_o(
+         q=q,
+         k=k,
+         v=v,
+         v_new=v_new,
+         b=b,
+         h=h,
+         scale=scale,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=BT
+     )
+     return o, final_state
+
+
+ class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(
+         ctx,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         scale: float,
+         initial_state: torch.Tensor,
+         output_final_state: bool,
+         offsets: Optional[torch.LongTensor] = None,
+         head_first: bool = True
+     ):
+         chunk_size = 64
+
+         # 2-d indices denoting the offsets of chunks in each sequence
+         # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
+         # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
+         # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+         indices = None
+         if offsets is not None:
+             indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
+             indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
+
+         o, final_state = chunk_generalized_iplr_delta_rule_fwd(
+             q=q,
+             k=k,
+             v=v,
+             a=a,
+             b=b,
+             scale=scale,
+             initial_state=initial_state,
+             output_final_state=output_final_state,
+             offsets=offsets,
+             indices=indices,
+             head_first=head_first,
+             chunk_size=chunk_size
+         )
+         return o.to(q.dtype), final_state
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(
+         ctx,
+         do: torch.Tensor,
+         dht: torch.Tensor
+     ):
+         raise NotImplementedError(
+             "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
+             "Stay tuned!"
+         )
+
+
+ @torch.compiler.disable
+ def chunk_iplr_delta_rule(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale: Optional[float] = None,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = True
+ ):
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         k (torch.Tensor):
+             keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         v (torch.Tensor):
+             values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+         a (torch.Tensor):
+             the `alpha` term of the identity-plus-low-rank (IPLR) transition,
+             of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         b (torch.Tensor):
+             the `beta` term of the identity-plus-low-rank (IPLR) transition,
+             of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+         scale (Optional[float]):
+             Scale factor for attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Default: `True`.
+
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+         final_state (torch.Tensor):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+     """
+     assert q.dtype == k.dtype == v.dtype
+     assert q.dtype != torch.float32, "ChunkGeneralizedIPLRDeltaRuleFunction does not support float32. Please use bfloat16."
+
+     if cu_seqlens is not None:
+         if q.shape[0] != 1:
+             raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                              f"Please flatten variable-length inputs before processing.")
+         if head_first:
+             raise RuntimeError("Sequences with variable lengths are not supported for head-first mode.")
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
+                              f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
+     scale = k.shape[-1] ** -0.5 if scale is None else scale
+     o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
+         q,
+         k,
+         v,
+         a,
+         b,
+         scale,
+         initial_state,
+         output_final_state,
+         cu_seqlens,
+         head_first
+     )
+     return o, final_state
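A hedged usage sketch of the entry point above; the shapes, device, and import path (assuming the package re-exports the function) are illustrative only, and `a`/`b` are random here purely for shape illustration:

```python
import torch
from fla.ops.generalized_delta_rule import chunk_iplr_delta_rule  # assumed re-export

B, H, T, K, V = 2, 4, 256, 64, 64
q, k, a, b = (torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda') for _ in range(4))
v = torch.randn(B, H, T, V, dtype=torch.bfloat16, device='cuda')
o, final_state = chunk_iplr_delta_rule(q, k, v, a, b, output_final_state=True, head_first=True)
```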
fla/ops/generalized_delta_rule/iplr/naive.py ADDED
@@ -0,0 +1,69 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ from einops import rearrange
+
+
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
+ # q, k, alpha, beta [B, H, L, D_K]
+ # v [B, H, L, D_V]
+ def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True):
+     orig_dtype = q.dtype
+     b, h, l, d_k = q.shape
+     q, k, v, alpha, beta = map(lambda x: x.float(), [q, k, v, alpha, beta])
+     d_v = v.shape[-1]
+     o = torch.zeros_like(v)
+     S = torch.zeros(b, h, d_k, d_v).to(v)
+     q = q * (d_k ** -0.5)
+
+     if initial_state is not None:
+         S += initial_state
+
+     for i in range(l):
+         _k = k[:, :, i]
+         _q = q[:, :, i]
+         _v = v[:, :, i]
+         _alpha = alpha[:, :, i]
+         _beta = beta[:, :, i]
+         _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
+         S = S + _kv
+         o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+     S = None if output_final_state is False else S
+     return o.to(orig_dtype), S
+
+
+ def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32):
+     b, h, l, d_k = q.shape
+     d_v = v.shape[-1]
+     q = q * (d_k ** -0.5)
+     assert l % chunk_size == 0
+
+     S = k.new_zeros(b, h, d_k, d_v)
+     if initial_state is not None:
+         S += initial_state
+
+     # note that the diagonal is masked as well
+     mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
+     q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta])
+
+     v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
+     attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0)
+     for i in range(1, chunk_size):
+         attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
+
+     attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=q.device)
+     u = attn @ v2
+     w = attn @ alpha
+     o = torch.zeros_like(v)
+     mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
+     for i in range(0, l // chunk_size):
+         q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
+         o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i
+         v2_i = u_i + w_i @ S
+         o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v2_i
+         o_3 = q_i @ S
+         o[:, :, i] = o_1 + o_2 + o_3
+         S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i
+     S = None if output_final_state is False else S
+     return rearrange(o, 'b h n c d -> b h (n c) d'), S
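The two references above should agree up to floating-point tolerance; a hypothetical sanity check, with `alpha`/`beta` scaled down to keep the recurrence well-conditioned:

```python
import torch

B, H, L, DK, DV = 1, 2, 64, 16, 16
torch.manual_seed(0)
q, k = torch.randn(B, H, L, DK), torch.randn(B, H, L, DK)
alpha, beta = 0.1 * torch.randn(B, H, L, DK), 0.1 * torch.randn(B, H, L, DK)
v = torch.randn(B, H, L, DV)
o1, S1 = iplr_recurrence(q, k, v, alpha, beta)
o2, S2 = iplr_chunkwise(q, k, v, alpha, beta, chunk_size=32)
assert torch.allclose(o1, o2, atol=1e-3) and torch.allclose(S1, S2, atol=1e-3)
```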
fla/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,338 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from fla.utils import check_shared_mem, is_nvidia_hopper
+
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
+
+
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [1, 2, 4, 8, 16]
+     ],
+     key=['BK']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def fwd_prepare_wy_repr_kernel_chunk32(
+     a,
+     b,
+     A,
+     offsets,
+     indices,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BC: tl.constexpr,  # dummy placeholder
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr,
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if USE_OFFSETS:
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     b_A = tl.zeros([BT, BT], dtype=tl.float32)
+     for i_k in range(tl.cdiv(K, BK)):
+         if HEAD_FIRST:
+             p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+             p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         else:
+             p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+             p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         b_a = tl.load(p_a, boundary_check=(0, 1))
+         b_b = tl.load(p_b, boundary_check=(0, 1))
+         b_A += tl.dot(b_a, b_b)
+
+     b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
+     for i in range(1, BT):
+         mask = tl.arange(0, BT) == i
+         b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
+         b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
+         b_A = tl.where(mask[:, None], b_a, b_A)
+     b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
+
+     if HEAD_FIRST:
+         p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+     else:
+         p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+     tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+
+
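The row-by-row loop in the kernel above is a forward substitution; as a dense reference (a hypothetical single-chunk sketch, not part of the diff), it produces the inverse of `I - strictly_lower(a @ b^T)`:

```python
import torch

def wy_inverse_reference(a, b):
    # a, b: [C, K] for one chunk; returns the [C, C] block the kernel stores into A
    C = a.shape[0]
    strict_lower = torch.tril(a @ b.t(), diagonal=-1)
    return torch.linalg.inv(torch.eye(C) - strict_lower)
```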
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in [1, 2, 4, 8, 16]
+     ],
+     key=['BK']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def fwd_prepare_wy_repr_kernel_chunk64(
+     a,
+     b,
+     A,
+     offsets,
+     indices,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BC: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if USE_OFFSETS:
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     b_A = tl.zeros([BC, BC], dtype=tl.float32)
+     b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
+     b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
+
+     for i_k in range(tl.cdiv(K, BK)):
+         if HEAD_FIRST:
+             p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
+             p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
+             p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
+             p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
+         else:
+             p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
+             p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
+             p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
+             p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
+         b_a1 = tl.load(p_a1, boundary_check=(0, 1))
+         b_a2 = tl.load(p_a2, boundary_check=(0, 1))
+         b_b1 = tl.load(p_b1, boundary_check=(0, 1))
+         b_b2 = tl.load(p_b2, boundary_check=(0, 1))
+         b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
+         b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
+         b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)
+
+     b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
+     b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
+
+     for i in range(1, BC):
+         mask = tl.arange(0, BC) == i
+         b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
+         b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
+         b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
+         b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
+         b_A = tl.where(mask[:, None], b_a, b_A)
+         b_A2 = tl.where(mask[:, None], b_a2, b_A2)
+
+     # blockwise computation of the lower triangular matrix's inverse,
+     # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
+     b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
+     b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
+     b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)
+
+     if HEAD_FIRST:
+         p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
+         p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
+         p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
+         p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
+     else:
+         p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
+         p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
+         p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
+         p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
+     tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
+     tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
+     tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
+     # zero out the upper-right block to keep the stored inverse causal
+     tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
+
+
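The sign in the block formula works out because the matrix actually being inverted is `I - A`, whose lower-left block is `-A21`; a hypothetical dense check of the identity the kernel relies on:

```python
import torch

C = 16
A = torch.tril(0.1 * torch.randn(2 * C, 2 * C), diagonal=-1)
full_inv = torch.linalg.inv(torch.eye(2 * C) - A)    # inverse over the whole chunk
T1 = torch.linalg.inv(torch.eye(C) - A[:C, :C])      # b_A after the identity is added
T2 = torch.linalg.inv(torch.eye(C) - A[C:, C:])      # b_A2 after the identity is added
# the lower-left block equals T2 @ A21 @ T1, matching b_A3 = dot(dot(b_A2, b_A3), b_A)
assert torch.allclose(full_inv[C:, :C], T2 @ A[C:, :C] @ T1, atol=1e-5)
```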
+ @triton.heuristics({
+     'USE_OFFSETS': lambda args: args['offsets'] is not None
+ })
+ @triton.autotune(
+     configs=[
+         triton.Config({}, num_warps=num_warps)
+         for num_warps in NUM_WARPS
+     ],
+     key=['BT', 'BK', 'BV']
+ )
+ @triton.jit(do_not_specialize=['T'])
+ def fwd_wu_kernel(
+     w,
+     u,
+     a,
+     k,
+     v,
+     A,
+     offsets,
+     indices,
+     T,
+     H: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BK: tl.constexpr,
+     BV: tl.constexpr,
+     USE_OFFSETS: tl.constexpr,
+     HEAD_FIRST: tl.constexpr
+ ):
+     i_t, i_bh = tl.program_id(0), tl.program_id(1)
+     i_b, i_h = i_bh // H, i_bh % H
+     if USE_OFFSETS:
+         i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
+         T = eos - bos
+     else:
+         bos, eos = i_b * T, i_b * T + T
+
+     if HEAD_FIRST:
+         p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+     else:
+         p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+
+     b_A = tl.load(p_A, boundary_check=(0, 1))
+     b_Aak = tl.zeros([BT, BT], dtype=tl.float32)
+
+     for i_k in range(tl.cdiv(K, BK)):
+         if HEAD_FIRST:
+             p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+             p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+             p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+         else:
+             p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+             p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+             p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_a = tl.load(p_a, boundary_check=(0, 1))
+         b_w = tl.dot(b_A, b_a)
+         b_Aak += tl.dot(b_a, tl.trans(b_k))
+         tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+     b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
+     b_Aak = b_Aak.to(k.dtype.element_ty)
+
+     for i_v in range(tl.cdiv(V, BV)):
+         if HEAD_FIRST:
+             p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+             p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         else:
+             p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+             p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         b_v = tl.load(p_v, boundary_check=(0, 1))
+         b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
+         b_u = tl.dot(b_A, b_v)
+         tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def fwd_prepare_wy_repr(
+     a: torch.Tensor,
+     b: torch.Tensor,
+     v: torch.Tensor,
+     k: torch.Tensor,
+     offsets: Optional[torch.LongTensor],
+     indices: Optional[torch.LongTensor],
+     head_first: bool = True,
+     chunk_size: int = 64
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K = a.shape
+     else:
+         B, T, H, K = a.shape
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     NT = triton.cdiv(T, BT) if offsets is None else len(indices)
+     BC = min(BT, 32)
+     BK = min(triton.next_power_of_2(K), 64)
+
+     A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
+     fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
+
+     fwd_fn[(NT, B * H)](
+         a=a,
+         b=b,
+         A=A,
+         offsets=offsets,
+         indices=indices,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+         BK=BK,
+         BC=BC,
+         HEAD_FIRST=head_first
+     )
+     w, u = fwd_wu(
+         a=a,
+         v=v,
+         k=k,
+         A=A,
+         offsets=offsets,
+         indices=indices,
+         head_first=head_first,
+         chunk_size=chunk_size
+     )
+     return w, u, A
+
+
+ def fwd_wu(
+     a: torch.Tensor,
+     v: torch.Tensor,
+     k: torch.Tensor,
+     A: torch.Tensor,
+     offsets: Optional[torch.LongTensor],
+     indices: Optional[torch.LongTensor],
+     head_first: bool,
+     chunk_size: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     if head_first:
+         B, H, T, K, V = *a.shape, v.shape[-1]
+     else:
+         B, T, H, K, V = *a.shape, v.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     NT = triton.cdiv(T, BT) if offsets is None else len(indices)
+     CONST_TILING = 64 if check_shared_mem() else 32
+     BK = min(triton.next_power_of_2(K), CONST_TILING)
+     BV = min(triton.next_power_of_2(V), CONST_TILING)
+
+     u = torch.empty_like(v)
+     w = torch.empty_like(a)
+     fwd_wu_kernel[(NT, B*H)](
+         a=a,
+         v=v,
+         w=w,
+         u=u,
+         A=A,
+         k=k,
+         offsets=offsets,
+         indices=indices,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         HEAD_FIRST=head_first
+     )
+     return w, u
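In dense terms, `fwd_wu` computes `w = T @ a` and `u = T @ (strictly_lower(a @ k^T) @ v)` per chunk, with `T` the chunk-local inverse stored in `A` by `fwd_prepare_wy_repr`; a minimal single-head sketch under assumed shapes:

```python
import torch

def wu_reference(a, k, v, T):
    # a, k: [C, K]; v: [C, V]; T: [C, C] chunk-local inverse from fwd_prepare_wy_repr
    w = T @ a
    u = T @ (torch.tril(a @ k.t(), diagonal=-1) @ v)
    return w, u
```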