msj19 commited on
Commit
ccefec1
·
verified ·
1 Parent(s): a80b08e

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. all_results.json +10 -0
  2. fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc +0 -0
  3. fla3/ops/path_attn/__pycache__/parallel_path_fwd.cpython-310.pyc +0 -0
  4. fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc +0 -0
  5. fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc +0 -0
  6. fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  7. fla3/ops/rwkv7/fused_recurrent.py +328 -0
  8. fla3/ops/simple_gla/__pycache__/__init__.cpython-312.pyc +0 -0
  9. fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc +0 -0
  10. fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  11. fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc +0 -0
  12. fla3/ops/simple_gla/parallel.py +732 -0
  13. fla3/ops/ttt/fused_chunk.py +835 -0
  14. fla3/ops/ttt/naive.py +126 -0
  15. fla3/ops/utils/__init__.py +54 -0
  16. fla3/ops/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  17. fla3/ops/utils/__pycache__/asm.cpython-310.pyc +0 -0
  18. fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc +0 -0
  19. fla3/ops/utils/__pycache__/index.cpython-310.pyc +0 -0
  20. fla3/ops/utils/__pycache__/index.cpython-312.pyc +0 -0
  21. fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc +0 -0
  22. fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc +0 -0
  23. fla3/ops/utils/__pycache__/matmul.cpython-312.pyc +0 -0
  24. fla3/ops/utils/__pycache__/op.cpython-310.pyc +0 -0
  25. fla3/ops/utils/__pycache__/op.cpython-312.pyc +0 -0
  26. fla3/ops/utils/__pycache__/pack.cpython-310.pyc +0 -0
  27. fla3/ops/utils/__pycache__/pack.cpython-312.pyc +0 -0
  28. fla3/ops/utils/__pycache__/softmax.cpython-310.pyc +0 -0
  29. fla3/ops/utils/__pycache__/solve_tril.cpython-310.pyc +0 -0
  30. fla3/ops/utils/asm.py +17 -0
  31. fla3/ops/utils/cumsum.py +414 -0
  32. fla3/ops/utils/index.py +83 -0
  33. fla3/ops/utils/logcumsumexp.py +52 -0
  34. fla3/ops/utils/logsumexp.py +80 -0
  35. fla3/ops/utils/matmul.py +245 -0
  36. fla3/ops/utils/pack.py +208 -0
  37. fla3/ops/utils/pooling.py +207 -0
  38. fla3/ops/utils/softmax.py +111 -0
  39. fla3/ops/utils/solve_tril.py +276 -0
  40. flame/__init__.py +0 -0
  41. flame/__pycache__/__init__.cpython-310.pyc +0 -0
  42. flame/__pycache__/__init__.cpython-312.pyc +0 -0
  43. flame/__pycache__/data.cpython-310.pyc +0 -0
  44. flame/__pycache__/data.cpython-312.pyc +0 -0
  45. flame/__pycache__/logging.cpython-312.pyc +0 -0
  46. flame/__pycache__/parser.cpython-310.pyc +0 -0
  47. flame/__pycache__/parser.cpython-312.pyc +0 -0
  48. flame/data.py +246 -0
  49. flame/logging.py +118 -0
  50. flame/parser.py +94 -0
all_results.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 0.7839559871158865,
3
+ "num_tokens": 104891154432,
4
+ "throughput": 12525.363357673923,
5
+ "total_flos": 8.619164947133655e+20,
6
+ "train_loss": 9.173326368447839,
7
+ "train_runtime": 261696.889,
8
+ "train_samples_per_second": 195.709,
9
+ "train_steps_per_second": 0.191
10
+ }
fla3/ops/path_attn/__pycache__/parallel_path_bwd_intra.cpython-310.pyc ADDED
Binary file (5.1 kB). View file
 
fla3/ops/path_attn/__pycache__/parallel_path_fwd.cpython-310.pyc ADDED
Binary file (4.95 kB). View file
 
fla3/ops/path_attn/__pycache__/prepare_k_cache.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
fla3/ops/rwkv7/__pycache__/chunk.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
fla3/ops/rwkv7/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
fla3/ops/rwkv7/fused_recurrent.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+
12
+ from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule
13
+ from fla.ops.utils.op import exp
14
+ from fla.utils import input_guard, use_cuda_graph
15
+
16
+
17
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BV in [16, 32, 64]
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_rwkv7_fwd_kernel(
    r,  # receptance/query, laid out as [*, H, K]
    w,  # per-channel decay in log space (exponentiated below)
    k,  # key, [*, H, K]
    v,  # value, [*, H, V]
    kk,  # removal key; the kernel uses -kk as the state readout direction
    a,  # in-context learning-rate term; kk * a forms the replacement direction
    o,  # output buffer, [*, H, V]
    h0,  # optional initial state, [N, H, K, V]
    ht,  # optional final-state buffer, [N, H, K, V]
    cu_seqlens,  # [N + 1] cumulative sequence lengths for varlen inputs
    scale,  # multiplier applied to r before readout
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_DECODE: tl.constexpr,
):
    # One program per (value block i_v, sequence*head i_nh): it scans the
    # sequence one timestep at a time, carrying the recurrent state
    # b_h of shape [BV, BK] in registers.
    i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H

    if IS_VARLEN:
        # Varlen mode: per-sequence [bos, eos) from cu_seqlens overrides T.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Pointers start at the last timestep when scanning in reverse,
    # otherwise at the first timestep.
    p_r = r + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
    p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
    p_kk = kk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k

    p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v

    # BK/BV are rounded to powers of two; mask off the padding lanes.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    if IS_DECODE:
        # Single-token fast path (T == 1): one state update, no pointer stepping.
        b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
        b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
        b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
        b_act_a = -b_kk
        b_b = b_kk * b_a

        # State update: h <- exp(w) * h + (h . (-kk)) (kk*a)^T + v k^T
        tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
        b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
        # Readout: o = h r
        b_o = tl.sum(b_h * b_r[None, :], axis=1)

        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
    else:
        # Sequential scan over all T timesteps (same update as the decode path).
        for _ in range(0, T):
            b_r = tl.load(p_r, mask=mask_k, other=0).to(tl.float32) * scale
            b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32)
            b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
            b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_kk = tl.load(p_kk, mask=mask_k, other=0).to(tl.float32)
            b_act_a = -b_kk
            b_b = b_kk * b_a

            tmp = tl.sum(b_h * b_act_a[None, :], axis=1)
            b_h = exp(b_w)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
            b_o = tl.sum(b_h * b_r[None, :], axis=1)

            tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
            # Advance all pointers by one timestep (backwards when REVERSE).
            p_r += (-1 if REVERSE else 1) * H*K
            p_w += (-1 if REVERSE else 1) * H*K
            p_k += (-1 if REVERSE else 1) * H*K
            p_v += (-1 if REVERSE else 1) * H*V
            p_a += (-1 if REVERSE else 1) * H*K
            p_kk += (-1 if REVERSE else 1) * H*K
            p_o += (-1 if REVERSE else 1) * H*V

    if STORE_FINAL_STATE:
        # Persist the carried state (in float32) for the next segment/caller.
        p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
130
+
131
@input_guard
def fused_recurrent_rwkv7_fwd(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kk: torch.Tensor,
    a: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the fused recurrent RWKV-7 forward kernel.

    Shapes are read from the inputs (`k` is `[B, T, H, K]`, `v`'s last dim is V).
    Returns the output tensor (same shape/dtype as `v`) and, when
    `output_final_state` is set, a float32 state of shape `[N, H, K, V]`
    (otherwise `None`).
    """
    B, T, H, K = k.shape
    V = v.shape[-1]
    # Number of sequences: batch size, or len(cu_seqlens) - 1 for varlen input.
    N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
    BK = triton.next_power_of_2(K)

    # Final-state buffer is only materialized when the caller asked for it.
    ht = r.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    o = torch.empty_like(v)

    # 2D grid: value-dim blocks x (sequence, head) pairs; BV is autotuned.
    def grid(meta):
        return (triton.cdiv(V, meta['BV']), N * H)

    fused_recurrent_rwkv7_fwd_kernel[grid](
        r,
        w,
        k,
        v,
        kk,
        a,
        o,
        initial_state,
        ht,
        cu_seqlens,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        REVERSE=reverse,
        IS_DECODE=(T == 1),  # single-token decode takes a loop-free fast path
    )
    return o, ht
180
+
181
+
182
def fused_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = 1.0,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
):
    """
    RWKV-7 recurrence computed via the generalized (DPLR) delta rule.

    This is a thin adapter: `r` plays the role of the query, `w` the log
    decay `gk`, and `a`/`b` the rank-1 update directions of
    `fused_recurrent_dplr_delta_rule`.

    Args:
        r (torch.Tensor):
            r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        w (torch.Tensor):
            log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        a (torch.Tensor):
            a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        b (torch.Tensor):
            b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        scale (float):
            scale of the attention.
        initial_state (torch.Tensor):
            initial state of shape `[B, H, K, V]` if cu_seqlens is None else `[N, H, K, V]` where N = len(cu_seqlens) - 1.
        output_final_state (bool):
            whether to output the final state.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (bool):
            whether to use head first. Recommended to be False to avoid extra transposes.
            Default: `False`.
    """
    # Delegate everything; no computation happens at this level.
    return fused_recurrent_dplr_delta_rule(
        q=r,
        gk=w,
        k=k,
        v=v,
        a=a,
        b=b,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        head_first=head_first,
    )
235
+
236
+
237
def fused_mul_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kk: torch.Tensor,
    a: torch.Tensor,
    scale: Optional[float] = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Fused recurrent RWKV-7 forward pass with the multiplicative (`kk * a`)
    parameterization. Per timestep the kernel computes
    `S_t = diag(exp(w_t)) S_{t-1} + S_{t-1} (-kk_t) (kk_t \odot a_t)^T + v_t k_t^T`
    and reads out `o_t = S_t r_t`.

    Args:
        r (torch.Tensor):
            queries (receptance) of shape `[B, T, H, K]`.
        w (torch.Tensor):
            decay term of shape `[B, T, H, K]`, given in log space.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        kk (torch.Tensor):
            removal keys of shape `[B, T, H, K]`; `-kk_t` is used as the state
            readout direction and `kk_t * a_t` as the replacement direction.
        a (torch.Tensor):
            in-context learning-rate term of shape `[B, T, H, K]`.
        scale (Optional[float]):
            Scale factor applied to `r`.
            If `None`, it defaults to `1 / sqrt(K)`. Default: `1.0`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the sequence in reverse order. Default: `False`.
        cu_seqlens (Optional[torch.Tensor]):
            Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Deprecated. Passing `True` raises; inputs must already be in the
            `[B, T, H, ...]` layout. Default: `False`.

    Returns:
        Tuple of the output of shape `[B, T, H, V]` and the final state
        (`[N, H, K, V]` float32 if `output_final_state` else `None`).

    Raises:
        DeprecationWarning: if `head_first=True` is passed.
        ValueError: on inconsistent `cu_seqlens` / batch / `initial_state` shapes.
    """
    if head_first:
        # The head-first layout is no longer supported; fail loudly.
        # (The transpose that used to follow this raise was unreachable dead
        # code and has been removed.)
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    # Heuristic sanity check: seq_len < num_heads frequently means the caller
    # actually passed [B, H, T, ...] tensors while claiming head_first=False.
    if r.shape[1] < r.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        # Varlen mode packs all sequences into a single flattened batch.
        if r.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = r.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    o, final_state = fused_recurrent_rwkv7_fwd(
        r,
        w,
        k,
        v,
        kk,
        a,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
    )
    return o, final_state
fla3/ops/simple_gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (393 Bytes). View file
 
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-310.pyc ADDED
Binary file (4.13 kB). View file
 
fla3/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (4.78 kB). View file
 
fla3/ops/simple_gla/__pycache__/parallel.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
fla3/ops/simple_gla/parallel.py ADDED
@@ -0,0 +1,732 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+
12
+ from fla.ops.utils import prepare_chunk_indices
13
+ from fla.ops.utils.cumsum import chunk_global_cumsum, chunk_local_cumsum
14
+ from fla.ops.utils.op import safe_exp
15
+ from fla.utils import (
16
+ autocast_custom_bwd,
17
+ autocast_custom_fwd,
18
+ check_shared_mem,
19
+ input_guard,
20
+ is_intel_alchemist,
21
+ is_nvidia_hopper
22
+ )
23
+
24
# Workaround for an Intel XPU Triton backend issue; see
# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
# Candidate num_warps for autotuning; Hopper uses a smaller set.
NUM_WARPS = [2, 4, 8] if is_nvidia_hopper else [2, 4, 8, 16]
27
+
28
+
29
@triton.heuristics({
    'NV': lambda args: triton.cdiv(args['V'], args['BV']),
    'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None,
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BS", "BK", "BV", "USE_G"],
)
@triton.jit
def parallel_simple_gla_fwd_kernel(
    q,
    k,
    v,
    g,  # optional per-timestep gate (log space); USE_G is derived from it
    o,
    attn,  # optional [.., T, T] attention-score output; may be None
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # query-block (chunk) size
    BS: tl.constexpr,  # key/value sub-block size within a chunk
    BK: tl.constexpr,
    BV: tl.constexpr,
    NV: tl.constexpr,  # number of value-dim blocks
    OUTPUT_ATTENTIONS: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr
):
    # One program per (k-block, v-block) pair x query chunk x (batch, head).
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    # Partial outputs for different k-blocks are written to disjoint slabs of o
    # (reduced by the caller).
    o += i_k * B * T * H * V

    if IS_VARLEN:
        # chunk_indices maps the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Rebase all pointers to this (sequence, head).
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    if USE_G:
        g += bos * H + i_h
    if OUTPUT_ATTENTIONS:
        # NOTE(review): offset mixes (bos*H + i_h*T)*T with an i_k slab term —
        # confirm against the attn buffer layout on the host side.
        attn += (bos * H + i_h * T) * T + i_k * B * H * T * T
    stride_qk = H * K
    stride_vo = H * V
    stride_g = H

    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # [BT]
    o_q = i_t * BT + tl.arange(0, BT)
    # [BS]
    o_k = i_t * BT + tl.arange(0, BS)
    # Q block and K block have overlap.
    # masks required
    if USE_G:
        p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        # [BT,]
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
        # rescale interchunk output
    else:
        b_gq = None

    # Pass 1: diagonal (intra-chunk) part — needs the causal mask.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BS]
        m_s = o_q[:, None] >= o_k[None, :]
        b_s = tl.dot(b_q, b_k)
        if USE_G:
            p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gk = tl.load(p_gk, boundary_check=(0,))
            # Gate the scores by the cumulative decay between query and key positions.
            b_s *= safe_exp(b_gq[:, None] - b_gk[None, :])
            b_s = tl.where(m_s, b_s, 0)
        else:
            b_s = tl.where(m_s, b_s, 0)
        # [BT, BV]
        if i_s >= 0:
            b_o += tl.dot(b_s.to(b_q.dtype), b_v)
        if OUTPUT_ATTENTIONS:
            p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
        o_k += BS

    # Pass 2: strictly-past (inter-chunk) part, walked right-to-left so the
    # running gate b_gq can be rebased block by block; no mask needed.
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BK, BS]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BS, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_s = tl.dot(b_q, b_k)
        if USE_G:
            p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gp = tl.load(g + (i_s-1) * stride_g) if i_s % BT > 0 else 0.
            # No concrete meaning. Just to avoid some layout bugs.
            b_s *= safe_exp(b_gq[:, None] + (b_gn - b_g)[None, :])
            # Rebase the accumulated query gate onto this block's boundary.
            b_gq += (b_gn - b_gp)
        if OUTPUT_ATTENTIONS:
            p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
        if i_s >= 0:
            b_o += tl.dot(b_s.to(b_v.dtype), b_v)
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
162
+
163
+
164
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel_dq(
    i_t,  # query-chunk index
    i_k,  # key-dim block index
    i_v,  # value-dim block index
    q,
    k,
    v,
    g,  # optional log-space gate (only read when USE_G)
    do,  # incoming output gradient
    dq,  # query-gradient output
    dg,  # gate-gradient output (partial; dkv subtracts its part later)
    stride_qk,
    stride_vo,
    stride_g,
    scale,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr
):
    # Computes dq (and the q-side contribution to dg) for one query chunk.
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BT, BK]
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)

    # Pass 1: strictly-past key blocks (no causal mask needed), left-to-right,
    # rescaling the accumulator across block boundaries when gated.
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BV] @ [BV, BS] = [BT, BS]
        b_ds = tl.dot(b_do, b_v)
        if USE_G:
            p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            # b_gn: gate at this block's last position; b_gp: gate just before it.
            b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gp = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
            b_ds *= safe_exp(b_gn - b_g)[None, :]
            if i_s > 0:
                # Carry the accumulator across the block boundary.
                b_dq *= safe_exp(b_gn - b_gp)
        # [BT, BS] @ [BS, BK] = [BT, BK]
        b_dq += tl.dot(b_ds.to(b_v.dtype), b_k)

    if USE_G:
        p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        # [BT,]
        b_gq = tl.load(p_gq, boundary_check=(0,))
        # Apply each query position's own cumulative gate.
        # [BT, BK]
        b_dq *= safe_exp(b_gq)[:, None]

    # [BT]
    o_q = i_t * BT + tl.arange(0, BT)
    # [BS]
    o_k = i_t * BT + tl.arange(0, BS)
    # Pass 2: diagonal blocks — Q block and K block have overlap. masks required
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
        # [BS, BK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BV] @ [BV, BS] = [BT, BS]
        b_ds = tl.dot(b_do, b_v)
        if USE_G:
            p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gk = tl.load(p_gk, boundary_check=(0,))
            b_ds *= safe_exp(b_gq[:, None] - b_gk[None, :])
        # Causal mask: query position must not attend to later keys.
        b_ds = tl.where(o_q[:, None] >= o_k[None, :], b_ds, 0)
        # [BT, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), b_k)
        o_k += BS

    b_dq *= scale
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    if USE_G:
        # dg (q-side) = sum_k dq * q; the dkv kernel later subtracts its part.
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_dg = tl.sum(b_dq * b_q, 1)
        p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
254
+
255
+
256
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel_dkv(
    i_t,  # key/value-chunk index
    i_k,  # key-dim block index
    i_v,  # value-dim block index
    q,
    k,
    v,
    g,  # optional log-space gate (only read when USE_G)
    do,  # incoming output gradient
    dk,
    dv,
    dg,  # gate-gradient buffer, pre-filled by the dq kernel
    scale,
    stride_qk,
    stride_vo,
    stride_g,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr
):
    # Computes dk and dv for one key/value chunk, then folds the k-side
    # contribution into dg.
    # [BT, BK]
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    # [BT, BV]
    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)
    if USE_G:
        p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_gk = tl.load(p_gk, boundary_check=(0,))
    NTS = tl.cdiv(T, BS)
    # Pass 1: strictly-future query blocks, walked right-to-left so the
    # accumulators can be rescaled across block boundaries when gated.
    # [BT, BK]
    for i_s in range(NTS * BS - BS, (i_t + 1) * BT - BS, -BS):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_ds = tl.dot(b_v, tl.trans(b_do))
        b_s = tl.dot(b_k, tl.trans(b_q))
        if USE_G:
            p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gq = tl.load(p_gq, boundary_check=(0,))
            # NOTE(review): naming is swapped relative to the dq kernel here —
            # b_gp is the gate at this block's last position, b_gn the gate
            # just before the block start; confirm intent.
            b_gp = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gn = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
            if i_s >= 0:
                # Rescale accumulators across the block boundary, then gate scores.
                tmp = safe_exp(b_gp - b_gn)
                b_dk *= tmp
                b_dv *= tmp
                tmp2 = safe_exp(b_gq - b_gn)
                b_ds *= tmp2[None, :]
                b_s *= tmp2[None, :]
        # [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        # [BT, BV]
        b_dv += tl.dot(b_s.to(b_do.dtype), b_do)

    if USE_G:
        # Fold in the decay between each key position and its chunk end.
        b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * stride_g)
        if i_t >= 0:
            tmp2 = safe_exp(b_g_last - b_gk)[:, None]
            b_dk *= tmp2
            b_dv *= tmp2

    o_q = i_t * BT + tl.arange(0, BS)
    o_k = i_t * BT + tl.arange(0, BT)
    # Pass 2: diagonal (intra-chunk) query blocks — causal mask required.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        # [BS, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BS, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BS]
        b_ds = tl.dot(b_v, tl.trans(b_do))
        b_s = tl.dot(b_k, tl.trans(b_q))
        if USE_G:
            p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gq = tl.load(p_gq, boundary_check=(0,))
            if i_s >= 0:
                tmp = safe_exp(-b_gk[:, None] + b_gq[None, :])
                b_ds *= tmp
                b_s *= tmp
        # Keys may only influence queries at the same or later positions.
        m_s = o_k[:, None] <= o_q[None, :]
        b_s = tl.where(m_s, b_s, 0)
        b_ds = tl.where(m_s, b_ds, 0)
        # [BT, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        b_dv += tl.dot(b_s.to(b_do.dtype), b_do)
        o_q += BS
    b_dk *= scale
    b_dv *= scale
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    if USE_G:
        # dg = (q-side contribution written by the dq kernel) - sum_k dk * k.
        p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_dg = tl.load(p_dg, boundary_check=(0,))
        b_dg -= tl.sum(b_dk * b_k, 1)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
363
+
364
+
365
@triton.heuristics({
    'NV': lambda args: triton.cdiv(args['V'], args['BV']),
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config(triton_config, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BS', 'BK', 'BV', 'USE_G'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel(
    q,
    k,
    v,
    g,  # optional log-space gate; USE_G is derived from it
    do,  # incoming output gradient
    dq,
    dk,
    dv,
    dg,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr
):
    # Dispatcher: resolves per-program offsets, then runs the dq sub-kernel
    # followed (after a barrier) by the dk/dv sub-kernel, which also reads the
    # dg partial written by the first.
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    # Partial gradients for different (k, v) blocks go to disjoint slabs,
    # reduced by the caller.
    dq += i_v * B * H * T * K
    dk += i_v * B * H * T * K
    dv += i_k * B * H * T * V
    if USE_G:
        dg += i_kv * B * H * T

    if IS_VARLEN:
        # chunk_indices maps the flat chunk id to (sequence, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Rebase all pointers to this (sequence, head).
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    do += (bos * H + i_h) * V
    dq += (bos * H + i_h) * K
    dk += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V
    if USE_G:
        g += bos * H + i_h
        dg += bos * H + i_h
    stride_qk = H * K
    stride_vo = H * V
    stride_g = H

    parallel_simple_gla_bwd_kernel_dq(
        i_t=i_t,
        i_k=i_k,
        i_v=i_v,
        q=q,
        k=k,
        v=v,
        g=g,
        do=do,
        dq=dq,
        dg=dg,
        scale=scale,
        stride_qk=stride_qk,
        stride_vo=stride_vo,
        stride_g=stride_g,
        T=T,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
        USE_G=USE_G
    )
    # The dkv sub-kernel reads the dg partial stored by the dq sub-kernel;
    # the barrier orders those accesses.
    tl.debug_barrier()
    parallel_simple_gla_bwd_kernel_dkv(
        i_t=i_t,
        i_k=i_k,
        i_v=i_v,
        q=q,
        k=k,
        v=v,
        g=g,
        do=do,
        dk=dk,
        dv=dv,
        dg=dg,
        scale=scale,
        stride_qk=stride_qk,
        stride_vo=stride_vo,
        stride_g=stride_g,
        T=T,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
        USE_G=USE_G
    )
484
+
485
+
486
def parallel_simple_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: float,
    output_attentions: bool = False,
    chunk_size: int = 128,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the forward kernel for parallel simple-GLA.

    Returns ``(o, g, attn)`` where ``g`` is the *chunk-local cumulative* gate
    (needed again by the backward pass) and ``attn`` is the materialized
    attention matrix or ``None``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT, BS = chunk_size, 32
    # Pick block sizes by available shared memory (newest architecture first).
    if check_shared_mem('hopper', k.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    elif check_shared_mem('ampere', k.device.index):
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    else:
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert BT % BS == 0

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    # local cumulative decay in log space
    if g is not None:
        g = chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens)
    grid = (NK * NV, NT, B * H)
    # Partial outputs per K-block; accumulate in fp32 whenever NK > 1.
    o = torch.empty(NK, *v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device)
    attn = q.new_zeros(NK, B, H, T, T) if output_attentions else None

    parallel_simple_gla_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        o=o,
        attn=attn,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        B=B,
        H=H,
        T=T,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    # Reduce over the K-block replication dimension.
    o = o.sum(0)

    if output_attentions:
        attn = attn.sum(0)
    return o, g, attn
548
+
549
def parallel_simple_gla_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    scale: float,
    chunk_size: int = 128,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the backward kernel and reduce the per-block partial gradients.

    ``g`` here is the chunk-local cumulative gate produced by the forward pass;
    the returned ``dg`` is post-processed with a reversed global cumsum.
    Returns ``(dq, dk, dv, dg)``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT, BS = chunk_size, 32
    # Block sizes by shared-memory capability (backward needs more memory than
    # forward, hence the extra 'ada' tier and the smaller fallback).
    if check_shared_mem('hopper', k.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    elif check_shared_mem('ampere', k.device.index):
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    elif check_shared_mem('ada', k.device.index):
        BK = min(64, triton.next_power_of_2(K))
        BV = min(64, triton.next_power_of_2(V))
    else:
        BK = min(32, triton.next_power_of_2(K))
        BV = min(32, triton.next_power_of_2(V))

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert BT % BS == 0

    # Partial gradients, replicated over the "other" block dimension and
    # accumulated in fp32 when more than one replica exists.
    dq = torch.empty(NV, * q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    dk = torch.empty(NV, * k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
    dv = torch.empty(NK, * v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device)
    dg = torch.empty(NK*NV, *g.shape, dtype=torch.float, device=q.device) if g is not None else None

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    grid = (NK * NV, NT, B * H)
    parallel_simple_gla_bwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        do=do,
        dq=dq,
        dk=dk,
        dv=dv,
        dg=dg,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    # The kernel produces chunk-local gate gradients; the reversed global
    # cumsum converts them into gradients w.r.t. the original (pre-cumsum) g.
    dg = chunk_global_cumsum(dg.sum(0), reverse=True, cu_seqlens=cu_seqlens) if g is not None else None
    return dq, dk, dv, dg
616
+
617
class ParallelSimpleGLAFunction(torch.autograd.Function):
    """Autograd wrapper around the parallel simple-GLA Triton kernels."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, g, scale, output_attentions, cu_seqlens):
        chunk_size = 128
        # Remember the input dtype so dg can be cast back in backward.
        ctx.dtype = q.dtype

        # NOTE: g returned here is the chunk-local cumsum'd gate, which is what
        # the backward pass expects — it is saved instead of the raw input g.
        o, g, attn = parallel_simple_gla_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            scale=scale,
            output_attentions=output_attentions,
            chunk_size=chunk_size,
            cu_seqlens=cu_seqlens,
        )
        ctx.save_for_backward(q, k, v, g, cu_seqlens)
        ctx.scale = scale
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), attn

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, da=None):
        # `da` is the (unused) gradient of the optional attention output.
        q, k, v, g, cu_seqlens = ctx.saved_tensors
        dq, dk, dv, dg = parallel_simple_gla_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            do=do,
            scale=ctx.scale,
            chunk_size=ctx.chunk_size,
            cu_seqlens=cu_seqlens,
        )
        # .to(tensor) matches both dtype and device of the corresponding input;
        # Nones line up with (scale, output_attentions, cu_seqlens).
        return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.dtype) if dg is not None else None, None, None, None
658
+
659
def parallel_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    output_attentions: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        g (torch.Tensor):
            Forget gates of shape `[B, T, H]`.
            Compared to GLA, the gating is head-wise instead of elementwise.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        output_attentions (bool):
            Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
        head_first (Optional[bool]):
            Deprecated; passing `True` raises. Inputs must be in the
            `[B, T, H, ...]` format. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        attn (torch.Tensor):
            Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`
    """
    if head_first:
        # head_first support was removed; fail loudly rather than silently
        # reinterpreting the layout. (Code after this raise would be dead, so
        # no rearrange is attempted.)
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    # Heuristic sanity check: seq_len < num_heads usually means the caller
    # passed head-first tensors by mistake.
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
    if output_attentions:
        assert cu_seqlens is None, "output_attentions=True is not supported with variable-length sequences"

    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, attn = ParallelSimpleGLAFunction.apply(
        q,
        k,
        v,
        g,
        scale,
        output_attentions,
        cu_seqlens
    )
    return o, attn
fla3/ops/ttt/fused_chunk.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan

import warnings
from typing import Optional

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.modules.layernorm import group_norm
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, is_nvidia_hopper

# Warp-count candidates for autotuning; a reduced set is tried on Hopper.
NUM_WARPS = [1, 2] if is_nvidia_hopper else [1, 2, 4, 8]
16
+
17
+
18
# Forward kernel for fused-chunk TTT-linear. One program per (sequence, head)
# scans the BT-sized chunks sequentially, maintaining a [K, V] fast-weight
# state `b_h` and a [V] bias state `b_hb` across chunks.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_fwd_kernel(
    q,
    k,
    v,
    eta,
    w,
    b,
    o,
    scale,
    eps,
    h0,
    hb0,
    ht,
    hbt,
    cu_seqlens,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # Causal (lower-triangular) mask within a chunk.
    m_A = o_i[:, None] >= o_i[None, :]
    # Per-head LayerNorm weight/bias, masked to the valid V range.
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)

    # [BK, BV] fast-weight state
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV] bias state
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    for i_t in range(NT):
        p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_o = tl.make_block_ptr(o+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        # Pointer to the eta of the chunk's last valid step (clamped to T-1
        # for the final, possibly partial, chunk).
        p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # Apply the current fast weights to k, then LayerNorm over V:
        # mean/var/rstd computed with masking so padded V lanes contribute 0.
        # [BT, BV]
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        # Residual of the TTT reconstruction objective, then backprop through
        # the LayerNorm to get the update direction b_v2.
        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        # [BT] per-step learning rates
        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_q = (b_q * scale).to(b_k.dtype)

        # Intra-chunk causal interaction. [BT, BT]
        b_A = tl.dot(b_q, b_k, allow_tf32=False)
        b_A = tl.where(m_A, b_A, 0)
        b_Ae = tl.where(m_A, b_e[:, None], 0.0)

        # Output = inter-chunk contribution from (b_h, b_hb) minus the
        # eta-weighted intra-chunk corrections.
        b_o = - tl.dot(b_e[:, None] * b_A.to(b_v2.dtype), b_v2, allow_tf32=False)
        b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v2.dtype), b_v2, allow_tf32=False)
        b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        # State update uses the chunk's last eta; re-mask padded V lanes.
        b_e_last = tl.load(p_e_last)
        b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
        b_h = tl.where((v_i < V)[None, :], b_h, 0.)
        b_hb = tl.where((v_i < V), b_hb, 0.)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))
141
+
142
# Backward pass, phase 1 (forward re-run). Recomputes and checkpoints the
# forward intermediates the gradient kernel needs — per-chunk states h, the
# normalized activations x (= kh_hat), pre-norm residual y, rstd r, and update
# direction v2 — and produces dq along the way. Fixed-length sequences only.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4)
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_bwd_kernel_h(
    k,
    v,
    v2,
    x,
    y,
    r,
    w,
    b,
    eta,
    h0,
    hb0,
    h,
    do,
    dq,
    scale,
    eps,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
):
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    bos, _ = i_n * T, i_n * T + T
    NT = tl.cdiv(T, BT)
    # Base offset of this sequence's chunk-state checkpoints in h.
    boh = i_n * NT

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    m_A = o_i[:, None] >= o_i[None, :]
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV]
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    for i_t in range(NT):
        p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
        p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        p_dq = tl.make_block_ptr(dq+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        # Checkpoint the state *entering* this chunk.
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # Recompute the forward LayerNorm statistics (same math as fwd kernel).
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((v_i < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
        # Persist the intermediates for the gradient kernel.
        tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_v2, b_v2.to(p_v2.dtype.element_ty), boundary_check=(0, 1))

        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")

        # dq = do @ h^T minus the eta-weighted intra-chunk term, scaled.
        b_v2 = tl.where((v_i < V)[None, :], b_v2, 0.)
        b_ds = tl.dot(b_do, tl.trans(b_v2).to(b_do.dtype))
        b_ds = tl.where(m_A, b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        b_dq = tl.dot(b_do, tl.trans(b_h).to(b_do.dtype))
        b_dq -= tl.dot(b_ds, tl.trans(b_k)) * b_e[:, None]
        b_dq *= scale

        # Advance the state exactly as the forward pass did.
        b_e_last = tl.load(p_e_last)
        b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0)
        b_h = tl.where((v_i < V)[None, :], b_h, 0.)
        b_hb = tl.where((v_i < V), b_hb, 0.)
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
259
+
260
# Backward pass, phase 2. Scans chunks in reverse, propagating the state
# gradients (b_dh, b_dhb) and emitting dk, dv, d(eta), dw, db from the
# checkpointed intermediates of phase 1. Fixed-length sequences only.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def fused_chunk_ttt_linear_bwd_kernel_dh(
    q,
    k,
    v,
    v2,
    x,
    y,
    r,
    w,
    b,
    eta,
    h,
    dht,
    dhbt,
    dh0,
    dhb0,
    do,
    dk,
    dv,
    de,
    dw,
    db,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
):
    i_nh = tl.program_id(0)
    i_n, i_h = i_nh // H, i_nh % H
    bos, _ = i_n * T, i_n * T + T
    NT = tl.cdiv(T, BT)
    boh = i_n * NT

    # State gradients, seeded from the final-state gradients if provided.
    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV]
    b_dhb = tl.zeros([BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
    if USE_FINAL_STATE_GRADIENT_B:
        p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
        b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")

    o_i = tl.arange(0, BT)
    v_i = tl.arange(0, BV)
    # m_A: causal mask; m_A_t: its transpose (for gradients flowing backward).
    m_A = o_i[:, None] >= o_i[None, :]
    m_A_t = o_i[:, None] <= o_i[None, :]
    b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.)
    b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.)
    # dw/db accumulate across all chunks; stored once per (sequence, head).
    b_dw = tl.zeros([BV,], dtype=b_w.dtype)
    b_db = tl.zeros([BV,], dtype=b_b.dtype)
    p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (0,), (BV,), (0,))
    p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (0,), (BV,), (0,))

    for i_t in range(NT - 1, -1, -1):
        p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1))
        p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
        p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0))
        p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
        p_de = tl.make_block_ptr(de+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,))
        p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
        b_e_last = tl.load(p_e_last)
        # Gradient flowing into v2: transposed intra-chunk terms plus the
        # contribution of the state-update through (b_dh, b_dhb).
        b_A = tl.dot(b_k, b_q)
        b_A = - tl.where(m_A_t, b_A * scale * b_e[None, :], 0).to(do.dtype.element_ty)
        b_Ae = - tl.where(m_A_t, b_e[None, :], 0).to(do.dtype.element_ty)
        b_dv_new = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
        b_dv_new -= tl.dot(b_e_last * b_k, b_dh.to(b_k.dtype))
        b_dv_new -= b_e_last * b_dhb.to(b_k.dtype)[None, :]

        # Backprop through the LayerNorm using checkpointed x, y, rstd.
        b_v2 = tl.load(p_v2, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
        b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
                         b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
                          b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v2.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)

        # Gradients w.r.t. v, k (partial), and the LayerNorm affine params.
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
        b_w = b_w.to(b_k.dtype)
        b_b = b_b.to(b_k.dtype)
        b_dv = -b_w * b_dy.to(b_k.dtype)
        b_dk = b_w * b_dy.to(b_k.dtype)
        b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
                       (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
        b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
        b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)

        # Gradient w.r.t. the pre-norm activation kh, masked to valid rows/V.
        b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
        b_q = (b_q * scale).to(b_q.dtype)
        b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
                          b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
        b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
        b_dkh = tl.where((v_i < V)[None, :] * (o_i < T-i_t*BT)[:, None], b_dkh, 0.)
        b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)

        b_ds = tl.dot(b_do, tl.trans(b_v2))
        b_ds = tl.where(m_A, b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        # Index of the chunk's last valid step (for the eta-gradient mask).
        i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1)
        mask = (o_i == i_last)
        b_dk -= b_e_last * tl.dot(b_v2, tl.trans(b_dh).to(b_v2.dtype))
        b_dk -= tl.dot(tl.trans(b_ds), tl.trans(b_q) * b_e[:, None])
        # d(eta): the state-update terms only touch the last step of the chunk;
        # the intra-chunk terms touch every step.
        b_de = mask * tl.sum(- b_dh * tl.trans(tl.dot(tl.trans(b_v2), b_k))).to(b_k.dtype)
        b_de -= mask * tl.sum(b_dhb * tl.sum(b_v2, axis=0)).to(b_k.dtype)
        b_de -= tl.sum(tl.dot(b_ds, b_k) * tl.trans(b_q).to(b_k.dtype), axis=1)
        b_de -= tl.sum(b_ds, axis=1)
        # Propagate the state gradients one chunk earlier.
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
        b_dhb += tl.sum(b_do + b_dkh, axis=0)
        b_dh = tl.where((v_i < V)[None, :], b_dh, 0.)
        b_dhb = tl.where((v_i < V), b_dhb, 0.)

        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))
        tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0+i_nh*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
    if USE_INITIAL_STATE_B:
        p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (0,), (BV,), (0,))
        tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))
422
+
423
def fused_chunk_ttt_linear_bwd_h(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Backward phase 1: re-run the forward scan to checkpoint intermediates.

    Returns ``(dq, h, v2, x, y, r)``: the query gradient plus the per-chunk
    states and LayerNorm intermediates consumed by phase 2.
    """
    assert cu_seqlens is None, "bwd of varlen is not implemented yet."
    B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N, NT = B, triton.cdiv(T, BT)
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # Checkpoint buffers: one [K, V] state per chunk, plus per-step LayerNorm
    # intermediates (rstd kept in fp32 for accuracy).
    h = k.new_empty(B, NT, H, K, V)
    r = v.new_empty(B, T, H, 1, dtype=torch.float32)
    v2 = torch.empty_like(v)
    x = torch.empty_like(v)
    y = torch.empty_like(v)
    dq = torch.empty_like(q)

    # One program per (sequence, head).
    grid = (N * H,)
    fused_chunk_ttt_linear_bwd_kernel_h[grid](
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h0=initial_state,
        hb0=initial_state_bias,
        h=h,
        do=do,
        dq=dq,
        scale=scale,
        eps=eps,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, h, v2, x, y, r
479
+
480
+
481
def fused_chunk_ttt_linear_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v2: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    r: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    h: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Backward phase 2: reverse scan over chunks producing the remaining grads.

    Consumes the checkpoints from :func:`fused_chunk_ttt_linear_bwd_h` and
    returns ``(dk, dv, de, dw, db, dh0, dhb0)``. ``dw``/``db`` are reduced over
    the batch dimension here; ``dh0``/``dhb0`` are ``None`` when no initial
    state was given.
    """
    assert cu_seqlens is None, "bwd of varlen is not implemented yet."
    B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    dh0 = torch.empty_like(initial_state, dtype=torch.float32) if initial_state is not None else None
    dhb0 = torch.empty_like(initial_state_bias, dtype=torch.float32) if initial_state_bias is not None else None
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    de = torch.empty_like(eta)
    # Per-sequence partials; reduced over B below.
    dw = w.new_empty(B, H, V)
    db = b.new_empty(B, H, V)

    grid = (N * H,)
    fused_chunk_ttt_linear_bwd_kernel_dh[grid](
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dh0=dh0,
        dhb0=dhb0,
        do=do,
        dk=dk,
        dv=dv,
        de=de,
        dw=dw,
        db=db,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    dw = dw.sum(dim=0)
    db = db.sum(dim=0)
    return dk, dv, de, dw, db, dh0, dhb0
552
+
553
def fused_chunk_ttt_linear_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
    BT: int = 16
):
    """
    Launch the fused forward kernel of chunked TTT-linear.

    Returns the output tensor plus the final hidden state and state bias
    (both `None` unless `output_final_state` is requested).
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    o = torch.empty_like(v)
    if output_final_state:
        final_state = k.new_empty(N, H, K, V, dtype=torch.float32)
        final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32)
    else:
        final_state = final_state_bias = None

    # One program per (sequence, head) pair.
    fused_chunk_ttt_linear_fwd_kernel[(N * H,)](
        q=q, k=k, v=v, eta=eta, w=w, b=b, o=o,
        scale=scale, eps=eps,
        h0=initial_state, hb0=initial_state_bias,
        ht=final_state, hbt=final_state_bias,
        cu_seqlens=cu_seqlens,
        T=T, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
    )
    return o, final_state, final_state_bias
+
604
def fused_chunk_ttt_linear_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """
    Full backward pass of fused chunked TTT-linear.

    Stage 1 recomputes the chunk-level hidden states and intermediate
    activations while producing `dq`; stage 2 consumes them to produce the
    remaining gradients.
    """
    assert cu_seqlens is None, "bwd of varlen is not implemented yet."
    # Stage 1: recompute forward intermediates and obtain dq.
    dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h(
        q=q, k=k, v=v, w=w, b=b, eta=eta,
        scale=scale, eps=eps, do=do, BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        cu_seqlens=cu_seqlens,
    )
    # Stage 2: backpropagate through the state recurrence.
    dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd_dh(
        q=q, k=k, v=v, v2=v2, x=x, y=y, r=rstd, w=w, b=b, eta=eta,
        scale=scale, h=h, do=do, dht=dht, dhbt=dhbt, BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        cu_seqlens=cu_seqlens,
    )
    return dq, dk, dv, de, dw, db, dh0, dhb0
+
661
class FusedChunkTTTLinearFunction(torch.autograd.Function):
    """Autograd glue for the fused chunked TTT-linear Triton kernels.

    `forward` runs the fused forward kernel; `backward` recomputes the
    needed intermediates from the saved raw inputs (no activation tensors
    beyond the inputs and initial states are stored on the context).
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
                initial_state_bias, output_final_state, cu_seqlens):
        o, final_state, final_state_bias = fused_chunk_ttt_linear_fwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=scale,
            eps=eps,
            BT=BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        # Only raw inputs are saved; backward recomputes intermediates.
        ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
        ctx.BT = BT
        ctx.scale = scale
        ctx.eps = eps
        ctx.cu_seqlens = cu_seqlens
        # Cast the output back to the input dtype (kernel may compute in fp32).
        return o.to(q.dtype), final_state, final_state_bias

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht, dhbt):
        q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
        dq, dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=ctx.scale,
            eps=ctx.eps,
            do=do,
            dht=dht,
            dhbt=dhbt,
            BT=ctx.BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            cu_seqlens=ctx.cu_seqlens,
        )
        # One gradient per forward input, in order:
        # (q, k, v, w, b, BT, eta, scale, eps,
        #  initial_state, initial_state_bias, output_final_state, cu_seqlens).
        return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None
+
715
def norm_residual(x, weight, bias, eps):
    # GroupNorm and Residual
    # NOTE: `x += ...` updates `x` in place, so the caller's tensor is
    # mutated as well. Heads are flattened into the channel dimension and
    # `num_groups=H` gives each head its own normalization group.
    B, T, H, D = x.shape
    x += group_norm(
        x.reshape(B, T, -1).clone(),
        weight=weight.reshape(-1).clone(),
        bias=bias.reshape(-1).clone(),
        eps=eps,
        num_groups=H,
    ).reshape(x.shape)
    return x
+
728
def fused_chunk_ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            layer norm weight of shape `(H, V)`
        b (torch.Tensor):
            layer norm bias of shape `(H, V)`
        eta (torch.Tensor):
            Learning rate for hidden state, of shape `(B, H, T, 1)`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        chunk_size (int):
            chunk size. Default: `16`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        initial_state_bias (Optional[torch.Tensor]):
            Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]`
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
        final_state_bias (torch.Tensor):
            Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
    if isinstance(eta, float):
        # Broadcast a scalar learning rate to a (..., 1) tensor matching q.
        eta = torch.full_like(q[:, :, :, :1], eta)
    if head_first:
        # BUGFIX: this previously *raised* DeprecationWarning, which made the
        # rearrange below (and the final head-first restore) unreachable and
        # broke the documented head_first path. Emit a warning instead so
        # head-first inputs keep working while deprecated.
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        q, k, v, eta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, eta))
    if not head_first and q.shape[1] < q.shape[2]:
        # Heuristic sanity check: T < H usually means the layout is wrong.
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "Scale must be positive."
    o, final_state, final_state_bias = FusedChunkTTTLinearFunction.apply(
        q,
        k,
        v,
        w,
        b,
        chunk_size,
        eta,
        scale,
        eps,
        initial_state,
        initial_state_bias,
        output_final_state,
        cu_seqlens,
    )
    # Apply the residual group-norm outside the autograd function.
    o = norm_residual(o, w, b, eps)
    if head_first:
        # Restore the caller's head-first layout.
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state, final_state_bias
fla3/ops/ttt/naive.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    """
    Pure-PyTorch reference of TTT-linear with mini-batch (chunked) updates.

    Expects head-first inputs: q/k/v of shape (B, H, T, D), eta of shape
    (B, H, T, 1); w/b are per-head layer-norm parameters (reshaped to
    (H, 1, D) below). T must already be divisible by `mini_batch_size`.

    NOTE(review): `q *= scale` scales the caller's tensor in place — callers
    are expected to pass private fp32 copies (see `chunk_ttt_linear_ref`).
    """
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # [NT, B, H, mini_batch_size, D]
    _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    # Fast-weight state (h) and its bias (hb), carried across mini-batches.
    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
    q *= scale
    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)

    for i in range(NT):
        q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
        # Prediction of the current fast weights on this mini-batch's keys.
        kh = k_i @ h + hb
        # TTT reconstruction target is the residual v - k.
        reconstruction_target = v_i - k_i

        # Layer-norm statistics of the prediction.
        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        # NOTE(review): despite the name, `rstd` holds the standard deviation
        # (values are divided by it below), not its reciprocal.
        rstd = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / rstd

        # g = gradient of the MSE loss w.r.t. the layer-norm output ...
        g = w * kh_hat + b - reconstruction_target
        g *= w
        # ... backpropagated through layer norm: v_new = dLoss/d(kh).
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)

        # Intra-chunk causal attention plus eta-weighted correction terms.
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        # State update is scaled by this chunk's LAST-position learning rate.
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)
        # layer norm with residuals

        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / rstd * w + b

    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb
72
def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = False,
):
    """
    Reference entry point for chunked TTT-linear.

    Normalizes the layout to head-first, pads T up to a multiple of
    `mini_batch_size`, runs the naive `ttt_linear`, and trims the padding.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    if isinstance(eta, float):
        # Broadcast a scalar learning rate to a (..., 1) tensor matching q.
        eta = torch.full_like(q[:, :, :, :1], eta)
    if scale is None:
        # Default to 1/sqrt(K), as in standard attention.
        scale = k.shape[-1] ** -0.5
    if not head_first:
        # Convert [B, T, H, ...] -> [B, H, T, ...].
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        eta = eta.transpose(1, 2)
    T = q.shape[-2]
    padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size
    if padded > 0:
        # Zero-pad along time so T becomes a multiple of the mini-batch size.
        q = F.pad(q, (0, 0, 0, padded))
        k = F.pad(k, (0, 0, 0, padded))
        v = F.pad(v, (0, 0, 0, padded))
        eta = F.pad(eta, (0, 0, 0, padded))
        # ttt_linear scales its state update by the chunk's LAST eta, so the
        # last real eta is copied into the final padded slot to preserve it.
        eta[:, :, -1, :] = eta[:, :, -(padded+1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    # Compute in fp32; `.to` also makes private copies, so ttt_linear's
    # in-place `q *= scale` cannot touch the caller's tensors.
    q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b])
    o, final_state, final_state_bias = ttt_linear(
        q,
        k,
        v,
        w,
        b,
        eta,
        scale,
        eps,
        mini_batch_size,
        initial_state,
        initial_state_bias,
        output_final_state,
    )
    # Drop the padded tail and restore the caller's layout.
    o = o[:, :, :T, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias
fla3/ops/utils/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .asm import fp32_to_tf32_asm
4
+ from .cumsum import (
5
+ chunk_global_cumsum,
6
+ chunk_global_cumsum_scalar,
7
+ chunk_global_cumsum_vector,
8
+ chunk_local_cumsum,
9
+ chunk_local_cumsum_scalar,
10
+ chunk_local_cumsum_vector
11
+ )
12
+ from .index import (
13
+ prepare_chunk_indices,
14
+ prepare_chunk_offsets,
15
+ prepare_cu_seqlens_from_mask,
16
+ prepare_lens,
17
+ prepare_lens_from_mask,
18
+ prepare_position_ids,
19
+ prepare_sequence_ids,
20
+ prepare_token_indices
21
+ )
22
+ from .logsumexp import logsumexp_fwd
23
+ from .matmul import addmm, matmul
24
+ from .pack import pack_sequence, unpack_sequence
25
+ from .pooling import mean_pooling
26
+ from .softmax import softmax_bwd, softmax_fwd
27
+ from .solve_tril import solve_tril
28
+
29
+ __all__ = [
30
+ 'chunk_global_cumsum',
31
+ 'chunk_global_cumsum_scalar',
32
+ 'chunk_global_cumsum_vector',
33
+ 'chunk_local_cumsum',
34
+ 'chunk_local_cumsum_scalar',
35
+ 'chunk_local_cumsum_vector',
36
+ 'pack_sequence',
37
+ 'unpack_sequence',
38
+ 'prepare_chunk_indices',
39
+ 'prepare_chunk_offsets',
40
+ 'prepare_cu_seqlens_from_mask',
41
+ 'prepare_lens',
42
+ 'prepare_lens_from_mask',
43
+ 'prepare_position_ids',
44
+ 'prepare_sequence_ids',
45
+ 'prepare_token_indices',
46
+ 'logsumexp_fwd',
47
+ 'addmm',
48
+ 'matmul',
49
+ 'mean_pooling',
50
+ 'softmax_bwd',
51
+ 'softmax_fwd',
52
+ 'fp32_to_tf32_asm',
53
+ 'solve_tril',
54
+ ]
fla3/ops/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.16 kB). View file
 
fla3/ops/utils/__pycache__/asm.cpython-310.pyc ADDED
Binary file (482 Bytes). View file
 
fla3/ops/utils/__pycache__/cumsum.cpython-312.pyc ADDED
Binary file (21.4 kB). View file
 
fla3/ops/utils/__pycache__/index.cpython-310.pyc ADDED
Binary file (3.12 kB). View file
 
fla3/ops/utils/__pycache__/index.cpython-312.pyc ADDED
Binary file (5.48 kB). View file
 
fla3/ops/utils/__pycache__/logcumsumexp.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
fla3/ops/utils/__pycache__/logsumexp.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
fla3/ops/utils/__pycache__/matmul.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
fla3/ops/utils/__pycache__/op.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
fla3/ops/utils/__pycache__/op.cpython-312.pyc ADDED
Binary file (1.56 kB). View file
 
fla3/ops/utils/__pycache__/pack.cpython-310.pyc ADDED
Binary file (4.56 kB). View file
 
fla3/ops/utils/__pycache__/pack.cpython-312.pyc ADDED
Binary file (8.01 kB). View file
 
fla3/ops/utils/__pycache__/softmax.cpython-310.pyc ADDED
Binary file (2.35 kB). View file
 
fla3/ops/utils/__pycache__/solve_tril.cpython-310.pyc ADDED
Binary file (7.63 kB). View file
 
fla3/ops/utils/asm.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from ...utils import device_platform
4
+
5
def fp32_to_tf32_asm() -> str:
    """
    Get the assembly code for converting FP32 to TF32.

    Returns an empty string on platforms with no known conversion snippet.
    """
    # Inline-asm snippets keyed by platform name.
    asm_by_platform = {
        'nvidia': 'cvt.rna.tf32.f32 $0, $1;'
    }
    # return empty string if the device is not supported
    return asm_by_platform.get(device_platform, "")
fla3/ops/utils/cumsum.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from ...ops.utils.index import prepare_chunk_indices
12
+ from ...utils import check_shared_mem, input_guard
13
+
# Candidate feature-block sizes for the vector-cumsum autotuner; the wider
# blocks are only tried on devices with enough shared memory.
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
+
16
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Per-chunk (length-BT) inclusive cumsum of a scalar per-position signal.
    # Grid: (num_chunks, B * H); one program handles one chunk of one head.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat chunk id to (sequence id, chunk-within-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Inclusive reverse cumsum: total - forward-cumsum + element.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
64
+
65
+ @triton.heuristics({
66
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
67
+ })
68
+ @triton.autotune(
69
+ configs=[
70
+ triton.Config({'BS': BS}, num_warps=num_warps)
71
+ for BS in BS_LIST
72
+ for num_warps in [2, 4, 8]
73
+ ],
74
+ key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']
75
+ )
76
+ @triton.jit(do_not_specialize=['T'])
77
+ def chunk_local_cumsum_vector_kernel(
78
+ s,
79
+ o,
80
+ cu_seqlens,
81
+ chunk_indices,
82
+ T,
83
+ B: tl.constexpr,
84
+ H: tl.constexpr,
85
+ S: tl.constexpr,
86
+ BT: tl.constexpr,
87
+ BS: tl.constexpr,
88
+ REVERSE: tl.constexpr,
89
+ IS_VARLEN: tl.constexpr,
90
+ HEAD_FIRST: tl.constexpr,
91
+ ):
92
+ i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ i_b, i_h = i_bh // H, i_bh % H
94
+ if IS_VARLEN:
95
+ i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
96
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
97
+ T = eos - bos
98
+ else:
99
+ bos, eos = i_b * T, i_b * T + T
100
+
101
+ o_i = tl.arange(0, BT)
102
+ if REVERSE:
103
+ m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
104
+ else:
105
+ m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
106
+
107
+ if HEAD_FIRST:
108
+ p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
109
+ p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
110
+ else:
111
+ p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
112
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
113
+ # [BT, BS]
114
+ b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
115
+ b_o = tl.dot(m_s, b_s, allow_tf32=False)
116
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
117
+
118
+
119
+ @triton.heuristics({
120
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
121
+ })
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
125
+ for BT in [32, 64, 128, 256]
126
+ for num_warps in [2, 4, 8]
127
+ for num_stages in [1, 2, 3, 4]
128
+ ],
129
+ key=['B', 'H', 'IS_VARLEN', 'REVERSE']
130
+ )
131
+ @triton.jit(do_not_specialize=['T'])
132
+ def chunk_global_cumsum_scalar_kernel(
133
+ s,
134
+ o,
135
+ cu_seqlens,
136
+ T,
137
+ B: tl.constexpr,
138
+ H: tl.constexpr,
139
+ BT: tl.constexpr,
140
+ REVERSE: tl.constexpr,
141
+ IS_VARLEN: tl.constexpr,
142
+ HEAD_FIRST: tl.constexpr,
143
+ ):
144
+ i_nh = tl.program_id(0)
145
+ i_n, i_h = i_nh // H, i_nh % H
146
+ if IS_VARLEN:
147
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
148
+ else:
149
+ bos, eos = i_n * T, i_n * T + T
150
+ T = eos - bos
151
+
152
+ b_z = tl.zeros([], dtype=tl.float32)
153
+ NT = tl.cdiv(T, BT)
154
+ for i_c in range(NT):
155
+ i_t = NT-1-i_c if REVERSE else i_c
156
+ if HEAD_FIRST:
157
+ p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
158
+ p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
159
+ else:
160
+ p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
161
+ p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
162
+ b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
163
+ b_o = tl.cumsum(b_s, axis=0)
164
+ b_ss = tl.sum(b_s, 0)
165
+ if REVERSE:
166
+ b_o = -b_o + b_ss + b_s
167
+ b_o += b_z
168
+ if i_c >= 0:
169
+ b_z += b_ss
170
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
171
+
172
+
173
+ @triton.heuristics({
174
+ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
175
+ })
176
+ @triton.autotune(
177
+ configs=[
178
+ triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
179
+ for BT in [16, 32, 64, 128]
180
+ for num_warps in [2, 4, 8]
181
+ for num_stages in [1, 2, 3, 4]
182
+ ],
183
+ key=['B', 'H', 'S', 'IS_VARLEN', 'REVERSE']
184
+ )
185
+ @triton.jit(do_not_specialize=['T'])
186
+ def chunk_global_cumsum_vector_kernel(
187
+ s,
188
+ z,
189
+ cu_seqlens,
190
+ T,
191
+ B: tl.constexpr,
192
+ H: tl.constexpr,
193
+ S: tl.constexpr,
194
+ BT: tl.constexpr,
195
+ BS: tl.constexpr,
196
+ REVERSE: tl.constexpr,
197
+ IS_VARLEN: tl.constexpr,
198
+ HEAD_FIRST: tl.constexpr,
199
+ ):
200
+ i_s, i_nh = tl.program_id(0), tl.program_id(1)
201
+ i_n, i_h = i_nh // H, i_nh % H
202
+ if IS_VARLEN:
203
+ bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
204
+ else:
205
+ bos, eos = i_n * T, i_n * T + T
206
+ T = eos - bos
207
+
208
+ o_i = tl.arange(0, BT)
209
+ if REVERSE:
210
+ m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
211
+ else:
212
+ m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
213
+
214
+ b_z = tl.zeros([BS], dtype=tl.float32)
215
+ NT = tl.cdiv(T, BT)
216
+ for i_c in range(NT):
217
+ i_t = NT-1-i_c if REVERSE else i_c
218
+ if HEAD_FIRST:
219
+ p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
220
+ p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
221
+ else:
222
+ p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
223
+ p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
224
+ # [BT, BS]
225
+ b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
226
+ b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
227
+ tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
228
+ if i_c >= 0:
229
+ b_z += tl.sum(b_s, 0)
230
+
231
+
232
+ def chunk_local_cumsum_scalar(
233
+ g: torch.Tensor,
234
+ chunk_size: int,
235
+ reverse: bool = False,
236
+ cu_seqlens: Optional[torch.Tensor] = None,
237
+ head_first: bool = False,
238
+ output_dtype: Optional[torch.dtype] = torch.float
239
+ ) -> torch.Tensor:
240
+ if head_first:
241
+ B, H, T = g.shape
242
+ else:
243
+ B, T, H = g.shape
244
+ assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
245
+ BT = chunk_size
246
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
247
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
248
+ g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
249
+ grid = (NT, B * H)
250
+ chunk_local_cumsum_scalar_kernel[grid](
251
+ g_org,
252
+ g,
253
+ cu_seqlens,
254
+ chunk_indices,
255
+ T=T,
256
+ B=B,
257
+ H=H,
258
+ BT=BT,
259
+ HEAD_FIRST=head_first,
260
+ REVERSE=reverse
261
+ )
262
+ return g
263
+
264
+
265
+ def chunk_local_cumsum_vector(
266
+ g: torch.Tensor,
267
+ chunk_size: int,
268
+ reverse: bool = False,
269
+ cu_seqlens: Optional[torch.Tensor] = None,
270
+ head_first: bool = False,
271
+ output_dtype: Optional[torch.dtype] = torch.float
272
+ ) -> torch.Tensor:
273
+ if head_first:
274
+ B, H, T, S = g.shape
275
+ else:
276
+ B, T, H, S = g.shape
277
+ BT = chunk_size
278
+ chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
279
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
280
+ assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
281
+
282
+ g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
283
+ def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
284
+ # keep cummulative normalizer in fp32
285
+ # this kernel is equivalent to
286
+ # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
287
+ chunk_local_cumsum_vector_kernel[grid](
288
+ g_org,
289
+ g,
290
+ cu_seqlens,
291
+ chunk_indices,
292
+ T=T,
293
+ B=B,
294
+ H=H,
295
+ S=S,
296
+ BT=BT,
297
+ HEAD_FIRST=head_first,
298
+ REVERSE=reverse
299
+ )
300
+ return g
301
+
302
+
303
+ @input_guard
304
+ def chunk_global_cumsum_scalar(
305
+ s: torch.Tensor,
306
+ reverse: bool = False,
307
+ cu_seqlens: Optional[torch.Tensor] = None,
308
+ head_first: bool = False,
309
+ output_dtype: Optional[torch.dtype] = torch.float
310
+ ) -> torch.Tensor:
311
+ if head_first:
312
+ B, H, T = s.shape
313
+ else:
314
+ B, T, H = s.shape
315
+ N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
316
+
317
+ z = torch.empty_like(s, dtype=output_dtype or s.dtype)
318
+ grid = (N * H,)
319
+ chunk_global_cumsum_scalar_kernel[grid](
320
+ s,
321
+ z,
322
+ cu_seqlens,
323
+ T=T,
324
+ B=B,
325
+ H=H,
326
+ HEAD_FIRST=head_first,
327
+ REVERSE=reverse
328
+ )
329
+ return z
330
+
331
+
332
+ @input_guard
333
+ def chunk_global_cumsum_vector(
334
+ s: torch.Tensor,
335
+ reverse: bool = False,
336
+ cu_seqlens: Optional[torch.Tensor] = None,
337
+ head_first: bool = False,
338
+ output_dtype: Optional[torch.dtype] = torch.float
339
+ ) -> torch.Tensor:
340
+ if head_first:
341
+ B, H, T, S = s.shape
342
+ else:
343
+ B, T, H, S = s.shape
344
+ N = len(cu_seqlens) - 1 if cu_seqlens is not None else B
345
+ BS = min(32, triton.next_power_of_2(S))
346
+
347
+ z = torch.empty_like(s, dtype=output_dtype or s.dtype)
348
+ grid = (triton.cdiv(S, BS), N * H)
349
+ chunk_global_cumsum_vector_kernel[grid](
350
+ s,
351
+ z,
352
+ cu_seqlens,
353
+ T=T,
354
+ B=B,
355
+ H=H,
356
+ S=S,
357
+ BS=BS,
358
+ HEAD_FIRST=head_first,
359
+ REVERSE=reverse
360
+ )
361
+ return z
362
+
363
+
364
+ @input_guard
365
+ def chunk_global_cumsum(
366
+ s: torch.Tensor,
367
+ reverse: bool = False,
368
+ cu_seqlens: Optional[torch.Tensor] = None,
369
+ head_first: bool = False,
370
+ output_dtype: Optional[torch.dtype] = torch.float
371
+ ) -> torch.Tensor:
372
+ if cu_seqlens is not None:
373
+ assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
374
+ if len(s.shape) == 3:
375
+ return chunk_global_cumsum_scalar(s, reverse, cu_seqlens, head_first, output_dtype)
376
+ elif len(s.shape) == 4:
377
+ return chunk_global_cumsum_vector(s, reverse, cu_seqlens, head_first, output_dtype)
378
+ else:
379
+ raise ValueError(
380
+ f"Unsupported input shape {s.shape}. "
381
+ f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` "
382
+ f"or [B, H, T]/[B, H, T, D] otherwise"
383
+ )
384
+
385
+
386
+ @input_guard
387
+ def chunk_local_cumsum(
388
+ g: torch.Tensor,
389
+ chunk_size: int,
390
+ reverse: bool = False,
391
+ cu_seqlens: Optional[torch.Tensor] = None,
392
+ head_first: bool = False,
393
+ output_dtype: Optional[torch.dtype] = torch.float,
394
+ **kwargs
395
+ ) -> torch.Tensor:
396
+ if not head_first and g.shape[1] < g.shape[2]:
397
+ warnings.warn(
398
+ f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
399
+ "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
400
+ "when head_first=False was specified. "
401
+ "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
402
+ )
403
+ if cu_seqlens is not None:
404
+ assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
405
+ if len(g.shape) == 3:
406
+ return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
407
+ elif len(g.shape) == 4:
408
+ return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
409
+ else:
410
+ raise ValueError(
411
+ f"Unsupported input shape {g.shape}. "
412
+ f"which should be (B, T, H, D) if `head_first=False` "
413
+ f"or (B, H, T, D) otherwise"
414
+ )
fla3/ops/utils/index.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from ...utils import tensor_cache
10
+
11
+
12
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    key=['B'],
)
@triton.jit
def prepare_position_ids_kernel(
    y,
    cu_seqlens,
    B: tl.constexpr
):
    # One program per sequence: writes positions 0..T-1 of sequence `i_n`
    # into the flattened output `y`, starting at offset cu_seqlens[i_n].
    i_n = tl.program_id(0)
    bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
    T = eos - bos

    o = tl.arange(0, B)
    # Stream the sequence in tiles of B; the `o_i < T` mask guards the ragged tail.
    for i in range(0, tl.cdiv(T, B) * B, B):
        o_i = o + i
        tl.store(y + bos + o_i, o_i, o_i < T)
33
+
34
+
35
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return per-sequence lengths from cumulative offsets `cu_seqlens`."""
    return torch.diff(cu_seqlens)
38
+
39
+
40
@tensor_cache
def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    """Count valid (True) positions per row of a boolean attention mask."""
    return torch.sum(mask, dim=-1, dtype=torch.int32)
43
+
44
+
45
@tensor_cache
def prepare_cu_seqlens_from_mask(mask: torch.BoolTensor, out_dtype: torch.dtype = torch.int32) -> torch.LongTensor:
    """Build cumulative sequence offsets [0, l0, l0+l1, ...] from a boolean mask."""
    lens = prepare_lens_from_mask(mask)
    cumulated = lens.cumsum(dim=0, dtype=out_dtype)
    # Prepend a leading zero so cu_seqlens[i]:cu_seqlens[i+1] indexes sequence i.
    return F.pad(cumulated, (1, 0))
48
+
49
+
50
@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Concatenate [0..T_i) ranges for every sequence described by `cu_seqlens`."""
    pieces = []
    for seq_len in prepare_lens(cu_seqlens).tolist():
        pieces.append(torch.arange(seq_len, dtype=cu_seqlens.dtype, device=cu_seqlens.device))
    return torch.cat(pieces)
56
+
57
+
58
@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Assign each flattened token the index of the sequence it belongs to."""
    position_ids = prepare_position_ids(cu_seqlens)
    # A position of 0 marks the start of a new sequence; count starts seen so far.
    return torch.cumsum(position_ids == 0, dim=0) - 1
61
+
62
+
63
@tensor_cache
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return a [total_tokens, 2] tensor of (sequence_id, position_id) pairs."""
    sequence_ids = prepare_sequence_ids(cu_seqlens)
    position_ids = prepare_position_ids(cu_seqlens)
    return torch.stack((sequence_ids, position_ids), dim=1).to(cu_seqlens)
67
+
68
+
69
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Return a [total_chunks, 2] tensor of (sequence_id, chunk_id) pairs."""
    chunk_counts = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    local = torch.cat([torch.arange(count) for count in chunk_counts])
    # Chunk id 0 marks a new sequence; a running count of those yields sequence ids.
    sequence_ids = local.eq(0).cumsum(0) - 1
    return torch.stack((sequence_ids, local), dim=1).to(cu_seqlens)
76
+
77
+
78
@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor,
    chunk_size: int
) -> torch.LongTensor:
    """Cumulative chunk offsets per sequence, analogous to `cu_seqlens` but in chunks."""
    n_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    prefixed = torch.cat((cu_seqlens.new_tensor([0]), n_chunks))
    return prefixed.cumsum(-1)
fla3/ops/utils/logcumsumexp.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from ...ops.utils.op import exp, log
8
+
9
+
10
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def logcumsumexp_fwd_kernel(
    s,
    z,
    T,
    S: tl.constexpr,
    BT: tl.constexpr
):
    # Numerically-stable log-cumsum-exp along T for one [T, S] slice per program,
    # processed in [BT, S] tiles with an online running max (same rescaling trick
    # as online softmax).
    i_bh = tl.program_id(0)
    o_i = tl.arange(0, BT)
    # Lower-triangular ones: dot with it computes a within-tile prefix sum over rows.
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    # b_mp: running max carried across tiles; b_zp: carried sum, scaled to b_mp.
    b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
    b_zp = tl.zeros([S,], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0))

        # [BT, S]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        # [S,] new running max over this tile and the carried max
        b_mc = tl.max(b_s, 0)
        b_mc = tl.maximum(b_mp, b_mc)
        # Rescale the carried sum from the old max to the new one.
        b_zp = b_zp * exp(b_mp - b_mc)
        # [BT, S]
        b_s = exp(b_s - b_mc)
        # Prefix sums within the tile, plus the carry from previous tiles.
        b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
        # [S,] — terms are non-negative, so the cumsum is non-decreasing down
        # rows and its max equals the last (total) row: that is the next carry.
        b_zc = tl.max(b_z, 0)
        b_mp = b_mc
        b_zp = b_zc
        # [BT, BS]
        # small eps to prevent underflows
        b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
        tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
fla3/ops/utils/logsumexp.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.op import exp, log
11
+
12
+
13
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    # Partial logsumexp of one B-wide slice of row `i_n`; the host reduces the
    # per-slice results with a final torch logsumexp over the slice dimension.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # -inf padding is the identity element for max/logsumexp.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Stable logsumexp: subtract the slice max before exponentiating.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
42
+
43
+
44
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor.
    """
    original_shape = x.shape
    flat = x.view(-1, original_shape[-1])
    N, D = flat.shape
    # Each program handles one B-wide slice; cap the block size at 64K lanes.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    num_blocks = triton.cdiv(D, B)

    # Per-slice partial results, reduced below with a final logsumexp.
    z = flat.new_empty(N, num_blocks, dtype=torch.float)
    logsumexp_fwd_kernel[(N, num_blocks)](
        x=flat,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    out = z.logsumexp(-1).view(*original_shape[:-1])
    if dtype is not None and dtype != torch.float:
        out = out.to(dtype)
    return out
fla3/ops/utils/matmul.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # code adapted from
5
+ # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from ...ops.utils.op import exp
14
+ from ...utils import input_guard
15
+
16
+
17
+ # `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
18
+ # - A list of `triton.Config` objects that define different configurations of
19
+ # meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
20
+ # - An auto-tuning *key* whose change in values will trigger evaluation of all the
21
+ # provided configs
22
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.heuristics({
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
        # Good config for fp8 inputs.
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,
    alpha,
    beta,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `s_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    stride_ab, stride_am, stride_ak,  # a: batch, M, K
    stride_bk, stride_bn,  # b: K, N
    stride_cb, stride_cm, stride_cn,  # c: batch, M, N
    # Meta-parameters
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
    G: tl.constexpr,
    ACTIVATION: tl.constexpr,
    HAS_INPUT: tl.constexpr,
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    X_DIM: tl.constexpr = 1,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NM, NN = tl.num_programs(1), tl.num_programs(2)
    # Swizzle tile order into groups of G rows for better L2 reuse of B tiles.
    i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `p_a` is a block of [BM, BK] pointers
    # `p_b` is a block of [BK, BN] pointers
    # See above `Pointer Arithmetic` section for details
    a_batch_ptr = a + i_b * stride_ab
    # `% M` / `% N` wrap out-of-range rows/cols so loads stay in bounds;
    # the final store masks those positions out.
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)

    p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak)
    p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn)

    # Accumulate in FP32 regardless of the input dtype for accuracy.
    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
        # We accumulate along the K dimension.
        b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32)
        # Advance the ptrs to the next K block.
        p_a += BK * stride_ak
        p_b += BK * stride_bk

    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)

    b_c = b_acc
    # You can fuse arbitrary activation functions here
    # while the b_acc is still in FP32!
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    elif ACTIVATION == "relu":
        b_c = relu(b_c)
    elif ACTIVATION == "sigmoid":
        b_c = sigmoid(b_c)
    elif ACTIVATION == "tanh":
        b_c = tanh(b_c)

    if HAS_ALPHA:
        # NOTE(review): `alpha`/`beta` go through tl.load, so they are treated as
        # pointers (scalar tensors), not Python floats — confirm at call sites.
        b_c *= tl.load(alpha)

    if HAS_INPUT:
        # X_DIM == 2: `input` is a full [M, N] matrix; X_DIM == 1: a row vector
        # of length N broadcast across rows (row stride dropped from the offset).
        p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :]
        mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask
        b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    c_batch_ptr = c + i_b * stride_cb
    p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :]
    tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)
150
+
151
+
152
+ # We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
153
@triton.jit
def leaky_relu(x):
    # Identity for non-negative inputs, slope 0.01 below zero.
    return tl.where(x < 0, 0.01 * x, x)
156
+
157
+
158
@triton.jit
def sigmoid(x):
    # Logistic function: sigma(x) = 1 / (1 + e^{-x}).
    denom = 1.0 + exp(-x)
    return 1.0 / denom
162
+
163
+
164
@triton.jit
def tanh(x):
    # Hyperbolic tangent via its defining ratio of exponentials:
    # tanh(x) = (e^x - e^{-x}) / (e^x + e^{-x})
    e_pos = exp(x)
    e_neg = exp(-x)
    return (e_pos - e_neg) / (e_pos + e_neg)
169
+
170
+
171
@triton.jit
def relu(x):
    # Clamp negative inputs to zero: ReLU(x) = max(0, x).
    return tl.maximum(0.0, x)
175
+
176
+
177
@input_guard
def matmul(a, b, activation=''):
    """Compute `a @ b` with the autotuned Triton kernel.

    Args:
        a: 2-D `(M, K)` or 3-D `(B, M, K)` tensor.
        b: 2-D `(K, N)` tensor, shared across the batch of `a`.
        activation: Optional fused activation name
            ('', 'leaky_relu', 'relu', 'sigmoid', 'tanh').

    Returns:
        `(M, N)` for 2-D `a`, otherwise `(B, M, N)`.
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    squeeze_output = a.dim() == 2
    if squeeze_output:
        # Promote to a batch of one so a single kernel path handles both cases.
        a = a.unsqueeze(0).contiguous()  # (1, M, K)
    # TF32 only for reduced-precision inputs; keep full precision for fp32.
    allow_tf32 = a.dtype != torch.float32

    B, M, K = a.shape
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    def grid(meta):
        return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))

    matmul_kernel[grid](
        a, b, c, None, None, None,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),  # stride_ab, stride_am, stride_ak
        b.stride(0), b.stride(1),  # stride_bk, stride_bn (b.dim() == 2)
        c.stride(0), c.stride(1), c.stride(2),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=activation,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=False,
    )
    return c.squeeze(0) if squeeze_output else c
207
+
208
+
209
@input_guard
def addmm(
    x: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
) -> torch.Tensor:
    """Fused `alpha * (a @ b) + beta * x` using the Triton matmul kernel.

    `x` may be 1-D (broadcast across rows, X_DIM=1) or 2-D (X_DIM=2).
    NOTE(review): the kernel reads `alpha`/`beta` with tl.load, which expects
    pointer-like (scalar tensor) arguments — the `Optional[float]` annotation
    may be inaccurate; confirm what callers actually pass.
    """
    assert a.dim() in [2, 3], "a must be 2D or 3D"
    assert b.dim() == 2, "b must be 2D"
    assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}"

    if a.dim() == 2:
        # Promote to a batch of one so one kernel path serves both ranks.
        a_dim = 2
        a = a.unsqueeze(0).contiguous()  # (1, M, K)
    else:
        a_dim = 3
    # TF32 only for reduced-precision inputs; keep full precision for fp32.
    allow_tf32 = False if a.dtype == torch.float32 else True

    B, M, K = a.shape[0], a.shape[1], a.shape[2]
    K_b, N = b.shape
    assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}"
    c = a.new_empty(B, M, N)

    def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, x, alpha, beta,
        M, N, K,
        a.stride(0), a.stride(1), a.stride(2),  # stride_ab, stride_am, stride_ak
        b.stride(0), b.stride(1),  # stride_bk, stride_bn (b.dim() == 2)
        c.stride(0), c.stride(1), c.stride(2),  # stride_cb, stride_cm, stride_cn
        ACTIVATION=None,
        ALLOW_TF32=allow_tf32,
        HAS_INPUT=True,
        X_DIM=x.dim(),
    )
    return c.squeeze(0) if a_dim == 2 else c
fla3/ops/utils/pack.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # Code adapted from https://github.com/mayank31398/cute-kernels
5
+
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from ...ops.utils.index import prepare_lens
13
+ from ...utils import input_guard
14
+
15
+
16
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4, 8, 16, 32]
    ],
    key=['D', 'PADDING_SIDE', 'PACK']
)
@triton.jit
def packunpack_sequence_kernel(
    x,
    y,
    cu_seqlens,
    S,
    D,
    BD: tl.constexpr,
    PADDING_SIDE: tl.constexpr,
    PACK: tl.constexpr,
):
    # Copies one BD-wide feature slice of one (batch, slot) position between the
    # padded layout [B, S, D] and the packed layout [total_tokens, D].
    # PACK=True: padded -> packed; PACK=False: packed -> padded.
    i_d, i_s, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)

    T = eos - bos
    if PADDING_SIDE == 'left':
        # With left padding the first S - T slots are padding: skip them and
        # map slot i_s to token (i_s - NP) of this sequence.
        NP = S - T
        if i_s < NP:
            return
        i_t = bos + (i_s - NP)
    else:
        # Right padding: slots beyond the sequence length are padding.
        if i_s >= T:
            return
        i_t = bos + i_s

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    if PACK:
        b_x = tl.load(x + (i_b * S + i_s) * D + o_d, mask=mask)
        tl.store(y + i_t * D + o_d, b_x, mask=mask)
    else:
        b_x = tl.load(x + i_t * D + o_d, mask=mask)
        tl.store(y + (i_b * S + i_s) * D + o_d, b_x, mask=mask)
57
+
58
+
59
def pack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
) -> torch.Tensor:
    """Copy a padded [B, S, ...] tensor into packed [total_tokens, ...] layout.

    Shared by the forward of `pack_sequence` and the backward of
    `unpack_sequence` (the two directions are each other's adjoints).
    """
    B, S = x.shape[0], x.shape[1]
    # Flattened feature size per (batch, slot) position.
    D = x.numel() // (B * S)
    BD = min(triton.next_power_of_2(D), 4096)
    ND = triton.cdiv(D, BD)

    total_tokens = cu_seqlens[-1].item()
    y = torch.empty(total_tokens, *x.shape[2:], device=x.device, dtype=x.dtype)
    packunpack_sequence_kernel[ND, S, B](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        S=S,
        D=D,
        BD=BD,
        PADDING_SIDE=padding_side,
        PACK=True,
    )
    return y
81
+
82
+
83
def unpack_sequence_fwdbwd(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str,
    desired_shape: torch.Size,
) -> torch.Tensor:
    """Scatter a packed [total_tokens, ...] tensor back into padded [B, S, ...] layout.

    Padding positions are zero-filled. If `desired_shape` is None it is derived
    from `cu_seqlens` (B sequences, padded to the longest one).
    """
    if desired_shape is None:
        desired_shape = (len(cu_seqlens) - 1, prepare_lens(cu_seqlens).max().item(), *x.shape[1:])
    # zeros (not empty): untouched padding slots must read as 0.
    y = torch.zeros(desired_shape, device=x.device, dtype=x.dtype)
    B, S = y.shape[0], y.shape[1]
    D = y.numel() // (B * S)
    BD = min(triton.next_power_of_2(D), 4096)
    ND = triton.cdiv(D, BD)

    packunpack_sequence_kernel[ND, S, B](
        x=x,
        y=y,
        cu_seqlens=cu_seqlens,
        S=S,
        D=D,
        BD=BD,
        PADDING_SIDE=padding_side,
        PACK=False,
    )
    return y
108
+
109
+
110
class PackSequenceFunction(torch.autograd.Function):
    """Autograd wrapper: forward packs padded -> flat, backward unpacks the grad."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2

        # Stash what backward needs; the original padded shape lets the grad be
        # scattered back to exactly the input layout.
        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side
        ctx.desired_shape = x.shape

        y = pack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
        )
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        # The adjoint of packing is unpacking (zeros at padding positions).
        dx = unpack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
            desired_shape=ctx.desired_shape,
        )
        # NOTE(review): forward has 3 inputs but 11 grads are returned here —
        # presumably relying on PyTorch dropping extra trailing Nones; confirm.
        return dx, *[None] * 10
144
+
145
+
146
class UnpackSequenceFunction(torch.autograd.Function):
    """Autograd wrapper: forward unpacks flat -> padded, backward re-packs the grad."""

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        padding_side: str,
        desired_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        assert padding_side in ['left', 'right']
        assert x.ndim >= 2
        if desired_shape is not None:
            # Batch must match the number of sequences; trailing dims must match x.
            assert desired_shape[0] == cu_seqlens.shape[0] - 1
            assert desired_shape[2:] == x.shape[1:]

        ctx.cu_seqlens = cu_seqlens
        ctx.padding_side = padding_side

        y = unpack_sequence_fwdbwd(
            x=x,
            cu_seqlens=cu_seqlens,
            padding_side=padding_side,
            desired_shape=desired_shape,
        )
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None]:
        # The adjoint of unpacking is packing: gradient at padding slots is dropped.
        dx = pack_sequence_fwdbwd(
            x=dy,
            cu_seqlens=ctx.cu_seqlens,
            padding_side=ctx.padding_side,
        )
        return dx, None, None, None
183
+
184
+
185
def pack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left'
) -> torch.Tensor:
    """Pack a padded [B, S, ...] tensor into flat [total_tokens, ...] layout (differentiable)."""
    return PackSequenceFunction.apply(x, cu_seqlens, padding_side)
195
+
196
+
197
def unpack_sequence(
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    padding_side: str = 'left',
    desired_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
    """Unpack a flat [total_tokens, ...] tensor into padded [B, S, ...] layout (differentiable)."""
    return UnpackSequenceFunction.apply(x, cu_seqlens, padding_side, desired_shape)
fla3/ops/utils/pooling.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.index import prepare_chunk_indices
11
+ from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
12
+
13
+
14
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_fwd_kernel(
    x,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Mean over each length-BT chunk of the time axis for one (head, feature-slice).
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg is the global chunk index before i_t is rebased to its sequence.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BT, BD]
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # [BD] — divide by the true chunk length so the ragged last chunk is a
    # proper mean (boundary_check zero-fills rows beyond T).
    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
58
+
59
+
60
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_bwd_kernel(
    do,
    dx,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Backward of chunk-mean pooling: broadcast each chunk's output gradient
    # uniformly over the chunk's time steps, scaled by 1 / chunk_length.
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg is the global chunk index before i_t is rebased to its sequence.
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
    # [BD]
    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
    # [BT, BD] — same 1/len scaling as the forward mean (ragged tail included);
    # rows beyond T are clipped by the boundary-checked store.
    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
104
+
105
+
106
def mean_pooling_fwd(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the forward kernel: mean-pool [B, T, H, D] into [B, NT, H, D] chunks."""
    B, T, H, D = x.shape
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        # Variable-length mode: one (sequence, chunk) pair per output slot.
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    o = x.new_empty(B, NT, H, D)

    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)

    mean_pooling_fwd_kernel[grid](
        x,
        o,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
    )
    return o
129
+
130
+
131
def mean_pooling_bwd(
    do: torch.Tensor,
    batch_size: int,
    seq_len: int,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    """Launch the backward kernel: spread chunk grads back over [B, T, H, D]."""
    H, D = do.shape[-2], do.shape[-1]
    B, T = batch_size, seq_len
    BT = chunk_size
    if cu_seqlens is None:
        chunk_indices = None
        NT = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        NT = len(chunk_indices)

    dx = do.new_empty(B, T, H, D)

    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)

    mean_pooling_bwd_kernel[grid](
        do,
        dx,
        cu_seqlens,
        chunk_indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
    )
    return dx
156
+
157
+
158
class MeanPoolingFunction(torch.autograd.Function):
    """Autograd wrapper pairing `mean_pooling_fwd` with `mean_pooling_bwd`."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        chunk_size: int,
        cu_seqlens: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        o = mean_pooling_fwd(x, chunk_size, cu_seqlens)
        # Backward only needs the input geometry, not the input values.
        ctx.batch_size = x.shape[0]
        ctx.seq_len = x.shape[1]
        ctx.chunk_size = chunk_size
        ctx.cu_seqlens = cu_seqlens
        return o

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx, do
    ) -> Tuple[torch.Tensor, None, None]:
        batch_size = ctx.batch_size
        seq_len = ctx.seq_len
        chunk_size = ctx.chunk_size
        cu_seqlens = ctx.cu_seqlens
        dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, cu_seqlens)
        # chunk_size and cu_seqlens are non-differentiable.
        return dx, None, None
188
+
189
+
190
def mean_pooling(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    """Mean-pool the time dimension of `x` in chunks of `chunk_size`.

    Args:
        x: Input of shape [B, T, H, D] (or [B, H, T, D] when `head_first=True`).
        chunk_size: Number of time steps averaged into each output step.
        cu_seqlens: Optional cumulative sequence lengths for packed
            variable-length inputs; requires batch size 1.
        head_first: Whether the head dimension precedes time.

    Returns:
        Pooled tensor of shape [B, NT, H, D] (transposed back if `head_first`).

    Raises:
        ValueError: If `cu_seqlens` is given but the batch size is not 1.
    """
    if head_first:
        # Normalize to [B, T, H, D] for the kernel.
        x = x.transpose(1, 2)
    if cu_seqlens is not None:
        if x.shape[0] != 1:
            # Fixed garbled/missing text in the original message
            # ("`cu_seqlens`.Please ..tten ...").
            raise ValueError(
                f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
    if head_first:
        o = o.transpose(1, 2)
    return o
fla3/ops/utils/softmax.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.op import exp
11
+
12
+
13
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def softmax_fwd_kernel(
    x,
    p,
    D: tl.constexpr,
    B: tl.constexpr
):
    # One program computes the softmax of one row of D elements; B is D rounded
    # up to a power of two, with out-of-range lanes masked.
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < D

    # -inf padding contributes exp(-inf) = 0 to the normalizer.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    # Stable softmax: shift by the row max before exponentiating.
    b_m = tl.max(b_x, 0)
    b_x = exp(b_x - b_m)
    b_p = b_x / tl.sum(b_x, 0)

    tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d)
41
+
42
+
43
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def softmax_bwd_kernel(
    p,
    dp,
    ds,
    D: tl.constexpr,
    B: tl.constexpr
):
    # Softmax backward per row: ds = p * (dp - sum(p * dp)).
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < D

    b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.)
    b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.)
    # Row-wise inner product <p, dp>; zero padding leaves it unaffected.
    b_pp = tl.sum(b_p * b_dp, 0)
    b_ds = b_p * b_dp - b_p * b_pp
    tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d)
71
+
72
+
73
def softmax_fwd(
    x: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Row-wise softmax over the last dimension via the Triton kernel."""
    original_shape = x.shape
    flat = x.view(-1, original_shape[-1])

    N, D = flat.shape
    # Block width: next power of two covering the row.
    B = triton.next_power_of_2(D)

    p = torch.empty_like(flat, dtype=dtype)
    softmax_fwd_kernel[(N,)](
        x=flat,
        p=p,
        D=D,
        B=B
    )
    return p.view(*original_shape)
91
+
92
+
93
def softmax_bwd(
    p: torch.Tensor,
    dp: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Backward of row-wise softmax: given probabilities `p` and upstream grad `dp`."""
    original_shape = p.shape
    flat_p = p.view(-1, original_shape[-1])

    N, D = flat_p.shape
    # Block width: next power of two covering the row.
    B = triton.next_power_of_2(D)

    ds = torch.empty_like(flat_p, dtype=dtype)
    softmax_bwd_kernel[(N,)](
        p=flat_p,
        dp=dp,
        ds=ds,
        D=D,
        B=B
    )
    return ds.view(*original_shape)
fla3/ops/utils/solve_tril.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from ...ops.utils.index import prepare_chunk_indices
11
+ from ...utils import input_guard
12
+
13
+
14
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def solve_tril_16x16_kernel(
    A,              # [B, T, H, BT] strictly lower-triangular blocks (see solve_tril)
    Ad,             # [B, T, H, 16] output: inverse of each diagonal 16x16 tile of (I + A)
    cu_seqlens,     # cumulative sequence lengths for varlen batches, or None
    chunk_indices,  # flat (sequence idx, chunk idx) pairs, varlen mode only
    T,
    H: tl.constexpr,   # number of heads
    BT: tl.constexpr,  # width of A's trailing block dimension (16/32/64)
    IS_VARLEN: tl.constexpr,
):
    # Each program inverts one unit-diagonal 16x16 tile of (I + A).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, chunk) and recover the
        # sequence boundaries plus its true length.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Advance the base pointers to this (sequence, head) slice.
    A = A + (bos*H + i_h) * BT
    Ad = Ad + (bos*H + i_h) * 16

    # Column offset of this 16x16 diagonal tile inside the BT-wide block.
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    # Keep only the strictly-lower part, negated: the running inverse estimate
    # (identity is added back at the end).
    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)

    o_i = tl.arange(0, 16)
    # Forward substitution: row i of the inverse depends only on rows < i,
    # which are already final in b_A. Loop bound is clipped for partial tiles.
    for i in range(1, min(16, T-i_t*16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H*BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    # Add the identity to complete (I + A)^-1 for this tile.
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
62
+
63
+
64
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_32x32_inverse_kernel(
    A,              # [B, T, H, 32] original strictly lower-triangular blocks
    Ad,             # [B, T, H, 16] inverses of the two diagonal 16x16 tiles
    Ai,             # [B, T, H, 32] output: inverse of the full 32x32 block
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Combine two 16x16 diagonal-tile inverses into one 32x32 inverse via the
    # block-triangular identity:
    #   [[L11, 0], [A21, L22]]^-1 = [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]]
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen: translate flat chunk id into per-sequence coordinates.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Shift base pointers to this (sequence, head) slice.
    A += (bos*H + i_h) * 32
    Ad += (bos*H + i_h) * 16
    Ai += (bos*H + i_h) * 32

    # Tile pointers: _21 is the off-diagonal input; _11/_22 the tile inverses.
    p_A_21 = tl.make_block_ptr(A, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H*32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    # Off-diagonal block of the inverse; ieee precision keeps fp32 accuracy.
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
114
+
115
+
116
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_64x64_inverse_kernel(
    A,              # [B, T, H, 64] original strictly lower-triangular blocks
    Ad,             # [B, T, H, 16] inverses of the four diagonal 16x16 tiles
    Ai,             # [B, T, H, 64] output: inverse of the full 64x64 block
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr
):
    # Assemble a 64x64 inverse from four 16x16 diagonal-tile inverses by
    # blockwise back-substitution on the 4x4 tile grid: each Ai_jk below is
    #   Ai_jk = -Ai_jj @ (sum_m A_jm @ Ai_mk)   for j > k,
    # evaluated in dependency order (first sub-diagonal, then deeper ones).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Varlen: translate flat chunk id into per-sequence coordinates.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Shift base pointers to this (sequence, head) slice.
    A += (bos*H + i_h) * 64
    Ad += (bos*H + i_h) * 16
    Ai += (bos*H + i_h) * 64

    # Off-diagonal input tiles (row, col) on the 4x4 grid of 16x16 tiles.
    p_A_21 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    # Precomputed diagonal tile inverses from solve_tril_16x16_kernel.
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H*16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)

    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)

    # First sub-diagonal: depends only on the diagonal inverses.
    Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee')
    Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee')
    Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee')

    # Second sub-diagonal: also needs the first sub-diagonal results.
    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision='ieee') +
        tl.dot(A_32, Ai_21, input_precision='ieee'),
        input_precision='ieee'
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision='ieee') +
        tl.dot(A_43, Ai_32, input_precision='ieee'),
        input_precision='ieee'
    )
    # Bottom-left corner: needs everything above it in the same column.
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision='ieee') +
        tl.dot(A_42, Ai_21, input_precision='ieee') +
        tl.dot(A_43, Ai_31, input_precision='ieee'),
        input_precision='ieee'
    )

    # Output tile pointers (upper-triangular tiles of Ai stay zero).
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H*64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
219
+
220
+
221
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
    """
    Compute the inverse of the lower triangular matrix (I + A) blockwise.
    A should be strictly lower triangular within each block, i.e., A.triu() == 0.

    The inverse is built hierarchically: each diagonal 16x16 tile of (I + A)
    is inverted first, then (for BT of 32/64) the tile inverses are merged
    into the full block inverse via block back-substitution.

    Args:
        A (torch.Tensor):
            [B, T, H, BT] with BT in {16, 32, 64}.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]

    B, T, H, BT = A.shape
    # Diagonal 16x16 inverses. Kept in fp32 when they still feed a merge
    # kernel; written directly in output_dtype when BT == 16 (final result).
    Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)

    # Grid: one program per 16-wide chunk (per sequence in varlen mode).
    chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    if BT == 16:
        return Ad

    # Merge tile inverses into the full BT x BT inverse; zeros init keeps the
    # upper-triangular tiles (never written by the kernels) at zero.
    Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
    )
    return Ai
flame/__init__.py ADDED
File without changes
flame/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (167 Bytes). View file
 
flame/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes). View file
 
flame/__pycache__/data.cpython-310.pyc ADDED
Binary file (8.17 kB). View file
 
flame/__pycache__/data.cpython-312.pyc ADDED
Binary file (14.9 kB). View file
 
flame/__pycache__/logging.cpython-312.pyc ADDED
Binary file (6.44 kB). View file
 
flame/__pycache__/parser.cpython-310.pyc ADDED
Binary file (2.89 kB). View file
 
flame/__pycache__/parser.cpython-312.pyc ADDED
Binary file (4.07 kB). View file
 
flame/data.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Iterable, List, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from datasets import Dataset, IterableDataset
12
+ from flame.logging import get_logger
13
+ from transformers import PreTrainedTokenizer
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
class HuggingfaceDataset(IterableDataset):
    """Iterable dataset that streams a 🤗 dataset shard, tokenizes it on the
    fly, and yields fixed-length ``input_ids`` chunks in randomized order.

    Chunks are served from a ``buffer_size``-slot buffer whose slots are
    refilled from the token stream as they are consumed. The full iteration
    state (shard position, pending tokens, RNG state) can be checkpointed
    via ``state_dict``/``load_state_dict`` for exact resumption.
    """

    def __init__(
        self,
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        context_len: int = 2048,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 1024
    ) -> HuggingfaceDataset:
        # NOTE(review): annotating __init__'s return type as the class itself
        # is unconventional (it returns None); left untouched in this doc pass.

        self.dataset = dataset
        self.tokenizer = tokenizer

        # Each rank iterates only its own shard of the dataset.
        self.data = dataset.shard(world_size, rank)
        self.context_len = context_len
        self.rank = rank
        self.world_size = world_size
        self.buffer_size = buffer_size

        # Store token ids in the narrowest integer dtype that fits the vocab
        # to shrink the buffer; widened back to long when yielded.
        if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
            self.dtype = torch.int16
        elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
            self.dtype = torch.int32
        else:
            self.dtype = torch.int64
        self.states = None                                 # saved dataset state for resumption
        self.buffer = torch.tensor([], dtype=self.dtype)   # [buffer_size, context_len] chunk buffer
        self.tokens = []                                   # flat list of pending token ids
        self.rand_id = 0                                   # position within the current random-index batch
        self.token_id = 0                                  # read offset into self.tokens
        self.rng_state = None                              # torch.Generator state for resumption
        self._epoch = 0

    def __iter__(self):
        # Deterministic seeding per (epoch, rank); a saved RNG state from a
        # checkpoint takes precedence over the fresh seed.
        g = torch.Generator()
        g.manual_seed(self._epoch + self.rank)
        if self.rng_state is not None:
            g.set_state(self.rng_state)

        rand_it = self.randint(0, self.buffer_size, g=g)
        # Resume the underlying dataset iterator from a checkpoint if present.
        if self.states is not None:
            self.data.load_state_dict(self.states)

        # max number of tokens allowed in the chunk buffer
        n_tokens = self.buffer_size * self.context_len

        while True:
            for sample in self.tokenize(self.data):
                # keep appending the samples to the token buffer
                self.tokens += sample
                # if the token buffer is full, start sampling
                # NOTE: we first convert the token ids to a tensor of shape [n_chunks, context_len] for efficiency
                if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
                    self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
                    self.tokens = self.tokens[n_tokens:]
                if len(self.buffer) == self.buffer_size:
                    yield from self.sample(rand_it)

            n_chunks = len(self.tokens) // self.context_len
            # handle the left tokens in the buffer
            if n_chunks > 0:
                n_tokens = n_chunks * self.context_len
                indices = torch.randperm(n_chunks, generator=g).tolist()
                self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
                self.tokens = self.tokens[n_tokens:]
                for i in indices:
                    yield {'input_ids': self.buffer[i]}

    def tokenize(self, data, batch_size: int = 64):
        # Tokenize in batches for throughput. The dataset state captured right
        # after each raw sample is replayed alongside its tokens so that
        # resumption restarts at exactly the right sample.
        texts, states = [], []
        for sample in data:
            texts.append(sample['text'])
            states.append(self.data.state_dict())
            if len(texts) == batch_size:
                for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                    self.states = s
                    yield tokenized
                texts, states = [], []
        # flush the final partial batch
        if len(texts) > 0:
            for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
                self.states = s
                yield tokenized

    def sample(self, indices):
        # Yield a randomly-chosen chunk from the buffer, then refill that slot
        # with the next unread chunk from the token stream.
        n_tokens = (len(self.tokens) // self.context_len) * self.context_len
        while self.token_id < n_tokens:
            i = next(indices)
            start, end = self.token_id, self.token_id + self.context_len
            self.token_id += self.context_len
            yield {'input_ids': self.buffer[i].to(torch.long)}
            self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
        self.token_id = 0
        self.tokens = self.tokens[n_tokens:]

    def randint(
        self,
        low: int,
        high: int,
        batch_size: int = 1024,
        g: torch.Generator = torch.Generator()
    ) -> Iterable[int]:
        # Infinite stream of random ints in [low, high), drawn batch-by-batch.
        # NOTE(review): the mutable default ``g`` is shared across calls;
        # current call sites always pass an explicit generator — confirm
        # before relying on the default.
        indices = torch.empty(batch_size, dtype=torch.long)
        while True:
            # record the generator states before sampling
            self.rng_state = g.get_state()
            indices = torch.randint(low, high, (batch_size,), out=indices, generator=g)
            # self.rand_id lets a resumed run skip already-consumed indices.
            for i in indices[self.rand_id:].tolist():
                self.rand_id += 1
                yield i
            self.rand_id = 0

    def set_epoch(self, epoch):
        # Propagate the epoch so streaming datasets reshuffle per epoch.
        self._epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def state_dict(self):
        # Snapshot everything needed to resume iteration exactly.
        return {
            'states': self.states,
            'buffer': self.buffer.clone(),
            'tokens': deepcopy(self.tokens),
            'rand_id': self.rand_id,
            'token_id': self.token_id,
            'rng_state': self.rng_state,
            'epoch': self._epoch
        }

    def load_state_dict(self, state_dict):
        # Restore a snapshot produced by ``state_dict``.
        self.states = state_dict['states']
        self.buffer = state_dict['buffer'].clone()
        self.tokens = deepcopy(state_dict['tokens'])
        self.rand_id = state_dict['rand_id']
        self.token_id = state_dict['token_id']
        self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
        self._epoch = state_dict['epoch']
155
+
156
+
157
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        varlen (`bool`):
            Whether to return sequences with variable lengths.
            If `True`, the offsets indicating the start and end of each sequence will be returned.
            For example, if the sequence lengths are `[4, 8, 12]`,
            the returned `input_ids` will be a long flattened tensor of shape `[1, 24]`, with `offsets` being `[0, 4, 12, 24]`.
            If `False`, the `input_ids` with shape `[batch_size, seq_len]` will be returned directly.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "pt".
    """

    tokenizer: PreTrainedTokenizer
    varlen: bool = False
    return_tensors: str = "pt"

    def __call__(
        self,
        examples: List[Union[List[int], Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Collate raw examples into a model-ready batch with ``labels``."""
        # Normalize bare token-id lists into {'input_ids': ...} dicts.
        if not isinstance(examples[0], Dict):
            examples = [{'input_ids': example} for example in examples]

        def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
            # Convert lists/ndarrays to long tensors; pass tensors through.
            tensorized = {}
            for key in ['input_ids', 'offsets']:
                if key not in example:
                    continue
                if isinstance(example[key], List):
                    tensorized[key] = torch.tensor(example[key], dtype=torch.long)
                elif isinstance(example[key], np.ndarray):
                    tensorized[key] = torch.from_numpy(example[key])
                else:
                    tensorized[key] = example[key]
            return tensorized

        examples = list(map(tensorize, examples))

        if not self.varlen:
            length_of_first = examples[0]['input_ids'].size(0)
            # Check if padding is necessary.
            if all(example['input_ids'].size(0) == length_of_first for example in examples):
                batch = {
                    'input_ids': torch.stack([example['input_ids'] for example in examples], dim=0),
                }
            else:
                # If yes, check if we have a `pad_token`.
                if self.tokenizer._pad_token is None:
                    raise ValueError(
                        f"You are attempting to pad samples but the tokenizer you are using "
                        f"({self.tokenizer.__class__.__name__}) does not have a pad token."
                    )
                batch = self.tokenizer.pad(examples, return_tensors=self.return_tensors, return_attention_mask=False)
        else:
            # Variable-length mode: flatten the single example to [1, total]
            # and attach per-sequence boundary offsets.
            if len(examples) > 1:
                raise ValueError("The batch size must be 1 for variable length inputs.")
            batch = {
                'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)
            }
            if 'offsets' in examples[0]:
                batch['offsets'] = torch.cat([example['offsets'] for example in examples], dim=0).unsqueeze(0)
            else:
                # determine boundaries by bos/eos positions
                if self.tokenizer.add_bos_token:
                    # Sequences start at each BOS token; prepend 0 when the
                    # batch does not open with BOS, and always append the
                    # total length as the final boundary.
                    offsets = []
                    if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
                        offsets.append(torch.tensor([0], dtype=torch.long))
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1])
                    offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                elif self.tokenizer.add_eos_token:
                    # Sequences end right after each EOS token.
                    offsets = [torch.tensor([0], dtype=torch.long)]
                    offsets.append(torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1)
                    if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
                        offsets.append(torch.tensor([len(batch['input_ids'][0])], dtype=torch.long))
                    batch['offsets'] = torch.cat(offsets, dim=0)
                else:
                    raise ValueError("You must allow the tokenizer to add either a bos or eos token as separators.")

        # Causal-LM labels: copy of inputs with pad positions masked to -100.
        labels = batch['input_ids'].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
flame/logging.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import sys
7
+ import time
8
+
9
+ from transformers.trainer_callback import (ExportableState, TrainerCallback,
10
+ TrainerControl, TrainerState)
11
+ from transformers.training_args import TrainingArguments
12
+
13
+
14
def get_logger(name: str = None) -> logging.Logger:
    """Return a logger that writes to stdout on the master process only.

    The INFO level and a stdout handler are attached only when the ``RANK``
    environment variable is present and equal to 0, so non-master ranks stay
    silent under the root logger's default level.

    Args:
        name: logger name, typically ``__name__``. Default: ``None`` (root).

    Returns:
        The configured :class:`logging.Logger`.
    """
    logger = logging.getLogger(name)
    if 'RANK' in os.environ and int(os.environ['RANK']) == 0:
        logger.setLevel(logging.INFO)
        # Guard against stacking duplicate handlers when the same logger name
        # is requested more than once — each extra handler would emit a copy
        # of every record.
        if not logger.handlers:
            formatter = logging.Formatter(
                fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
            )
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(formatter)
            logger.addHandler(handler)

    return logger
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+ LOG_FILE_NAME = "trainer_log.jsonl"
32
+
33
+
34
class LogCallback(TrainerCallback, ExportableState):
    """Trainer callback that appends per-log-step metrics (including token
    throughput) to ``trainer_log.jsonl`` and carries its timing state across
    checkpoint resumption via the ``ExportableState`` protocol."""

    def __init__(self, start_time: float = None, elapsed_time: float = None):

        # Wall-clock anchors; restored from a checkpoint when resuming.
        self.start_time = time.time() if start_time is None else start_time
        self.elapsed_time = 0 if elapsed_time is None else elapsed_time
        self.last_time = self.start_time

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        r"""
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            if not args.resume_from_checkpoint:
                self.start_time = time.time()
                self.elapsed_time = 0
            else:
                # Restore timing carried over from the interrupted run.
                self.start_time = state.stateful_callbacks['LogCallback']['start_time']
                self.elapsed_time = state.stateful_callbacks['LogCallback']['elapsed_time']

        # Only the process that owns the log file proceeds past this point.
        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        self.last_time = time.time()
        # Start fresh when the output dir is being overwritten.
        if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs,
        **kwargs
    ):
        # Only the log-owning process writes records.
        if args.save_on_each_node:
            if not state.is_local_process_zero:
                return
        else:
            if not state.is_world_process_zero:
                return

        # Accumulate wall time between consecutive log events.
        self.elapsed_time += time.time() - self.last_time
        self.last_time = time.time()
        if 'num_input_tokens_seen' in logs:
            # Rename the HF counter and derive per-device token throughput
            # (tokens / world_size / elapsed seconds).
            logs['num_tokens'] = logs.pop('num_input_tokens_seen')
            state.log_history[-1].pop('num_input_tokens_seen')
            throughput = logs['num_tokens'] / args.world_size / self.elapsed_time
            state.log_history[-1]['throughput'] = logs['throughput'] = throughput
            state.stateful_callbacks["LogCallback"] = self.state()

        logs = dict(
            current_steps=state.global_step,
            total_steps=state.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
        )

        # Append one JSON record per log event.
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def state(self) -> dict:
        # Serialized form persisted inside the trainer checkpoint.
        return {
            'start_time': self.start_time,
            'elapsed_time': self.elapsed_time
        }

    @classmethod
    def from_state(cls, state):
        # Inverse of ``state``: rebuild the callback from a checkpoint.
        return cls(state['start_time'], state['elapsed_time'])
flame/parser.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+
8
+ import transformers
9
+ from transformers import HfArgumentParser, TrainingArguments
10
+
11
+ from flame.logging import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
@dataclass
class TrainingArguments(TrainingArguments):
    # Extends — and intentionally shadows — ``transformers.TrainingArguments``
    # with the model/tokenizer/data options used by the flame training scripts.

    model_name_or_path: str = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="fla-hub/gla-1.3B-100B",
        metadata={"help": "Name of the tokenizer to use."}
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    cache_dir: str = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )
    varlen: bool = field(
        default=False,
        metadata={"help": "Enable training with variable length inputs."},
    )
77
+
78
+
79
def get_train_args():
    """Parse command-line arguments into the extended ``TrainingArguments``.

    Rejects any argument the parser does not recognize, configures the
    transformers logging verbosity, and seeds all RNGs.

    Returns:
        The populated ``TrainingArguments`` instance.

    Raises:
        ValueError: if unknown/deprecated arguments are supplied.
    """
    parser = HfArgumentParser(TrainingArguments)
    parsed = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    args, unknown_args = parsed

    # Fail loudly on anything the dataclass parser could not consume.
    if unknown_args:
        print(parser.format_help())
        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))

    if args.should_log:
        verbosity = args.get_process_log_level()
        transformers.utils.logging.set_verbosity(verbosity)
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    # set seeds manually
    transformers.set_seed(args.seed)
    return args