Kernels
danieldk (HF Staff) committed on
Commit d1a2b62
1 Parent(s): 7bbcfad

Remove old flash-attn3 builds

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/__init__.py +0 -17
  2. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +0 -3
  3. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +0 -3
  4. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_ops.py +0 -9
  5. build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  6. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +0 -17
  7. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +0 -3
  8. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +0 -3
  9. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +0 -9
  10. build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  11. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/__init__.py +0 -17
  12. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +0 -3
  13. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +0 -3
  14. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_ops.py +0 -9
  15. build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  16. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/__init__.py +0 -17
  17. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +0 -3
  18. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so +0 -3
  19. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_ops.py +0 -9
  20. build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  21. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +0 -17
  22. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +0 -3
  23. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +0 -9
  24. build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  25. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/__init__.py +0 -17
  26. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so +0 -3
  27. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_ops.py +0 -9
  28. build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  29. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__init__.py +0 -17
  30. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  31. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  32. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  33. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so +0 -3
  34. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_ops.py +0 -9
  35. build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/flash_attn_interface.py +0 -828
  36. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py +0 -17
  37. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_48fe103_dirty.abi3.so +0 -3
  38. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py +0 -9
  39. build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py +0 -828
  40. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__init__.py +0 -17
  41. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  42. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  43. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
  44. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so +0 -3
  45. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_ops.py +0 -9
  46. build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/flash_attn_interface.py +0 -828
  47. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__init__.py +0 -17
  48. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc +0 -0
  49. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc +0 -0
  50. build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc +0 -0
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:21b44e8e5e447a8b8ee051d347f0e32a3446a750f79d0bd1755e553f2119aa3b
- size 838459656
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:12d4ff964085fd02252777a2008f5ca47c90ea6a93da590e2fc5065dd5330207
- size 838459656
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _flash_attn3_557701f
- ops = torch.ops._flash_attn3_557701f
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
- # Copyright (c) 2023, Tri Dao.
-
- from typing import Optional, Union
-
- import torch
- import torch.nn as nn
-
- from ._ops import ops as flash_attn_3_cuda
-
- def maybe_contiguous(x):
-     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
-
-
- def _flash_attn_forward(
-     q,
-     k,
-     v,
-     k_new,
-     v_new,
-     qv,
-     out,
-     cu_seqlens_q,
-     cu_seqlens_k,
-     cu_seqlens_k_new,
-     seqused_q,
-     seqused_k,
-     max_seqlen_q,
-     max_seqlen_k,
-     page_table,
-     kv_batch_idx,
-     leftpad_k,
-     rotary_cos,
-     rotary_sin,
-     seqlens_rotary,
-     q_descale,
-     k_descale,
-     v_descale,
-     softmax_scale,
-     causal,
-     window_size=(-1, -1),
-     attention_chunk=0,
-     softcap=0.0,
-     rotary_interleaved=True,
-     scheduler_metadata=None,
-     num_splits=1,
-     pack_gqa=None,
-     sm_margin=0):
-     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
-     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
-     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
-         maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
-     ]
-     seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
-     page_table, kv_batch_idx, leftpad_k = [
-         maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
-     ]
-     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
-     seqlens_rotary = maybe_contiguous(seqlens_rotary)
-     out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
-         q,
-         k,
-         v,
-         k_new,
-         v_new,
-         qv,
-         out,
-         cu_seqlens_q,
-         cu_seqlens_k,
-         cu_seqlens_k_new,
-         seqused_q,
-         seqused_k,
-         max_seqlen_q,
-         max_seqlen_k,
-         page_table,
-         kv_batch_idx,
-         leftpad_k,
-         rotary_cos,
-         rotary_sin,
-         seqlens_rotary,
-         q_descale,
-         k_descale,
-         v_descale,
-         softmax_scale,
-         causal,
-         window_size[0],
-         window_size[1],
-         attention_chunk,
-         softcap,
-         rotary_interleaved,
-         scheduler_metadata,
-         num_splits,
-         pack_gqa,
-         sm_margin,
-     )
-     return out, softmax_lse, *rest
-
-
- def _flash_attn_backward(
-     dout,
-     q,
-     k,
-     v,
-     out,
-     softmax_lse,
-     cu_seqlens_q,
-     cu_seqlens_k,
-     sequed_q,
-     sequed_k,
-     max_seqlen_q,
-     max_seqlen_k,
-     dq,
-     dk,
-     dv,
-     softmax_scale,
-     causal,
-     window_size=(-1, -1),
-     softcap=0.0,
-     deterministic=False,
-     sm_margin=0,
- ):
-     # dq, dk, dv are allocated by us so they should already be contiguous
-     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-     dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
-         dout,
-         q,
-         k,
-         v,
-         out,
-         softmax_lse,
-         dq,
-         dk,
-         dv,
-         cu_seqlens_q,
-         cu_seqlens_k,
-         sequed_q,
-         sequed_k,
-         max_seqlen_q,
-         max_seqlen_k,
-         softmax_scale,
-         causal,
-         window_size[0],
-         window_size[1],
-         softcap,
-         deterministic,
-         sm_margin,
-     )
-     return dq, dk, dv, softmax_d
-
-
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         qkv,
-         softmax_scale,
-         causal,
-         q_descale=None, k_descale=None, v_descale=None,
-         window_size=(-1, -1),
-         attention_chunk=0,
-         softcap=0.0,
-         deterministic=False,
-         num_heads_q=None,
-         sm_margin=0,
-     ):
-         if softmax_scale is None:
-             softmax_scale = qkv.shape[-1] ** (-0.5)
-         if qkv.dim() == 5:
-             assert qkv.shape[-3] == 3
-             q, k, v = qkv.unbind(dim=-3)
-         else:
-             assert qkv.dim() == 4
-             assert num_heads_q is not None
-             num_heads_k = (qkv.shape[2] - num_heads_q) // 2
-             assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
-             q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
-         out, softmax_lse, *rest = _flash_attn_forward(
-             q,
-             k,
-             v,
-             None, None, # k_new, v_new
-             None, # qv
-             None, # out
-             None, None, None, # cu_seqlens_q/k/k_new
-             None, None, # seqused_q/k
-             None, None, # max_seqlen_q/k
-             None, None, None, # page_table, kv_batch_idx, leftpad_k,
-             None, None, None, # rotary_cos/sin, seqlens_rotary
-             q_descale, k_descale, v_descale,
-             softmax_scale,
-             causal=causal,
-             window_size=window_size,
-             attention_chunk=attention_chunk,
-             softcap=softcap,
-             sm_margin=sm_margin,
-         )
-         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
-         ctx.save_for_backward(q, k, v, out, softmax_lse)
-         ctx.softmax_scale = softmax_scale
-         ctx.causal = causal
-         ctx.window_size = window_size
-         ctx.attention_chunk = attention_chunk
-         ctx.softcap = softcap
-         ctx.deterministic = deterministic
-         ctx.ndim = qkv.dim()
-         ctx.sm_margin = sm_margin
-         # return out, softmax_lse
-         return out
-
-     @staticmethod
-     def backward(ctx, dout, *args):
-         q, k, v, out, softmax_lse = ctx.saved_tensors
-         assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
-         if ctx.ndim == 5:
-             qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
-             dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-             dq, dk, dv = dqkv.unbind(dim=-3)
-         else:
-             num_heads_q = q.shape[2]
-             num_heads_k = k.shape[2]
-             qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
-             dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-             dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
-         _flash_attn_backward(
-             dout,
-             q,
-             k,
-             v,
-             out,
-             softmax_lse,
-             None, None, # cu_seqlens_q, cu_seqlens_k,
-             None, None, # sequed_q, sequed_k,
-             None, None, # max_seqlen_q, max_seqlen_k,
-             dq,
-             dk,
-             dv,
-             ctx.softmax_scale,
-             ctx.causal,
-             ctx.window_size,
-             ctx.softcap,
-             ctx.deterministic,
-             ctx.sm_margin,
-         )
-         dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
-         return dqkv, None, None, None, None, None, None, None, None, None, None, None
-
-
- class FlashAttnFunc(torch.autograd.Function):
-
-     @staticmethod
-     def forward(
-         ctx,
-         q,
-         k,
-         v,
-         softmax_scale,
-         causal,
-         qv=None,
-         q_descale=None, k_descale=None, v_descale=None,
-         window_size=(-1, -1),
-         attention_chunk=0,
-         softcap=0.0,
-         num_splits=1,
-         pack_gqa=None,
-         deterministic=False,
-         sm_margin=0,
-     ):
-         if softmax_scale is None:
-             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-         # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
-         out, softmax_lse, *rest = _flash_attn_forward(
-             q,
-             k,
-             v,
-             None, None, # k_new, v_new
-             qv, # qv
-             None, # out
-             None, None, None, # cu_seqlens_q/k/k_new
-             None, None, # seqused_q/k
-             None, None, # max_seqlen_q/k
-             None, None, None, # page_table, kv_batch_idx, leftpad_k,
-             None, None, None, # rotary_cos/sin, seqlens_rotary
-             q_descale, k_descale, v_descale,
-             softmax_scale,
-             causal=causal,
-             window_size=window_size,
-             attention_chunk=attention_chunk,
-             softcap=softcap,
-             num_splits=num_splits,
-             pack_gqa=pack_gqa,
-             sm_margin=sm_margin,
-         )
-         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
-         ctx.save_for_backward(q, k, v, out, softmax_lse)
-         ctx.softmax_scale = softmax_scale
-         ctx.causal = causal
-         ctx.window_size = window_size
-         ctx.attention_chunk = attention_chunk
-         ctx.softcap = softcap
-         ctx.deterministic = deterministic
-         ctx.sm_margin = sm_margin
-         return out, softmax_lse
-
-     @staticmethod
-     def backward(ctx, dout, *args):
-         q, k, v, out, softmax_lse = ctx.saved_tensors
-         assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
-         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-         _flash_attn_backward(
-             dout,
-             q,
-             k,
-             v,
-             out,
-             softmax_lse,
-             None, None, # cu_seqlens_q, cu_seqlens_k,
-             None, None, # sequed_q, sequed_k,
-             None, None, # max_seqlen_q, max_seqlen_k,
-             dq,
-             dk,
-             dv,
-             ctx.softmax_scale,
-             ctx.causal,
-             ctx.window_size,
-             ctx.softcap,
-             ctx.deterministic,
-             ctx.sm_margin,
-         )
-         dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
-         dk = dk[..., : k.shape[-1]]
-         dv = dv[..., : v.shape[-1]]
-         return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
-
-
- class FlashAttnVarlenFunc(torch.autograd.Function):
-
-     @staticmethod
-     def forward(
-         ctx,
-         q,
-         k,
-         v,
-         cu_seqlens_q,
-         cu_seqlens_k,
-         seqused_q,
-         seqused_k,
-         max_seqlen_q,
-         max_seqlen_k,
-         softmax_scale,
-         causal,
-         qv=None,
-         q_descale=None, k_descale=None, v_descale=None,
-         window_size=(-1, -1),
-         attention_chunk=0,
-         softcap=0.0,
-         num_splits=1,
-         pack_gqa=None,
-         deterministic=False,
-         sm_margin=0,
-     ):
-         if softmax_scale is None:
-             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-         # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
-         out, softmax_lse, *rest = _flash_attn_forward(
-             q,
-             k,
-             v,
-             None, None, # k_new, v_new
-             qv, # qv
-             None, # out
-             cu_seqlens_q,
-             cu_seqlens_k,
-             None, # cu_seqlens_k_new
-             seqused_q,
-             seqused_k,
-             max_seqlen_q,
-             max_seqlen_k,
-             None, None, None, # page_table, kv_batch_idx, leftpad_k,
-             None, None, None, # rotary_cos/sin, seqlens_rotary
-             q_descale, k_descale, v_descale,
-             softmax_scale,
-             causal=causal,
-             window_size=window_size,
-             attention_chunk=attention_chunk,
-             softcap=softcap,
-             num_splits=num_splits,
-             pack_gqa=pack_gqa,
-             sm_margin=sm_margin,
-         )
-         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
-         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
-         ctx.max_seqlen_q = max_seqlen_q
-         ctx.max_seqlen_k = max_seqlen_k
-         ctx.softmax_scale = softmax_scale
-         ctx.causal = causal
-         ctx.window_size = window_size
-         ctx.attention_chunk = attention_chunk
-         ctx.softcap = softcap
-         ctx.deterministic = deterministic
-         ctx.sm_margin = sm_margin
-         return out, softmax_lse
-
-     @staticmethod
-     def backward(ctx, dout, *args):
-         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
-         assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
-         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-         _flash_attn_backward(
-             dout,
-             q,
-             k,
-             v,
-             out,
-             softmax_lse,
-             cu_seqlens_q,
-             cu_seqlens_k,
-             seqused_q,
-             seqused_k,
-             ctx.max_seqlen_q,
-             ctx.max_seqlen_k,
-             dq,
-             dk,
-             dv,
-             ctx.softmax_scale,
-             ctx.causal,
-             ctx.window_size,
-             ctx.softcap,
-             ctx.deterministic,
-             ctx.sm_margin,
-         )
-         dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
-         dk = dk[..., : k.shape[-1]]
-         dv = dv[..., : v.shape[-1]]
-         return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
-
-
- def flash_attn_qkvpacked_func(
-     qkv,
-     softmax_scale=None,
-     causal=False,
-     q_descale=None, k_descale=None, v_descale=None,
-     window_size=(-1, -1),
-     attention_chunk=0,
-     softcap=0.0,
-     deterministic=False,
-     num_heads_q=None,
-     sm_margin=0,
- ):
-     """dropout_p should be set to 0.0 during evaluation
-     If Q, K, V are already stacked into 1 tensor, this function will be faster than
-     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
-     of the gradients of Q, K, V.
-     For multi-query and grouped-query attention (MQA/GQA), please see
-     flash_attn_kvpacked_func and flash_attn_func.
-
-     If window_size != (-1, -1), implements sliding window local attention. Query at position i
-     will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
-
-     Arguments:
-         qkv: (batch_size, seqlen, 3, nheads, headdim)
-         dropout_p: float. Dropout probability.
-         softmax_scale: float. The scaling of QK^T before applying softmax.
-             Default to 1 / sqrt(headdim).
-         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-         softcap: float. Anything > 0 activates softcapping attention.
-         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
-             the attention score of query i and key j.
-         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
-             which is slightly slower and uses more memory. The forward pass is always deterministic.
-         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
-             testing only. The returned probabilities are not guaranteed to be correct
-             (they might not have the right scaling).
-     Return:
-         out: (batch_size, seqlen, nheads, headdim).
-         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
-             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-             normalization factor).
-         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-             The output of softmax (possibly with different scaling). It also encodes the dropout
-             pattern (negative means that location was dropped, nonnegative means it was kept).
-     """
-     return FlashAttnQKVPackedFunc.apply(
-         qkv,
-         softmax_scale,
-         causal,
-         q_descale, k_descale, v_descale,
-         window_size,
-         attention_chunk,
-         softcap,
-         deterministic,
-         num_heads_q,
-         sm_margin,
-     )
-
-
- def flash_attn_func(
-     q,
-     k,
-     v,
-     softmax_scale=None,
-     causal=False,
-     qv=None,
-     q_descale=None, k_descale=None, v_descale=None,
-     window_size=(-1, -1),
-     attention_chunk=0,
-     softcap=0.0,
-     num_splits=1,
-     pack_gqa=None,
-     deterministic=False,
-     sm_margin=0,
- ):
-     """dropout_p should be set to 0.0 during evaluation
-     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
-
-     If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
-     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
-         1 1 1 1 0
-         1 1 1 1 1
-     If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
-         0 0
-         0 0
-         0 0
-         1 0
-         1 1
-     If the row of the mask is all zero, the output will be zero.
-
-     If window_size != (-1, -1), implements sliding window local attention. Query at position i
-     will only attend to keys between
-     [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
-
-     Arguments:
-         q: (batch_size, seqlen, nheads, headdim)
-         k: (batch_size, seqlen, nheads_k, headdim)
-         v: (batch_size, seqlen, nheads_k, headdim)
-         dropout_p: float. Dropout probability.
-         softmax_scale: float. The scaling of QK^T before applying softmax.
-             Default to 1 / sqrt(headdim).
-         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
-             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
-             is added to the attention score of query i and key j.
-         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
-             which is slightly slower and uses more memory. The forward pass is always deterministic.
-         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
-             testing only. The returned probabilities are not guaranteed to be correct
-             (they might not have the right scaling).
-     Return:
-         out: (batch_size, seqlen, nheads, headdim).
-         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
-             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-             normalization factor).
-     """
-     return FlashAttnFunc.apply(
-         q,
-         k,
-         v,
-         softmax_scale,
-         causal,
-         qv,
-         q_descale, k_descale, v_descale,
-         window_size,
-         attention_chunk,
-         softcap,
-         num_splits,
-         pack_gqa,
-         deterministic,
-         sm_margin,
-     )
-
-
- def flash_attn_varlen_func(
-     q,
-     k,
-     v,
-     cu_seqlens_q,
-     cu_seqlens_k,
-     max_seqlen_q,
-     max_seqlen_k,
-     seqused_q=None,
-     seqused_k=None,
-     softmax_scale=None,
-     causal=False,
-     qv=None,
-     q_descale=None, k_descale=None, v_descale=None,
-     window_size=(-1, -1),
-     attention_chunk=0,
-     softcap=0.0,
-     num_splits=1,
-     pack_gqa=None,
-     deterministic=False,
-     sm_margin=0,
- ):
-     return FlashAttnVarlenFunc.apply(
-         q,
-         k,
-         v,
-         cu_seqlens_q,
-         cu_seqlens_k,
-         seqused_q,
-         seqused_k,
-         max_seqlen_q,
-         max_seqlen_k,
-         softmax_scale,
-         causal,
-         qv,
-         q_descale, k_descale, v_descale,
-         window_size,
-         attention_chunk,
-         softcap,
-         num_splits,
-         pack_gqa,
-         deterministic,
-         sm_margin,
-     )
-
-
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
-     return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
-
-
- def flash_attn_with_kvcache(
-     q,
-     k_cache,
-     v_cache,
-     k=None,
-     v=None,
-     qv=None,
-     rotary_cos=None,
-     rotary_sin=None,
-     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
-     cache_batch_idx: Optional[torch.Tensor] = None,
-     cache_leftpad: Optional[torch.Tensor] = None,
-     page_table: Optional[torch.Tensor] = None,
-     cu_seqlens_q: Optional[torch.Tensor] = None,
-     cu_seqlens_k_new: Optional[torch.Tensor] = None,
-     max_seqlen_q: Optional[int] = None,
-     rotary_seqlens: Optional[torch.Tensor] = None,
-     q_descale: Optional[torch.Tensor] = None,
-     k_descale: Optional[torch.Tensor] = None,
-     v_descale: Optional[torch.Tensor] = None,
-     softmax_scale=None,
-     causal=False,
-     window_size=(-1, -1), # -1 means infinite context window
-     attention_chunk=0,
-     softcap=0.0, # 0.0 means deactivated
-     rotary_interleaved=True,
-     scheduler_metadata=None,
-     num_splits=0, # Can be tuned for speed
-     pack_gqa=None, # Can be tuned for speed
-     sm_margin=0, # Can be tuned if some SMs are used for communication
-     return_softmax_lse=False,
- ):
-     """
-     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
-     k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
-     the previous step, and update them with the new keys/values from the current step, and do
-     attention with the updated cache, all in 1 kernel.
-
-     If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
-     For example, the KV cache could be pre-allocated with the max sequence length, and you can use
-     cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
-
-     Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
-     rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-     If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
-     and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-     If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
-     indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
-
-     See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
-
-     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
-
-     If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
-     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
-         1 1 1 1 0
-         1 1 1 1 1
-     If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
-         0 0
-         0 0
-         0 0
-         1 0
-         1 1
-     If the row of the mask is all zero, the output will be zero.
-
-     If window_size != (-1, -1), implements sliding window local attention. Query at position i
-     will only attend to keys between
-     [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
-
-     Note: Does not support backward pass.
-
-     Arguments:
-         q: (batch_size, seqlen, nheads, headdim)
-         k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
-             or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
-             page_block_size must be a multiple of 256.
-         v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
-             or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
-         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
-             k with k_cache, starting at the indices specified by cache_seqlens.
-         v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
-         qv [optional]: (batch_size, seqlen, nheads, headdim_v)
-         rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
-             to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
-         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
-         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
-             KV cache.
-         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
-             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
-             If the indices are not distinct, and k and v are provided, the values updated in the cache
-             might come from any of the duplicate indices.
-         cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
-         page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
-         softmax_scale: float. The scaling of QK^T before applying softmax.
-             Default to 1 / sqrt(headdim).
-         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-         softcap: float. Anything > 0 activates softcapping attention.
-         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
-             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
-             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
-             (i.e. GPT-NeoX style).
-         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-             to automatically determine the number of splits.
-             Don't change this unless you know what you are doing.
-         return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
-
-     Return:
-         out: (batch_size, seqlen, nheads, headdim).
-         softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
-             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-             normalization factor).
-     """
-     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
-     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
-     if softmax_scale is None:
-         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-     if cache_seqlens is not None and isinstance(cache_seqlens, int):
-         cache_seqlens = torch.full(
-             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
-         )
-         cache_seqlens = maybe_contiguous(cache_seqlens)
-     out, softmax_lse, *rest = _flash_attn_forward(
-         q,
-         k_cache,
-         v_cache,
-         k,
-         v,
-         qv,
-         None, # out
-         cu_seqlens_q,
-         None, # cu_seqlens_k
-         cu_seqlens_k_new,
-         None, # seqused_q
-         cache_seqlens,
-         max_seqlen_q,
-         None, # max_seqlen_k
-         page_table,
-         cache_batch_idx,
-         cache_leftpad,
-         rotary_cos,
-         rotary_sin,
-         rotary_seqlens,
-         q_descale, k_descale, v_descale,
-         softmax_scale,
-         causal=causal,
-         window_size=window_size,
-         attention_chunk=attention_chunk,
-         softcap=softcap,
-         rotary_interleaved=rotary_interleaved,
-         scheduler_metadata=scheduler_metadata,
-         num_splits=num_splits,
-         pack_gqa=pack_gqa,
-         sm_margin=sm_margin,
-     )
-     # return (out, softmax_lse) if return_softmax_lse else out
-     return (out, softmax_lse, *rest) if return_softmax_lse else out
-
-
- def get_scheduler_metadata(
-     batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
-     cache_seqlens: torch.Tensor,
-     qkv_dtype=torch.bfloat16,
-     headdim_v=None,
-     cu_seqlens_q: Optional[torch.Tensor] = None,
-     cu_seqlens_k_new: Optional[torch.Tensor] = None,
-     cache_leftpad: Optional[torch.Tensor] = None,
-     page_size: Optional[int] = None,
-     max_seqlen_k_new=0,
-     causal=False,
-     window_size=(-1, -1), # -1 means infinite context window
-     attention_chunk=0,
-     has_softcap=False,
-     num_splits=0, # Can be tuned for speed
-     pack_gqa=None, # Can be tuned for speed
-     sm_margin=0, # Can be tuned if some SMs are used for communication
- ):
-     cache_seqlens = maybe_contiguous(cache_seqlens)
-     if headdim_v is None:
-         headdim_v = headdim
-     scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
-         batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
-         qkv_dtype,
-         cache_seqlens,
-         cu_seqlens_q,
-         None, # cu_seqlens_k
-         cu_seqlens_k_new,
-         None, # seqused_q
-         cache_leftpad,
-         page_size,
-         max_seqlen_k_new,
-         causal,
-         window_size[0], window_size[1],
-         attention_chunk,
-         has_softcap,
-         num_splits,
-         pack_gqa,
-         sm_margin,
-     )
-     return scheduler_metadata
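For reference, a minimal usage sketch of the flash_attn_func interface removed above, based on the docstrings in the deleted file. The import path, device, dtype, and tensor shapes are illustrative assumptions (a GPU supported by this kernel is required); they are not specified by this commit.

import torch
from flash_attn3 import flash_attn_func  # hypothetical import path; adjust to how the built kernel package is loaded

# Shapes follow the docstring: q is (batch_size, seqlen, nheads, headdim),
# k and v are (batch_size, seqlen, nheads_k, headdim), with nheads divisible by nheads_k (GQA).
q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 1024, 2, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 1024, 2, 128, dtype=torch.bfloat16, device="cuda")

# Causal attention with the default softmax scale of 1 / sqrt(headdim).
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
print(out.shape)          # torch.Size([2, 1024, 8, 128])
print(softmax_lse.shape)  # per the docstring: (batch_size, nheads, seqlen)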
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:21b44e8e5e447a8b8ee051d347f0e32a3446a750f79d0bd1755e553f2119aa3b
- size 838459656
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:12d4ff964085fd02252777a2008f5ca47c90ea6a93da590e2fc5065dd5330207
- size 838459656
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _flash_attn3_557701f
- ops = torch.ops._flash_attn3_557701f
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
-     If window_size != (-1, -1), implements sliding window local attention. Query at position i
-     will only attend to keys between
-     [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
-
-     Arguments:
-         q: (batch_size, seqlen, nheads, headdim)
-         k: (batch_size, seqlen, nheads_k, headdim)
-         v: (batch_size, seqlen, nheads_k, headdim)
-         dropout_p: float. Dropout probability.
-         softmax_scale: float. The scaling of QK^T before applying softmax.
-             Default to 1 / sqrt(headdim).
-         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
-             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
-             is added to the attention score of query i and key j.
-         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
-             which is slightly slower and uses more memory. The forward pass is always deterministic.
-         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
-             testing only. The returned probabilities are not guaranteed to be correct
-             (they might not have the right scaling).
-     Return:
-         out: (batch_size, seqlen, nheads, headdim).
-         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
-             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-             normalization factor).
-     """
-     return FlashAttnFunc.apply(
-         q,
-         k,
-         v,
-         softmax_scale,
-         causal,
-         qv,
-         q_descale, k_descale, v_descale,
-         window_size,
-         attention_chunk,
-         softcap,
-         num_splits,
-         pack_gqa,
-         deterministic,
-         sm_margin,
-     )
-
-
- def flash_attn_varlen_func(
-     q,
-     k,
-     v,
-     cu_seqlens_q,
-     cu_seqlens_k,
-     max_seqlen_q,
-     max_seqlen_k,
-     seqused_q=None,
-     seqused_k=None,
-     softmax_scale=None,
-     causal=False,
-     qv=None,
-     q_descale=None, k_descale=None, v_descale=None,
-     window_size=(-1, -1),
-     attention_chunk=0,
-     softcap=0.0,
-     num_splits=1,
-     pack_gqa=None,
-     deterministic=False,
-     sm_margin=0,
- ):
-     return FlashAttnVarlenFunc.apply(
-         q,
-         k,
-         v,
-         cu_seqlens_q,
-         cu_seqlens_k,
-         seqused_q,
-         seqused_k,
-         max_seqlen_q,
-         max_seqlen_k,
-         softmax_scale,
-         causal,
-         qv,
-         q_descale, k_descale, v_descale,
-         window_size,
-         attention_chunk,
-         softcap,
-         num_splits,
-         pack_gqa,
-         deterministic,
-         sm_margin,
-     )
-
-
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
-     return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
-
-
- def flash_attn_with_kvcache(
-     q,
-     k_cache,
-     v_cache,
-     k=None,
-     v=None,
-     qv=None,
-     rotary_cos=None,
-     rotary_sin=None,
-     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
-     cache_batch_idx: Optional[torch.Tensor] = None,
-     cache_leftpad: Optional[torch.Tensor] = None,
-     page_table: Optional[torch.Tensor] = None,
-     cu_seqlens_q: Optional[torch.Tensor] = None,
-     cu_seqlens_k_new: Optional[torch.Tensor] = None,
-     max_seqlen_q: Optional[int] = None,
-     rotary_seqlens: Optional[torch.Tensor] = None,
-     q_descale: Optional[torch.Tensor] = None,
-     k_descale: Optional[torch.Tensor] = None,
-     v_descale: Optional[torch.Tensor] = None,
-     softmax_scale=None,
-     causal=False,
-     window_size=(-1, -1), # -1 means infinite context window
-     attention_chunk=0,
-     softcap=0.0, # 0.0 means deactivated
-     rotary_interleaved=True,
-     scheduler_metadata=None,
-     num_splits=0, # Can be tuned for speed
-     pack_gqa=None, # Can be tuned for speed
-     sm_margin=0, # Can be tuned if some SMs are used for communication
-     return_softmax_lse=False,
- ):
-     """
-     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
-     k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
-     the previous step, and update them with the new keys/values from the current step, and do
-     attention with the updated cache, all in 1 kernel.
-
-     If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
-     For example, the KV cache could be pre-allocated with the max sequence length, and you can use
-     cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
-
-     Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
-     rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-     If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
-     and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-     If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
-     indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
-
-     See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
-
-     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
-
-     If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
-     For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
-         1 1 1 1 0
-         1 1 1 1 1
-     If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
-         0 0
-         0 0
-         0 0
-         1 0
-         1 1
-     If the row of the mask is all zero, the output will be zero.
-
-     If window_size != (-1, -1), implements sliding window local attention. Query at position i
-     will only attend to keys between
-     [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
-
-     Note: Does not support backward pass.
-
-     Arguments:
-         q: (batch_size, seqlen, nheads, headdim)
-         k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
-             or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
-             page_block_size must be a multiple of 256.
-         v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
-             or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
-         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
-             k with k_cache, starting at the indices specified by cache_seqlens.
-         v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
-         qv [optional]: (batch_size, seqlen, nheads, headdim_v)
-         rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
-             to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
-         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
-         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
-             KV cache.
-         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
-             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
-             If the indices are not distinct, and k and v are provided, the values updated in the cache
-             might come from any of the duplicate indices.
-         cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
-         page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
-         softmax_scale: float. The scaling of QK^T before applying softmax.
-             Default to 1 / sqrt(headdim).
-         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-         softcap: float. Anything > 0 activates softcapping attention.
-         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
-             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
-             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
-             (i.e. GPT-NeoX style).
-         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-             to automatically determine the number of splits.
-             Don't change this unless you know what you are doing.
-         return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
-
-     Return:
-         out: (batch_size, seqlen, nheads, headdim).
-         softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
-             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-             normalization factor).
-     """
-     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
-     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
-     if softmax_scale is None:
-         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-     if cache_seqlens is not None and isinstance(cache_seqlens, int):
-         cache_seqlens = torch.full(
-             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
-         )
-         cache_seqlens = maybe_contiguous(cache_seqlens)
-     out, softmax_lse, *rest = _flash_attn_forward(
-         q,
-         k_cache,
-         v_cache,
-         k,
-         v,
-         qv,
-         None, # out
-         cu_seqlens_q,
-         None, # cu_seqlens_k
-         cu_seqlens_k_new,
-         None, # seqused_q
-         cache_seqlens,
-         max_seqlen_q,
-         None, # max_seqlen_k
-         page_table,
-         cache_batch_idx,
-         cache_leftpad,
-         rotary_cos,
-         rotary_sin,
-         rotary_seqlens,
-         q_descale, k_descale, v_descale,
-         softmax_scale,
-         causal=causal,
-         window_size=window_size,
-         attention_chunk=attention_chunk,
-         softcap=softcap,
-         rotary_interleaved=rotary_interleaved,
-         scheduler_metadata=scheduler_metadata,
-         num_splits=num_splits,
-         pack_gqa=pack_gqa,
-         sm_margin=sm_margin,
-     )
-     # return (out, softmax_lse) if return_softmax_lse else out
-     return (out, softmax_lse, *rest) if return_softmax_lse else out
-
-
- def get_scheduler_metadata(
-     batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
-     cache_seqlens: torch.Tensor,
-     qkv_dtype=torch.bfloat16,
-     headdim_v=None,
-     cu_seqlens_q: Optional[torch.Tensor] = None,
-     cu_seqlens_k_new: Optional[torch.Tensor] = None,
-     cache_leftpad: Optional[torch.Tensor] = None,
-     page_size: Optional[int] = None,
-     max_seqlen_k_new=0,
-     causal=False,
-     window_size=(-1, -1), # -1 means infinite context window
-     attention_chunk=0,
-     has_softcap=False,
-     num_splits=0, # Can be tuned for speed
-     pack_gqa=None, # Can be tuned for speed
-     sm_margin=0, # Can be tuned if some SMs are used for communication
- ):
-     cache_seqlens = maybe_contiguous(cache_seqlens)
-     if headdim_v is None:
-         headdim_v = headdim
-     scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
-         batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
-         qkv_dtype,
-         cache_seqlens,
-         cu_seqlens_q,
-         None, # cu_seqlens_k
-         cu_seqlens_k_new,
-         None, # seqused_q
-         cache_leftpad,
-         page_size,
-         max_seqlen_k_new,
-         causal,
-         window_size[0], window_size[1],
-         attention_chunk,
-         has_softcap,
-         num_splits,
-         pack_gqa,
-         sm_margin,
-     )
-     return scheduler_metadata
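The docstrings above spell out the public FA3 call surface: bottom-right-aligned causal masking, sliding-window local attention, MQA/GQA head grouping, and an in-place KV-cache update fused into a single kernel. As a reading aid only, here is a minimal decode-step sketch; it assumes the deleted package is importable as `flash_attn3` on a FlashAttention-3-capable GPU, and the batch size, head counts, head dimension, and dtype are illustrative, not taken from this repository.

import torch
from flash_attn3 import flash_attn_with_kvcache  # assumed install/import path

batch, nheads, nheads_k, headdim, max_seqlen = 2, 8, 2, 128, 4096  # illustrative sizes
dev, dt = "cuda", torch.bfloat16

# Pre-allocated KV cache; cache_seqlens records how many positions are already filled per sequence.
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device=dev, dtype=dt)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device=dev)

# One new query/key/value token per sequence for this decoding step.
q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)
k_new = torch.randn(batch, 1, nheads_k, headdim, device=dev, dtype=dt)
v_new = torch.randn(batch, 1, nheads_k, headdim, device=dev, dtype=dt)

# Writes k_new/v_new into the caches at cache_seqlens and attends over the updated cache in one call.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)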
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
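For orientation, these six names are the whole public surface the package exposed. A small usage sketch for the basic entry point follows; it assumes the package is importable as `flash_attn3` on a supported GPU, and the tensor sizes are made up for illustration, with shapes following the flash_attn_func docstring (batch_size, seqlen, nheads, headdim).

import torch
from flash_attn3 import flash_attn_func  # assumed import path

q = torch.randn(1, 512, 8, 128, device="cuda", dtype=torch.bfloat16)  # (batch, seqlen, nheads, headdim)
k = torch.randn(1, 512, 2, 128, device="cuda", dtype=torch.bfloat16)  # 2 KV heads for 8 Q heads -> GQA
v = torch.randn(1, 512, 2, 128, device="cuda", dtype=torch.bfloat16)

# FA3 returns the attention output together with the row-wise logsumexp of the scores.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)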
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9627e08ec8778d2a409a2a0477572edb3e03eaca2b45e7b4810ee0a9126d6547
- size 838456048
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:07fe025ba95671f6ff957991f74c66063bfb10ab6737641c88f88116c9f83718
- size 838456048
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _flash_attn3_557701f
- ops = torch.ops._flash_attn3_557701f
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_flash_attn3_557701f::{op_name}"
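The `_ops` shim above simply binds the compiled extension under a build-specific torch op namespace. A brief, hypothetical call-site illustration, assuming the same module layout as in this file:

from flash_attn3._ops import ops, add_op_namespace_prefix  # assumed import path

qualified = add_op_namespace_prefix("fwd")  # -> "_flash_attn3_557701f::fwd"
# ops.fwd is the registered kernel that flash_attn_interface.py invokes as flash_attn_3_cuda.fwd(...)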
build/torch26-cxx98-cu124-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9627e08ec8778d2a409a2a0477572edb3e03eaca2b45e7b4810ee0a9126d6547
- size 838456048
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_flash_attn3_557701f.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:07fe025ba95671f6ff957991f74c66063bfb10ab6737641c88f88116c9f83718
- size 838456048
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _flash_attn3_557701f
- ops = torch.ops._flash_attn3_557701f
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_flash_attn3_557701f::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
- will only attend to keys between
532
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
-
534
- Arguments:
535
- q: (batch_size, seqlen, nheads, headdim)
536
- k: (batch_size, seqlen, nheads_k, headdim)
537
- v: (batch_size, seqlen, nheads_k, headdim)
538
- dropout_p: float. Dropout probability.
539
- softmax_scale: float. The scaling of QK^T before applying softmax.
540
- Default to 1 / sqrt(headdim).
541
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
- is added to the attention score of query i and key j.
546
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
- which is slightly slower and uses more memory. The forward pass is always deterministic.
548
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
- testing only. The returned probabilities are not guaranteed to be correct
550
- (they might not have the right scaling).
551
- Return:
552
- out: (batch_size, seqlen, nheads, headdim).
553
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
- normalization factor).
556
- """
557
- return FlashAttnFunc.apply(
558
- q,
559
- k,
560
- v,
561
- softmax_scale,
562
- causal,
563
- qv,
564
- q_descale, k_descale, v_descale,
565
- window_size,
566
- attention_chunk,
567
- softcap,
568
- num_splits,
569
- pack_gqa,
570
- deterministic,
571
- sm_margin,
572
- )
573
-
574
-
575
- def flash_attn_varlen_func(
576
- q,
577
- k,
578
- v,
579
- cu_seqlens_q,
580
- cu_seqlens_k,
581
- max_seqlen_q,
582
- max_seqlen_k,
583
- seqused_q=None,
584
- seqused_k=None,
585
- softmax_scale=None,
586
- causal=False,
587
- qv=None,
588
- q_descale=None, k_descale=None, v_descale=None,
589
- window_size=(-1, -1),
590
- attention_chunk=0,
591
- softcap=0.0,
592
- num_splits=1,
593
- pack_gqa=None,
594
- deterministic=False,
595
- sm_margin=0,
596
- ):
597
- return FlashAttnVarlenFunc.apply(
598
- q,
599
- k,
600
- v,
601
- cu_seqlens_q,
602
- cu_seqlens_k,
603
- seqused_q,
604
- seqused_k,
605
- max_seqlen_q,
606
- max_seqlen_k,
607
- softmax_scale,
608
- causal,
609
- qv,
610
- q_descale, k_descale, v_descale,
611
- window_size,
612
- attention_chunk,
613
- softcap,
614
- num_splits,
615
- pack_gqa,
616
- deterministic,
617
- sm_margin,
618
- )
619
-
620
-
621
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
- return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
-
624
-
625
- def flash_attn_with_kvcache(
626
- q,
627
- k_cache,
628
- v_cache,
629
- k=None,
630
- v=None,
631
- qv=None,
632
- rotary_cos=None,
633
- rotary_sin=None,
634
- cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
- cache_batch_idx: Optional[torch.Tensor] = None,
636
- cache_leftpad: Optional[torch.Tensor] = None,
637
- page_table: Optional[torch.Tensor] = None,
638
- cu_seqlens_q: Optional[torch.Tensor] = None,
639
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
- max_seqlen_q: Optional[int] = None,
641
- rotary_seqlens: Optional[torch.Tensor] = None,
642
- q_descale: Optional[torch.Tensor] = None,
643
- k_descale: Optional[torch.Tensor] = None,
644
- v_descale: Optional[torch.Tensor] = None,
645
- softmax_scale=None,
646
- causal=False,
647
- window_size=(-1, -1), # -1 means infinite context window
648
- attention_chunk=0,
649
- softcap=0.0, # 0.0 means deactivated
650
- rotary_interleaved=True,
651
- scheduler_metadata=None,
652
- num_splits=0, # Can be tuned for speed
653
- pack_gqa=None, # Can be tuned for speed
654
- sm_margin=0, # Can be tuned if some SMs are used for communication
655
- return_softmax_lse=False,
656
- ):
657
- """
658
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
- the previous step, and update them with the new keys/values from the current step, and do
661
- attention with the updated cache, all in 1 kernel.
662
-
663
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
-
667
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
-
674
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
-
676
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
-
681
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
- 1 1 1 1 0
684
- 1 1 1 1 1
685
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
- 0 0
687
- 0 0
688
- 0 0
689
- 1 0
690
- 1 1
691
- If the row of the mask is all zero, the output will be zero.
692
-
693
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
- will only attend to keys between
695
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
-
697
- Note: Does not support backward pass.
698
-
699
- Arguments:
700
- q: (batch_size, seqlen, nheads, headdim)
701
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
- k with k_cache, starting at the indices specified by cache_seqlens.
708
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
- KV cache.
715
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
- If the indices are not distinct, and k and v are provided, the values updated in the cache
718
- might come from any of the duplicate indices.
719
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
- page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
- softmax_scale: float. The scaling of QK^T before applying softmax.
722
- Default to 1 / sqrt(headdim).
723
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
- softcap: float. Anything > 0 activates softcapping attention.
726
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
- (i.e. GPT-NeoX style).
730
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
- to automatically determine the number of splits.
733
- Don't change this unless you know what you are doing.
734
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
-
736
- Return:
737
- out: (batch_size, seqlen, nheads, headdim).
738
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
- normalization factor).
741
- """
742
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
- if softmax_scale is None:
745
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
- cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
- )
750
- cache_seqlens = maybe_contiguous(cache_seqlens)
751
- out, softmax_lse, *rest = _flash_attn_forward(
752
- q,
753
- k_cache,
754
- v_cache,
755
- k,
756
- v,
757
- qv,
758
- None, # out
759
- cu_seqlens_q,
760
- None, # cu_seqlens_k
761
- cu_seqlens_k_new,
762
- None, # seqused_q
763
- cache_seqlens,
764
- max_seqlen_q,
765
- None, # max_seqlen_k
766
- page_table,
767
- cache_batch_idx,
768
- cache_leftpad,
769
- rotary_cos,
770
- rotary_sin,
771
- rotary_seqlens,
772
- q_descale, k_descale, v_descale,
773
- softmax_scale,
774
- causal=causal,
775
- window_size=window_size,
776
- attention_chunk=attention_chunk,
777
- softcap=softcap,
778
- rotary_interleaved=rotary_interleaved,
779
- scheduler_metadata=scheduler_metadata,
780
- num_splits=num_splits,
781
- pack_gqa=pack_gqa,
782
- sm_margin=sm_margin,
783
- )
784
- # return (out, softmax_lse) if return_softmax_lse else out
785
- return (out, softmax_lse, *rest) if return_softmax_lse else out
786
-
787
-
788
- def get_scheduler_metadata(
789
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
- cache_seqlens: torch.Tensor,
791
- qkv_dtype=torch.bfloat16,
792
- headdim_v=None,
793
- cu_seqlens_q: Optional[torch.Tensor] = None,
794
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
- cache_leftpad: Optional[torch.Tensor] = None,
796
- page_size: Optional[int] = None,
797
- max_seqlen_k_new=0,
798
- causal=False,
799
- window_size=(-1, -1), # -1 means infinite context window
800
- attention_chunk=0,
801
- has_softcap=False,
802
- num_splits=0, # Can be tuned for speed
803
- pack_gqa=None, # Can be tuned for speed
804
- sm_margin=0, # Can be tuned if some SMs are used for communication
805
- ):
806
- cache_seqlens = maybe_contiguous(cache_seqlens)
807
- if headdim_v is None:
808
- headdim_v = headdim
809
- scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
- qkv_dtype,
812
- cache_seqlens,
813
- cu_seqlens_q,
814
- None, # cu_seqlens_k
815
- cu_seqlens_k_new,
816
- None, # seqused_q
817
- cache_leftpad,
818
- page_size,
819
- max_seqlen_k_new,
820
- causal,
821
- window_size[0], window_size[1],
822
- attention_chunk,
823
- has_softcap,
824
- num_splits,
825
- pack_gqa,
826
- sm_margin,
827
- )
828
- return scheduler_metadata
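
The Python interface above is identical across every build variant removed in this commit, so one usage sketch is enough to show what is being deleted. The example below is a hypothetical illustration of the GQA case described in the docstrings (more query heads than KV heads, bottom-right-aligned causal mask); the tensor sizes, dtype, and device are assumptions for illustration, not part of the diff.

import torch
from flash_attn3 import flash_attn_func  # assumes the built package is importable as flash_attn3

# Illustrative shapes: batch 2, seqlen 1024, 6 query heads, 2 KV heads (GQA), head dim 128.
q = torch.randn(2, 1024, 6, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 1024, 2, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 1024, 2, 128, dtype=torch.bfloat16, device="cuda")

# Causal self-attention; softmax_scale defaults to 1/sqrt(headdim) when left as None.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
# out: (2, 1024, 6, 128); softmax_lse: logsumexp of each row of QK^T * scaling.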
 
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
 
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c0302224ac29ba4773d926d4cb16c01c45a374c6dd61286aae1f423f2bf495ea
- size 838459544
 
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _flash_attn3_2e75662
- ops = torch.ops._flash_attn3_2e75662
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_flash_attn3_2e75662::{op_name}"
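
This small shim is how the Python wrapper reaches the compiled kernels: the shared object registers its operators under the hash-suffixed torch.ops namespace, and add_op_namespace_prefix builds a fully qualified operator name. A minimal sketch of how the shim is consumed (the fwd op name is taken from flash_attn_interface.py above; the rest is illustrative):

from flash_attn3._ops import ops, add_op_namespace_prefix

# flash_attn_interface.py calls the kernels through this namespace, e.g. ops.fwd(...) and ops.bwd(...).
print(add_op_namespace_prefix("fwd"))  # "_flash_attn3_2e75662::fwd"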
 
build/torch27-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
- will only attend to keys between
532
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
-
534
- Arguments:
535
- q: (batch_size, seqlen, nheads, headdim)
536
- k: (batch_size, seqlen, nheads_k, headdim)
537
- v: (batch_size, seqlen, nheads_k, headdim)
538
- dropout_p: float. Dropout probability.
539
- softmax_scale: float. The scaling of QK^T before applying softmax.
540
- Default to 1 / sqrt(headdim).
541
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
- is added to the attention score of query i and key j.
546
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
- which is slightly slower and uses more memory. The forward pass is always deterministic.
548
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
- testing only. The returned probabilities are not guaranteed to be correct
550
- (they might not have the right scaling).
551
- Return:
552
- out: (batch_size, seqlen, nheads, headdim).
553
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
- normalization factor).
556
- """
557
- return FlashAttnFunc.apply(
558
- q,
559
- k,
560
- v,
561
- softmax_scale,
562
- causal,
563
- qv,
564
- q_descale, k_descale, v_descale,
565
- window_size,
566
- attention_chunk,
567
- softcap,
568
- num_splits,
569
- pack_gqa,
570
- deterministic,
571
- sm_margin,
572
- )
573
-
574
-
575
- def flash_attn_varlen_func(
576
- q,
577
- k,
578
- v,
579
- cu_seqlens_q,
580
- cu_seqlens_k,
581
- max_seqlen_q,
582
- max_seqlen_k,
583
- seqused_q=None,
584
- seqused_k=None,
585
- softmax_scale=None,
586
- causal=False,
587
- qv=None,
588
- q_descale=None, k_descale=None, v_descale=None,
589
- window_size=(-1, -1),
590
- attention_chunk=0,
591
- softcap=0.0,
592
- num_splits=1,
593
- pack_gqa=None,
594
- deterministic=False,
595
- sm_margin=0,
596
- ):
597
- return FlashAttnVarlenFunc.apply(
598
- q,
599
- k,
600
- v,
601
- cu_seqlens_q,
602
- cu_seqlens_k,
603
- seqused_q,
604
- seqused_k,
605
- max_seqlen_q,
606
- max_seqlen_k,
607
- softmax_scale,
608
- causal,
609
- qv,
610
- q_descale, k_descale, v_descale,
611
- window_size,
612
- attention_chunk,
613
- softcap,
614
- num_splits,
615
- pack_gqa,
616
- deterministic,
617
- sm_margin,
618
- )
619
-
620
-
621
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
- return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
-
624
-
625
- def flash_attn_with_kvcache(
626
- q,
627
- k_cache,
628
- v_cache,
629
- k=None,
630
- v=None,
631
- qv=None,
632
- rotary_cos=None,
633
- rotary_sin=None,
634
- cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
- cache_batch_idx: Optional[torch.Tensor] = None,
636
- cache_leftpad: Optional[torch.Tensor] = None,
637
- page_table: Optional[torch.Tensor] = None,
638
- cu_seqlens_q: Optional[torch.Tensor] = None,
639
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
- max_seqlen_q: Optional[int] = None,
641
- rotary_seqlens: Optional[torch.Tensor] = None,
642
- q_descale: Optional[torch.Tensor] = None,
643
- k_descale: Optional[torch.Tensor] = None,
644
- v_descale: Optional[torch.Tensor] = None,
645
- softmax_scale=None,
646
- causal=False,
647
- window_size=(-1, -1), # -1 means infinite context window
648
- attention_chunk=0,
649
- softcap=0.0, # 0.0 means deactivated
650
- rotary_interleaved=True,
651
- scheduler_metadata=None,
652
- num_splits=0, # Can be tuned for speed
653
- pack_gqa=None, # Can be tuned for speed
654
- sm_margin=0, # Can be tuned if some SMs are used for communication
655
- return_softmax_lse=False,
656
- ):
657
- """
658
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
- the previous step, and update them with the new keys/values from the current step, and do
661
- attention with the updated cache, all in 1 kernel.
662
-
663
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
-
667
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
-
674
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
-
676
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
-
681
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
- 1 1 1 1 0
684
- 1 1 1 1 1
685
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
- 0 0
687
- 0 0
688
- 0 0
689
- 1 0
690
- 1 1
691
- If the row of the mask is all zero, the output will be zero.
692
-
693
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
- will only attend to keys between
695
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
-
697
- Note: Does not support backward pass.
698
-
699
- Arguments:
700
- q: (batch_size, seqlen, nheads, headdim)
701
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
- k with k_cache, starting at the indices specified by cache_seqlens.
708
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
- KV cache.
715
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
- If the indices are not distinct, and k and v are provided, the values updated in the cache
718
- might come from any of the duplicate indices.
719
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
- page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
- softmax_scale: float. The scaling of QK^T before applying softmax.
722
- Default to 1 / sqrt(headdim).
723
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
- softcap: float. Anything > 0 activates softcapping attention.
726
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
- (i.e. GPT-NeoX style).
730
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
- to automatically determine the number of splits.
733
- Don't change this unless you know what you are doing.
734
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
-
736
- Return:
737
- out: (batch_size, seqlen, nheads, headdim).
738
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
- normalization factor).
741
- """
742
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
- if softmax_scale is None:
745
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
- cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
- )
750
- cache_seqlens = maybe_contiguous(cache_seqlens)
751
- out, softmax_lse, *rest = _flash_attn_forward(
752
- q,
753
- k_cache,
754
- v_cache,
755
- k,
756
- v,
757
- qv,
758
- None, # out
759
- cu_seqlens_q,
760
- None, # cu_seqlens_k
761
- cu_seqlens_k_new,
762
- None, # seqused_q
763
- cache_seqlens,
764
- max_seqlen_q,
765
- None, # max_seqlen_k
766
- page_table,
767
- cache_batch_idx,
768
- cache_leftpad,
769
- rotary_cos,
770
- rotary_sin,
771
- rotary_seqlens,
772
- q_descale, k_descale, v_descale,
773
- softmax_scale,
774
- causal=causal,
775
- window_size=window_size,
776
- attention_chunk=attention_chunk,
777
- softcap=softcap,
778
- rotary_interleaved=rotary_interleaved,
779
- scheduler_metadata=scheduler_metadata,
780
- num_splits=num_splits,
781
- pack_gqa=pack_gqa,
782
- sm_margin=sm_margin,
783
- )
784
- # return (out, softmax_lse) if return_softmax_lse else out
785
- return (out, softmax_lse, *rest) if return_softmax_lse else out
786
-
787
-
788
- def get_scheduler_metadata(
789
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
- cache_seqlens: torch.Tensor,
791
- qkv_dtype=torch.bfloat16,
792
- headdim_v=None,
793
- cu_seqlens_q: Optional[torch.Tensor] = None,
794
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
- cache_leftpad: Optional[torch.Tensor] = None,
796
- page_size: Optional[int] = None,
797
- max_seqlen_k_new=0,
798
- causal=False,
799
- window_size=(-1, -1), # -1 means infinite context window
800
- attention_chunk=0,
801
- has_softcap=False,
802
- num_splits=0, # Can be tuned for speed
803
- pack_gqa=None, # Can be tuned for speed
804
- sm_margin=0, # Can be tuned if some SMs are used for communication
805
- ):
806
- cache_seqlens = maybe_contiguous(cache_seqlens)
807
- if headdim_v is None:
808
- headdim_v = headdim
809
- scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
- qkv_dtype,
812
- cache_seqlens,
813
- cu_seqlens_q,
814
- None, # cu_seqlens_k
815
- cu_seqlens_k_new,
816
- None, # seqused_q
817
- cache_leftpad,
818
- page_size,
819
- max_seqlen_k_new,
820
- causal,
821
- window_size[0], window_size[1],
822
- attention_chunk,
823
- has_softcap,
824
- num_splits,
825
- pack_gqa,
826
- sm_margin,
827
- )
828
- return scheduler_metadata
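
The flash_attn_with_kvcache entry point in the file above is the decode-time path, and its docstring describes the in-place cache update. Since the interface repeats for every build target, a single hedged sketch covers all of them; the shapes, dtype, and cache lengths below are assumptions made for illustration.

import torch
from flash_attn3 import flash_attn_with_kvcache

batch, nheads, nheads_k, headdim, max_seqlen = 2, 16, 16, 128, 4096

# Pre-allocated KV cache plus the current length of each cached sequence.
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
cache_seqlens = torch.tensor([17, 42], dtype=torch.int32, device="cuda")

# One new token per sequence: its query plus the K/V to append to the cache in place.
q     = torch.randn(batch, 1, nheads,   headdim, dtype=torch.bfloat16, device="cuda")
k_new = torch.randn(batch, 1, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
v_new = torch.randn(batch, 1, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")

# Writes k_new/v_new into the cache at cache_seqlens and attends over the updated cache in one kernel.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
# out: (batch, 1, nheads, headdim)

For repeated decode steps with the same shapes, get_scheduler_metadata appears intended to let the scheduling decision be computed once and passed back via the scheduler_metadata argument; per the inline comments, num_splits, pack_gqa, and sm_margin are the knobs that can be tuned for speed.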
 
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
 
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_flash_attn3_2e75662.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c0302224ac29ba4773d926d4cb16c01c45a374c6dd61286aae1f423f2bf495ea
- size 838459544
 
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
- import torch
- from . import _flash_attn3_2e75662
- ops = torch.ops._flash_attn3_2e75662
-
- def add_op_namespace_prefix(op_name: str):
-     """
-     Prefix op by namespace.
-     """
-     return f"_flash_attn3_2e75662::{op_name}"
 
build/torch27-cxx11-cu128-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
- will only attend to keys between
532
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
-
534
- Arguments:
535
- q: (batch_size, seqlen, nheads, headdim)
536
- k: (batch_size, seqlen, nheads_k, headdim)
537
- v: (batch_size, seqlen, nheads_k, headdim)
538
- dropout_p: float. Dropout probability.
539
- softmax_scale: float. The scaling of QK^T before applying softmax.
540
- Default to 1 / sqrt(headdim).
541
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
- is added to the attention score of query i and key j.
546
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
- which is slightly slower and uses more memory. The forward pass is always deterministic.
548
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
- testing only. The returned probabilities are not guaranteed to be correct
550
- (they might not have the right scaling).
551
- Return:
552
- out: (batch_size, seqlen, nheads, headdim).
553
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
- normalization factor).
556
- """
557
- return FlashAttnFunc.apply(
558
- q,
559
- k,
560
- v,
561
- softmax_scale,
562
- causal,
563
- qv,
564
- q_descale, k_descale, v_descale,
565
- window_size,
566
- attention_chunk,
567
- softcap,
568
- num_splits,
569
- pack_gqa,
570
- deterministic,
571
- sm_margin,
572
- )
573
-
574
-
575
- def flash_attn_varlen_func(
576
- q,
577
- k,
578
- v,
579
- cu_seqlens_q,
580
- cu_seqlens_k,
581
- max_seqlen_q,
582
- max_seqlen_k,
583
- seqused_q=None,
584
- seqused_k=None,
585
- softmax_scale=None,
586
- causal=False,
587
- qv=None,
588
- q_descale=None, k_descale=None, v_descale=None,
589
- window_size=(-1, -1),
590
- attention_chunk=0,
591
- softcap=0.0,
592
- num_splits=1,
593
- pack_gqa=None,
594
- deterministic=False,
595
- sm_margin=0,
596
- ):
597
- return FlashAttnVarlenFunc.apply(
598
- q,
599
- k,
600
- v,
601
- cu_seqlens_q,
602
- cu_seqlens_k,
603
- seqused_q,
604
- seqused_k,
605
- max_seqlen_q,
606
- max_seqlen_k,
607
- softmax_scale,
608
- causal,
609
- qv,
610
- q_descale, k_descale, v_descale,
611
- window_size,
612
- attention_chunk,
613
- softcap,
614
- num_splits,
615
- pack_gqa,
616
- deterministic,
617
- sm_margin,
618
- )
619
-
620
-
621
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
- return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
-
624
-
625
- def flash_attn_with_kvcache(
626
- q,
627
- k_cache,
628
- v_cache,
629
- k=None,
630
- v=None,
631
- qv=None,
632
- rotary_cos=None,
633
- rotary_sin=None,
634
- cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
- cache_batch_idx: Optional[torch.Tensor] = None,
636
- cache_leftpad: Optional[torch.Tensor] = None,
637
- page_table: Optional[torch.Tensor] = None,
638
- cu_seqlens_q: Optional[torch.Tensor] = None,
639
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
- max_seqlen_q: Optional[int] = None,
641
- rotary_seqlens: Optional[torch.Tensor] = None,
642
- q_descale: Optional[torch.Tensor] = None,
643
- k_descale: Optional[torch.Tensor] = None,
644
- v_descale: Optional[torch.Tensor] = None,
645
- softmax_scale=None,
646
- causal=False,
647
- window_size=(-1, -1), # -1 means infinite context window
648
- attention_chunk=0,
649
- softcap=0.0, # 0.0 means deactivated
650
- rotary_interleaved=True,
651
- scheduler_metadata=None,
652
- num_splits=0, # Can be tuned for speed
653
- pack_gqa=None, # Can be tuned for speed
654
- sm_margin=0, # Can be tuned if some SMs are used for communication
655
- return_softmax_lse=False,
656
- ):
657
- """
658
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
- the previous step, and update them with the new keys/values from the current step, and do
661
- attention with the updated cache, all in 1 kernel.
662
-
663
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
-
667
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
-
674
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
-
676
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
-
681
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
- 1 1 1 1 0
684
- 1 1 1 1 1
685
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
- 0 0
687
- 0 0
688
- 0 0
689
- 1 0
690
- 1 1
691
- If the row of the mask is all zero, the output will be zero.
692
-
693
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
- will only attend to keys between
695
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
-
697
- Note: Does not support backward pass.
698
-
699
- Arguments:
700
- q: (batch_size, seqlen, nheads, headdim)
701
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
- k with k_cache, starting at the indices specified by cache_seqlens.
708
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
- KV cache.
715
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
- If the indices are not distinct, and k and v are provided, the values updated in the cache
718
- might come from any of the duplicate indices.
719
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
- page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
- softmax_scale: float. The scaling of QK^T before applying softmax.
722
- Default to 1 / sqrt(headdim).
723
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
- softcap: float. Anything > 0 activates softcapping attention.
726
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
- (i.e. GPT-NeoX style).
730
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
- to automatically determine the number of splits.
733
- Don't change this unless you know what you are doing.
734
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
-
736
- Return:
737
- out: (batch_size, seqlen, nheads, headdim).
738
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
- normalization factor).
741
- """
742
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
- if softmax_scale is None:
745
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
- cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
- )
750
- cache_seqlens = maybe_contiguous(cache_seqlens)
751
- out, softmax_lse, *rest = _flash_attn_forward(
752
- q,
753
- k_cache,
754
- v_cache,
755
- k,
756
- v,
757
- qv,
758
- None, # out
759
- cu_seqlens_q,
760
- None, # cu_seqlens_k
761
- cu_seqlens_k_new,
762
- None, # seqused_q
763
- cache_seqlens,
764
- max_seqlen_q,
765
- None, # max_seqlen_k
766
- page_table,
767
- cache_batch_idx,
768
- cache_leftpad,
769
- rotary_cos,
770
- rotary_sin,
771
- rotary_seqlens,
772
- q_descale, k_descale, v_descale,
773
- softmax_scale,
774
- causal=causal,
775
- window_size=window_size,
776
- attention_chunk=attention_chunk,
777
- softcap=softcap,
778
- rotary_interleaved=rotary_interleaved,
779
- scheduler_metadata=scheduler_metadata,
780
- num_splits=num_splits,
781
- pack_gqa=pack_gqa,
782
- sm_margin=sm_margin,
783
- )
784
- # return (out, softmax_lse) if return_softmax_lse else out
785
- return (out, softmax_lse, *rest) if return_softmax_lse else out
786
-
787
-
788
- def get_scheduler_metadata(
789
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
- cache_seqlens: torch.Tensor,
791
- qkv_dtype=torch.bfloat16,
792
- headdim_v=None,
793
- cu_seqlens_q: Optional[torch.Tensor] = None,
794
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
- cache_leftpad: Optional[torch.Tensor] = None,
796
- page_size: Optional[int] = None,
797
- max_seqlen_k_new=0,
798
- causal=False,
799
- window_size=(-1, -1), # -1 means infinite context window
800
- attention_chunk=0,
801
- has_softcap=False,
802
- num_splits=0, # Can be tuned for speed
803
- pack_gqa=None, # Can be tuned for speed
804
- sm_margin=0, # Can be tuned if some SMs are used for communication
805
- ):
806
- cache_seqlens = maybe_contiguous(cache_seqlens)
807
- if headdim_v is None:
808
- headdim_v = headdim
809
- scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
- qkv_dtype,
812
- cache_seqlens,
813
- cu_seqlens_q,
814
- None, # cu_seqlens_k
815
- cu_seqlens_k_new,
816
- None, # seqused_q
817
- cache_leftpad,
818
- page_size,
819
- max_seqlen_k_new,
820
- causal,
821
- window_size[0], window_size[1],
822
- attention_chunk,
823
- has_softcap,
824
- num_splits,
825
- pack_gqa,
826
- sm_margin,
827
- )
828
- return scheduler_metadata
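
For reference, a minimal usage sketch of the flash_attn_func interface defined in the file removed above. The flash_attn3 import path, tensor shapes, dtype, and device are illustrative assumptions, not taken from this commit:

    import torch
    from flash_attn3 import flash_attn_func  # assumed import path for the built package

    batch, seqlen, nheads, nheads_k, headdim = 2, 1024, 8, 2, 128
    # GQA: 8 query heads share 2 KV heads; nheads must be divisible by nheads_k.
    q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(batch, seqlen, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
    v = torch.randn_like(k)

    # causal=True aligns the mask to the bottom-right corner of the attention matrix.
    out, lse = flash_attn_func(q, k, v, causal=True)
    print(out.shape)  # (2, 1024, 8, 128)
    print(lse.shape)  # (2, 8, 1024): logsumexp of each query row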
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- from .flash_attn_interface import (
2
- flash_attn_combine,
3
- flash_attn_func,
4
- flash_attn_qkvpacked_func,
5
- flash_attn_varlen_func,
6
- flash_attn_with_kvcache,
7
- get_scheduler_metadata,
8
- )
9
-
10
- __all__ = [
11
- "flash_attn_combine",
12
- "flash_attn_func",
13
- "flash_attn_qkvpacked_func",
14
- "flash_attn_varlen_func",
15
- "flash_attn_with_kvcache",
16
- "get_scheduler_metadata",
17
- ]
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (438 Bytes)
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc DELETED
Binary file (530 Bytes)
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc DELETED
Binary file (26.2 kB)
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e9aef52109e5974778e3ccc2f697c4e6050b365624c843a675ce894b938341cc
3
- size 822395576
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _flash_attn3_8d4f83f
3
- ops = torch.ops._flash_attn3_8d4f83f
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_flash_attn3_8d4f83f::{op_name}"
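
A note on the removed _ops.py: it binds the compiled extension to a hashed torch.ops namespace, and flash_attn_interface.py imports that binding as flash_attn_3_cuda. A minimal sketch of how the helper is used (import path assumed):

    from flash_attn3._ops import ops, add_op_namespace_prefix

    add_op_namespace_prefix("fwd")  # returns "_flash_attn3_8d4f83f::fwd"
    # ops.fwd(...) dispatches to the kernel registered under that namespace;
    # this is the call flash_attn_interface.py makes as flash_attn_3_cuda.fwd(...).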
 
build/torch28-cxx11-cu126-aarch64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
- will only attend to keys between
532
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
-
534
- Arguments:
535
- q: (batch_size, seqlen, nheads, headdim)
536
- k: (batch_size, seqlen, nheads_k, headdim)
537
- v: (batch_size, seqlen, nheads_k, headdim)
538
- dropout_p: float. Dropout probability.
539
- softmax_scale: float. The scaling of QK^T before applying softmax.
540
- Default to 1 / sqrt(headdim).
541
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
- is added to the attention score of query i and key j.
546
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
- which is slightly slower and uses more memory. The forward pass is always deterministic.
548
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
- testing only. The returned probabilities are not guaranteed to be correct
550
- (they might not have the right scaling).
551
- Return:
552
- out: (batch_size, seqlen, nheads, headdim).
553
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
- normalization factor).
556
- """
557
- return FlashAttnFunc.apply(
558
- q,
559
- k,
560
- v,
561
- softmax_scale,
562
- causal,
563
- qv,
564
- q_descale, k_descale, v_descale,
565
- window_size,
566
- attention_chunk,
567
- softcap,
568
- num_splits,
569
- pack_gqa,
570
- deterministic,
571
- sm_margin,
572
- )
573
-
574
-
575
- def flash_attn_varlen_func(
576
- q,
577
- k,
578
- v,
579
- cu_seqlens_q,
580
- cu_seqlens_k,
581
- max_seqlen_q,
582
- max_seqlen_k,
583
- seqused_q=None,
584
- seqused_k=None,
585
- softmax_scale=None,
586
- causal=False,
587
- qv=None,
588
- q_descale=None, k_descale=None, v_descale=None,
589
- window_size=(-1, -1),
590
- attention_chunk=0,
591
- softcap=0.0,
592
- num_splits=1,
593
- pack_gqa=None,
594
- deterministic=False,
595
- sm_margin=0,
596
- ):
597
- return FlashAttnVarlenFunc.apply(
598
- q,
599
- k,
600
- v,
601
- cu_seqlens_q,
602
- cu_seqlens_k,
603
- seqused_q,
604
- seqused_k,
605
- max_seqlen_q,
606
- max_seqlen_k,
607
- softmax_scale,
608
- causal,
609
- qv,
610
- q_descale, k_descale, v_descale,
611
- window_size,
612
- attention_chunk,
613
- softcap,
614
- num_splits,
615
- pack_gqa,
616
- deterministic,
617
- sm_margin,
618
- )
619
-
620
-
621
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
- return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
-
624
-
625
- def flash_attn_with_kvcache(
626
- q,
627
- k_cache,
628
- v_cache,
629
- k=None,
630
- v=None,
631
- qv=None,
632
- rotary_cos=None,
633
- rotary_sin=None,
634
- cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
- cache_batch_idx: Optional[torch.Tensor] = None,
636
- cache_leftpad: Optional[torch.Tensor] = None,
637
- page_table: Optional[torch.Tensor] = None,
638
- cu_seqlens_q: Optional[torch.Tensor] = None,
639
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
- max_seqlen_q: Optional[int] = None,
641
- rotary_seqlens: Optional[torch.Tensor] = None,
642
- q_descale: Optional[torch.Tensor] = None,
643
- k_descale: Optional[torch.Tensor] = None,
644
- v_descale: Optional[torch.Tensor] = None,
645
- softmax_scale=None,
646
- causal=False,
647
- window_size=(-1, -1), # -1 means infinite context window
648
- attention_chunk=0,
649
- softcap=0.0, # 0.0 means deactivated
650
- rotary_interleaved=True,
651
- scheduler_metadata=None,
652
- num_splits=0, # Can be tuned for speed
653
- pack_gqa=None, # Can be tuned for speed
654
- sm_margin=0, # Can be tuned if some SMs are used for communication
655
- return_softmax_lse=False,
656
- ):
657
- """
658
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
- the previous step, and update them with the new keys/values from the current step, and do
661
- attention with the updated cache, all in 1 kernel.
662
-
663
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
-
667
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
-
674
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
-
676
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
-
681
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
- 1 1 1 1 0
684
- 1 1 1 1 1
685
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
- 0 0
687
- 0 0
688
- 0 0
689
- 1 0
690
- 1 1
691
- If the row of the mask is all zero, the output will be zero.
692
-
693
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
- will only attend to keys between
695
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
-
697
- Note: Does not support backward pass.
698
-
699
- Arguments:
700
- q: (batch_size, seqlen, nheads, headdim)
701
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
- k with k_cache, starting at the indices specified by cache_seqlens.
708
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
- KV cache.
715
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
- If the indices are not distinct, and k and v are provided, the values updated in the cache
718
- might come from any of the duplicate indices.
719
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
- page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
- softmax_scale: float. The scaling of QK^T before applying softmax.
722
- Default to 1 / sqrt(headdim).
723
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
- softcap: float. Anything > 0 activates softcapping attention.
726
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
- (i.e. GPT-NeoX style).
730
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
- to automatically determine the number of splits.
733
- Don't change this unless you know what you are doing.
734
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
-
736
- Return:
737
- out: (batch_size, seqlen, nheads, headdim).
738
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
- normalization factor).
741
- """
742
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
- if softmax_scale is None:
745
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
- cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
- )
750
- cache_seqlens = maybe_contiguous(cache_seqlens)
751
- out, softmax_lse, *rest = _flash_attn_forward(
752
- q,
753
- k_cache,
754
- v_cache,
755
- k,
756
- v,
757
- qv,
758
- None, # out
759
- cu_seqlens_q,
760
- None, # cu_seqlens_k
761
- cu_seqlens_k_new,
762
- None, # seqused_q
763
- cache_seqlens,
764
- max_seqlen_q,
765
- None, # max_seqlen_k
766
- page_table,
767
- cache_batch_idx,
768
- cache_leftpad,
769
- rotary_cos,
770
- rotary_sin,
771
- rotary_seqlens,
772
- q_descale, k_descale, v_descale,
773
- softmax_scale,
774
- causal=causal,
775
- window_size=window_size,
776
- attention_chunk=attention_chunk,
777
- softcap=softcap,
778
- rotary_interleaved=rotary_interleaved,
779
- scheduler_metadata=scheduler_metadata,
780
- num_splits=num_splits,
781
- pack_gqa=pack_gqa,
782
- sm_margin=sm_margin,
783
- )
784
- # return (out, softmax_lse) if return_softmax_lse else out
785
- return (out, softmax_lse, *rest) if return_softmax_lse else out
786
-
787
-
788
- def get_scheduler_metadata(
789
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
- cache_seqlens: torch.Tensor,
791
- qkv_dtype=torch.bfloat16,
792
- headdim_v=None,
793
- cu_seqlens_q: Optional[torch.Tensor] = None,
794
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
- cache_leftpad: Optional[torch.Tensor] = None,
796
- page_size: Optional[int] = None,
797
- max_seqlen_k_new=0,
798
- causal=False,
799
- window_size=(-1, -1), # -1 means infinite context window
800
- attention_chunk=0,
801
- has_softcap=False,
802
- num_splits=0, # Can be tuned for speed
803
- pack_gqa=None, # Can be tuned for speed
804
- sm_margin=0, # Can be tuned if some SMs are used for communication
805
- ):
806
- cache_seqlens = maybe_contiguous(cache_seqlens)
807
- if headdim_v is None:
808
- headdim_v = headdim
809
- scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
- qkv_dtype,
812
- cache_seqlens,
813
- cu_seqlens_q,
814
- None, # cu_seqlens_k
815
- cu_seqlens_k_new,
816
- None, # seqused_q
817
- cache_leftpad,
818
- page_size,
819
- max_seqlen_k_new,
820
- causal,
821
- window_size[0], window_size[1],
822
- attention_chunk,
823
- has_softcap,
824
- num_splits,
825
- pack_gqa,
826
- sm_margin,
827
- )
828
- return scheduler_metadata
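
For reference, a minimal decoding-step sketch of the flash_attn_with_kvcache interface removed above. Cache sizes, dtype, and the flash_attn3 import path are illustrative assumptions:

    import torch
    from flash_attn3 import flash_attn_with_kvcache

    batch, max_seqlen, nheads, nheads_k, headdim = 2, 4096, 8, 2, 128
    k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")  # current lengths

    # One new token per sequence; k/v are written into the caches in place at cache_seqlens.
    q = torch.randn(batch, 1, nheads, headdim, dtype=torch.bfloat16, device="cuda")
    k_new = torch.randn(batch, 1, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
    v_new = torch.randn_like(k_new)

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k_new, v=v_new,
        cache_seqlens=cache_seqlens, causal=True,
    )
    print(out.shape)  # (2, 1, 8, 128); this path has no backward pass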
 
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- from .flash_attn_interface import (
2
- flash_attn_combine,
3
- flash_attn_func,
4
- flash_attn_qkvpacked_func,
5
- flash_attn_varlen_func,
6
- flash_attn_with_kvcache,
7
- get_scheduler_metadata,
8
- )
9
-
10
- __all__ = [
11
- "flash_attn_combine",
12
- "flash_attn_func",
13
- "flash_attn_qkvpacked_func",
14
- "flash_attn_varlen_func",
15
- "flash_attn_with_kvcache",
16
- "get_scheduler_metadata",
17
- ]
 
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_flash_attn3_48fe103_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bc32b815563bc9051986a333a362ff61e37cbd967893212243292fef03b461a5
3
- size 838544688
 
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _flash_attn3_48fe103_dirty
3
- ops = torch.ops._flash_attn3_48fe103_dirty
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_flash_attn3_48fe103_dirty::{op_name}"
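
For reference, a minimal sketch of the flash_attn_varlen_func interface from the flash_attn_interface.py listings in this diff, using the standard packed varlen layout (one row per token, int32 prefix-sum cu_seqlens). The packed layout and import path are assumptions, since the removed file documents this function only through its signature:

    import torch
    from flash_attn3 import flash_attn_varlen_func

    nheads, nheads_k, headdim = 8, 2, 128
    seqlens = [3, 5]                    # two sequences packed into one tensor
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")

    q = torch.randn(sum(seqlens), nheads, headdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(sum(seqlens), nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
    v = torch.randn_like(k)

    out, lse = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
        causal=True,
    )
    print(out.shape)  # (8, 8, 128): one output row per packed token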
 
build/torch28-cxx11-cu126-x86_64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
- will only attend to keys between
532
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
-
534
- Arguments:
535
- q: (batch_size, seqlen, nheads, headdim)
536
- k: (batch_size, seqlen, nheads_k, headdim)
537
- v: (batch_size, seqlen, nheads_k, headdim)
538
- dropout_p: float. Dropout probability.
539
- softmax_scale: float. The scaling of QK^T before applying softmax.
540
- Default to 1 / sqrt(headdim).
541
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
- is added to the attention score of query i and key j.
546
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
- which is slightly slower and uses more memory. The forward pass is always deterministic.
548
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
- testing only. The returned probabilities are not guaranteed to be correct
550
- (they might not have the right scaling).
551
- Return:
552
- out: (batch_size, seqlen, nheads, headdim).
553
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
- normalization factor).
556
- """
557
- return FlashAttnFunc.apply(
558
- q,
559
- k,
560
- v,
561
- softmax_scale,
562
- causal,
563
- qv,
564
- q_descale, k_descale, v_descale,
565
- window_size,
566
- attention_chunk,
567
- softcap,
568
- num_splits,
569
- pack_gqa,
570
- deterministic,
571
- sm_margin,
572
- )
573
-
574
-
575
- def flash_attn_varlen_func(
576
- q,
577
- k,
578
- v,
579
- cu_seqlens_q,
580
- cu_seqlens_k,
581
- max_seqlen_q,
582
- max_seqlen_k,
583
- seqused_q=None,
584
- seqused_k=None,
585
- softmax_scale=None,
586
- causal=False,
587
- qv=None,
588
- q_descale=None, k_descale=None, v_descale=None,
589
- window_size=(-1, -1),
590
- attention_chunk=0,
591
- softcap=0.0,
592
- num_splits=1,
593
- pack_gqa=None,
594
- deterministic=False,
595
- sm_margin=0,
596
- ):
597
- return FlashAttnVarlenFunc.apply(
598
- q,
599
- k,
600
- v,
601
- cu_seqlens_q,
602
- cu_seqlens_k,
603
- seqused_q,
604
- seqused_k,
605
- max_seqlen_q,
606
- max_seqlen_k,
607
- softmax_scale,
608
- causal,
609
- qv,
610
- q_descale, k_descale, v_descale,
611
- window_size,
612
- attention_chunk,
613
- softcap,
614
- num_splits,
615
- pack_gqa,
616
- deterministic,
617
- sm_margin,
618
- )
619
-
620
-
621
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
- return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
-
624
-
625
- def flash_attn_with_kvcache(
626
- q,
627
- k_cache,
628
- v_cache,
629
- k=None,
630
- v=None,
631
- qv=None,
632
- rotary_cos=None,
633
- rotary_sin=None,
634
- cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
- cache_batch_idx: Optional[torch.Tensor] = None,
636
- cache_leftpad: Optional[torch.Tensor] = None,
637
- page_table: Optional[torch.Tensor] = None,
638
- cu_seqlens_q: Optional[torch.Tensor] = None,
639
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
- max_seqlen_q: Optional[int] = None,
641
- rotary_seqlens: Optional[torch.Tensor] = None,
642
- q_descale: Optional[torch.Tensor] = None,
643
- k_descale: Optional[torch.Tensor] = None,
644
- v_descale: Optional[torch.Tensor] = None,
645
- softmax_scale=None,
646
- causal=False,
647
- window_size=(-1, -1), # -1 means infinite context window
648
- attention_chunk=0,
649
- softcap=0.0, # 0.0 means deactivated
650
- rotary_interleaved=True,
651
- scheduler_metadata=None,
652
- num_splits=0, # Can be tuned for speed
653
- pack_gqa=None, # Can be tuned for speed
654
- sm_margin=0, # Can be tuned if some SMs are used for communication
655
- return_softmax_lse=False,
656
- ):
657
- """
658
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
- the previous step, and update them with the new keys/values from the current step, and do
661
- attention with the updated cache, all in 1 kernel.
662
-
663
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
-
667
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
-
674
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
-
676
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
-
681
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
- 1 1 1 1 0
684
- 1 1 1 1 1
685
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
- 0 0
687
- 0 0
688
- 0 0
689
- 1 0
690
- 1 1
691
- If the row of the mask is all zero, the output will be zero.
692
-
693
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
- will only attend to keys between
695
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
-
697
- Note: Does not support backward pass.
698
-
699
- Arguments:
700
- q: (batch_size, seqlen, nheads, headdim)
701
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
- k with k_cache, starting at the indices specified by cache_seqlens.
708
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
- KV cache.
715
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
716
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
717
- If the indices are not distinct, and k and v are provided, the values updated in the cache
718
- might come from any of the duplicate indices.
719
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
720
- page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
721
- softmax_scale: float. The scaling of QK^T before applying softmax.
722
- Default to 1 / sqrt(headdim).
723
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
724
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
725
- softcap: float. Anything > 0 activates softcapping attention.
726
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
727
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
728
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
729
- (i.e. GPT-NeoX style).
730
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
731
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
732
- to automatically determine the number of splits.
733
- Don't change this unless you know what you are doing.
734
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
735
-
736
- Return:
737
- out: (batch_size, seqlen, nheads, headdim).
738
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
739
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
740
- normalization factor).
741
- """
742
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
743
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
744
- if softmax_scale is None:
745
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
746
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
747
- cache_seqlens = torch.full(
748
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
749
- )
750
- cache_seqlens = maybe_contiguous(cache_seqlens)
751
- out, softmax_lse, *rest = _flash_attn_forward(
752
- q,
753
- k_cache,
754
- v_cache,
755
- k,
756
- v,
757
- qv,
758
- None, # out
759
- cu_seqlens_q,
760
- None, # cu_seqlens_k
761
- cu_seqlens_k_new,
762
- None, # seqused_q
763
- cache_seqlens,
764
- max_seqlen_q,
765
- None, # max_seqlen_k
766
- page_table,
767
- cache_batch_idx,
768
- cache_leftpad,
769
- rotary_cos,
770
- rotary_sin,
771
- rotary_seqlens,
772
- q_descale, k_descale, v_descale,
773
- softmax_scale,
774
- causal=causal,
775
- window_size=window_size,
776
- attention_chunk=attention_chunk,
777
- softcap=softcap,
778
- rotary_interleaved=rotary_interleaved,
779
- scheduler_metadata=scheduler_metadata,
780
- num_splits=num_splits,
781
- pack_gqa=pack_gqa,
782
- sm_margin=sm_margin,
783
- )
784
- # return (out, softmax_lse) if return_softmax_lse else out
785
- return (out, softmax_lse, *rest) if return_softmax_lse else out
786
-
787
-
788
- def get_scheduler_metadata(
789
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
790
- cache_seqlens: torch.Tensor,
791
- qkv_dtype=torch.bfloat16,
792
- headdim_v=None,
793
- cu_seqlens_q: Optional[torch.Tensor] = None,
794
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
795
- cache_leftpad: Optional[torch.Tensor] = None,
796
- page_size: Optional[int] = None,
797
- max_seqlen_k_new=0,
798
- causal=False,
799
- window_size=(-1, -1), # -1 means infinite context window
800
- attention_chunk=0,
801
- has_softcap=False,
802
- num_splits=0, # Can be tuned for speed
803
- pack_gqa=None, # Can be tuned for speed
804
- sm_margin=0, # Can be tuned if some SMs are used for communication
805
- ):
806
- cache_seqlens = maybe_contiguous(cache_seqlens)
807
- if headdim_v is None:
808
- headdim_v = headdim
809
- scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
810
- batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
811
- qkv_dtype,
812
- cache_seqlens,
813
- cu_seqlens_q,
814
- None, # cu_seqlens_k
815
- cu_seqlens_k_new,
816
- None, # seqused_q
817
- cache_leftpad,
818
- page_size,
819
- max_seqlen_k_new,
820
- causal,
821
- window_size[0], window_size[1],
822
- attention_chunk,
823
- has_softcap,
824
- num_splits,
825
- pack_gqa,
826
- sm_margin,
827
- )
828
- return scheduler_metadata
 
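For reference, a minimal sketch of how the `flash_attn_func` entry point removed above was typically called (shapes follow its docstring; tensor sizes and dtypes here are illustrative assumptions, and the `flash_attn3` package is assumed to be installed):

```python
# Hypothetical usage sketch of the removed flash_attn_func API.
import torch
from flash_attn3 import flash_attn_func

batch, seqlen, nheads, nheads_k, headdim = 2, 1024, 8, 2, 128
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nheads_k, headdim, dtype=torch.bfloat16, device="cuda")

# Causal self-attention with GQA: 8 query heads share 2 KV heads.
# Per the interface above, the function returns the output and the softmax logsumexp.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
```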
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- from .flash_attn_interface import (
2
- flash_attn_combine,
3
- flash_attn_func,
4
- flash_attn_qkvpacked_func,
5
- flash_attn_varlen_func,
6
- flash_attn_with_kvcache,
7
- get_scheduler_metadata,
8
- )
9
-
10
- __all__ = [
11
- "flash_attn_combine",
12
- "flash_attn_func",
13
- "flash_attn_qkvpacked_func",
14
- "flash_attn_varlen_func",
15
- "flash_attn_with_kvcache",
16
- "get_scheduler_metadata",
17
- ]
 
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (438 Bytes)
 
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc DELETED
Binary file (530 Bytes)
 
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc DELETED
Binary file (26.2 kB)
 
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_flash_attn3_8d4f83f.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e9aef52109e5974778e3ccc2f697c4e6050b365624c843a675ce894b938341cc
3
- size 822395576
 
 
 
 
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _flash_attn3_8d4f83f
3
- ops = torch.ops._flash_attn3_8d4f83f
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_flash_attn3_8d4f83f::{op_name}"
 
 
build/torch28-cxx11-cu128-aarch64-linux/flash_attn3/flash_attn_interface.py DELETED
@@ -1,828 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ._ops import ops as flash_attn_3_cuda
9
-
10
- def maybe_contiguous(x):
11
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
12
-
13
-
14
- def _flash_attn_forward(
15
- q,
16
- k,
17
- v,
18
- k_new,
19
- v_new,
20
- qv,
21
- out,
22
- cu_seqlens_q,
23
- cu_seqlens_k,
24
- cu_seqlens_k_new,
25
- seqused_q,
26
- seqused_k,
27
- max_seqlen_q,
28
- max_seqlen_k,
29
- page_table,
30
- kv_batch_idx,
31
- leftpad_k,
32
- rotary_cos,
33
- rotary_sin,
34
- seqlens_rotary,
35
- q_descale,
36
- k_descale,
37
- v_descale,
38
- softmax_scale,
39
- causal,
40
- window_size=(-1, -1),
41
- attention_chunk=0,
42
- softcap=0.0,
43
- rotary_interleaved=True,
44
- scheduler_metadata=None,
45
- num_splits=1,
46
- pack_gqa=None,
47
- sm_margin=0):
48
- q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
49
- v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
50
- cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
51
- maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
52
- ]
53
- seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
54
- page_table, kv_batch_idx, leftpad_k = [
55
- maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
56
- ]
57
- rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
58
- seqlens_rotary = maybe_contiguous(seqlens_rotary)
59
- out, softmax_lse, *rest = flash_attn_3_cuda.fwd(
60
- q,
61
- k,
62
- v,
63
- k_new,
64
- v_new,
65
- qv,
66
- out,
67
- cu_seqlens_q,
68
- cu_seqlens_k,
69
- cu_seqlens_k_new,
70
- seqused_q,
71
- seqused_k,
72
- max_seqlen_q,
73
- max_seqlen_k,
74
- page_table,
75
- kv_batch_idx,
76
- leftpad_k,
77
- rotary_cos,
78
- rotary_sin,
79
- seqlens_rotary,
80
- q_descale,
81
- k_descale,
82
- v_descale,
83
- softmax_scale,
84
- causal,
85
- window_size[0],
86
- window_size[1],
87
- attention_chunk,
88
- softcap,
89
- rotary_interleaved,
90
- scheduler_metadata,
91
- num_splits,
92
- pack_gqa,
93
- sm_margin,
94
- )
95
- return out, softmax_lse, *rest
96
-
97
-
98
- def _flash_attn_backward(
99
- dout,
100
- q,
101
- k,
102
- v,
103
- out,
104
- softmax_lse,
105
- cu_seqlens_q,
106
- cu_seqlens_k,
107
- sequed_q,
108
- sequed_k,
109
- max_seqlen_q,
110
- max_seqlen_k,
111
- dq,
112
- dk,
113
- dv,
114
- softmax_scale,
115
- causal,
116
- window_size=(-1, -1),
117
- softcap=0.0,
118
- deterministic=False,
119
- sm_margin=0,
120
- ):
121
- # dq, dk, dv are allocated by us so they should already be contiguous
122
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
123
- dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
124
- dout,
125
- q,
126
- k,
127
- v,
128
- out,
129
- softmax_lse,
130
- dq,
131
- dk,
132
- dv,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- sequed_q,
136
- sequed_k,
137
- max_seqlen_q,
138
- max_seqlen_k,
139
- softmax_scale,
140
- causal,
141
- window_size[0],
142
- window_size[1],
143
- softcap,
144
- deterministic,
145
- sm_margin,
146
- )
147
- return dq, dk, dv, softmax_d
148
-
149
-
150
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
151
- @staticmethod
152
- def forward(
153
- ctx,
154
- qkv,
155
- softmax_scale,
156
- causal,
157
- q_descale=None, k_descale=None, v_descale=None,
158
- window_size=(-1, -1),
159
- attention_chunk=0,
160
- softcap=0.0,
161
- deterministic=False,
162
- num_heads_q=None,
163
- sm_margin=0,
164
- ):
165
- if softmax_scale is None:
166
- softmax_scale = qkv.shape[-1] ** (-0.5)
167
- if qkv.dim() == 5:
168
- assert qkv.shape[-3] == 3
169
- q, k, v = qkv.unbind(dim=-3)
170
- else:
171
- assert qkv.dim() == 4
172
- assert num_heads_q is not None
173
- num_heads_k = (qkv.shape[2] - num_heads_q) // 2
174
- assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
175
- q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
176
- out, softmax_lse, *rest = _flash_attn_forward(
177
- q,
178
- k,
179
- v,
180
- None, None, # k_new, v_new
181
- None, # qv
182
- None, # out
183
- None, None, None, # cu_seqlens_q/k/k_new
184
- None, None, # seqused_q/k
185
- None, None, # max_seqlen_q/k
186
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
187
- None, None, None, # rotary_cos/sin, seqlens_rotary
188
- q_descale, k_descale, v_descale,
189
- softmax_scale,
190
- causal=causal,
191
- window_size=window_size,
192
- attention_chunk=attention_chunk,
193
- softcap=softcap,
194
- sm_margin=sm_margin,
195
- )
196
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
197
- ctx.save_for_backward(q, k, v, out, softmax_lse)
198
- ctx.softmax_scale = softmax_scale
199
- ctx.causal = causal
200
- ctx.window_size = window_size
201
- ctx.attention_chunk = attention_chunk
202
- ctx.softcap = softcap
203
- ctx.deterministic = deterministic
204
- ctx.ndim = qkv.dim()
205
- ctx.sm_margin = sm_margin
206
- # return out, softmax_lse
207
- return out
208
-
209
- @staticmethod
210
- def backward(ctx, dout, *args):
211
- q, k, v, out, softmax_lse = ctx.saved_tensors
212
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
213
- if ctx.ndim == 5:
214
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
215
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
216
- dq, dk, dv = dqkv.unbind(dim=-3)
217
- else:
218
- num_heads_q = q.shape[2]
219
- num_heads_k = k.shape[2]
220
- qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
221
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
222
- dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
223
- _flash_attn_backward(
224
- dout,
225
- q,
226
- k,
227
- v,
228
- out,
229
- softmax_lse,
230
- None, None, # cu_seqlens_q, cu_seqlens_k,
231
- None, None, # sequed_q, sequed_k,
232
- None, None, # max_seqlen_q, max_seqlen_k,
233
- dq,
234
- dk,
235
- dv,
236
- ctx.softmax_scale,
237
- ctx.causal,
238
- ctx.window_size,
239
- ctx.softcap,
240
- ctx.deterministic,
241
- ctx.sm_margin,
242
- )
243
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
244
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
245
-
246
-
247
- class FlashAttnFunc(torch.autograd.Function):
248
-
249
- @staticmethod
250
- def forward(
251
- ctx,
252
- q,
253
- k,
254
- v,
255
- softmax_scale,
256
- causal,
257
- qv=None,
258
- q_descale=None, k_descale=None, v_descale=None,
259
- window_size=(-1, -1),
260
- attention_chunk=0,
261
- softcap=0.0,
262
- num_splits=1,
263
- pack_gqa=None,
264
- deterministic=False,
265
- sm_margin=0,
266
- ):
267
- if softmax_scale is None:
268
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
269
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
270
- out, softmax_lse, *rest = _flash_attn_forward(
271
- q,
272
- k,
273
- v,
274
- None, None, # k_new, v_new
275
- qv, # qv
276
- None, # out
277
- None, None, None, # cu_seqlens_q/k/k_new
278
- None, None, # seqused_q/k
279
- None, None, # max_seqlen_q/k
280
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
281
- None, None, None, # rotary_cos/sin, seqlens_rotary
282
- q_descale, k_descale, v_descale,
283
- softmax_scale,
284
- causal=causal,
285
- window_size=window_size,
286
- attention_chunk=attention_chunk,
287
- softcap=softcap,
288
- num_splits=num_splits,
289
- pack_gqa=pack_gqa,
290
- sm_margin=sm_margin,
291
- )
292
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
293
- ctx.save_for_backward(q, k, v, out, softmax_lse)
294
- ctx.softmax_scale = softmax_scale
295
- ctx.causal = causal
296
- ctx.window_size = window_size
297
- ctx.attention_chunk = attention_chunk
298
- ctx.softcap = softcap
299
- ctx.deterministic = deterministic
300
- ctx.sm_margin = sm_margin
301
- return out, softmax_lse
302
-
303
- @staticmethod
304
- def backward(ctx, dout, *args):
305
- q, k, v, out, softmax_lse = ctx.saved_tensors
306
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
307
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
308
- _flash_attn_backward(
309
- dout,
310
- q,
311
- k,
312
- v,
313
- out,
314
- softmax_lse,
315
- None, None, # cu_seqlens_q, cu_seqlens_k,
316
- None, None, # sequed_q, sequed_k,
317
- None, None, # max_seqlen_q, max_seqlen_k,
318
- dq,
319
- dk,
320
- dv,
321
- ctx.softmax_scale,
322
- ctx.causal,
323
- ctx.window_size,
324
- ctx.softcap,
325
- ctx.deterministic,
326
- ctx.sm_margin,
327
- )
328
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
329
- dk = dk[..., : k.shape[-1]]
330
- dv = dv[..., : v.shape[-1]]
331
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
332
-
333
-
334
- class FlashAttnVarlenFunc(torch.autograd.Function):
335
-
336
- @staticmethod
337
- def forward(
338
- ctx,
339
- q,
340
- k,
341
- v,
342
- cu_seqlens_q,
343
- cu_seqlens_k,
344
- seqused_q,
345
- seqused_k,
346
- max_seqlen_q,
347
- max_seqlen_k,
348
- softmax_scale,
349
- causal,
350
- qv=None,
351
- q_descale=None, k_descale=None, v_descale=None,
352
- window_size=(-1, -1),
353
- attention_chunk=0,
354
- softcap=0.0,
355
- num_splits=1,
356
- pack_gqa=None,
357
- deterministic=False,
358
- sm_margin=0,
359
- ):
360
- if softmax_scale is None:
361
- softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
362
- # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
363
- out, softmax_lse, *rest = _flash_attn_forward(
364
- q,
365
- k,
366
- v,
367
- None, None, # k_new, v_new
368
- qv, # qv
369
- None, # out
370
- cu_seqlens_q,
371
- cu_seqlens_k,
372
- None, # cu_seqlens_k_new
373
- seqused_q,
374
- seqused_k,
375
- max_seqlen_q,
376
- max_seqlen_k,
377
- None, None, None, # page_table, kv_batch_idx, leftpad_k,
378
- None, None, None, # rotary_cos/sin, seqlens_rotary
379
- q_descale, k_descale, v_descale,
380
- softmax_scale,
381
- causal=causal,
382
- window_size=window_size,
383
- attention_chunk=attention_chunk,
384
- softcap=softcap,
385
- num_splits=num_splits,
386
- pack_gqa=pack_gqa,
387
- sm_margin=sm_margin,
388
- )
389
- # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
390
- ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
391
- ctx.max_seqlen_q = max_seqlen_q
392
- ctx.max_seqlen_k = max_seqlen_k
393
- ctx.softmax_scale = softmax_scale
394
- ctx.causal = causal
395
- ctx.window_size = window_size
396
- ctx.attention_chunk = attention_chunk
397
- ctx.softcap = softcap
398
- ctx.deterministic = deterministic
399
- ctx.sm_margin = sm_margin
400
- return out, softmax_lse
401
-
402
- @staticmethod
403
- def backward(ctx, dout, *args):
404
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
405
- assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
406
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
407
- _flash_attn_backward(
408
- dout,
409
- q,
410
- k,
411
- v,
412
- out,
413
- softmax_lse,
414
- cu_seqlens_q,
415
- cu_seqlens_k,
416
- seqused_q,
417
- seqused_k,
418
- ctx.max_seqlen_q,
419
- ctx.max_seqlen_k,
420
- dq,
421
- dk,
422
- dv,
423
- ctx.softmax_scale,
424
- ctx.causal,
425
- ctx.window_size,
426
- ctx.softcap,
427
- ctx.deterministic,
428
- ctx.sm_margin,
429
- )
430
- dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
431
- dk = dk[..., : k.shape[-1]]
432
- dv = dv[..., : v.shape[-1]]
433
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
434
-
435
-
436
- def flash_attn_qkvpacked_func(
437
- qkv,
438
- softmax_scale=None,
439
- causal=False,
440
- q_descale=None, k_descale=None, v_descale=None,
441
- window_size=(-1, -1),
442
- attention_chunk=0,
443
- softcap=0.0,
444
- deterministic=False,
445
- num_heads_q=None,
446
- sm_margin=0,
447
- ):
448
- """dropout_p should be set to 0.0 during evaluation
449
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
450
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
451
- of the gradients of Q, K, V.
452
- For multi-query and grouped-query attention (MQA/GQA), please see
453
- flash_attn_kvpacked_func and flash_attn_func.
454
-
455
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
456
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
457
-
458
- Arguments:
459
- qkv: (batch_size, seqlen, 3, nheads, headdim)
460
- dropout_p: float. Dropout probability.
461
- softmax_scale: float. The scaling of QK^T before applying softmax.
462
- Default to 1 / sqrt(headdim).
463
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
464
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
465
- softcap: float. Anything > 0 activates softcapping attention.
466
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
467
- the attention score of query i and key j.
468
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
469
- which is slightly slower and uses more memory. The forward pass is always deterministic.
470
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
471
- testing only. The returned probabilities are not guaranteed to be correct
472
- (they might not have the right scaling).
473
- Return:
474
- out: (batch_size, seqlen, nheads, headdim).
475
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
476
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
477
- normalization factor).
478
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
479
- The output of softmax (possibly with different scaling). It also encodes the dropout
480
- pattern (negative means that location was dropped, nonnegative means it was kept).
481
- """
482
- return FlashAttnQKVPackedFunc.apply(
483
- qkv,
484
- softmax_scale,
485
- causal,
486
- q_descale, k_descale, v_descale,
487
- window_size,
488
- attention_chunk,
489
- softcap,
490
- deterministic,
491
- num_heads_q,
492
- sm_margin,
493
- )
494
-
495
-
496
- def flash_attn_func(
497
- q,
498
- k,
499
- v,
500
- softmax_scale=None,
501
- causal=False,
502
- qv=None,
503
- q_descale=None, k_descale=None, v_descale=None,
504
- window_size=(-1, -1),
505
- attention_chunk=0,
506
- softcap=0.0,
507
- num_splits=1,
508
- pack_gqa=None,
509
- deterministic=False,
510
- sm_margin=0,
511
- ):
512
- """dropout_p should be set to 0.0 during evaluation
513
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
514
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
515
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
516
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
517
-
518
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
519
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
520
- 1 1 1 1 0
521
- 1 1 1 1 1
522
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
523
- 0 0
524
- 0 0
525
- 0 0
526
- 1 0
527
- 1 1
528
- If the row of the mask is all zero, the output will be zero.
529
-
530
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
531
- will only attend to keys between
532
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
533
-
534
- Arguments:
535
- q: (batch_size, seqlen, nheads, headdim)
536
- k: (batch_size, seqlen, nheads_k, headdim)
537
- v: (batch_size, seqlen, nheads_k, headdim)
538
- dropout_p: float. Dropout probability.
539
- softmax_scale: float. The scaling of QK^T before applying softmax.
540
- Default to 1 / sqrt(headdim).
541
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
542
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
543
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
544
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
545
- is added to the attention score of query i and key j.
546
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
547
- which is slightly slower and uses more memory. The forward pass is always deterministic.
548
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
549
- testing only. The returned probabilities are not guaranteed to be correct
550
- (they might not have the right scaling).
551
- Return:
552
- out: (batch_size, seqlen, nheads, headdim).
553
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
554
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
555
- normalization factor).
556
- """
557
- return FlashAttnFunc.apply(
558
- q,
559
- k,
560
- v,
561
- softmax_scale,
562
- causal,
563
- qv,
564
- q_descale, k_descale, v_descale,
565
- window_size,
566
- attention_chunk,
567
- softcap,
568
- num_splits,
569
- pack_gqa,
570
- deterministic,
571
- sm_margin,
572
- )
573
-
574
-
575
- def flash_attn_varlen_func(
576
- q,
577
- k,
578
- v,
579
- cu_seqlens_q,
580
- cu_seqlens_k,
581
- max_seqlen_q,
582
- max_seqlen_k,
583
- seqused_q=None,
584
- seqused_k=None,
585
- softmax_scale=None,
586
- causal=False,
587
- qv=None,
588
- q_descale=None, k_descale=None, v_descale=None,
589
- window_size=(-1, -1),
590
- attention_chunk=0,
591
- softcap=0.0,
592
- num_splits=1,
593
- pack_gqa=None,
594
- deterministic=False,
595
- sm_margin=0,
596
- ):
597
- return FlashAttnVarlenFunc.apply(
598
- q,
599
- k,
600
- v,
601
- cu_seqlens_q,
602
- cu_seqlens_k,
603
- seqused_q,
604
- seqused_k,
605
- max_seqlen_q,
606
- max_seqlen_k,
607
- softmax_scale,
608
- causal,
609
- qv,
610
- q_descale, k_descale, v_descale,
611
- window_size,
612
- attention_chunk,
613
- softcap,
614
- num_splits,
615
- pack_gqa,
616
- deterministic,
617
- sm_margin,
618
- )
619
-
620
-
621
- def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
622
- return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
623
-
624
-
625
- def flash_attn_with_kvcache(
626
- q,
627
- k_cache,
628
- v_cache,
629
- k=None,
630
- v=None,
631
- qv=None,
632
- rotary_cos=None,
633
- rotary_sin=None,
634
- cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
635
- cache_batch_idx: Optional[torch.Tensor] = None,
636
- cache_leftpad: Optional[torch.Tensor] = None,
637
- page_table: Optional[torch.Tensor] = None,
638
- cu_seqlens_q: Optional[torch.Tensor] = None,
639
- cu_seqlens_k_new: Optional[torch.Tensor] = None,
640
- max_seqlen_q: Optional[int] = None,
641
- rotary_seqlens: Optional[torch.Tensor] = None,
642
- q_descale: Optional[torch.Tensor] = None,
643
- k_descale: Optional[torch.Tensor] = None,
644
- v_descale: Optional[torch.Tensor] = None,
645
- softmax_scale=None,
646
- causal=False,
647
- window_size=(-1, -1), # -1 means infinite context window
648
- attention_chunk=0,
649
- softcap=0.0, # 0.0 means deactivated
650
- rotary_interleaved=True,
651
- scheduler_metadata=None,
652
- num_splits=0, # Can be tuned for speed
653
- pack_gqa=None, # Can be tuned for speed
654
- sm_margin=0, # Can be tuned if some SMs are used for communication
655
- return_softmax_lse=False,
656
- ):
657
- """
658
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
659
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
660
- the previous step, and update them with the new keys/values from the current step, and do
661
- attention with the updated cache, all in 1 kernel.
662
-
663
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
664
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
665
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
666
-
667
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
668
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
669
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
670
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
671
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
672
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
673
-
674
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
675
-
676
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
677
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
678
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
679
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
680
-
681
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
682
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
683
- 1 1 1 1 0
684
- 1 1 1 1 1
685
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
686
- 0 0
687
- 0 0
688
- 0 0
689
- 1 0
690
- 1 1
691
- If the row of the mask is all zero, the output will be zero.
692
-
693
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
694
- will only attend to keys between
695
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
696
-
697
- Note: Does not support backward pass.
698
-
699
- Arguments:
700
- q: (batch_size, seqlen, nheads, headdim)
701
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
702
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
703
- page_block_size must be a multiple of 256.
704
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
705
- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
706
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
707
- k with k_cache, starting at the indices specified by cache_seqlens.
708
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
709
- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
710
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
711
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
712
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
713
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
714
- KV cache.
715
-         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
-             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
-             If the indices are not distinct, and k and v are provided, the values updated in the cache
-                  might come from any of the duplicate indices.
-         cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
-         page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
-         softmax_scale: float. The scaling of QK^T before applying softmax.
-             Default to 1 / sqrt(headdim).
-         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-         softcap: float. Anything > 0 activates softcapping attention.
-         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
-             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
-             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
-             (i.e. GPT-NeoX style).
-         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-             to automatically determine the number of splits.
-             Don't change this unless you know what you are doing.
-         return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
-
-     Return:
-         out: (batch_size, seqlen, nheads, headdim).
-         softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
-             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-             normalization factor).
-     """
-     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
-     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
-     if softmax_scale is None:
-         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-     if cache_seqlens is not None and isinstance(cache_seqlens, int):
-         cache_seqlens = torch.full(
-             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
-         )
-         cache_seqlens = maybe_contiguous(cache_seqlens)
-     out, softmax_lse, *rest = _flash_attn_forward(
-         q,
-         k_cache,
-         v_cache,
-         k,
-         v,
-         qv,
-         None, # out
-         cu_seqlens_q,
-         None, # cu_seqlens_k
-         cu_seqlens_k_new,
-         None, # seqused_q
-         cache_seqlens,
-         max_seqlen_q,
-         None, # max_seqlen_k
-         page_table,
-         cache_batch_idx,
-         cache_leftpad,
-         rotary_cos,
-         rotary_sin,
-         rotary_seqlens,
-         q_descale, k_descale, v_descale,
-         softmax_scale,
-         causal=causal,
-         window_size=window_size,
-         attention_chunk=attention_chunk,
-         softcap=softcap,
-         rotary_interleaved=rotary_interleaved,
-         scheduler_metadata=scheduler_metadata,
-         num_splits=num_splits,
-         pack_gqa=pack_gqa,
-         sm_margin=sm_margin,
-     )
-     # return (out, softmax_lse) if return_softmax_lse else out
-     return (out, softmax_lse, *rest) if return_softmax_lse else out
-
-
- def get_scheduler_metadata(
-     batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
-     cache_seqlens: torch.Tensor,
-     qkv_dtype=torch.bfloat16,
-     headdim_v=None,
-     cu_seqlens_q: Optional[torch.Tensor] = None,
-     cu_seqlens_k_new: Optional[torch.Tensor] = None,
-     cache_leftpad: Optional[torch.Tensor] = None,
-     page_size: Optional[int] = None,
-     max_seqlen_k_new=0,
-     causal=False,
-     window_size=(-1, -1), # -1 means infinite context window
-     attention_chunk=0,
-     has_softcap=False,
-     num_splits=0, # Can be tuned for speed
-     pack_gqa=None, # Can be tuned for speed
-     sm_margin=0, # Can be tuned if some SMs are used for communication
- ):
-     cache_seqlens = maybe_contiguous(cache_seqlens)
-     if headdim_v is None:
-         headdim_v = headdim
-     scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
-         batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
-         qkv_dtype,
-         cache_seqlens,
-         cu_seqlens_q,
-         None, # cu_seqlens_k
-         cu_seqlens_k_new,
-         None, # seqused_q
-         cache_leftpad,
-         page_size,
-         max_seqlen_k_new,
-         causal,
-         window_size[0], window_size[1],
-         attention_chunk,
-         has_softcap,
-         num_splits,
-         pack_gqa,
-         sm_margin,
-     )
-     return scheduler_metadata
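For context, a minimal decode-step sketch of how the two deleted entry points above fit together: get_scheduler_metadata precomputes split/scheduling information once per shape, and flash_attn_with_kvcache appends new keys/values into the cache and attends over it. Everything below (shapes, dtypes, and an installed flash_attn3 package) is an illustrative assumption based on the docstring, not part of this commit.

# Hypothetical usage sketch; shapes and dtypes are assumptions, not from this diff.
import torch
from flash_attn3 import flash_attn_with_kvcache, get_scheduler_metadata

batch, nheads, nheads_kv, headdim, max_seqlen_k = 4, 32, 8, 128, 4096
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)  # one new query token
k_cache = torch.zeros(batch, max_seqlen_k, nheads_kv, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
k_new = torch.randn(batch, 1, nheads_kv, headdim, device=device, dtype=dtype)  # token to append
v_new = torch.randn_like(k_new)

# Number of valid tokens already in the cache per batch element
# (per the code above, a plain int would be broadcast to such a tensor).
cache_seqlens = torch.full((batch,), 1027, dtype=torch.int32, device=device)

# Optional: precompute scheduler metadata once and reuse it across decode steps.
metadata = get_scheduler_metadata(
    batch, 1, max_seqlen_k, nheads, nheads_kv, headdim,
    cache_seqlens, qkv_dtype=dtype,
    max_seqlen_k_new=1,  # one appended token per sequence
    causal=True,
)

# Writes k_new/v_new into the cache at position cache_seqlens, then runs attention.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    scheduler_metadata=metadata,
)
# out: (batch, 1, nheads, headdim), per the Return section of the docstring above.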
build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__init__.py DELETED
@@ -1,17 +0,0 @@
- from .flash_attn_interface import (
-     flash_attn_combine,
-     flash_attn_func,
-     flash_attn_qkvpacked_func,
-     flash_attn_varlen_func,
-     flash_attn_with_kvcache,
-     get_scheduler_metadata,
- )
-
- __all__ = [
-     "flash_attn_combine",
-     "flash_attn_func",
-     "flash_attn_qkvpacked_func",
-     "flash_attn_varlen_func",
-     "flash_attn_with_kvcache",
-     "get_scheduler_metadata",
- ]
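The deleted __init__.py above simply re-exported the kernel's public entry points. For reference, the most basic of these, flash_attn_func, would be used roughly as follows; the shapes and the installed flash_attn3 package are assumptions, and depending on the build the call may also return the softmax LSE alongside the output.

# Hypothetical basic call of the re-exported API; illustrative assumptions only.
import torch
from flash_attn3 import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal self-attention; the output keeps q's (batch, seqlen, nheads, headdim) layout.
out = flash_attn_func(q, k, v, causal=True)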
build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (438 Bytes)
 
build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/_ops.cpython-313.pyc DELETED
Binary file (530 Bytes)
 
build/torch28-cxx11-cu129-aarch64-linux/flash_attn3/__pycache__/flash_attn_interface.cpython-313.pyc DELETED
Binary file (26.2 kB)