hypnopump commited on
Commit
81d95e1
·
verified ·
1 Parent(s): 260ab8a

Upload folder using huggingface_hub

Browse files
Files changed (31) hide show
  1. .gitattributes +9 -0
  2. README.md +61 -0
  3. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  4. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/__init__.py +1 -0
  5. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_interface.py +1101 -0
  7. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  8. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/__init__.py +1 -0
  9. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  10. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_interface.py +1101 -0
  11. build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  12. build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/__init__.py +1 -0
  13. build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  14. build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_interface.py +1116 -0
  15. build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  16. build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/__init__.py +1 -0
  17. build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  18. build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_interface.py +1116 -0
  19. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  20. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/__init__.py +1 -0
  21. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  22. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_interface.py +1101 -0
  23. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  24. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/__init__.py +1 -0
  25. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  26. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_interface.py +1101 -0
  27. build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/_C.abi3.so +3 -0
  28. build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/__init__.py +1 -0
  29. build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/flash_attn_config.py +7 -0
  30. build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/flash_attn_interface.py +1101 -0
  31. config.json +1 -0
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ build/torch28-cxx11-cu126-x86_64-linux/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
37
+ build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
38
+ build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
39
+ build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
40
+ build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
41
+ build/torch210-cxx11-cu129-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
42
+ build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
43
+ build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
44
+ build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Flash Attention 3 compatible with `torch.compile`. See [this PR](https://github.com/Dao-AILab/flash-attention/pull/1769) by guilhermeleobas for more details.
2
+
3
+ There is a build here for Torch 2.8.0 and a build for Torch Nightlies from 08/30 onward.
4
+
5
+ Reproduce:
6
+
7
+ ## Torch 2.8.0 Build
8
+
9
+ Compiled from `https://github.com/varunneal/flash-attention` on branch `guilhermeleobas/fa3-compile`.
10
+
11
+ Compilation commands:
12
+
13
+ ```
14
+ pip install -U pip wheel setuptools ninja numpy packaging psutil
15
+ pip install torch==2.8.0
16
+
17
+ git clone https://github.com/varunneal/flash-attention
18
+ cd flash-attention/hopper
19
+
20
+ export MAX_JOBS=32
21
+ export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch
22
+ export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only
23
+ export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8
24
+
25
+ # Optional, for faster compilation time
26
+ export FLASH_ATTENTION_DISABLE_HDIM64=TRUE
27
+ export FLASH_ATTENTION_DISABLE_HDIM96=TRUE
28
+ export FLASH_ATTENTION_DISABLE_HDIM192=TRUE
29
+ export FLASH_ATTENTION_DISABLE_HDIM256=TRUE
30
+
31
+ python setup.py bdist_wheel
32
+ ```
33
+
34
+ ## Torch Nightlies build
35
+
36
+ Compiled from `https://github.com/varunneal/flash-attention` on branch `stable`.
37
+
38
+ This is a custom fork that combines [ABI Compatibility](https://github.com/Dao-AILab/flash-attention/pull/1791) with `torch.compile` compatibility.
39
+ This build should be consistent with Torch Nightlies from 08/30 onward.
40
+
41
+ Compilation commands:
42
+
43
+
44
+ ```
45
+ pip install -U pip wheel setuptools ninja numpy packaging psutil
46
+ # Any Torch Nightly after 08/30 should be alright
47
+ pip install --pre "torch==2.10.0.dev20250928+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126
48
+
49
+ git clone https://github.com/varunneal/flash-attention
50
+ cd flash-attention/hopper
51
+
52
+ export MAX_JOBS=32
53
+ export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch
54
+ export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only
55
+ export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8
56
+
57
+
58
+ python setup.py bdist_wheel
59
+ ```
60
+
61
+ Please contact me if you would like me to build wheels for any other version of python or torch.
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02f36ddbcfde635b8a04d2c877d3cda292b866e080a995b183f903c30f4c311b
3
+ size 811890112
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
+ def show():
5
+ from pprint import pprint
6
+ pprint(CONFIG)
7
+
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a contiguous last dimension.

    ``None`` passes through unchanged, and tensors whose last-dim stride is
    already 1 are returned as-is (no copy).
    """
    if x is None:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round ``x`` up to the nearest multiple of ``m``."""
    remainder = x % m
    return x if remainder == 0 else x + (m - remainder)
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round ``head_size`` up to the smallest head dimension compiled into
    this build, as reported by the generated ``flash_attn_config``.

    Buckets disabled at build time are skipped; anything that does not fit an
    enabled bucket falls back to 256.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    buckets = (
        (64, "FLASHATTENTION_DISABLE_HDIM64"),
        (96, "FLASHATTENTION_DISABLE_HDIM96"),
        (128, "FLASHATTENTION_DISABLE_HDIM128"),
        (192, "FLASHATTENTION_DISABLE_HDIM192"),
        (256, "FLASHATTENTION_DISABLE_HDIM256"),
    )
    # First enabled bucket that is large enough wins.
    for dim, disable_flag in buckets:
        if not flags[disable_flag] and head_size <= dim:
            return dim
    return 256
43
+
44
+
45
# Registered as a torch custom op so torch.compile / torch.export can trace
# through the call; the actual kernels live in the compiled _C extension.
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass: dispatch to the compiled ``flash_attn_3.fwd`` kernel.

    Returns ``(out, softmax_lse, out_accum, softmax_lse_accum)``.  The two
    ``*_accum`` tensors are only produced by the kernel in some configurations
    (see the ``num_splits > 1`` branch of the fake implementation); missing
    ones are substituted with empty tensors so the op always returns four
    tensors.
    """
    # Make the last dim contiguous where required, copying only when needed.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v is additionally accepted when stride(-3) == 1; only copy when
    # neither stride is 1.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # NOTE: argument order must match the C++ fwd signature exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Replace absent accumulators with empty tensors so the returned tuple
    # always contains four real tensors.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
# Fake (meta) implementation used during tracing: allocates outputs with the
# shapes/dtypes the CUDA kernel would produce, without running it.
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    """

    # cu_seqlens_q present => varlen (packed, ragged-batch) layout.
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor.
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from the query head dimension).
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs).
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Allocate the output tensor unless the caller supplied one.
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Log-sum-exp of the softmax, always fp32.
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits".
    # There's a heuristic to compute num_splits when "num_splits <= 0";
    # until it is ported here, require an explicit positive value.
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV: the kernel also emits per-split fp32 accumulators.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # No split-KV accumulators in this case; mirror the real op's
        # empty-tensor substitution.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
# Registered as a torch custom op; gradients are computed by the compiled
# flash_attn_3.bwd kernel.
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass: dispatch to the compiled ``flash_attn_3.bwd`` kernel.

    Returns ``(dq, dk, dv, softmax_d)``.
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv.
    # NOTE: argument order must match the C++ bwd signature exactly.
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
# Fake (meta) implementation of the backward op used during tracing.
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake for the backward op.

    Allocates ``dq``/``dk``/``dv``/``softmax_d`` with the shapes the CUDA
    kernel would produce, without running it.  The block-size (kBlockM)
    selection mirrors the SM90 heuristics compiled into the kernel so that
    the padded ``softmax_d`` shape matches.
    """

    is_varlen_q = cu_seqlens_q is not None
    # Fixed: this previously checked cu_seqlens_q again (copy-paste bug),
    # so varlen-K-only inputs were not detected as varlen.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        # NOTE(review): if max_seqlen_q/k are None here, the window-size
        # comparisons below would fail — callers appear to always pass them
        # in varlen mode; confirm upstream.
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window covering the whole sequence is equivalent to no window.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Canonical causal form: unlimited left window, zero right window.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper GPUs use CUDA compute capability 9.0.
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the kernel's SM90 tile-size heuristics.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    # Varlen softmax_d is padded per-batch-entry up to a kBlockM multiple.
    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors with the same shapes as the inputs.
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash tensors and scalar options needed by the registered backward.

    ``inputs`` is the positional argument tuple of ``_flash_attn_forward``;
    scalar options are read by negative offset from its tail so the mapping
    stays valid regardless of how many leading tensor args there are.
    """
    q, k, v = inputs[0], inputs[1], inputs[2]
    out, softmax_lse = output[0], output[1]
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    # Offsets mirror _flash_attn_forward's parameter order (from the end):
    # -11 softmax_scale, -10 causal, -9/-8 window sizes, -7 attention_chunk,
    # -6 softcap, -1 sm_margin.
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Backward hook for the ``_flash_attn_forward`` custom op.

    Only dq/dk/dv are returned as gradients; the remaining forward inputs
    get ``None``.  This non-varlen path always passes ``None`` for the
    cu_seqlens/seqused/max_seqlen arguments and disables determinism.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # NOTE(review): 3 grads + 21 Nones = 24 entries; confirm this matches
    # how register_autograd counts the op's grad-requiring inputs.
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention on a packed QKV tensor.

    ``qkv`` is either 5-D with Q/K/V stacked in dim -3 (``qkv.shape[-3] == 3``)
    or 4-D with Q/K/V concatenated along the heads dim (-2), in which case
    ``num_heads_q`` must be given (GQA/MQA packing).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale: 1/sqrt(head_dim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (..., 3, num_heads, head_size): unbind the stacked Q/K/V.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # 4-D packing: heads dim holds num_heads_q + 2 * num_heads_k.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the packing layout so backward can re-pack the grads.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function.
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients back into the caller's qkv layout.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension.
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length (batched) flash attention on
    separate Q/K/V tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale: 1/sqrt(q head_dim [+ qv head_dim when qv given]).
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed/ragged) flash attention.

    Sequences are packed along the first dim; per-batch boundaries are given
    by ``cu_seqlens_q``/``cu_seqlens_k`` (cumulative lengths) and optionally
    limited by ``seqused_q``/``seqused_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale: 1/sqrt(q head_dim [+ qv head_dim when qv given]).
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # The varlen metadata tensors are needed again in backward.
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention where Q, K, V are stacked into a single tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale factors, forwarded to the
            kernel (used with FP8 inputs).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. 0 disables; nonzero values are forwarded to the kernel
            (chunked attention).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: optional int, forwarded to the kernel (number of query heads when
            Q and KV head counts differ).
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This
            option is for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 forward+backward over fixed-length batches.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor whose head dim contributes to the default scale.
        q_descale, k_descale, v_descale: optional descale factors, forwarded to the
            kernel (used with FP8 inputs).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. 0 disables; nonzero values are forwarded to the kernel
            (chunked attention).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Number of splits along the KV sequence (1 = no split).
        pack_gqa: optional bool, forwarded to the kernel (GQA packing heuristic override).
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This
            option is for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 over variable-length (packed) sequences.

    Sequences in the batch are concatenated along the token dimension;
    ``cu_seqlens_q`` / ``cu_seqlens_k`` hold the cumulative sequence lengths.
    See :func:`flash_attn_func` for the semantics of the remaining arguments.

    Returns the attention output, plus the softmax logsumexp when
    ``return_attn_probs=True``.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Merge partial attention outputs and their logsumexps into a single result
    via the FA3 CUDA combine kernel."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """Incremental-decoding attention against a (possibly paged) KV cache.

    If k and v are not None, k_cache and v_cache are updated *inplace* with the new
    values from k and v: pass the cached keys/values from the previous step, append
    the new keys/values for the current step, and attend over the updated cache, all
    in one kernel. The caller must ensure the cache is large enough to hold the new
    values (e.g. pre-allocate to the max sequence length and track current lengths
    with cache_seqlens).

    If rotary_cos and rotary_sin are passed in, rotary embedding is also applied:
    k is rotated at indices cache_seqlens, cache_seqlens + 1, etc.; q is rotated the
    same way when causal or local, otherwise at index cache_seqlens only.

    Supports MQA/GQA (KV with fewer heads than Q; Q heads must be divisible by KV
    heads). If causal=True, the causal mask is aligned to the bottom-right corner of
    the attention matrix; if window_size != (-1, -1), sliding-window local attention
    is used, exactly as in flash_attn_func.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) without page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) with a page_table
            (paged KV cache; page_block_size is arbitrary).
        v_cache: same layout as k_cache with last dim headdim_v.
        k, v [optional]: (batch_size, seqlen_new, nheads_k, headdim[_v]) new values
            appended to the cache starting at cache_seqlens.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos, rotary_sin [optional]: (seqlen_ro, rotary_dim / 2); rotary_dim
            must be divisible by 16. Only applicable if k and v are passed in.
        cache_seqlens: int or (batch_size,) int32 tensor of current cache lengths.
        cache_batch_idx: (batch_size,) int32 indices into the KV cache; defaults to
            [0, ..., batch_size - 1]. With duplicate indices and k/v provided, the
            cache update may come from any duplicate.
        cache_leftpad: (batch_size,) int32 start index of each cache entry (default 0).
        page_table [optional]: (batch_size, max_num_blocks_per_seq) int32.
        softmax_scale: float, default 1 / sqrt(headdim).
        causal: bool. Apply causal masking.
        window_size: (left, right) sliding window; (-1, -1) disables.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. True pairs dims 0&1, 2&3, ...; False pairs
            dim i with dim i + rotary_dim/2 (GPT-NeoX style).
        num_splits: int. 0 picks the split count heuristically; 1 disables splitting.
        return_softmax_lse: bool. Also return the attention logsumexp.

    Return:
        out: (batch_size, seqlen, nheads, headdim), and softmax_lse
        (batch_size, nheads, seqlen) when return_softmax_lse=True.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        qk_dim = q.shape[-1] + (qv.shape[-1] if qv is not None else 0)
        softmax_scale = qk_dim ** (-0.5)
    # Broadcast a scalar cache length to a per-batch int32 tensor on the cache device.
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute FA3 scheduler metadata for a given problem shape.

    The returned tensor can be passed as ``scheduler_metadata`` to
    ``flash_attn_with_kvcache`` to skip scheduling work inside the kernel.
    ``headdim_v`` defaults to ``headdim`` when not given.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    headdim_v = headdim if headdim_v is None else headdim_v
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02f36ddbcfde635b8a04d2c877d3cda292b866e080a995b183f903c30f4c311b
3
+ size 811890112
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
+ def show():
5
+ from pprint import pprint
6
+ pprint(CONFIG)
7
+
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return ``x`` unchanged when it is None or already has unit stride in its
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round ``x`` up to the nearest multiple of ``m`` (identity when already a multiple)."""
    return ((x + m - 1) // m) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round a head dimension up to the smallest kernel-supported size.

    Each candidate size is considered only when its build flag did not disable
    that head-dim specialization; anything larger than 256 is clamped to 256.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for bound in (64, 96, 128, 192, 256):
        if not flags[f"FLASHATTENTION_DISABLE_HDIM{bound}"] and head_size <= bound:
            return bound
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Invoke the FA3 forward CUDA kernel.

    Inputs are made contiguous where the kernel requires unit strides, then
    passed positionally to ``flash_attn_3_cuda.fwd``. Returns
    ``(out, softmax_lse, out_accum, softmax_lse_accum)``; the two accumulator
    tensors are replaced by empty tensors when the kernel returns None for them
    (custom ops cannot return optional tensors).
    """
    q, k, k_new, v_new = (maybe_contiguous(t) for t in (q, k, k_new, v_new))
    # v may be laid out row-major or seqlen-major; copy only when neither
    # the last stride nor the seqlen stride is 1.
    if v.stride(-1) != 1 and v.stride(-3) != 1:
        v = v.contiguous()
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = (
        maybe_contiguous(t) for t in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    )
    seqused_q, seqused_k = (maybe_contiguous(t) for t in (seqused_q, seqused_k))
    page_table, kv_batch_idx, leftpad_k = (
        maybe_contiguous(t) for t in (page_table, kv_batch_idx, leftpad_k)
    )
    rotary_cos, rotary_sin = (maybe_contiguous(t) for t in (rotary_cos, rotary_sin))
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)
    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)
    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention forward.

    Produces output tensors with the correct shapes/dtypes without running the
    kernel, so the op can be traced by torch.compile / torch.export.
    """
    # varlen mode is signalled by the presence of cu_seqlens_q.
    is_varlen_q = cu_seqlens_q is not None

    if is_varlen_q:
        # varlen: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1
        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]

    head_size_v = v.shape[-1]

    # FP8 inputs produce BF16 outputs; otherwise the output matches q's dtype.
    out_dtype = torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype

    if out is None:
        out_shape = (
            (total_q, num_heads, head_size_v)
            if is_varlen_q
            else (batch_size, seqlen_q, num_heads, head_size_v)
        )
        out = torch.empty(out_shape, dtype=out_dtype, device=q.device)

    lse_shape = (
        (num_heads, total_q) if is_varlen_q else (batch_size, num_heads, seqlen_q)
    )
    softmax_lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        if is_varlen_q:
            out_accum = torch.empty(
                (num_splits, num_heads, total_q, head_size_v),
                dtype=torch.float32, device=q.device,
            )
            softmax_lse_accum = torch.empty(
                (num_splits, num_heads, total_q),
                dtype=torch.float32, device=q.device,
            )
        else:
            out_accum = torch.empty(
                (num_splits, batch_size, num_heads, seqlen_q, head_size_v),
                dtype=torch.float32, device=q.device,
            )
            softmax_lse_accum = torch.empty(
                (num_splits, batch_size, num_heads, seqlen_q),
                dtype=torch.float32, device=q.device,
            )
    else:
        # Accumulators are unused when there is a single split.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Invoke the FA3 backward CUDA kernel.

    Returns ``(dq, dk, dv, softmax_d)``. The C++ side allocates the gradient
    tensors itself, so ``None`` is passed for the dq/dk/dv output slots.
    """
    dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))
    dq, dk, dv, softmax_d, *_extra = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of the backward op for torch.compile tracing.

    Allocates dq/dk/dv/softmax_d with the same shapes and dtypes the CUDA
    kernel would produce, without running it.
    """
    is_varlen_q = cu_seqlens_q is not None
    # BUG FIX: this previously re-tested cu_seqlens_q (copy-paste), so a
    # K-only varlen call would size softmax_d with the dense-batch layout.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window spanning the whole sequence is equivalent to "no window".
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1
    if window_size_right >= seqlen_q - 1:
        window_size_right = -1
    if is_causal:
        window_size_right = 0
    # Re-derive causality after window normalization, matching the kernel.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper GPUs use cuda compute capability 9.0.
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # Mirror the kernel's SM90 kBlockM selection so softmax_d's rounded
    # seqlen matches what the real backward would allocate.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)
    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Gradients match the input layouts exactly.
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash tensors and scalar params from the custom-op forward for backward.

    Scalar parameters are recovered by fixed negative offsets from the end
    of the forward argument list.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    left, right = inputs[-9], inputs[-8]
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [left, right]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd bridge: route the upstream grad into the custom backward op.

    Only q/k/v receive gradients; all other forward inputs get None.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    grads_qkv = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k
        None, None,  # sequed_q, sequed_k
        None, None,  # max_seqlen_q, max_seqlen_k
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    dq, dk, dv = grads_qkv[:3]
    return (dq, dk, dv) + (None,) * 21


_flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention on a single packed QKV tensor.

    Accepts either (B, S, 3, H, D) packing or (B, S, Hq + 2*Hk, D) GQA/MQA
    packing (the latter requires num_heads_q).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (B, S, 3, H, D): equal head counts, split along the packed axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # (B, S, Hq + 2*Hk, D): GQA/MQA packing along the heads axis.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Re-pack the gradients into the caller's qkv layout.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # GQA/MQA packing: concatenate along the heads dimension.
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # BUG FIX: forward takes 13 inputs after ctx, so backward must return
        # 13 values (dqkv + 12 None). It previously returned only 12, which
        # makes autograd raise "returned an incorrect number of gradients".
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for standard (batched, fixed-length) attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling uses the combined Q (+ optional qv) head dim.
            qv_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + qv_dim) ** (-0.5)
        left, right = window_size
        out, softmax_lse, *_extra = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=left,
            window_size_right=right,
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Trim any head-dim padding back to the caller's sizes.
        dq, dk, dv = (g[..., : t.shape[-1]] for g, t in ((dq, q), (dk, k), (dv, v)))
        # One gradient slot per forward input (17 inputs after ctx).
        return (dq, dk, dv) + (None,) * 14
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed-sequence) attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling uses the combined Q (+ optional qv) head dim.
            qv_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + qv_dim) ** (-0.5)
        left, right = window_size
        out, softmax_lse, *_extra = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=left,
            window_size_right=right,
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Trim any head-dim padding back to the caller's sizes.
        dq, dk, dv = (g[..., : t.shape[-1]] for g, t in ((dq, q), (dk, k), (dv, v)))
        # One gradient slot per forward input (23 inputs after ctx).
        return (dq, dk, dv) + (None,) * 20
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention where Q, K, V are packed into a single tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) for GQA/MQA
            packing, in which case num_heads_q must be provided.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Number of query heads; required for the 4-D GQA/MQA packed layout.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # NOTE: the previous docstring documented dropout_p / alibi_slopes /
    # S_dmask, none of which exist in this FA3 interface.
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 forward+backward for fixed-length batched inputs.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor whose head dim contributes to the default scale.
        q_descale, k_descale, v_descale: optional descale tensors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; 1 means no split.
        pack_gqa: optional bool. Whether to pack GQA heads; tuned automatically if None.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # NOTE: the previous docstring documented dropout_p / alibi_slopes,
    # neither of which exists in this FA3 interface.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length attention over sequences packed along the first dim.

    Thin wrapper around FlashAttnVarlenFunc; note that the Function takes
    (seqused_q, seqused_k) before (max_seqlen_q, max_seqlen_k), so the
    arguments are reordered here.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine per-split partial outputs and logsumexps into the final output."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """Incremental-decoding attention against a (possibly paged) KV cache.

    If k and v are not None, k_cache and v_cache are updated *inplace* with the new
    values at the positions given by cache_seqlens, and attention is computed against
    the updated cache — all in one kernel. If rotary_cos/rotary_sin are passed, rotary
    embedding is applied to the new k (and to q, at positions depending on causal/local
    mode). Does not support a backward pass.

    Supports MQA/GQA (fewer KV heads than Q heads; Q heads must be divisible by KV
    heads) and paged KV caches via page_table. If causal=True, the causal mask is
    aligned to the bottom-right corner of the attention matrix. If
    window_size != (-1, -1), sliding-window local attention is used: query i attends to
    keys in [i + seqlen_k - seqlen_q - window_size[0],
    i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) without page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) with one.
        v_cache: same layout as k_cache, with headdim_v as the last dim.
        k, v [optional]: (batch_size, seqlen_new, nheads_k, headdim/headdim_v) new
            keys/values appended into the cache at cache_seqlens.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos, rotary_sin [optional]: (seqlen_ro, rotary_dim / 2); rotary_dim
            must be divisible by 16. Only applied when k and v are passed in.
        cache_seqlens: int or (batch_size,) int32 — current KV-cache lengths.
        cache_batch_idx [optional]: (batch_size,) int32 indices into the cache.
        cache_leftpad [optional]: (batch_size,) int32 index where each cache entry starts.
        page_table [optional]: (batch_size, max_num_blocks_per_seq) int32.
        cu_seqlens_q, cu_seqlens_k_new, max_seqlen_q, rotary_seqlens: varlen controls.
        q_descale, k_descale, v_descale: optional FP8 descale tensors.
        softmax_scale: float; defaults to 1/sqrt(headdim (+ qv headdim)).
        causal: bool. Whether to apply the (bottom-right-aligned) causal mask.
        window_size: (left, right) sliding-window bounds; (-1, -1) disables.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Interleaved (True) vs GPT-NeoX-style (False) rotary.
        scheduler_metadata: optional precomputed metadata from get_scheduler_metadata.
        num_splits: int. 0 = heuristic, 1 = no split, >1 = split KV into that many chunks.
        pack_gqa, sm_margin: performance tuning knobs.
        return_softmax_lse: bool. Whether to also return the logsumexp.

    Return:
        out: (batch_size, seqlen, nheads, headdim), plus softmax_lse
        (batch_size, nheads, seqlen) and any extra kernel outputs when
        return_softmax_lse=True.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        qv_dim = qv.shape[-1] if qv is not None else 0
        softmax_scale = (q.shape[-1] + qv_dim) ** (-0.5)
    # Broadcast a scalar cache length to a per-batch int32 tensor.
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    left, right = window_size
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=left,
        window_size_right=right,
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for flash_attn_with_kvcache.

    Mirrors the arguments of the attention call so the C++ scheduler can size
    its work distribution once and be reused across decode steps.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # V defaults to the same head dim as Q/K.
    if headdim_v is None:
        headdim_v = headdim
    left, right = window_size
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        left, right,
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85ced01718a7e3e0488fd2395ecefc72025185e0bc7423c134495752494ae63a
3
+ size 124877472
build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': True, 'FLASHATTENTION_DISABLE_HDIM96': True, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': True, 'FLASHATTENTION_DISABLE_HDIM256': True, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
+ def show():
5
+ from pprint import pprint
6
+ pprint(CONFIG)
7
+
build/torch28-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* unchanged when it is None or already unit-stride in its
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m* (m > 0)."""
    remainder = x % m
    return x if remainder == 0 else x + (m - remainder)
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest head-dimension bucket that this
    build was compiled with (per ``flash_attn_config.CONFIG``).

    Falls back to 256 when no enabled bucket is large enough.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        # Skip buckets disabled at compile time.
        if flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"]:
            continue
        if head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Flash-attention 3 forward custom op: thin wrapper over ``flash_attn_3_cuda.fwd``.

    Normalizes strides of the tensor arguments, forwards everything
    positionally to the CUDA binding, and returns
    ``(out, softmax_lse, out_accum, softmax_lse_accum)``.  The two split
    accumulators are replaced by empty tensors when the kernel returns None,
    so the op always yields four concrete tensors (the custom-op schema
    declares plain Tensor returns — presumably why None is not allowed;
    NOTE(review): confirm against the registered schema).
    """
    # Copy to contiguous only where the last dim is not unit-stride.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v is additionally accepted when its -3 dim has unit stride — assumes the
    # kernel supports that alternative V layout; TODO confirm.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional call: argument order must match the C++ binding exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Replace kernel's optional accumulators with empty tensors (see docstring).
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.

    Used by torch.compile / torch.export to trace the op: it allocates
    outputs with the correct shapes and dtypes without running the kernel.
    Shapes must mirror what ``flash_attn_3_cuda.fwd`` actually returns —
    NOTE(review): keep this in sync with the C++ binding when it changes.
    """

    # Varlen mode is signalled by the presence of cu_seqlens_q.
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor.
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Output head dim follows V, which may differ from the QK head dim.
    head_size_v = v.shape[-1]

    # FP8 inputs produce BF16 outputs; otherwise dtype is preserved.
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Allocate the output unless the caller pre-supplied one.
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # log-sum-exp of each softmax row, always fp32.
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0";
    # the fake cannot reproduce it, so tracing requires an explicit value.
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split fp32 accumulators are materialized.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Single-split: accumulators are empty placeholders, matching the real op.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Flash-attention 3 backward custom op: wrapper over ``flash_attn_3_cuda.bwd``.

    Gradients are written *in place* into ``dq``/``dk``/``dv`` (declared in
    ``mutates_args``); the only returned tensor is ``softmax_d``.
    """
    # dq, dk, dv are allocated by us so they should already be contiguous.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # Positional call: argument order must match the C++ binding exactly.
    softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return softmax_d
293
+
294
+
295
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Symbolic fake of the backward op for torch.compile / torch.export.

    Mirrors the shape logic of ``flash_attn_3_cuda.bwd`` to produce a
    correctly-shaped ``softmax_d`` without running the kernel.  ``dq``, ``dk``
    and ``dv`` are allocated here when not supplied (they are mutated in
    place by the real op).
    """

    is_varlen_q = cu_seqlens_q is not None
    # BUGFIX: previously tested cu_seqlens_q again (copy-paste), so a
    # k-only varlen call was misclassified as non-varlen.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        # batch mode: q is (batch, seqlen_q, heads, dim), k is (batch, seqlen_k, heads_k, dim)
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        # varlen mode: q is (total_q, heads, dim); seqlens come from the caller.
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window covering the whole sequence is equivalent to no window.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Normalized causal flag used for the kBlockM heuristic below.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus use cuda compute capability 9.0.
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the SM90 kernel's tile-size heuristic —
    # NOTE(review): must stay in sync with the C++ side.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    dq = torch.empty_like(q) if dq is None else dq
    dk = torch.empty_like(k) if dk is None else dk
    dv = torch.empty_like(v) if dv is None else dv

    # softmax_d is padded to the block size, fp32 in both layouts.
    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return softmax_d
386
+
387
+
388
def setup_context(ctx, inputs, output):
    """Stash tensors and attention parameters for the registered backward.

    Indexing from the *end* of ``inputs`` keeps this robust to the long run
    of optional arguments in the middle of the forward signature.
    """
    q, k, v = inputs[0], inputs[1], inputs[2]
    out, softmax_lse = output[0], output[1]
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    (ctx.softmax_scale, ctx.causal, win_left, win_right,
     ctx.attention_chunk, ctx.softcap) = inputs[-11:-5]
    ctx.window_size = [win_left, win_right]
    ctx.sm_margin = inputs[-1]
398
+
399
+
400
def _backward(ctx, dout, *grads):
    """Autograd backward for the ``_flash_attn_forward`` custom op.

    Allocates dq/dk/dv, fills them in place via ``_flash_attn_backward``,
    and pads the remaining 21 non-tensor forward inputs with None grads.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
    # Positional call — the None placeholders must line up with the
    # backward op's signature exactly.
    _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        dq,
        dk,
        dv,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # q, k, v get gradients; the other 21 forward inputs do not.
    return dq, dk, dv, *((None,) * 21)
425
+
426
+
427
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
428
+
429
+
430
+
431
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention over a packed QKV tensor.

    ``qkv`` is either 5-D ``(batch, seqlen, 3, nheads, headdim)`` or, for
    MQA/GQA, 4-D ``(batch, seqlen, nheads_q + 2 * nheads_k, headdim)`` with
    ``num_heads_q`` supplied so the head split points can be recovered.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
    ):
        # Default scale is 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (b, s, 3, h, d): unbind the stacked q/k/v axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # (b, s, h_q + 2*h_k, d): heads laid out as [q | k | v].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the packing layout so backward can rebuild dqkv.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Allocate dqkv once in the original packed layout and view dq/dk/dv
        # into it, so no concatenation is needed afterwards.
        if ctx.ndim == 5:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.unbind(dim=-3)
        else:
            num_heads_q = q.shape[2]
            num_heads_k = k.shape[2]
            qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One grad per forward input; only qkv is differentiable.
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
528
+
529
+
530
class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for fixed-length (batched) flash attention with
    separate q/k/v tensors; supports MQA/GQA, FP8 descales, sliding
    windows, softcap and split-KV tuning knobs."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv if given.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One grad per forward input; only q, k, v are differentiable.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
617
+
618
+
619
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length (packed/unpadded) flash
    attention.  Sequences are concatenated along dim 0 and delimited by
    ``cu_seqlens_q`` / ``cu_seqlens_k`` cumulative-length tensors, with
    optional ``seqused_q`` / ``seqused_k`` per-sequence used lengths."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv if given.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One grad per forward input; only q, k, v are differentiable.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
721
+
722
+
723
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
):
    """Flash attention on a pre-packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of Q, K, V.

    If window_size != (-1, -1), implements sliding window local attention. Query at
    position i will only attend to keys between [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or for MQA/GQA
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) together
            with ``num_heads_q``.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional FP8 descale tensors.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Anything > 0 enables chunked attention
            (forward only; the backward asserts it is 0).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Required for the 4-D GQA packing to locate the Q/K/V head split.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
    )
781
+
782
+
783
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    """Flash attention over fixed-length batched q/k/v.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's head dim when given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor fused into the score computation.
        q_descale, k_descale, v_descale: optional FP8 descale tensors.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Anything > 0 enables chunked attention
            (forward only; the backward asserts it is 0).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; 1 disables splitting.
        pack_gqa: optional bool. Whether to pack GQA heads for the kernel.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )
860
+
861
+
862
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    """Variable-length flash attention over concatenated (unpadded) sequences.

    q/k/v are (total_tokens, nheads, headdim) with per-sequence boundaries
    given by the cumulative-length tensors ``cu_seqlens_q`` / ``cu_seqlens_k``
    (shape (batch + 1,)); ``max_seqlen_q`` / ``max_seqlen_k`` are the longest
    sequence lengths in the batch.  ``seqused_q`` / ``seqused_k`` optionally
    limit how many tokens of each sequence are actually used.  Remaining
    parameters match :func:`flash_attn_func`.

    Note: the positional order of apply() below intentionally moves
    seqused_q/seqused_k before max_seqlen_q/max_seqlen_k to match
    FlashAttnVarlenFunc.forward's signature.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )
906
+
907
+
908
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial (split-KV) attention outputs into the final output.

    Delegates to the compiled ``flash_attn_3_cuda.fwd_combine`` kernel,
    which merges the per-split partial outputs using their per-split
    log-sum-exp values. ``out`` optionally receives the result in place;
    ``out_dtype`` optionally overrides the output dtype.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
910
+
911
+
912
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Passed through to the kernel; 0 disables chunked
            attention — presumably chunk length, confirm against the kernel docs.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale uses the combined q/qv head dim.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length into a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,  # passed as seqused_k
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1074
+
1075
+
1076
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for the FA3 forward pass.

    Delegates to the compiled ``flash_attn_3_cuda.get_scheduler_metadata``
    and returns its result, which callers may pass back in as the
    ``scheduler_metadata`` argument of the attention entry points so the
    split/scheduling decision is not recomputed on every call.

    ``headdim_v`` defaults to ``headdim`` when not given. ``cache_seqlens``
    is made contiguous before the call. The positional ``None`` arguments
    below correspond to ``cu_seqlens_k`` and ``seqused_q`` slots of the
    compiled function, which this wrapper never supplies.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85ced01718a7e3e0488fd2395ecefc72025185e0bc7423c134495752494ae63a
3
+ size 124877472
build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': True, 'FLASHATTENTION_DISABLE_HDIM96': True, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': True, 'FLASHATTENTION_DISABLE_HDIM256': True, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the build-time configuration flags to stdout."""
    import pprint

    pprint.pprint(CONFIG)
7
+
build/torch28-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return ``x`` with a contiguous last dimension, passing ``None`` through.

    Tensors whose last-dim stride is already 1 are returned unchanged (no copy).
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round integer ``x`` up to the nearest multiple of ``m``."""
    # Ceiling division via negated floor division, then scale back up.
    return -(-x // m) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round ``head_size`` up to the smallest head-dim bucket compiled into this build.

    Buckets whose ``FLASHATTENTION_DISABLE_HDIM*`` build flag is set are
    skipped; if no enabled bucket fits, falls back to 256.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        if not flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"] and head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the compiled FA3 forward kernel (``flash_attn_3_cuda.fwd``).

    Registered as a PyTorch custom op so torch.compile/export can trace it
    (see ``_flash_attn_forward_fake`` for the meta implementation).

    Normalizes input layouts (contiguous last dims where the kernel
    requires it), forwards everything positionally to the kernel, and
    returns ``(out, softmax_lse, out_accum, softmax_lse_accum)``. The two
    accumulator tensors are only populated by the kernel for split-KV
    execution; ``None`` results are replaced with empty tensors because a
    custom op must return tensors, not ``None``.
    """
    # Make Q/K(-new) last-dim contiguous; V is allowed to be contiguous in
    # either the last dim or the seqlen dim (stride(-3) == 1), so it gets a
    # looser check than maybe_contiguous would apply.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # NOTE: argument order must match the compiled op's signature exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Custom ops cannot return Optional tensors: substitute empty tensors.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from head_size, e.g. MLA-style layouts)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, one value per query row per head)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split partial outputs and LSEs are accumulated
        # in fp32 and later combined.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1; mirror the real op's
        # empty-tensor placeholders.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Call the compiled FA3 backward kernel (``flash_attn_3_cuda.bwd``).

    Registered as a custom op that mutates ``dq``/``dk``/``dv`` in place;
    callers pre-allocate those gradient buffers. Returns ``softmax_d``
    (the kernel's first output; any extra kernel outputs are discarded).
    """
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # NOTE: argument order must match the compiled op's signature exactly.
    softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return softmax_d
293
+
294
+
295
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    """Symbolic fake implementation of the FA3 backward op.

    Mirrors the shape logic of the CUDA backward kernel to produce a
    ``softmax_d`` tensor of the right shape/dtype without any computation.
    ``dq``/``dk``/``dv`` are allocated here when not supplied (the real op
    mutates them in place).
    """
    is_varlen_q = cu_seqlens_q is not None
    # BUGFIX: previously tested cu_seqlens_q again; varlen-K is signalled by
    # cu_seqlens_k.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Normalize the window/causal flags the same way the kernel does:
    # windows that cover the whole sequence degenerate to "no window".
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the SM90 kernel's tiling choice; it only
    # affects the rounded seqlen used for softmax_d's padded shape.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    dq = torch.empty_like(q) if dq is None else dq
    dk = torch.empty_like(k) if dk is None else dk
    dv = torch.empty_like(v) if dv is None else dv

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return softmax_d
386
+
387
+
388
def setup_context(ctx, inputs, output):
    """Save what ``_backward`` needs, for ``register_autograd``.

    ``inputs`` is the flat positional argument tuple of
    ``_flash_attn_forward``; the negative indices below count back from its
    last parameter (``sm_margin``). Verified against the forward signature:
    -11 = softmax_scale, -10 = causal, -9/-8 = window_size_left/right,
    -7 = attention_chunk, -6 = softcap, -1 = sm_margin. If the forward
    signature changes, these offsets must be updated in lockstep.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]  # softmax_scale
    ctx.causal = inputs[-10]         # causal
    ctx.window_size = [inputs[-9], inputs[-8]]  # window_size_left/right
    ctx.attention_chunk = inputs[-7]  # attention_chunk
    ctx.softcap = inputs[-6]          # softcap
    ctx.sm_margin = inputs[-1]        # sm_margin
398
+
399
+
400
def _backward(ctx, dout, *grads):
    """Autograd backward for the ``_flash_attn_forward`` custom op.

    Allocates dq/dk/dv and lets ``_flash_attn_backward`` fill them in
    place; gradients for the accumulator outputs (``*grads``) are ignored.
    Returns grads for q, k, v and ``None`` for the remaining forward
    arguments. Note: deterministic is hard-coded to False on this path
    (ctx does not carry it).
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
    _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        dq,
        dk,
        dv,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)
425
+
426
+
427
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
428
+
429
+
430
+
431
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd Function for attention on a packed QKV tensor.

    Accepts either a 5-D qkv of shape (..., 3, nheads, headdim) (unbound
    along dim -3) or a 4-D qkv whose head dimension packs
    num_heads_q + 2 * num_heads_k heads (split along dim -2, MQA/GQA).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
    ):
        # Default scale: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (..., 3, nheads, headdim): split into q, k, v views.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # 4-D GQA packing: heads laid out as [q-heads | k-heads | v-heads].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the packing layout so backward can rebuild dqkv in kind.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        # return out, softmax_lse
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Allocate one packed gradient buffer matching the forward layout, so
        # dq/dk/dv are views into it and no re-concatenation is needed.
        if ctx.ndim == 5:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.unbind(dim=-3)
        else:
            num_heads_q = q.shape[2]
            num_heads_k = k.shape[2]
            qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
528
+
529
+
530
class FlashAttnFunc(torch.autograd.Function):
    """autograd.Function wrapping the FA3 kernels for fixed-length (batched) attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
    ):
        if softmax_scale is None:
            # Default scale 1/sqrt(d); d includes the extra qv head dim when provided.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # NOTE: arguments below are positional and must match _flash_attn_forward's order.
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        # Chunked attention has no backward kernel.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k
            None, None,  # seqused_q, seqused_k
            None, None,  # max_seqlen_q, max_seqlen_k
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per tensor input; trailing Nones for the non-tensor args.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
617
+
618
+
619
class FlashAttnVarlenFunc(torch.autograd.Function):
    """autograd.Function wrapping the FA3 kernels for variable-length (ragged) batches.

    Sequences are concatenated along dim 0 and delimited by the cumulative
    lengths in cu_seqlens_q / cu_seqlens_k.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
    ):
        if softmax_scale is None:
            # Default scale 1/sqrt(d); d includes the extra qv head dim when provided.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # NOTE: arguments below are positional and must match _flash_attn_forward's order.
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        # Chunked attention has no backward kernel.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per tensor input; trailing Nones for the non-tensor args.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
721
+
722
+
723
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
):
    """FlashAttention-3 with Q, K, V packed into a single tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of Q, K, V.

    If window_size != (-1, -1), implements sliding window local attention. Query at
    position i will only attend to keys between [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim); the backward pass also
            supports a head-dim-packed layout (batch_size, seqlen,
            nheads_q + 2 * nheads_k, headdim) together with num_heads_q.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors applied to
            Q/K/V — presumably for the FP8 path; confirm against the kernel.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, chunked attention (forward only; the backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the
            backward pass, which is slightly slower and uses more memory. The forward
            pass is always deterministic.
        num_heads_q: int. Number of query heads when qkv is packed along the head
            dimension (MQA/GQA) — the forward implementation is not visible here,
            so confirm against FlashAttnQKVPackedFunc.forward.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
    )
781
+
782
+
783
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    """FlashAttention-3 for fixed-length batched attention.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's head dim when qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query-value tensor passed through to the kernel.
        q_descale, k_descale, v_descale: optional descale tensors applied to
            Q/K/V — presumably for the FP8 path; confirm against the kernel.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, chunked attention (forward only; the backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Number of KV splits for the forward kernel (1 = no split).
        pack_gqa: optional bool. Whether to pack GQA heads in the kernel.
        deterministic: bool. Whether to use the deterministic implementation of the
            backward pass, which is slightly slower and uses more memory. The forward
            pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )
860
+
861
+
862
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
):
    """FlashAttention-3 for variable-length (ragged) batches.

    Sequences of different lengths are concatenated along the first dimension;
    their boundaries are given by the cumulative-length tensors.

    Arguments:
        q: (total_q, nheads, headdim), all query sequences concatenated.
        k, v: (total_k, nheads_k, headdim), all key/value sequences concatenated.
        cu_seqlens_q, cu_seqlens_k: (batch_size + 1,), int32 cumulative sequence lengths.
        max_seqlen_q, max_seqlen_k: int. Maximum sequence length in the batch.
        seqused_q, seqused_k: optional (batch_size,) tensors of the number of
            tokens actually used per sequence.
        Remaining arguments are as in flash_attn_func.
    Return:
        out: (total_q, nheads, headdim).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
    )
906
+
907
+
908
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs and their log-sum-exp values into one result.

    Thin pass-through to the CUDA extension's `fwd_combine` op; `out` and
    `out_dtype` are optional and forwarded unchanged.
    """
    combine_op = flash_attn_3_cuda.fwd_combine
    return combine_op(out_partial, lse_partial, out, out_dtype)
910
+
911
+
912
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        scheduler_metadata [optional]: precomputed metadata from get_scheduler_metadata.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale 1/sqrt(d); d includes the extra qv head dim when provided.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor on the cache device.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # NOTE: arguments below are positional and must match _flash_attn_forward's order.
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1074
+
1075
+
1076
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute FA3 scheduler metadata for a given problem shape.

    The returned tensor can be passed to flash_attn_with_kvcache via
    `scheduler_metadata` — presumably to reuse the split/scheduling decision
    across calls; the actual semantics live in the CUDA extension.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        # V defaults to the same head dimension as Q/K.
        headdim_v = headdim
    # NOTE: arguments are positional and must match the extension op's signature.
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02f36ddbcfde635b8a04d2c877d3cda292b866e080a995b183f903c30f4c311b
3
+ size 811890112
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
+ def show():
5
+ from pprint import pprint
6
+ pprint(CONFIG)
7
+
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a unit-stride last dimension; None passes through unchanged."""
    if x is None:
        return None
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m* (assumes m > 0)."""
    # Ceiling division via negated floor division: -(-x // m) == ceil(x / m).
    return -(-x // m) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest compiled-in head dimension (capped at 256)."""
    from .flash_attn_config import CONFIG

    build_flags = CONFIG["build_flags"]
    # Candidate head dims in ascending order, each skipped when its kernel
    # was disabled at build time.
    for hdim in (64, 96, 128, 192, 256):
        if build_flags[f"FLASHATTENTION_DISABLE_HDIM{hdim}"]:
            continue
        if head_size <= hdim:
            return hdim
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper over the FA3 CUDA forward kernel.

    Normalizes strides on all tensor arguments, then forwards everything
    positionally to flash_attn_3_cuda.fwd. Returns
    (out, softmax_lse, out_accum, softmax_lse_accum); the *_accum outputs hold
    split-KV partial results and are replaced with empty tensors when the
    kernel returns None, because a custom_op must return real tensors.
    """
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v may keep either a unit-stride last dim or a unit stride at dim -3;
    # NOTE(review): presumably the kernel also accepts a head-dim-major V
    # layout — confirm against the CUDA extension.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op return values must be tensors, not None: substitute empties.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake (meta) implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation,
    so the custom op can be traced by torch.compile / torch.export.

    Returns:
        (out, softmax_lse, out_accum, softmax_lse_accum) — the accumulator
        tensors are empty placeholders unless num_splits > 1.
    """

    # Determine if we're in varlen mode.
    # NOTE(review): only cu_seqlens_q is consulted here; presumably the kernel
    # treats q-varlen as the deciding factor for output layout — confirm vs. C++.
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from head_size for some head-dim combos)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (only if the caller did not pre-allocate one)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, per-row log-sum-exp)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split partial outputs and LSEs are accumulated in fp32.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1; empty placeholders keep the
        # op's 4-tensor return signature stable.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUDA custom op wrapping the FA3 backward kernel.

    Computes gradients w.r.t. q, k, v given the forward output and its
    log-sum-exp. Handles both batch mode (cu_seqlens_* None) and varlen mode.
    Returns (dq, dk, dv, softmax_d); softmax_d is an fp32 auxiliary tensor
    produced by the kernel.
    """
    # The kernel requires contiguous last-dim tensors; maybe_contiguous is a no-op
    # for None / already-contiguous inputs.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake (meta) implementation of the FA3 backward op.

    Mirrors the shapes/dtypes produced by ``flash_attn_3_cuda.bwd`` without
    doing any computation, so the op can be traced by torch.compile/export.
    Returns (dq, dk, dv, softmax_d).
    """

    is_varlen_q = cu_seqlens_q is not None
    # Fixed: this previously re-checked cu_seqlens_q (copy/paste bug), so a
    # call that was varlen only on the key side would not be treated as varlen
    # when shaping softmax_d below.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        # Fail with a clear message instead of a TypeError on `None - 1` below;
        # mirrors the explicit check in _flash_attn_forward_fake.
        if max_seqlen_q is None or max_seqlen_k is None:
            raise ValueError(
                "max_seqlen_q and max_seqlen_k must be provided if cu_seqlens_q is provided"
            )
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Windows that cover the whole sequence are equivalent to "no window".
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Canonical causal definition used by the kernel: unbounded left window,
    # zero right window.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus use cuda compute capability 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # Block-size selection mirrors the SM90 kernel's tiling heuristics; it only
    # affects the padded size of softmax_d here.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors (same shape/dtype as the inputs)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash forward tensors and flags on ctx for the registered autograd backward.

    ``inputs`` is the positional argument tuple passed to the
    ``flash_attn_3::_flash_attn_forward`` op; ``output`` is its 4-tuple result.
    """
    query, key, value = inputs[0], inputs[1], inputs[2]
    attn_out, lse, _, _ = output
    ctx.save_for_backward(query, key, value, attn_out, lse)
    # Scalar options sit at fixed offsets from the *end* of the op's argument
    # list, so index from the tail rather than the front.
    (
        ctx.softmax_scale,
        ctx.causal,
        win_left,
        win_right,
        ctx.attention_chunk,
        ctx.softcap,
    ) = inputs[-11:-5]
    ctx.window_size = [win_left, win_right]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd backward hooked into the custom forward op via register_autograd.

    Only the batch-mode path is wired here: varlen metadata (cu_seqlens, seqused,
    max_seqlen) is always passed as None.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # NOTE(review): the forward op takes 34 inputs but this returns 3 grads +
    # 21 Nones (24 total); torch.library.register_autograd appears to accept
    # this — confirm against the torch.library documentation.
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for FA3 attention on a packed QKV tensor.

    Accepts either a 5-D packing (batch, seqlen, 3, nheads, headdim) or a 4-D
    MQA/GQA packing (batch, seqlen, nheads_q + 2 * nheads_k, headdim) split via
    ``num_heads_q``. Splits into q, k, v, runs the forward kernel, and re-packs
    the gradients in backward.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(headdim)
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): equal head counts for q/k/v
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # 4-D MQA/GQA packing: heads laid out as [q-heads | k-heads | v-heads]
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()  # remembered so backward can re-pack the gradients
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the expected format
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Fixed: autograd.Function.backward must return exactly one gradient per
        # forward input. forward() takes 13 inputs after ctx (qkv, softmax_scale,
        # causal, q/k/v_descale, window_size, attention_chunk, softcap,
        # deterministic, num_heads_q, sm_margin, return_softmax), so dqkv is
        # followed by 12 Nones (previously 11, which raises at runtime).
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for FA3 attention with separate q, k, v tensors
    (batch mode, fixed sequence lengths)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling uses the combined head dim when qv is supplied
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input: dq, dk, dv + 14 Nones for the 14
        # non-differentiable arguments (softmax_scale ... return_softmax).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for FA3 attention in varlen mode: q/k/v are packed as
    (total_tokens, nheads, headdim) with per-sequence boundaries given by
    cu_seqlens_q / cu_seqlens_k (and optionally seqused_q / seqused_k)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling uses the combined head dim when qv is supplied
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input: dq, dk, dv + 20 Nones for the 20
        # non-differentiable arguments (cu_seqlens_q ... return_softmax).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 on a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pass a 4-D qkv with
    heads laid out as [q-heads | k-heads | v-heads] and set num_heads_q.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) with num_heads_q set.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional per-tensor descale factors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Chunked-attention size; forward only (backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Number of query heads; required for the 4-D MQA/GQA packing.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 with separate q, k, v tensors (batch mode).

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's headdim when qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional (batch_size, seqlen, nheads, headdim_v) extra query-value tensor.
        q_descale, k_descale, v_descale: optional per-tensor descale factors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Chunked-attention size; forward only (backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; 1 disables splitting.
        pack_gqa: optional bool. Whether to pack GQA heads; kernel heuristic if None.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 in varlen mode.

    q/k/v are packed over the batch as (total_tokens, nheads, headdim), with
    per-sequence boundaries given by the cumulative-length tensors
    cu_seqlens_q / cu_seqlens_k (shape (batch_size + 1,), int32) and the
    longest sequence lengths max_seqlen_q / max_seqlen_k. seqused_q/seqused_k
    optionally limit how many tokens of each sequence are actually used.
    Other options match flash_attn_func. Returns out (total_q, nheads, headdim),
    plus the softmax log-sum-exp if return_attn_probs=True.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine per-split partial attention outputs and their log-sum-exps into
    the final attention output (the reduction step of split-KV attention).
    Thin wrapper over the CUDA extension's fwd_combine."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scaling uses the combined head dim when qv is supplied
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for flash_attn_with_kvcache.

    Wraps the CUDA extension's get_scheduler_metadata so the (potentially
    expensive) scheduling decision can be computed once and reused across
    decode steps via the scheduler_metadata argument. If headdim_v is None
    it defaults to headdim.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02f36ddbcfde635b8a04d2c877d3cda292b866e080a995b183f903c30f4c311b
3
+ size 811890112
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}

def show():
    """Pretty-print the build-time flag configuration to stdout."""
    import pprint
    pprint.pprint(CONFIG)
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Make ``x`` contiguous when its last dim is strided; pass ``None`` through."""
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round ``x`` up to the nearest multiple of ``m``."""
    remainder = x % m
    return x if remainder == 0 else x + (m - remainder)
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round ``head_size`` up to the smallest head dim enabled in this build.

    Walks the compiled head-dim tiers in ascending order, skipping any tier
    disabled at build time, and falls back to 256 when nothing smaller fits.
    """
    from .flash_attn_config import CONFIG

    build_flags = CONFIG["build_flags"]
    tiers = (
        ("FLASHATTENTION_DISABLE_HDIM64", 64),
        ("FLASHATTENTION_DISABLE_HDIM96", 96),
        ("FLASHATTENTION_DISABLE_HDIM128", 128),
        ("FLASHATTENTION_DISABLE_HDIM192", 192),
        ("FLASHATTENTION_DISABLE_HDIM256", 256),
    )
    for disable_flag, hdim in tiers:
        if not build_flags[disable_flag] and head_size <= hdim:
            return hdim
    # Largest supported head dim is the final fallback.
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper over the FA3 CUDA forward kernel.

    Normalizes tensor layouts (contiguity in the last dim) and forwards all
    arguments positionally to ``flash_attn_3_cuda.fwd``. Returns
    ``(out, softmax_lse, out_accum, softmax_lse_accum)``; the two ``*_accum``
    tensors are split-KV partial results and are replaced by empty tensors
    when the kernel returns ``None`` for them, because ``custom_op`` outputs
    must all be tensors.
    """
    # Kernels want the last dim contiguous; copy only when necessary.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v is accepted without a copy when either its last or its third-to-last
    # dim is contiguous — presumably the two value layouts the kernel supports
    # (NOTE(review): confirm against the CUDA binding).
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional call: the order below must match the C++ signature exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op requires tensor outputs, so map None -> empty tensor.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    Used by torch.compile/torch.export tracing; shapes must mirror what
    `flash_attn_3_cuda.fwd` allocates.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from the query/key head dim)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (reuse caller-provided `out` when given)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, one value per query position per head)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV: per-split fp32 partial outputs and LSEs to be combined later.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper over the FA3 CUDA backward kernel.

    Forwards all arguments positionally to ``flash_attn_3_cuda.bwd`` and
    returns ``(dq, dk, dv, softmax_d)``. Gradient buffers are allocated by
    the C++ side (the three ``None`` placeholders below).
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake of the backward op for torch.compile/torch.export.

    Allocates dq/dk/dv/softmax_d with the shapes the CUDA kernel would
    produce, without performing any computation. The block-size (kBlockM)
    heuristic below mirrors the SM90 kernel's tiling so that softmax_d's
    rounded seqlen dimension matches the real op.
    """

    is_varlen_q = cu_seqlens_q is not None
    # Bug fix: this previously re-tested cu_seqlens_q (copy-paste), so a call
    # with only cu_seqlens_k set never selected the varlen softmax_d layout.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        # varlen layout: q is (total_q, num_heads, head_size); seqlens come
        # from max_seqlen_q/k (callers must supply them in varlen mode).
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Degenerate windows that cover the whole sequence mean "no window".
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Canonical form: causal == unbounded left window + right window of 0.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM tiling choice, mirroring the SM90 kernel configuration.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash tensors and flags that `_backward` needs, per register_autograd.

    `inputs` is the full positional argument tuple of `_flash_attn_forward`;
    the negative indices below count back from `sm_margin` (the last
    parameter) through the tail of that signature, so they must be kept in
    sync with it.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Backward bridge for the `_flash_attn_forward` custom op.

    Only q, k, v receive gradients; every other forward input is
    non-differentiable and gets None.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # `_flash_attn_forward` takes 34 inputs and autograd expects exactly one
    # gradient per input, so pad with 34 - 3 = 31 Nones. (The previous value
    # of 21 under-filled the gradient tuple.)
    return dq, dk, dv, *((None,) * 31)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Attention autograd.Function for Q, K, V packed into a single tensor."""

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): equal head counts; unbind the 3-axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # (batch, seqlen, nheads_q + 2*nheads_k, headdim): MQA/GQA packing;
            # split along the heads dimension using num_heads_q.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the expected format
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # forward (minus ctx) takes 13 inputs, so backward must return 13
        # gradients: dqkv plus one None per non-differentiable input. (The
        # previous tuple had only 11 Nones — one short — which makes autograd
        # raise "returned an incorrect number of gradients".)
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Attention autograd.Function for separate q, k, v tensors (batch mode)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined q/qv head dim.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        # Backward kernel has no attention_chunk support.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input (3 tensors + 14 non-differentiable).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Attention autograd.Function for variable-length (packed) sequences."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined q/qv head dim.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        # Backward kernel has no attention_chunk support.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input (3 tensors + 20 non-differentiable).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention over Q, K, V packed into a single tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pass a 4-dim `qkv`
    together with `num_heads_q` (see below), or use flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) with num_heads_q given.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors forwarded to the
            kernel (presumably for FP8 inputs — NOTE(review): confirm).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; the backward pass requires 0.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Number of query heads; required for the 4-dim packed layout.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 over separate q, k, v tensors (fixed-length batch mode).

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's head dim when qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query/value tensor forwarded to the kernel
            (NOTE(review): semantics defined by the CUDA op — confirm).
        q_descale, k_descale, v_descale: optional descale tensors forwarded to the
            kernel (presumably for FP8 inputs — NOTE(review): confirm).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; the backward pass requires 0.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; can be tuned for speed.
        pack_gqa: optional bool. Can be tuned for speed.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length ("varlen") flash attention.

    Thin wrapper that forwards every argument to ``FlashAttnVarlenFunc.apply``
    (note the argument order there differs from this signature: the seqused_*
    tensors come before the max_seqlen_* values).

    Arguments:
        q, k, v: packed query/key/value tensors for all sequences of the batch.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths delimiting each
            sequence inside the packed q and k/v tensors.
        max_seqlen_q, max_seqlen_k: maximum sequence lengths over the batch.
        seqused_q, seqused_k: optional per-sequence number of tokens actually used.
        softmax_scale: scaling of QK^T before softmax (None lets the op default it).
        causal: whether to apply a causal mask (bottom-right aligned).
        qv: optional extra query-value tensor forwarded to the kernel.
        q_descale, k_descale, v_descale: optional descale tensors (FP8 path).
        window_size: (left, right) sliding-window sizes; (-1, -1) disables.
        attention_chunk: chunked-attention size; 0 disables.
        softcap: softcapping value; 0.0 disables.
        num_splits: number of KV splits; 1 means no split.
        pack_gqa: whether to pack GQA heads (None = let the kernel decide).
        deterministic: use the deterministic (slower) backward implementation.
        sm_margin: number of SMs left free (e.g. for communication kernels).
        return_attn_probs: also return attention statistics (testing only).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Merge partial attention outputs and their logsumexp values into one
    output via the CUDA ``fwd_combine`` kernel (presumably the reduction used
    by the split-KV path — confirm against the kernel's documentation).

    Arguments:
        out_partial: partial attention outputs to be combined.
        lse_partial: logsumexp values matching each partial output.
        out: optional preallocated output tensor.
        out_dtype: optional dtype for the combined output.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cu_seqlens_q [optional]: cumulative query sequence lengths for varlen queries.
        cu_seqlens_k_new [optional]: cumulative sequence lengths of the appended k/v.
        max_seqlen_q [optional]: int. Maximum query length (required with cu_seqlens_q).
        rotary_seqlens [optional]: per-sequence rotary position offsets.
        q_descale, k_descale, v_descale [optional]: descale tensors (FP8 path).
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Chunked-attention size; 0 disables.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        scheduler_metadata [optional]: precomputed metadata from get_scheduler_metadata().
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        pack_gqa: bool or None. Whether to pack GQA heads (None = heuristic).
        sm_margin: int. Number of SMs left free (e.g. for communication kernels).
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale includes qv's head dim when a qv tensor is supplied.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    # Positional/keyword argument order below must match _flash_attn_forward.
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for a given problem shape.

    The returned tensor can be passed as ``scheduler_metadata`` to
    flash_attn_with_kvcache so the scheduling work is done once up front
    (e.g. when the same shapes are decoded repeatedly).

    Arguments mirror the attention call: problem sizes (batch_size,
    max_seqlen_q/k, head counts, head dims), the KV-cache lengths
    (cache_seqlens), dtype, optional varlen/paging descriptors, and the
    tuning knobs (num_splits, pack_gqa, sm_margin).
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        # V head dim defaults to the QK head dim.
        headdim_v = headdim
    # Positional order must match the C++ get_scheduler_metadata binding.
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b84449c94fce73560a620696ead715bc63e8ed025d1e6949f9bbe3fd72edfec6
3
+ size 812081576
build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
def show():
    """Pretty-print the build-time flags this flash-attn extension was compiled with."""
    from pprint import pprint
    pprint(CONFIG)
7
+
build/torch29-cxx11-cu129-x86_64-linux/flash_attention_3/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* unchanged if it is None or already has unit stride in the
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m*."""
    quotient, remainder = divmod(x, m)
    return x if remainder == 0 else (quotient + 1) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest head dimension this build supports.

    The kernels are compiled only for a fixed set of head dims
    (64/96/128/192/256); any dim disabled at build time via the
    FLASHATTENTION_DISABLE_HDIM* flags is skipped. Falls back to 256 when
    nothing smaller matches.
    """
    # Bug fix: flash_attn_config is a sibling module inside this package
    # (cf. `from . import _C` above), so the previous absolute
    # `from flash_attn_config import CONFIG` only worked when the package
    # directory itself happened to be on sys.path. Try the package-relative
    # import first and keep the absolute form as a fallback for script use.
    try:
        from .flash_attn_config import CONFIG
    except ImportError:
        from flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for dim in (64, 96, 128, 192, 256):
        if not flags[f"FLASHATTENTION_DISABLE_HDIM{dim}"] and head_size <= dim:
            return dim
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level forward: normalize strides, then call the C++ `fwd` kernel.

    Registered as a PyTorch custom op so it composes with torch.compile;
    `mutates_args=()` declares the op as non-mutating from autograd's view.
    Returns (out, softmax_lse, out_accum, softmax_lse_accum); the two accum
    tensors are replaced by empty placeholders when the kernel returns None,
    so the op always yields exactly four tensors.
    """
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v may be laid out either row-major or column-major in its last two
    # dims; only force a copy when neither stride pattern is present.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional order below must match the C++ binding exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Custom ops must return tensors, not None: substitute empty placeholders.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, matching the real kernel's layout)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split partial outputs/LSEs are accumulated in fp32.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1; mirror the real op's
        # empty-tensor placeholders.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level backward: normalize strides, then call the C++ `bwd` kernel.

    Returns (dq, dk, dv, softmax_d); the gradient tensors are allocated by
    the C++ side (None is passed for dq/dk/dv below).
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention backward.

    Mirrors the real kernel's allocation logic (dq/dk/dv like their inputs,
    plus a padded fp32 softmax_d buffer) without doing any computation, so
    the op can be traced by torch.compile / torch.export.
    """
    is_varlen_q = cu_seqlens_q is not None
    # Bug fix: this previously re-tested cu_seqlens_q (copy-paste), so a
    # K-only varlen call would be misclassified as non-varlen.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window covering the whole sequence is equivalent to no window.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Canonical causal form: unbounded left window, zero right window.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # Replicates the SM90 kernel's M-tile selection, which determines the
    # rounded seqlen used for the softmax_d buffer.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash forward tensors/attributes for the registered autograd backward.

    `inputs` is the full positional argument tuple of `_flash_attn_forward`;
    the negative indices count back from the last parameter (`sm_margin` is
    inputs[-1]), so inputs[-11] is `softmax_scale`, inputs[-10] `causal`,
    inputs[-9]/[-8] the window sizes, inputs[-7] `attention_chunk`, and
    inputs[-6] `softcap` — keep these in sync with the forward signature.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd backward for the `_flash_attn_forward` custom op.

    Only dq/dk/dv are produced; all other forward inputs get None gradients.
    No varlen metadata is saved in `setup_context`, so the backward is always
    invoked with batch-mode (None) varlen arguments here.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Flash attention over a packed QKV tensor.

    Accepts either a 5-D qkv of shape (..., 3, nheads, headdim) unbound along
    the stacked dimension, or a 4-D qkv where Q/K/V heads are concatenated
    along the heads dimension (requires num_heads_q for the split, MQA/GQA
    layout).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): unbind into q, k, v.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # (batch, seqlen, nheads_q + 2*nheads_k, headdim): split by heads.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the same layout the forward received.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
            dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Bug fix: forward takes 13 inputs after ctx (qkv, softmax_scale,
        # causal, q/k/v_descale, window_size, attention_chunk, softcap,
        # deterministic, num_heads_q, sm_margin, return_softmax), so backward
        # must return 13 gradients. The previous code returned only 12
        # (11 Nones), which makes autograd raise "incorrect number of
        # gradients" when this path is exercised.
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd binding for FA3 fixed-length (non-varlen) attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined head dim of q and the optional qv.
            extra_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + extra_dim) ** (-0.5)
        win_left, win_right = window_size
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,        # k_new, v_new
            qv,                # qv
            None,              # preallocated out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,        # seqused_q/k
            None, None,        # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=win_left,
            window_size_right=win_right,
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        win_left, win_right = ctx.window_size
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k
            None, None,  # seqused_q, seqused_k
            None, None,  # max_seqlen_q, max_seqlen_k
            ctx.softmax_scale,
            ctx.causal,
            win_left,
            win_right,
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # The kernel may pad the head dimension; trim gradients back.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return (dq, dk, dv) + (None,) * 14
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd binding for FA3 variable-length (cu_seqlens-based) attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined head dim of q and the optional qv.
            extra_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + extra_dim) ** (-0.5)
        win_left, win_right = window_size
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,        # k_new, v_new
            qv,                # qv
            None,              # preallocated out
            cu_seqlens_q,
            cu_seqlens_k,
            None,              # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=win_left,
            window_size_right=win_right,
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(
            q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k
        )
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        (q, k, v, out, softmax_lse,
         cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        win_left, win_right = ctx.window_size
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            win_left,
            win_right,
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # The kernel may pad the head dimension; trim gradients back.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return (dq, dk, dv) + (None,) * 20
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention over a pre-packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit
    concatenation of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pass a 4-D ``qkv``
    with ``num_heads_q`` set, or see flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at
    position i will only attend to keys between [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) when
            ``num_heads_q`` is given (packed MQA/GQA layout, split as
            [nheads_q, nheads_k, nheads_k] along the heads dim).
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for
            auto-regressive modeling).
        q_descale, k_descale, v_descale: optional dequantization scale tensors
            forwarded to the kernel (presumably for FP8 inputs — TODO confirm).
        window_size: (left, right). If not (-1, -1), implements sliding window
            local attention.
        attention_chunk: int. Forwarded to the kernel; the backward pass
            requires it to be 0.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of
            the backward pass, which is slightly slower and uses more memory.
            The forward pass is always deterministic.
        num_heads_q: int, required when ``qkv`` is 4-D; number of query heads in
            the packed heads dimension.
        sm_margin: int. Number of SMs to leave free (e.g. for communication
            kernels); forwarded to the kernel.
        return_attn_probs: bool. If True, additionally return the logsumexp of
            the attention scores (for testing).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads,
            seqlen). The logsumexp of each row of the matrix QK^T * scaling
            (e.g., log of the softmax normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 forward/backward for fixed-length batches.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim + qv_headdim) (qv_headdim = 0 if qv is None).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor fused into the kernel call
            (presumably an auxiliary low-rank query — TODO confirm semantics).
        q_descale, k_descale, v_descale: optional dequantization scale tensors
            forwarded to the kernel (presumably for FP8 inputs — TODO confirm).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; the backward pass requires it to be 0.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Number of KV splits; forwarded to the kernel.
        pack_gqa: optional bool. Whether to pack GQA heads; forwarded to the kernel.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free; forwarded to the kernel.
        return_attn_probs: bool. If True, additionally return the logsumexp of the
            attention scores (for testing).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 for variable-length (packed/unpadded) batches.

    Like flash_attn_func, but q/k/v are concatenated over the batch (total_tokens,
    nheads, headdim) and sequence boundaries are described by cumulative-length
    tensors.

    Arguments:
        q, k, v: unpadded tensors with all sequences concatenated along dim 0.
        cu_seqlens_q, cu_seqlens_k: (batch_size + 1,), dtype int32 — cumulative
            sequence lengths delimiting each sequence in q and in k/v.
        max_seqlen_q, max_seqlen_k: int. Maximum sequence length in the batch.
        seqused_q, seqused_k: optional (batch_size,) int32 tensors giving the
            number of tokens actually used per sequence (forwarded to the kernel).
        softmax_scale: float. Default to 1 / sqrt(headdim + qv_headdim).
        causal: bool. Causal mask aligned to the bottom-right corner (see
            flash_attn_func).
        qv: optional extra query tensor fused into the kernel call.
        q_descale, k_descale, v_descale: optional dequantization scale tensors
            (presumably for FP8 inputs — TODO confirm).
        window_size: (left, right). Sliding-window local attention if not (-1, -1).
        attention_chunk: int. Forwarded to the kernel; backward requires 0.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits, pack_gqa, sm_margin: kernel tuning knobs, forwarded as-is.
        deterministic: bool. Deterministic (slower) backward pass.
        return_attn_probs: bool. If True, additionally return softmax_lse.
    Return:
        out: (total_tokens, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: logsumexp of the
            attention scores.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs (split-KV) into a single output.

    Thin wrapper around the CUDA ``fwd_combine`` kernel; merges per-split
    partial outputs using their per-split logsumexps.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """Inference attention against an (optionally paged) KV cache.

    If k and v are not None, k_cache and v_cache are updated *inplace* with the
    new values, starting at the indices given by cache_seqlens — useful for
    incremental decoding where cache append, rotary embedding, and attention all
    run in one kernel. The cache must be pre-allocated large enough to hold the
    appended values.

    If rotary_cos/rotary_sin are passed (only applicable together with k and v),
    rotary embedding is applied to k at positions cache_seqlens, cache_seqlens+1,
    etc.; q is rotated the same way when causal or local, and at position
    cache_seqlens only otherwise.

    MQA/GQA is supported by passing KV with fewer heads than Q (nheads must be
    divisible by nheads_k). With causal=True the mask is aligned to the
    bottom-right corner of the attention matrix; window_size != (-1, -1) gives
    sliding-window local attention over
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]].

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) without a
            page_table, or (num_blocks, page_block_size, nheads_k, headdim) with
            one (paged KV cache; page_block_size can be arbitrary).
        v_cache: same layout as k_cache with last dim headdim_v.
        k, v [optional]: (batch_size, seqlen_new, nheads_k, headdim[_v]) new
            values appended to the cache at indices cache_seqlens.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos, rotary_sin [optional]: (seqlen_ro, rotary_dim / 2);
            rotary_dim must be divisible by 16.
        cache_seqlens: int or (batch_size,) int32 — current KV-cache lengths.
        cache_batch_idx: (batch_size,) int32 — indices into the cache; defaults
            to [0 .. batch_size - 1]. With duplicate indices and k/v given, the
            cached values may come from any duplicate.
        cache_leftpad: (batch_size,) int32 — start index of each cache entry
            (default 0).
        page_table [optional]: (batch_size, max_num_blocks_per_seq), int32.
        softmax_scale: float, default 1 / sqrt(headdim + qv_headdim).
        causal: bool. Causal mask (bottom-right aligned).
        window_size: (left, right) sliding-window bounds.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. True pairs dims 0&1, 2&3, ...; False pairs
            dim i with i + rotary_dim/2 (GPT-NeoX style).
        num_splits: int. >1 splits KV along the sequence; 0 uses a heuristic.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Also return the attention logsumexp.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads,
            seqlen) logsumexp of QK^T * scaling.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        extra_dim = qv.shape[-1] if qv is not None else 0
        softmax_scale = (q.shape[-1] + extra_dim) ** (-0.5)
    # Broadcast a scalar cache length to a per-batch int32 tensor.
    if isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    win_left, win_right = window_size
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=win_left,
        window_size_right=win_right,
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    if return_softmax_lse:
        return (out, softmax_lse, *rest)
    return out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for flash_attn_with_kvcache.

    Thin wrapper over the CUDA ``get_scheduler_metadata`` binding; the result
    can be passed as ``scheduler_metadata`` to amortize scheduling work across
    decode steps.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # V head dim defaults to the K/Q head dim when not given.
    if headdim_v is None:
        headdim_v = headdim
    win_left, win_right = window_size
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        win_left, win_right,
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}