varunneal committed on
Commit
5569d82
·
verified ·
1 Parent(s): 884fcc7

Upload 42 files

Browse files
Files changed (43) hide show
  1. .gitattributes +6 -0
  2. build/torch210-cxx11-cu126-aarch64-linux/flash_attention_hopper/__init__.py +1 -0
  3. build/torch210-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  4. build/torch210-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  5. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/_C.abi3.so +3 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/__init__.py +1 -0
  7. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  8. build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  9. build/torch210-cxx11-cu128-aarch64-linux/flash_attention_hopper/__init__.py +1 -0
  10. build/torch210-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  11. build/torch210-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  12. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/_C.abi3.so +3 -0
  13. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/__init__.py +1 -0
  14. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  15. build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  16. build/torch210-cxx11-cu130-aarch64-linux/flash_attention_hopper/__init__.py +1 -0
  17. build/torch210-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  18. build/torch210-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  19. build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/_C.abi3.so +3 -0
  20. build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/__init__.py +1 -0
  21. build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  22. build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  23. build/torch29-cxx11-cu126-aarch64-linux/flash_attention_hopper/__init__.py +1 -0
  24. build/torch29-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  25. build/torch29-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  26. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/_C.abi3.so +3 -0
  27. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/__init__.py +1 -0
  28. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  29. build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  30. build/torch29-cxx11-cu128-aarch64-linux/flash_attention_hopper/__init__.py +1 -0
  31. build/torch29-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  32. build/torch29-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  33. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/_C.abi3.so +3 -0
  34. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/__init__.py +1 -0
  35. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  36. build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  37. build/torch29-cxx11-cu130-aarch64-linux/flash_attention_hopper/__init__.py +1 -0
  38. build/torch29-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  39. build/torch29-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
  40. build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/_C.abi3.so +3 -0
  41. build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/__init__.py +1 -0
  42. build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_config.py +7 -0
  43. build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_interface.py +1101 -0
.gitattributes CHANGED
@@ -39,3 +39,9 @@ build/torch210-cxx11-cu130-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs
39
  build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
40
  build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
41
  build/torch29-cxx11-cu130-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
39
  build/torch29-cxx11-cu126-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
40
  build/torch29-cxx11-cu128-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
41
  build/torch29-cxx11-cu130-x86_64-linux/flash_attention_3/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
42
+ build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
43
+ build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
44
+ build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
45
+ build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
46
+ build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
47
+ build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/_C.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch210-cxx11-cu126-aarch64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the compile-time build-flag configuration."""
    import pprint
    pprint.pprint(CONFIG)
build/torch210-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return ``x`` unchanged when it is None or already unit-strided in the
    last dimension; otherwise return a contiguous copy."""
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round ``x`` up to the nearest multiple of ``m``."""
    remainder = x % m
    return x if remainder == 0 else x + m - remainder
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round ``head_size`` up to the smallest compiled-in head dimension.

    A candidate dimension is only usable when its build flag was not
    disabled at compile time; 256 is the unconditional fallback.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    candidates = (
        (64, "FLASHATTENTION_DISABLE_HDIM64"),
        (96, "FLASHATTENTION_DISABLE_HDIM96"),
        (128, "FLASHATTENTION_DISABLE_HDIM128"),
        (192, "FLASHATTENTION_DISABLE_HDIM192"),
    )
    for dim, disable_key in candidates:
        if not flags[disable_key] and head_size <= dim:
            return dim
    # Largest supported head dimension; also returned when every smaller
    # candidate is disabled or too small.
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the FlashAttention-3 CUDA forward kernel.

    Normalizes strides of all tensor arguments, then forwards everything
    *positionally* to ``flash_attn_3_cuda.fwd`` — the argument order below
    is the contract with the C++ binding and must not be reordered.

    Returns:
        ``(out, softmax_lse, out_accum, softmax_lse_accum)``. The two
        ``*_accum`` split-accumulator tensors are mapped from None to empty
        tensors so the custom op always returns four concrete tensors
        (torch.library op schemas cannot return optional tensors).
    """
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v gets a looser check than the other tensors: stride 1 in dim -3 is
    # also accepted — presumably to allow a column-major V layout
    # (FLASHATTENTION_ENABLE_VCOLMAJOR); confirm against the kernel.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # The kernel may return None for the split accumulators (num_splits == 1);
    # substitute empty tensors so the op's return schema stays all-tensor.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    Used by torch.compile / torch.export tracing; shapes here must mirror
    what the CUDA kernel allocates.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from head_size for asymmetric QK/V dims)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (reuse caller-provided `out` if given)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, heads-major in varlen mode)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split fp32 accumulators are materialized.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1; match the real op, which
        # substitutes empty tensors for the absent accumulators.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the FlashAttention-3 CUDA backward kernel.

    Arguments are forwarded *positionally* to ``flash_attn_3_cuda.bwd``;
    the order below is the contract with the C++ binding.

    Returns:
        ``(dq, dk, dv, softmax_d)``; any extra values the binding returns
        are discarded.
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention backward.

    Allocates dq/dk/dv/softmax_d with the shapes the CUDA kernel would
    produce, without doing any computation. The kBlockM tile heuristic below
    must mirror the C++ dispatch logic so softmax_d's padded size matches.
    """

    is_varlen_q = cu_seqlens_q is not None
    # BUGFIX: this previously re-tested cu_seqlens_q; varlen-K must be keyed
    # off cu_seqlens_k, otherwise a K-only varlen call (cu_seqlens_k set,
    # cu_seqlens_q None, sequed_* None) picked the non-varlen softmax_d shape.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Windows that span the whole sequence collapse to "no window" (-1).
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Re-derive causality the same way the kernel does after window folding.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper GPUs use CUDA compute capability 9.0.
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # M-tile size heuristic for SM90 — must match the C++ kernel dispatch.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    # Varlen softmax_d is padded by one extra tile per batch element.
    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash tensors and options needed by _backward on the autograd ctx.

    ``inputs`` is the full positional argument tuple of _flash_attn_forward
    (34 parameters); the negative indices below address its trailing
    parameters by counting back from sm_margin (-1).
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]  # softmax_scale
    ctx.causal = inputs[-10]  # causal
    ctx.window_size = [inputs[-9], inputs[-8]]  # window_size_left, window_size_right
    ctx.attention_chunk = inputs[-7]  # attention_chunk
    ctx.softcap = inputs[-6]  # softcap
    ctx.sm_margin = inputs[-1]  # sm_margin
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd bridge: route the incoming output gradient through the FA3
    backward op and emit one gradient slot per forward input."""
    q, k, v, out, softmax_lse = ctx.saved_tensors
    win_left, win_right = ctx.window_size
    dq, dk, dv, _softmax_d = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # cu_seqlens_q
        None,  # cu_seqlens_k
        None,  # sequed_q
        None,  # sequed_k
        None,  # max_seqlen_q
        None,  # max_seqlen_k
        ctx.softmax_scale,
        ctx.causal,
        win_left,
        win_right,
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # Only q/k/v receive gradients; the remaining forward inputs get None.
    return (dq, dk, dv) + (None,) * 21
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention over a packed QKV tensor.

    Accepts either a 5-D ``(batch, seqlen, 3, nheads, headdim)`` tensor or a
    4-D GQA packing ``(batch, seqlen, nheads_q + 2*nheads_k, headdim)``
    (the latter requires ``num_heads_q``).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): split along the stack dim.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # 4-D GQA packing: heads dim holds nheads_q + 2 * nheads_k.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients back into the same layout as the forward input.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # forward() takes 13 inputs after ctx (qkv, softmax_scale, causal,
        # q/k/v_descale, window_size, attention_chunk, softcap, deterministic,
        # num_heads_q, sm_margin, return_softmax), so autograd expects exactly
        # 13 gradients. BUGFIX: previously only 12 values were returned
        # (dqkv + 11 Nones), one short of the input count.
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length (batched) flash attention over
    separate q/k/v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv if given.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # 17 gradient slots: one per forward input after ctx (q, k, v get
        # real gradients; the 14 non-differentiable options get None).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed/ragged) flash attention,
    driven by cu_seqlens_* prefix-sum offsets and optional seqused_* masks."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv if given.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # 23 gradient slots: one per forward input after ctx (q, k, v get
        # real gradients; the 20 non-differentiable inputs get None).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention where Q, K, V are packed into a single tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors, forwarded to the kernel
            (used for FP8 inputs -- presumably per-tensor scales; confirm against the kernel).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; 0 disables chunked attention.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: optional int, forwarded to FlashAttnQKVPackedFunc
            (NOTE(review): presumably the query-head count for packed GQA layouts -- confirm).
        sm_margin: int. Number of SMs to leave unused (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Standard (fixed-length, batched) FlashAttention-3 forward/backward.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query-value tensor, forwarded to the kernel
            (NOTE(review): presumably (batch_size, seqlen, nheads, headdim_v) -- confirm).
        q_descale, k_descale, v_descale: optional descale tensors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; 0 disables chunked attention.
            Must be 0 if the backward pass is needed.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. How many chunks to split KV into for the forward; 1 = no split.
        pack_gqa: optional bool. Whether to pack GQA heads; None lets the kernel decide.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave unused (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length version of flash_attn_func.

    Sequences in the batch are concatenated along dimension 0 (no padding) and
    addressed via cumulative-sequence-length tensors.

    Arguments:
        q: (total_q, nheads, headdim), where total_q is the total number of query
            tokens in the batch.
        k: (total_k, nheads_k, headdim)
        v: (total_k, nheads_k, headdim)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the query sequences.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths of the key/value sequences.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        seqused_q, seqused_k: optional (batch_size,) int32 tensors, forwarded to the
            kernel (NOTE(review): presumably the number of tokens actually used per
            sequence -- confirm against the CUDA op).
        Remaining arguments have the same meaning as in flash_attn_func.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine split-KV partial attention outputs into the final output.

    Thin wrapper around the CUDA `fwd_combine` kernel: merges per-split partial
    outputs and their log-sum-exp values (as produced when num_splits > 1).
    If `out` is given the result is written into it; `out_dtype` selects the
    output dtype when a new tensor is allocated.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0, # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Include qv's head dim in the scale when a packed query-value tensor is used.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor for the kernel.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    # Positional arguments below must match _flash_attn_forward's parameter order exactly.
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # `rest` carries the split accumulators (out_accum, softmax_lse_accum) as well.
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute the tile-scheduler metadata for flash_attn_with_kvcache.

    Calls the CUDA `get_scheduler_metadata` op so its result can be passed as
    `scheduler_metadata=` to flash_attn_with_kvcache, moving the scheduling
    work out of the decode hot path. Arguments mirror flash_attn_with_kvcache;
    `headdim_v` defaults to `headdim` when not given.

    Returns the opaque scheduler-metadata tensor produced by the kernel.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    # Positional order must match the CUDA op's signature exactly; the explicit
    # Nones stand in for cu_seqlens_k and seqused_q, which this API does not take.
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e722dcebc49024ab2dd2c58c7b27efb9a0618ce104f66b6b21cc0d75bf5de4b6
3
+ size 814568840
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the build-time flash-attention configuration."""
    import pprint

    pprint.pprint(CONFIG)
build/torch210-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a unit-stride last dimension, copying only when needed.

    None is passed through unchanged, and tensors whose last dimension is
    already contiguous are returned as-is (no copy).
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m* (m > 0)."""
    # x + ((-x) mod m) is the ceiling-to-multiple identity for non-negative m.
    return x + (-x) % m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest head-dim bucket compiled into this build.

    Buckets disabled at build time (FLASHATTENTION_DISABLE_HDIM<n>) are skipped;
    anything larger than every enabled bucket falls back to 256.
    """
    from .flash_attn_config import CONFIG

    build_flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        if not build_flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"] and head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level forward custom op wrapping the CUDA kernel flash_attn_3_cuda.fwd.

    Forces a contiguous last dimension on every tensor argument, then dispatches
    to the kernel. Returns (out, softmax_lse, out_accum, softmax_lse_accum);
    the accumulator tensors are produced by the kernel when it splits the KV
    sequence, and are replaced by empty placeholder tensors when the kernel
    returns None for them (a registered custom op must return real tensors).
    """
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v is also accepted with stride(-3) == 1 -- NOTE(review): presumably a
    # head-dim-major layout the kernel supports directly; confirm against the kernel.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional order below must match the C++ fwd signature exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op return values must be tensors, never None: substitute empty tensors.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation,
    so torch.compile / torch.export can trace the op without a GPU launch.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from q/k head dim)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (only when the caller did not supply one)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, regardless of input dtype)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # num_splits == 1 here (<= 0 already raised): the kernel returns no
        # accumulators, matching the empty placeholders of the real op.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level backward custom op wrapping the CUDA kernel flash_attn_3_cuda.bwd.

    Forces a contiguous last dimension on the tensor inputs and dispatches to
    the kernel, which allocates and returns the gradients.
    Returns (dq, dk, dv, softmax_d).
    (The `sequed_q` / `sequed_k` names are kept as-is for op-schema
    compatibility, although `seqused_*` is the spelling used elsewhere.)
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention backward.

    Mirrors the shape/dtype logic of the CUDA backward so torch.compile /
    torch.export can trace it without a GPU launch.
    Returns (dq, dk, dv, softmax_d) with the same shapes/dtypes the real
    kernel would allocate.
    """
    is_varlen_q = cu_seqlens_q is not None
    # BUGFIX: this previously re-tested cu_seqlens_q, so a call with only
    # cu_seqlens_k set was mis-classified as non-varlen and softmax_d got the
    # wrong shape.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        # varlen mode: q is (total_q, num_heads, head_size)
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Normalize the window/causal flags the same way the C++ kernel does:
    # a window covering the whole sequence is equivalent to no window.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # Pick the backward M-tile size the SM90 kernel would use; softmax_d's
    # padded seqlen depends on it.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Gradients match their inputs exactly.
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Capture what the registered backward needs from the forward call.

    The scalar settings are picked out of ``inputs`` by negative index,
    counting back from the end of _flash_attn_forward's argument list
    (sm_margin is last; softmax_scale is 11th from the end).
    """
    query, key, value = inputs[0], inputs[1], inputs[2]
    attn_out, lse, _, _ = output
    ctx.save_for_backward(query, key, value, attn_out, lse)
    (
        ctx.softmax_scale,
        ctx.causal,
        win_left,
        win_right,
        ctx.attention_chunk,
        ctx.softcap,
    ) = inputs[-11:-5]
    ctx.window_size = [win_left, win_right]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd backward hook for the _flash_attn_forward custom op.

    Only the dense (non-varlen) path is wired up here: all cu_seqlens /
    seqused / max_seqlen arguments are passed as None.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout, q, k, v, out, softmax_lse,
        None, None,          # cu_seqlens_q, cu_seqlens_k
        None, None,          # sequed_q, sequed_k
        None, None,          # max_seqlen_q, max_seqlen_k
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,               # deterministic
        ctx.sm_margin,
    )
    # Gradients flow only to q/k/v; the remaining 21 forward inputs get None.
    return (dq, dk, dv) + (None,) * 21
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """autograd.Function for attention over a packed QKV tensor.

    Accepts either a dim-5 qkv of shape (batch, seqlen, 3, nheads, headdim),
    or a dim-4 qkv of shape (batch, seqlen, nheads_total, headdim) whose head
    axis is laid out as [num_heads_q, num_heads_k, num_heads_k] (requires
    num_heads_q to be passed).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale: 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # Dim-4 layout: heads concatenated as [H_q, H_k, H_k].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the input layout so backward can repack gradients the same way.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the same layout the packed qkv came in.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # BUGFIX: forward takes 13 inputs after ctx (return_softmax was added
        # to the signature), but only 12 gradients were returned here, which
        # makes autograd raise "returned an incorrect number of gradients".
        # Return exactly one gradient slot per forward input.
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """autograd.Function wrapping the FA3 kernels for dense (fixed-length)
    attention on (batch, seqlen, nheads, headdim) tensors.

    NOTE: the arguments to _flash_attn_forward/_flash_attn_backward below are
    positional and order-critical; keep them in sync with the custom-op
    signatures.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv's last
        # dim when a qv tensor is supplied.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        # The FA3 backward kernel has no attention_chunk support.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient slot per forward input (3 tensors + 14 non-differentiable args).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """autograd.Function wrapping the FA3 kernels for variable-length
    (packed / ragged) attention: q, k, v are (total_tokens, nheads, headdim)
    with per-sequence boundaries given by cu_seqlens_q / cu_seqlens_k.

    NOTE: the arguments to _flash_attn_forward/_flash_attn_backward below are
    positional and order-critical; keep them in sync with the custom-op
    signatures.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv's last
        # dim when a qv tensor is supplied.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        # The FA3 backward kernel has no attention_chunk support.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient slot per forward input (3 tensors + 20 non-differentiable args).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention over a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pass a dim-4 qkv with
    num_heads_q set (heads laid out as [num_heads_q, num_heads_k, num_heads_k]),
    or see flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_total, headdim) together with num_heads_q.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors (FP8 path).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; the backward pass requires it to be 0.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Required when qkv is dim-4, to split the packed head axis.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Dense (fixed-length) FlashAttention-3.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's headdim when qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor appended along the head dimension.
        q_descale, k_descale, v_descale: optional descale tensors (FP8 path).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forwarded to the kernel; the backward pass requires it to be 0.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV parallelism factor for the forward pass.
        pack_gqa: optional bool. Whether to pack GQA heads; None lets the kernel decide.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length FlashAttention-3.

    q, k, v are packed over the batch as (total_tokens, nheads, headdim);
    per-sequence boundaries are given by the cumulative-length tensors
    cu_seqlens_q / cu_seqlens_k of shape (batch_size + 1,). max_seqlen_q /
    max_seqlen_k are the largest per-sequence lengths. seqused_q / seqused_k
    optionally limit how many tokens of each sequence are actually used.
    Remaining arguments match flash_attn_func.

    Note the argument reorder below: FlashAttnVarlenFunc.forward takes
    seqused_* before max_seqlen_*.

    Return:
        out: (total_tokens, nheads, headdim), plus softmax_lse when
        return_attn_probs=True.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial outputs and log-sum-exps from split attention
    computations into a single output via the fwd_combine kernel.
    If `out` is provided the result is written into it; `out_dtype`
    optionally selects the output dtype."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    # An int cache_seqlens is broadcast to a per-batch int32 tensor on the cache device.
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # NOTE: positional argument order below must match the custom-op signature.
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for flash_attn_with_kvcache.

    The returned object can be passed as ``scheduler_metadata`` to
    flash_attn_with_kvcache to avoid recomputing the work schedule on each
    call (the problem sizes and flags here must match that call).
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # headdim_v defaults to headdim when K and V share a head dimension.
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch210-cxx11-cu128-aarch64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the build-time kernel configuration flags."""
    import pprint

    pprint.pprint(CONFIG)
build/torch210-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
# Handle to the operator namespace registered by the _C extension imported above.
flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a unit-stride last dimension, passing ``None`` through.

    Tensors that already have ``stride(-1) == 1`` are returned unchanged so
    the common case costs no copy.
    """
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round integer ``x`` up to the nearest multiple of ``m``."""
    remainder = x % m
    return x if remainder == 0 else x + (m - remainder)
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round a head dimension up to the smallest enabled compile-time bucket.

    Buckets whose specialization was compiled out (the corresponding
    ``FLASHATTENTION_DISABLE_HDIM*`` build flag is set) are skipped.
    Anything above the largest bucket falls back to 256.
    """
    from .flash_attn_config import CONFIG

    build_flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        if build_flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"]:
            # This head-dim specialization was disabled at build time.
            continue
        if head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Invoke the FlashAttention-3 CUDA forward kernel.

    Registered as a PyTorch custom op so it composes with torch.compile and
    torch.export (see the matching register_fake below). All arguments are
    forwarded *positionally* to ``flash_attn_3_cuda.fwd``, so the order here
    must stay in sync with the C++ signature.

    Returns:
        (out, softmax_lse, out_accum, softmax_lse_accum). The two ``*_accum``
        tensors carry the per-split partial results when split-KV is used;
        when the kernel returns ``None`` for them they are replaced with
        empty tensors, because a custom op's schema cannot return ``None``.
    """
    # Normalize strides: the kernel wants unit stride in the last dim.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v is accepted if EITHER the last dim or dim -3 has unit stride.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional call: order must match the C++ fwd() signature exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Custom ops must return tensors, never None: substitute empty tensors
    # for the accumulators so the 4-tensor schema always holds.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    Used by torch.compile / torch.export to trace the custom op.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        # cu_seqlens has batch_size + 1 cumulative offsets.
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (only if the caller did not pre-allocate one)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, one value per query position/head)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV: per-split partial outputs and LSEs are accumulated in fp32.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1
        # (empty placeholders mirror the real op, which may not return None).
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Invoke the FlashAttention-3 CUDA backward kernel.

    Arguments are forwarded *positionally* to ``flash_attn_3_cuda.bwd``, so
    the order below must stay in sync with the C++ signature. The gradient
    buffers dq/dk/dv are allocated by the C++ side (the three ``None``
    placeholders below).

    Returns:
        (dq, dk, dv, softmax_d).
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention backward.
    Returns tensors with the correct shapes and dtypes without actual
    computation: dq/dk/dv mirror q/k/v, and softmax_d is padded up to the
    backward kernel's M-block size (kBlockM) so tracing matches the real op.
    """

    is_varlen_q = cu_seqlens_q is not None
    # BUGFIX: this previously re-tested cu_seqlens_q (copy-paste), so a call
    # with only cu_seqlens_k set was not recognized as varlen and softmax_d
    # got the batch-mode shape.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        # NOTE(review): assumes max_seqlen_q / max_seqlen_k are provided in
        # varlen mode; the window-size comparisons below would fail on None.
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Normalize the sliding window the same way the kernel does: a window at
    # least as wide as the sequence is equivalent to no window at all.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the SM90 backward kernel's tile heuristics.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash the tensors/attrs _backward needs onto the autograd ctx.

    ``inputs`` is the full positional argument tuple of the 34-parameter
    _flash_attn_forward op; the negative indices below map to its trailing
    parameters (keep in sync with the op's signature):
        inputs[-11] -> softmax_scale
        inputs[-10] -> causal
        inputs[-9]  -> window_size_left
        inputs[-8]  -> window_size_right
        inputs[-7]  -> attention_chunk
        inputs[-6]  -> softcap
        inputs[-1]  -> sm_margin
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd backward for the _flash_attn_forward custom op.

    Only q, k, v receive gradients; every other forward input is
    non-differentiable and gets ``None``. Varlen/seqused/max_seqlen are
    passed as ``None`` here because this path is only wired up for the
    fixed-length call (see setup_context, which saves no varlen state).
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # NOTE(review): 3 + 21 = 24 gradient slots are returned while the forward
    # op declares 34 parameters (22 of them tensors) — confirm this count is
    # what torch.library's register_autograd glue expects for this op.
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Attention over a packed QKV tensor.

    ``qkv`` is either 5-D ``(batch, seqlen, 3, nheads, headdim)`` or, for
    MQA/GQA, 4-D ``(batch, seqlen, nheads_q + 2 * nheads_k, headdim)``; in
    the 4-D case ``num_heads_q`` must be given so Q/K/V can be split out
    along the heads dimension.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): split along the QKV axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # (batch, seqlen, nheads_q + 2*nheads_k, headdim): GQA/MQA layout.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the packing layout so backward can re-pack the gradients.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the format the forward input used.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # autograd requires exactly one gradient slot per forward input:
        # forward takes 13 arguments after ctx and only qkv is differentiable,
        # so return dqkv plus 12 Nones. (The previous hand-written tuple had
        # one None too few after return_softmax was added to forward.)
        return dqkv, *((None,) * 12)
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Fixed-length, batched FlashAttention-3 with autograd support.

    Thin autograd wrapper around the _flash_attn_forward / _flash_attn_backward
    custom ops for q/k/v of shape (batch, seqlen, nheads, headdim).
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined head dim of q and qv (if given).
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input after ctx (17 total): only q, k, v
        # are differentiable; the remaining 14 slots are None.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Variable-length (ragged batch) FlashAttention-3 with autograd support.

    Sequences are packed into (total_tokens, nheads, headdim) tensors and
    delimited by the cumulative-offset tensors cu_seqlens_q / cu_seqlens_k.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined head dim of q and qv (if given).
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input after ctx (23 total): only q, k, v
        # are differentiable; the remaining 20 slots are None.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention over a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pass the 4-D packed
    layout and num_heads_q (see below), or use flash_attn_func directly.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) for MQA/GQA,
            in which case num_heads_q is required.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors passed through to the
            kernel (presumably for FP8 inputs — confirm against the CUDA op).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forward-only feature; must be 0 to use the backward pass.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Number of query heads; required for the 4-D packed layout.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Fixed-length, batched FlashAttention-3.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim + qv_headdim), where qv_headdim is 0 if qv is None.
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor forwarded to the kernel — confirm semantics
            against the CUDA op.
        q_descale, k_descale, v_descale: optional descale tensors passed through to the
            kernel (presumably for FP8 inputs — confirm against the CUDA op).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. Forward-only feature; must be 0 to use the backward pass.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Number of KV splits in the forward; can be tuned for speed.
        pack_gqa: optional bool. Kernel tuning knob, passed through to the forward op.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length (ragged batch) FlashAttention-3.

    Sequences of different lengths are packed into q/k/v of shape
    (total_tokens, nheads, headdim). cu_seqlens_q / cu_seqlens_k hold the
    cumulative per-sequence token offsets (length batch_size + 1), and
    max_seqlen_q / max_seqlen_k the longest sequence lengths in the batch.
    seqused_q / seqused_k optionally override how many tokens of each
    sequence are used — confirm exact semantics against the CUDA op.
    See flash_attn_func for the attention arguments shared with the
    fixed-length path (causal, window_size, softcap, descales, etc.).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs and log-sum-exps.

    Thin wrapper over ``flash_attn_3_cuda.fwd_combine``; merges the per-split
    partial results (e.g. from a split-KV forward) into a single output.
    ``out``/``out_dtype`` optionally select a pre-allocated destination and
    its dtype.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale includes the extra qv head dim when qv is supplied.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor, as the kernel expects.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    # NOTE: argument order below must match _flash_attn_forward's positional signature exactly.
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for a given problem shape.

    The returned tensor can be passed as ``scheduler_metadata`` to
    ``flash_attn_with_kvcache`` / ``_flash_attn_forward`` so the scheduling
    work is not redone on every call with the same shapes. All arguments
    mirror the forward call that the metadata will be used with.

    Returns:
        The opaque scheduler-metadata tensor produced by the CUDA extension.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        # The value head dim defaults to the query/key head dim.
        headdim_v = headdim
    # NOTE: positional argument order must match the C++ get_scheduler_metadata signature.
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e722dcebc49024ab2dd2c58c7b27efb9a0618ce104f66b6b21cc0d75bf5de4b6
3
+ size 814568840
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the build-time configuration flags to stdout."""
    import pprint
    pprint.pprint(CONFIG)
build/torch210-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a contiguous last dimension; None and already-suitable tensors pass through."""
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m* (for non-negative x and positive m)."""
    quotient, remainder = divmod(x, m)
    return x if remainder == 0 else (quotient + 1) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest head dimension compiled into this build.

    Candidate dims (64/96/128/192/256) are tried in increasing order; a dim is
    usable only when its FLASHATTENTION_DISABLE_HDIM* build flag is False.
    Falls back to 256 when nothing smaller matches.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for dim in (64, 96, 128, 192, 256):
        if not flags[f"FLASHATTENTION_DISABLE_HDIM{dim}"] and head_size <= dim:
            return dim
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the FlashAttention-3 CUDA forward kernel.

    This is the single low-level entry point for all forward variants
    (batched, varlen, KV-cache append, paged KV). It normalizes strides,
    then forwards every argument positionally to ``flash_attn_3_cuda.fwd``.

    Returns:
        (out, softmax_lse, out_accum, softmax_lse_accum). The two ``*_accum``
        tensors are the split-KV partial accumulators; when the kernel does
        not produce them they are replaced with empty tensors so the custom-op
        schema always sees four tensors.
    """
    # Ensure contiguous last dims where required by the kernel.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v is also accepted in a column-major-like layout (stride(-3) == 1), so only
    # force a copy when neither the last nor the -3 dim is contiguous.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional order must match the C++ fwd() signature exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op requires concrete tensors in the returned tuple, so substitute
    # empty placeholders when the kernel returned no accumulators.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    Used by torch.compile / torch.export tracing in place of the CUDA kernel.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from the query/key head dim)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (only when the caller did not pre-allocate one)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, regardless of input dtype)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split fp32 accumulators are produced.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Accumulator tensors are not produced when num_splits == 1; mirror the
        # real op's empty placeholders.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the FlashAttention-3 CUDA backward kernel.

    Computes gradients w.r.t. q, k, v given the upstream gradient ``dout``,
    the forward output ``out``, and the saved ``softmax_lse``.

    Returns:
        (dq, dk, dv, softmax_d) where softmax_d is an fp32 intermediate
        produced by the kernel.
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention backward.

    Mirrors the shape logic of the CUDA backward so torch.compile / export
    can trace through it: dq/dk/dv match their inputs, and softmax_d's shape
    reproduces the kernel's block-rounded layout.
    """
    is_varlen_q = cu_seqlens_q is not None
    # Bug fix: previously tested cu_seqlens_q again, so a call that passed only
    # cu_seqlens_k would not be treated as varlen and softmax_d got the wrong shape.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Normalize window sizes the same way the kernel does: a window covering
    # the whole sequence is equivalent to no window at all.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper GPUs use CUDA compute capability 9.0.
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the SM90 backward kernel's tile-size heuristic.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Save forward state for the custom-op autograd backward.

    ``inputs`` is the positional argument tuple of _flash_attn_forward;
    the negative indices below map to its trailing parameters
    (..., softmax_scale, causal, window_size_left, window_size_right,
    attention_chunk, softcap, rotary_interleaved, scheduler_metadata,
    num_splits, pack_gqa, sm_margin). Keep them in sync if the forward
    signature changes.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]   # softmax_scale
    ctx.causal = inputs[-10]          # causal
    ctx.window_size = [inputs[-9], inputs[-8]]  # window_size_left/right
    ctx.attention_chunk = inputs[-7]  # attention_chunk
    ctx.softcap = inputs[-6]          # softcap
    ctx.sm_margin = inputs[-1]        # sm_margin
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd backward for the _flash_attn_forward custom op.

    Only ``dout`` (the gradient of the attention output) is consumed; the
    gradients for softmax_lse and the split accumulators in ``*grads`` are
    ignored. Returns dq, dk, dv followed by None for every remaining
    forward input (none of which are differentiable).
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)


# Wire the backward into the custom op so autograd works through torch.compile.
_flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention on a packed QKV tensor.

    Accepts either a 5-D packed tensor (batch, seqlen, 3, nheads, headdim)
    or a 4-D GQA layout (batch, seqlen, nheads_q + 2 * nheads_k, headdim)
    where ``num_heads_q`` must be given to split out q/k/v.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale: 1 / sqrt(headdim)
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): unbind the q/k/v axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # GQA packing: heads are laid out as [q-heads | k-heads | v-heads].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the packing layout so backward can re-pack the gradients.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients back into the same layout the input qkv used.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for standard (fixed-length, batched) attention
    with separate q, k, v tensors."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale includes the extra qv head dim when qv is supplied.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """autograd wrapper for FA3 variable-length (ragged-batch) attention.

    Sequences are packed along the first dimension and delimited by the
    cumulative-length tensors cu_seqlens_q / cu_seqlens_k, optionally refined
    by seqused_q / seqused_k.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(d) over the combined q (+ optional qv) head dim.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        """Gradients for the varlen forward; non-tensor inputs get None."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        # The FA3 backward kernel has no support for chunked attention.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # The kernel may pad the head dimension; slice grads back to input sizes.
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention where Q, K, V are packed into a single qkv tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional tensors forwarded to the kernel
            (NOTE(review): presumably FP8 dequantization scales — confirm against kernel docs).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. 0 disables chunked attention; the FA3 backward pass
            does not support nonzero values.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: optional int forwarded to FlashAttnQKVPackedFunc
            (NOTE(review): likely the number of query heads when unpacking GQA-packed qkv —
            confirm against the Function's forward, which is not visible here).
        sm_margin: int. Number of SMs to leave unused by the kernel.
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only. The returned values are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 over batched (non-varlen) q, k, v.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's head dim when qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor; its head dim contributes to the default scale
            (NOTE(review): full qv semantics are defined by the kernel — confirm).
        q_descale, k_descale, v_descale: optional tensors forwarded to the kernel
            (NOTE(review): presumably FP8 dequantization scales — confirm).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. 0 disables chunked attention; the backward pass asserts it is 0.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Number of KV splits in the forward kernel (1 = no split).
        pack_gqa: optional bool. Whether to pack GQA query heads; None lets the kernel decide.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave unused by the kernel.
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length (ragged-batch) version of flash_attn_func.

    Sequences are packed along the first dimension: q is (total_q, nheads, headdim),
    and per-sequence boundaries are given by the cumulative-length tensors.

    Arguments:
        q, k, v: packed tensors as described above.
        cu_seqlens_q: (batch_size + 1,). Cumulative sequence lengths of q.
        cu_seqlens_k: (batch_size + 1,). Cumulative sequence lengths of k/v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        seqused_q, seqused_k: optional (batch_size,) tensors giving how many tokens
            of each sequence to actually use (NOTE(review): inferred from the
            forward's seqused handling — confirm exact semantics).
        Remaining arguments match flash_attn_func.
    Return:
        out: (total_q, nheads, headdim), plus softmax_lse if return_attn_probs=True.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs with their log-sum-exp values.

    Thin wrapper over the CUDA fwd_combine kernel
    (NOTE(review): presumably merges per-split partial results produced when
    num_splits > 1 — confirm against the kernel documentation).
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    # Default scale is 1/sqrt(d) over the combined q (+ optional qv) head dim.
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    # A scalar cache_seqlens applies to every sequence in the batch.
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    # rest holds the split accumulators (out_accum, softmax_lse_accum).
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduling metadata for flash_attn_with_kvcache.

    Runs the C++ scheduler ahead of time; the returned tensor can be passed as
    ``scheduler_metadata=`` to flash_attn_with_kvcache so the scheduling
    decision is not recomputed on every forward call.

    headdim_v defaults to headdim when not given. num_splits == 0 lets the
    scheduler choose the number of splits heuristically.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch210-cxx11-cu130-aarch64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
# Build-time feature flags baked into the compiled _C extension for this wheel.
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}

def show():
    """Pretty-print the build configuration this wheel was compiled with."""
    from pprint import pprint
    pprint(CONFIG)
7
+
build/torch210-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a contiguous last dimension; None passes through unchanged."""
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m* (m > 0)."""
    remainder = x % m
    if remainder == 0:
        return x
    return x + (m - remainder)
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest head dimension compiled into this build.

    Candidate sizes whose build flag disables them are skipped; anything larger
    than the biggest candidate falls back to 256.
    """
    from .flash_attn_config import CONFIG

    build_flags = CONFIG["build_flags"]
    candidates = (
        (64, "FLASHATTENTION_DISABLE_HDIM64"),
        (96, "FLASHATTENTION_DISABLE_HDIM96"),
        (128, "FLASHATTENTION_DISABLE_HDIM128"),
        (192, "FLASHATTENTION_DISABLE_HDIM192"),
        (256, "FLASHATTENTION_DISABLE_HDIM256"),
    )
    for dim, disable_flag in candidates:
        if not build_flags[disable_flag] and head_size <= dim:
            return dim
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level forward entry point, registered as a PyTorch custom op.

    Normalizes tensor strides and calls the CUDA kernel flash_attn_3_cuda.fwd.
    Always returns four tensors (out, softmax_lse, out_accum,
    softmax_lse_accum); the accumulators are replaced with empty placeholder
    tensors when the kernel returns None for them, since a custom op must
    return concrete tensors.
    """
    # The kernel requires a contiguous last dim; v alternatively allows a
    # stride-1 dim -3 (column-major-in-headdim layout).
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op outputs must be tensors, never None: substitute empty tensors.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    Used by torch.compile / torch.export to trace the custom op.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from q/k head dim)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1; mirror the real op's
        # empty-tensor placeholders.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level backward entry point, registered as a PyTorch custom op.

    Calls the CUDA kernel flash_attn_3_cuda.bwd and returns
    (dq, dk, dv, softmax_d). The dq/dk/dv buffers are allocated on the C++
    side (None is passed for them below).
    """
    # Kernels require contiguous last dimensions on all tensor inputs.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) kernel for flash_attn_3::_flash_attn_backward.

    Reproduces the output shapes the real CUDA kernel would allocate, so the
    op can be traced by torch.compile / torch.export without launching work.
    Returns ``(dq, dk, dv, softmax_d)`` with dq/dk/dv shaped like q/k/v and
    softmax_d shaped per the kernel's block-rounded layout.
    """
    is_varlen_q = cu_seqlens_q is not None
    # BUG FIX: this previously re-tested cu_seqlens_q, so a call that was
    # varlen only on the key side computed softmax_d with the dense shape.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        # NOTE(review): assumes max_seqlen_q/k are provided whenever
        # cu_seqlens_q is — the comparisons below would fail on None.
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window covering the whole sequence is equivalent to no window.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    # Causal attention is expressed as a (-1, 0) window, mirroring the C++ side.
    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the SM90 backward kernel's tile-size heuristic.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash forward tensors and flags on ctx for the custom-op autograd hook.

    ``inputs`` is the full positional argument tuple that was passed to
    _flash_attn_forward; the negative indices below count from the end of
    that signature (sm_margin is the last parameter) and must stay in sync
    with it if the op's signature ever changes.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd hook for the _flash_attn_forward custom op.

    Returns gradients for (q, k, v) followed by None for the remaining 21
    non-differentiable forward inputs.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    # Consistency fix: FlashAttnFunc/FlashAttnVarlenFunc.backward both assert
    # this; the FA3 backward kernel does not support attention_chunk, so fail
    # loudly here too instead of silently producing wrong gradients.
    assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd Function for FlashAttention-3 on a packed QKV tensor.

    Accepts either a 5-D ``(batch, seqlen, 3, nheads, headdim)`` tensor or a
    4-D ``(batch, seqlen, nheads_total, headdim)`` tensor (GQA/MQA packing,
    requires ``num_heads_q``). Use flash_attn_qkvpacked_func() instead of
    calling this class directly.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # Layout (batch, seqlen, 3, nheads, headdim): split on the stack dim.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # GQA/MQA packing: heads laid out as [q-heads | k-heads | v-heads].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the expected format
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # BUG FIX: forward receives 13 inputs after ctx (qkv, softmax_scale,
        # causal, q/k/v_descale, window_size, attention_chunk, softcap,
        # deterministic, num_heads_q, sm_margin, return_softmax), so backward
        # must return 13 values; the previous version returned only 12 and
        # autograd would raise "returned an incorrect number of gradients".
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd Function for dense (fixed-length) FlashAttention-3.

    Thin wrapper over the _flash_attn_forward / _flash_attn_backward custom
    ops; use flash_attn_func() instead of calling this class directly.
    The long positional argument list passed to _flash_attn_forward below is
    order-critical — keep it in sync with the op's signature.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale 1/sqrt(total query head dim), counting qv if given.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input after ctx: dq, dk, dv plus 14 Nones = 17.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd Function for variable-length (packed/ragged) FlashAttention-3.

    Sequences are concatenated along dim 0 and delimited by cu_seqlens_q /
    cu_seqlens_k prefix sums; use flash_attn_varlen_func() instead of calling
    this class directly. The positional argument list passed to
    _flash_attn_forward below is order-critical.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale 1/sqrt(total query head dim), counting qv if given.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input after ctx: dq, dk, dv plus 20 Nones = 23.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 on a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pack the heads into a
    4-D tensor and pass num_heads_q, or see flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) with num_heads_q set.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional fp8 descale tensors, forwarded to the kernel.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Number of query heads; required when qkv is 4-D (GQA/MQA packing).
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 for dense (fixed-length) batches.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's head dim if qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional (batch_size, seqlen, nheads, headdim_v) extra query-value tensor.
        q_descale, k_descale, v_descale: optional fp8 descale tensors, forwarded to the kernel.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. If > 1, split the KV along the sequence; 0 uses a heuristic.
        pack_gqa: optional bool. Whether to pack GQA heads; tunable for speed.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length FlashAttention-3.

    Sequences in the batch are concatenated along dim 0 (no padding) and
    delimited by the int32 prefix-sum tensors cu_seqlens_q / cu_seqlens_k of
    shape (batch_size + 1,). Semantics of causal/window_size/softcap and the
    MQA/GQA head rules match flash_attn_func.

    Arguments:
        q: (total_q, nheads, headdim), all queries concatenated.
        k, v: (total_k, nheads_k, headdim), all keys/values concatenated.
        cu_seqlens_q, cu_seqlens_k: (batch_size + 1,), int32 cumulative sequence lengths.
        max_seqlen_q, max_seqlen_k: int. Maximum per-sequence lengths in the batch.
        seqused_q, seqused_k: optional (batch_size,) tensors limiting how many
            tokens of each sequence are actually used.
        Remaining arguments are as in flash_attn_func.
    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial (split-KV) attention outputs and logsumexps into the
    final attention output, via the CUDA fwd_combine kernel."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # Positional argument order below must match the custom op's signature.
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute kernel scheduler metadata for flash_attn_with_kvcache.

    Passing the returned object as ``scheduler_metadata`` lets repeated calls
    with an identical problem shape skip the scheduling work. The value-head
    dimension defaults to ``headdim`` when ``headdim_v`` is not given.
    """
    if headdim_v is None:
        headdim_v = headdim
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        maybe_contiguous(cache_seqlens),
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e722dcebc49024ab2dd2c58c7b27efb9a0618ce104f66b6b21cc0d75bf5de4b6
3
+ size 814568840
build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
# Build-time feature flags baked into the compiled _C extension for this wheel.
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}

def show():
    """Pretty-print the build configuration this wheel was compiled with."""
    from pprint import pprint
    pprint(CONFIG)
build/torch210-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
+ def maybe_contiguous(x):
17
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
18
+
19
+
20
+ def round_multiple(x, m):
21
+ return (x + m - 1) // m * m
22
+
23
+
24
+ def round_up_headdim(head_size: int) -> int:
25
+ from .flash_attn_config import CONFIG
26
+
27
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]:
28
+ if head_size <= 64:
29
+ return 64
30
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]:
31
+ if head_size <= 96:
32
+ return 96
33
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]:
34
+ if head_size <= 128:
35
+ return 128
36
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]:
37
+ if head_size <= 192:
38
+ return 192
39
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]:
40
+ if head_size <= 256:
41
+ return 256
42
+ return 256
43
+
44
+
45
+ @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
46
+ def _flash_attn_forward(
47
+ q: torch.Tensor,
48
+ k: torch.Tensor,
49
+ v: torch.Tensor,
50
+ k_new: Optional[torch.Tensor] = None,
51
+ v_new: Optional[torch.Tensor] = None,
52
+ qv: Optional[torch.Tensor] = None,
53
+ out: Optional[torch.Tensor] = None,
54
+ cu_seqlens_q: Optional[torch.Tensor] = None,
55
+ cu_seqlens_k: Optional[torch.Tensor] = None,
56
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
57
+ seqused_q: Optional[torch.Tensor] = None,
58
+ seqused_k: Optional[torch.Tensor] = None,
59
+ max_seqlen_q: Optional[int] = None,
60
+ max_seqlen_k: Optional[int] = None,
61
+ page_table: Optional[torch.Tensor] = None,
62
+ kv_batch_idx: Optional[torch.Tensor] = None,
63
+ leftpad_k: Optional[torch.Tensor] = None,
64
+ rotary_cos: Optional[torch.Tensor] = None,
65
+ rotary_sin: Optional[torch.Tensor] = None,
66
+ seqlens_rotary: Optional[torch.Tensor] = None,
67
+ q_descale: Optional[torch.Tensor] = None,
68
+ k_descale: Optional[torch.Tensor] = None,
69
+ v_descale: Optional[torch.Tensor] = None,
70
+ softmax_scale: Optional[float] = None,
71
+ causal: bool = False,
72
+ window_size_left: int = -1,
73
+ window_size_right: int = -1,
74
+ attention_chunk: int = 0,
75
+ softcap: float = 0.0,
76
+ rotary_interleaved: bool = True,
77
+ scheduler_metadata: Optional[torch.Tensor] = None,
78
+ num_splits: int = 1,
79
+ pack_gqa: Optional[bool] = None,
80
+ sm_margin: int = 0,
81
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
82
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
83
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
84
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
85
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
86
+ ]
87
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
88
+ page_table, kv_batch_idx, leftpad_k = [
89
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
90
+ ]
91
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
92
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
93
+ out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
94
+ q,
95
+ k,
96
+ v,
97
+ k_new,
98
+ v_new,
99
+ qv,
100
+ out,
101
+ cu_seqlens_q,
102
+ cu_seqlens_k,
103
+ cu_seqlens_k_new,
104
+ seqused_q,
105
+ seqused_k,
106
+ max_seqlen_q,
107
+ max_seqlen_k,
108
+ page_table,
109
+ kv_batch_idx,
110
+ leftpad_k,
111
+ rotary_cos,
112
+ rotary_sin,
113
+ seqlens_rotary,
114
+ q_descale,
115
+ k_descale,
116
+ v_descale,
117
+ softmax_scale,
118
+ causal,
119
+ window_size_left,
120
+ window_size_right,
121
+ attention_chunk,
122
+ softcap,
123
+ rotary_interleaved,
124
+ scheduler_metadata,
125
+ num_splits,
126
+ pack_gqa,
127
+ sm_margin,
128
+ )
129
+
130
+ if out_accum is None:
131
+ out_accum = torch.tensor([], device=out.device)
132
+
133
+ if softmax_lse_accum is None:
134
+ softmax_lse_accum = torch.tensor([], device=out.device)
135
+
136
+ return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
+ @torch.library.register_fake("flash_attn_3::_flash_attn_forward")
140
+ def _flash_attn_forward_fake(
141
+ q: torch.Tensor,
142
+ k: torch.Tensor,
143
+ v: torch.Tensor,
144
+ k_new: Optional[torch.Tensor] = None,
145
+ v_new: Optional[torch.Tensor] = None,
146
+ qv: Optional[torch.Tensor] = None,
147
+ out: Optional[torch.Tensor] = None,
148
+ cu_seqlens_q: Optional[torch.Tensor] = None,
149
+ cu_seqlens_k: Optional[torch.Tensor] = None,
150
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
151
+ seqused_q: Optional[torch.Tensor] = None,
152
+ seqused_k: Optional[torch.Tensor] = None,
153
+ max_seqlen_q: Optional[int] = None,
154
+ max_seqlen_k: Optional[int] = None,
155
+ page_table: Optional[torch.Tensor] = None,
156
+ kv_batch_idx: Optional[torch.Tensor] = None,
157
+ leftpad_k: Optional[torch.Tensor] = None,
158
+ rotary_cos: Optional[torch.Tensor] = None,
159
+ rotary_sin: Optional[torch.Tensor] = None,
160
+ seqlens_rotary: Optional[torch.Tensor] = None,
161
+ q_descale: Optional[torch.Tensor] = None,
162
+ k_descale: Optional[torch.Tensor] = None,
163
+ v_descale: Optional[torch.Tensor] = None,
164
+ softmax_scale: Optional[float] = None,
165
+ causal: bool = False,
166
+ window_size_left: int = -1,
167
+ window_size_right: int = -1,
168
+ attention_chunk: int = 0,
169
+ softcap: float = 0.0,
170
+ rotary_interleaved: bool = True,
171
+ scheduler_metadata: Optional[torch.Tensor] = None,
172
+ num_splits: int = 1,
173
+ pack_gqa: Optional[bool] = None,
174
+ sm_margin: int = 0,
175
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
176
+ """
177
+ Symbolic fake implementation of flash attention forward.
178
+ Returns tensors with the correct shapes and dtypes without actual computation.
179
+ """
180
+
181
+ # Determine if we're in varlen mode
182
+ is_varlen_q = cu_seqlens_q is not None
183
+
184
+ # Get dimensions from query tensor
185
+ if is_varlen_q:
186
+ # varlen mode: q is (total_q, num_heads, head_size)
187
+ total_q, num_heads, head_size = q.shape
188
+ batch_size = cu_seqlens_q.shape[0] - 1
189
+
190
+ if max_seqlen_q is None:
191
+ raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
192
+ seqlen_q = max_seqlen_q
193
+ else:
194
+ # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
195
+ batch_size, seqlen_q, num_heads, head_size = q.shape
196
+ total_q = batch_size * q.shape[1]
197
+ # Get value head dimension
198
+ head_size_v = v.shape[-1]
199
+
200
+ # Determine output dtype (FP8 inputs produce BF16 outputs)
201
+ q_type = q.dtype
202
+ if q_type == torch.float8_e4m3fn:
203
+ out_dtype = torch.bfloat16
204
+ else:
205
+ out_dtype = q_type
206
+
207
+ # Create output tensor
208
+ if out is None:
209
+ if is_varlen_q:
210
+ out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
211
+ else:
212
+ out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
213
+
214
+ # Create softmax_lse tensor
215
+ if is_varlen_q:
216
+ softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
217
+ else:
218
+ softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
219
+
220
+ # TODO(guilhermeleobas): Implement "get_num_splits"
221
+ # There's an heuristic to compute num_splits when "num_splits <= 0"
222
+ # assert that num_splits is > 0 for now
223
+ if num_splits <= 0:
224
+ raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")
225
+
226
+ if num_splits > 1:
227
+ if is_varlen_q:
228
+ out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
229
+ softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
230
+ else:
231
+ out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
232
+ softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
233
+ else:
234
+ # Tensors are not set when num_splits < 1
235
+ out_accum = torch.tensor([], device=out.device)
236
+ softmax_lse_accum = torch.tensor([], device=out.device)
237
+
238
+ return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
+ @torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
242
+ def _flash_attn_backward(
243
+ dout: torch.Tensor,
244
+ q: torch.Tensor,
245
+ k: torch.Tensor,
246
+ v: torch.Tensor,
247
+ out: torch.Tensor,
248
+ softmax_lse: torch.Tensor,
249
+ cu_seqlens_q: Optional[torch.Tensor] = None,
250
+ cu_seqlens_k: Optional[torch.Tensor] = None,
251
+ sequed_q: Optional[torch.Tensor] = None,
252
+ sequed_k: Optional[torch.Tensor] = None,
253
+ max_seqlen_q: Optional[int] = None,
254
+ max_seqlen_k: Optional[int] = None,
255
+ softmax_scale: Optional[float] = None,
256
+ is_causal: bool = False,
257
+ window_size_left: int = -1,
258
+ window_size_right: int = -1,
259
+ softcap: float = 0.0,
260
+ deterministic: bool = False,
261
+ sm_margin: int = 0,
262
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
263
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
264
+ # C++ now allocates and returns dq, dk, dv
265
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
266
+ dout,
267
+ q,
268
+ k,
269
+ v,
270
+ out,
271
+ softmax_lse,
272
+ None, # dq - let C++ allocate
273
+ None, # dk - let C++ allocate
274
+ None, # dv - let C++ allocate
275
+ cu_seqlens_q,
276
+ cu_seqlens_k,
277
+ sequed_q,
278
+ sequed_k,
279
+ max_seqlen_q,
280
+ max_seqlen_k,
281
+ softmax_scale,
282
+ is_causal,
283
+ window_size_left,
284
+ window_size_right,
285
+ softcap,
286
+ deterministic,
287
+ sm_margin,
288
+ )
289
+ return dq, dk, dv, softmax_d
290
+
291
+
292
+ @torch.library.register_fake("flash_attn_3::_flash_attn_backward")
293
+ def _flash_attn_backward_fake(
294
+ dout: torch.Tensor,
295
+ q: torch.Tensor,
296
+ k: torch.Tensor,
297
+ v: torch.Tensor,
298
+ out: torch.Tensor,
299
+ softmax_lse: torch.Tensor,
300
+ cu_seqlens_q: Optional[torch.Tensor] = None,
301
+ cu_seqlens_k: Optional[torch.Tensor] = None,
302
+ sequed_q: Optional[torch.Tensor] = None,
303
+ sequed_k: Optional[torch.Tensor] = None,
304
+ max_seqlen_q: Optional[int] = None,
305
+ max_seqlen_k: Optional[int] = None,
306
+ softmax_scale: Optional[float] = None,
307
+ is_causal: bool = False,
308
+ window_size_left: int = -1,
309
+ window_size_right: int = -1,
310
+ softcap: float = 0.0,
311
+ deterministic: bool = False,
312
+ sm_margin: int = 0,
313
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
314
+
315
+ is_varlen_q = cu_seqlens_q is not None
316
+ is_varlen_k = cu_seqlens_q is not None
317
+ is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None
318
+
319
+ if not is_varlen_q:
320
+ batch_size = q.size(0)
321
+ seqlen_q = q.size(1)
322
+ seqlen_k = k.size(1)
323
+ total_q = batch_size * q.size(1)
324
+ else:
325
+ batch_size = cu_seqlens_q.size(0) - 1
326
+ total_q = q.size(0)
327
+ seqlen_q = max_seqlen_q
328
+ seqlen_k = max_seqlen_k
329
+
330
+ if window_size_left >= seqlen_k - 1:
331
+ window_size_left = -1
332
+
333
+ if window_size_right >= seqlen_q - 1:
334
+ window_size_right = -1
335
+
336
+ if is_causal:
337
+ window_size_right = 0
338
+
339
+ is_causal = window_size_left < 0 and window_size_right == 0
340
+
341
+ head_size = q.size(-1)
342
+ head_size_v = v.size(-1)
343
+ head_size_rounded = round_up_headdim(max(head_size, head_size_v))
344
+
345
+ # Hopper gpus uses cuda compute capabilities 9.0
346
+ cap = torch.cuda.get_device_capability(q.device)
347
+ arch = cap[0] * 10 + cap[1]
348
+
349
+ is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
350
+
351
+ if arch < 90:
352
+ raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")
353
+
354
+ if head_size_rounded <= 64:
355
+ kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
356
+ elif head_size_rounded <= 96:
357
+ kBlockM_sm90 = 64
358
+ elif head_size_rounded <= 128:
359
+ kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
360
+ else:
361
+ kBlockM_sm90 = 64
362
+
363
+ kBlockM = kBlockM_sm90
364
+
365
+ num_heads = q.shape[-2]
366
+ seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)
367
+
368
+ total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)
369
+
370
+ # Allocate gradient tensors
371
+ dq = torch.empty_like(q)
372
+ dk = torch.empty_like(k)
373
+ dv = torch.empty_like(v)
374
+
375
+ if not is_varlen:
376
+ softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
377
+ else:
378
+ softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)
379
+
380
+ return dq, dk, dv, softmax_d
381
+
382
+
383
+ def setup_context(ctx, inputs, output):
384
+ q, k, v = inputs[:3]
385
+ out, softmax_lse, _, _ = output
386
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
387
+ ctx.softmax_scale = inputs[-11]
388
+ ctx.causal = inputs[-10]
389
+ ctx.window_size = [inputs[-9], inputs[-8]]
390
+ ctx.attention_chunk = inputs[-7]
391
+ ctx.softcap = inputs[-6]
392
+ ctx.sm_margin = inputs[-1]
393
+
394
+
395
+ def _backward(ctx, dout, *grads):
396
+ q, k, v, out, softmax_lse = ctx.saved_tensors
397
+ dq, dk, dv, _ = _flash_attn_backward(
398
+ dout,
399
+ q,
400
+ k,
401
+ v,
402
+ out,
403
+ softmax_lse,
404
+ None, None, # cu_seqlens_q, cu_seqlens_k,
405
+ None, None, # sequed_q, sequed_k,
406
+ None, None, # max_seqlen_q, max_seqlen_k,
407
+ ctx.softmax_scale,
408
+ ctx.causal,
409
+ ctx.window_size[0],
410
+ ctx.window_size[1],
411
+ ctx.softcap,
412
+ False, # deterministic
413
+ ctx.sm_margin,
414
+ )
415
+ return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
423
+ @staticmethod
424
+ def forward(
425
+ ctx,
426
+ qkv,
427
+ softmax_scale,
428
+ causal,
429
+ q_descale=None, k_descale=None, v_descale=None,
430
+ window_size=(-1, -1),
431
+ attention_chunk=0,
432
+ softcap=0.0,
433
+ deterministic=False,
434
+ num_heads_q=None,
435
+ sm_margin=0,
436
+ return_softmax=False,
437
+ ):
438
+ if softmax_scale is None:
439
+ softmax_scale = qkv.shape[-1] ** (-0.5)
440
+ if qkv.dim() == 5:
441
+ assert qkv.shape[-3] == 3
442
+ q, k, v = qkv.unbind(dim=-3)
443
+ else:
444
+ assert qkv.dim() == 4
445
+ assert num_heads_q is not None
446
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
447
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
448
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
449
+ out, softmax_lse, *rest = _flash_attn_forward(
450
+ q,
451
+ k,
452
+ v,
453
+ None, None, # k_new, v_new
454
+ None, # qv
455
+ None, # out
456
+ None, None, None, # cu_seqlens_q/k/k_new
457
+ None, None, # seqused_q/k
458
+ None, None, # max_seqlen_q/k
459
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
460
+ None, None, None, # rotary_cos/sin, seqlens_rotary
461
+ q_descale, k_descale, v_descale,
462
+ softmax_scale,
463
+ causal=causal,
464
+ window_size_left=window_size[0],
465
+ window_size_right=window_size[1],
466
+ attention_chunk=attention_chunk,
467
+ softcap=softcap,
468
+ sm_margin=sm_margin,
469
+ )
470
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
471
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
472
+ ctx.softmax_scale = softmax_scale
473
+ ctx.causal = causal
474
+ ctx.window_size = window_size
475
+ ctx.attention_chunk = attention_chunk
476
+ ctx.softcap = softcap
477
+ ctx.deterministic = deterministic
478
+ ctx.ndim = qkv.dim()
479
+ ctx.sm_margin = sm_margin
480
+ return (out, softmax_lse) if return_softmax else out
481
+
482
+ @staticmethod
483
+ def backward(ctx, dout, *args):
484
+ q, k, v, out, softmax_lse = ctx.saved_tensors
485
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
486
+ # Get gradients from the backward function
487
+ dq, dk, dv, _ = _flash_attn_backward(
488
+ dout,
489
+ q,
490
+ k,
491
+ v,
492
+ out,
493
+ softmax_lse,
494
+ None, None, # cu_seqlens_q, cu_seqlens_k,
495
+ None, None, # sequed_q, sequed_k,
496
+ None, None, # max_seqlen_q, max_seqlen_k,
497
+ ctx.softmax_scale,
498
+ ctx.causal,
499
+ ctx.window_size[0],
500
+ ctx.window_size[1],
501
+ ctx.softcap,
502
+ ctx.deterministic,
503
+ ctx.sm_margin,
504
+ )
505
+ # Pack the gradients into the expected format
506
+ if ctx.ndim == 5:
507
+ dqkv = torch.stack([dq, dk, dv], dim=-3)
508
+ else:
509
+ # Concatenate along the heads dimension
510
+ dqkv = torch.cat([dq, dk, dv], dim=-2)
511
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
512
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
+ class FlashAttnFunc(torch.autograd.Function):
516
+
517
+ @staticmethod
518
+ def forward(
519
+ ctx,
520
+ q,
521
+ k,
522
+ v,
523
+ softmax_scale,
524
+ causal,
525
+ qv=None,
526
+ q_descale=None, k_descale=None, v_descale=None,
527
+ window_size=(-1, -1),
528
+ attention_chunk=0,
529
+ softcap=0.0,
530
+ num_splits=1,
531
+ pack_gqa=None,
532
+ deterministic=False,
533
+ sm_margin=0,
534
+ return_softmax=False,
535
+ ):
536
+ if softmax_scale is None:
537
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
538
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
539
+ out, softmax_lse, *rest = _flash_attn_forward(
540
+ q,
541
+ k,
542
+ v,
543
+ None, None, # k_new, v_new
544
+ qv, # qv
545
+ None, # out
546
+ None, None, None, # cu_seqlens_q/k/k_new
547
+ None, None, # seqused_q/k
548
+ None, None, # max_seqlen_q/k
549
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
550
+ None, None, None, # rotary_cos/sin, seqlens_rotary
551
+ q_descale, k_descale, v_descale,
552
+ softmax_scale,
553
+ causal=causal,
554
+ window_size_left=window_size[0],
555
+ window_size_right=window_size[1],
556
+ attention_chunk=attention_chunk,
557
+ softcap=softcap,
558
+ num_splits=num_splits,
559
+ pack_gqa=pack_gqa,
560
+ sm_margin=sm_margin,
561
+ )
562
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
563
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
564
+ ctx.softmax_scale = softmax_scale
565
+ ctx.causal = causal
566
+ ctx.window_size = window_size
567
+ ctx.attention_chunk = attention_chunk
568
+ ctx.softcap = softcap
569
+ ctx.deterministic = deterministic
570
+ ctx.sm_margin = sm_margin
571
+ return (out, softmax_lse) if return_softmax else out
572
+
573
+ @staticmethod
574
+ def backward(ctx, dout, *args):
575
+ q, k, v, out, softmax_lse = ctx.saved_tensors
576
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
577
+ dq, dk, dv, _ = _flash_attn_backward(
578
+ dout,
579
+ q,
580
+ k,
581
+ v,
582
+ out,
583
+ softmax_lse,
584
+ None, None, # cu_seqlens_q, cu_seqlens_k,
585
+ None, None, # sequed_q, sequed_k,
586
+ None, None, # max_seqlen_q, max_seqlen_k,
587
+ ctx.softmax_scale,
588
+ ctx.causal,
589
+ ctx.window_size[0],
590
+ ctx.window_size[1],
591
+ ctx.softcap,
592
+ ctx.deterministic,
593
+ ctx.sm_margin,
594
+ )
595
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
596
+ dk = dk[..., : k.shape[-1]]
597
+ dv = dv[..., : v.shape[-1]]
598
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
+ class FlashAttnVarlenFunc(torch.autograd.Function):
602
+
603
+ @staticmethod
604
+ def forward(
605
+ ctx,
606
+ q,
607
+ k,
608
+ v,
609
+ cu_seqlens_q,
610
+ cu_seqlens_k,
611
+ seqused_q,
612
+ seqused_k,
613
+ max_seqlen_q,
614
+ max_seqlen_k,
615
+ softmax_scale,
616
+ causal,
617
+ qv=None,
618
+ q_descale=None, k_descale=None, v_descale=None,
619
+ window_size=(-1, -1),
620
+ attention_chunk=0,
621
+ softcap=0.0,
622
+ num_splits=1,
623
+ pack_gqa=None,
624
+ deterministic=False,
625
+ sm_margin=0,
626
+ return_softmax=False,
627
+ ):
628
+ if softmax_scale is None:
629
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
630
+ # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
631
+ out, softmax_lse, *rest = _flash_attn_forward(
632
+ q,
633
+ k,
634
+ v,
635
+ None, None, # k_new, v_new
636
+ qv, # qv
637
+ None, # out
638
+ cu_seqlens_q,
639
+ cu_seqlens_k,
640
+ None, # cu_seqlens_k_new
641
+ seqused_q,
642
+ seqused_k,
643
+ max_seqlen_q,
644
+ max_seqlen_k,
645
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
646
+ None, None, None, # rotary_cos/sin, seqlens_rotary
647
+ q_descale, k_descale, v_descale,
648
+ softmax_scale,
649
+ causal=causal,
650
+ window_size_left=window_size[0],
651
+ window_size_right=window_size[1],
652
+ attention_chunk=attention_chunk,
653
+ softcap=softcap,
654
+ num_splits=num_splits,
655
+ pack_gqa=pack_gqa,
656
+ sm_margin=sm_margin,
657
+ )
658
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
659
+ ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
660
+ ctx.max_seqlen_q = max_seqlen_q
661
+ ctx.max_seqlen_k = max_seqlen_k
662
+ ctx.softmax_scale = softmax_scale
663
+ ctx.causal = causal
664
+ ctx.window_size = window_size
665
+ ctx.attention_chunk = attention_chunk
666
+ ctx.softcap = softcap
667
+ ctx.deterministic = deterministic
668
+ ctx.sm_margin = sm_margin
669
+ return (out, softmax_lse) if return_softmax else out
670
+
671
+ @staticmethod
672
+ def backward(ctx, dout, *args):
673
+ q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
674
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
675
+ dq, dk, dv, _ = _flash_attn_backward(
676
+ dout,
677
+ q,
678
+ k,
679
+ v,
680
+ out,
681
+ softmax_lse,
682
+ cu_seqlens_q,
683
+ cu_seqlens_k,
684
+ seqused_q,
685
+ seqused_k,
686
+ ctx.max_seqlen_q,
687
+ ctx.max_seqlen_k,
688
+ ctx.softmax_scale,
689
+ ctx.causal,
690
+ ctx.window_size[0],
691
+ ctx.window_size[1],
692
+ ctx.softcap,
693
+ ctx.deterministic,
694
+ ctx.sm_margin,
695
+ )
696
+ dq = dq[..., : q.shape[-1]] # We could have padded the head dimension
697
+ dk = dk[..., : k.shape[-1]]
698
+ dv = dv[..., : v.shape[-1]]
699
+ return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
+ def flash_attn_qkvpacked_func(
703
+ qkv,
704
+ softmax_scale=None,
705
+ causal=False,
706
+ q_descale=None, k_descale=None, v_descale=None,
707
+ window_size=(-1, -1),
708
+ attention_chunk=0,
709
+ softcap=0.0,
710
+ deterministic=False,
711
+ num_heads_q=None,
712
+ sm_margin=0,
713
+ return_attn_probs=False,
714
+ ):
715
+ """dropout_p should be set to 0.0 during evaluation
716
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
717
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
718
+ of the gradients of Q, K, V.
719
+ For multi-query and grouped-query attention (MQA/GQA), please see
720
+ flash_attn_kvpacked_func and flash_attn_func.
721
+
722
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
723
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
724
+
725
+ Arguments:
726
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
727
+ dropout_p: float. Dropout probability.
728
+ softmax_scale: float. The scaling of QK^T before applying softmax.
729
+ Default to 1 / sqrt(headdim).
730
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
731
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
732
+ softcap: float. Anything > 0 activates softcapping attention.
733
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
734
+ the attention score of query i and key j.
735
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
736
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
737
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
738
+ testing only. The returned probabilities are not guaranteed to be correct
739
+ (they might not have the right scaling).
740
+ Return:
741
+ out: (batch_size, seqlen, nheads, headdim).
742
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
743
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
744
+ normalization factor).
745
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
746
+ The output of softmax (possibly with different scaling). It also encodes the dropout
747
+ pattern (negative means that location was dropped, nonnegative means it was kept).
748
+ """
749
+ return FlashAttnQKVPackedFunc.apply(
750
+ qkv,
751
+ softmax_scale,
752
+ causal,
753
+ q_descale, k_descale, v_descale,
754
+ window_size,
755
+ attention_chunk,
756
+ softcap,
757
+ deterministic,
758
+ num_heads_q,
759
+ sm_margin,
760
+ return_attn_probs,
761
+ )
762
+
763
+
764
+ def flash_attn_func(
765
+ q,
766
+ k,
767
+ v,
768
+ softmax_scale=None,
769
+ causal=False,
770
+ qv=None,
771
+ q_descale=None, k_descale=None, v_descale=None,
772
+ window_size=(-1, -1),
773
+ attention_chunk=0,
774
+ softcap=0.0,
775
+ num_splits=1,
776
+ pack_gqa=None,
777
+ deterministic=False,
778
+ sm_margin=0,
779
+ return_attn_probs=False,
780
+ ):
781
+ """dropout_p should be set to 0.0 during evaluation
782
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
783
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
784
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
785
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
786
+
787
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
788
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
789
+ 1 1 1 1 0
790
+ 1 1 1 1 1
791
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
792
+ 0 0
793
+ 0 0
794
+ 0 0
795
+ 1 0
796
+ 1 1
797
+ If the row of the mask is all zero, the output will be zero.
798
+
799
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
800
+ will only attend to keys between
801
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
802
+
803
+ Arguments:
804
+ q: (batch_size, seqlen, nheads, headdim)
805
+ k: (batch_size, seqlen, nheads_k, headdim)
806
+ v: (batch_size, seqlen, nheads_k, headdim)
807
+ dropout_p: float. Dropout probability.
808
+ softmax_scale: float. The scaling of QK^T before applying softmax.
809
+ Default to 1 / sqrt(headdim).
810
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
811
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
812
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
813
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
814
+ is added to the attention score of query i and key j.
815
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
816
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
817
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
818
+ testing only. The returned probabilities are not guaranteed to be correct
819
+ (they might not have the right scaling).
820
+ Return:
821
+ out: (batch_size, seqlen, nheads, headdim).
822
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
823
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
824
+ normalization factor).
825
+ """
826
+ return FlashAttnFunc.apply(
827
+ q,
828
+ k,
829
+ v,
830
+ softmax_scale,
831
+ causal,
832
+ qv,
833
+ q_descale, k_descale, v_descale,
834
+ window_size,
835
+ attention_chunk,
836
+ softcap,
837
+ num_splits,
838
+ pack_gqa,
839
+ deterministic,
840
+ sm_margin,
841
+ return_attn_probs,
842
+ )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Flash attention for variable-length (packed) sequences.

    Like ``flash_attn_func`` but operates on packed tensors where sequences of
    different lengths are concatenated along the first dimension (i.e. q is
    (total_q, nheads, headdim)) and delimited by the ``cu_seqlens_q`` /
    ``cu_seqlens_k`` boundary tensors. ``max_seqlen_q`` / ``max_seqlen_k`` are
    the longest query/key sequence lengths in the batch; ``seqused_q`` /
    ``seqused_k`` optionally limit how much of each sequence is used.
    Remaining arguments have the same meaning as in ``flash_attn_func``.
    """
    # NOTE: FlashAttnVarlenFunc.apply takes (cu_seqlens_q, cu_seqlens_k,
    # seqused_q, seqused_k, max_seqlen_q, max_seqlen_k), which is a different
    # order than this wrapper's signature — keep the reordering below intact.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs and their log-sum-exp statistics.

    Thin wrapper over the CUDA ``fwd_combine`` kernel. ``out`` may be provided
    to write the result in place; ``out_dtype`` selects the output dtype.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    # Fixed: the original annotation was Optional[Union[(int, torch.Tensor)]],
    # which puts a tuple inside Union[...]; the canonical spelling is below.
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale is 1/sqrt(total query head dim), including qv if given.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute scheduler metadata for the FA3 forward kernel.

    Thin wrapper over the CUDA ``get_scheduler_metadata`` op. The result can be
    passed as ``scheduler_metadata`` to ``flash_attn_with_kvcache`` (or any
    caller of ``_flash_attn_forward``) to avoid recomputing the work schedule
    on every call with the same problem sizes.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # V head dim defaults to the K/Q head dim when not given explicitly.
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch29-cxx11-cu126-aarch64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the build-time feature-flag configuration to stdout."""
    import pprint
    pprint.pprint(CONFIG)
build/torch29-cxx11-cu126-aarch64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
# Handle to the FA3 operators that the `_C` extension import above registered
# under the `flash_attn_3` namespace in the PyTorch dispatcher.
flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return ``x`` with a contiguous last dimension.

    ``None`` passes through unchanged, as does any tensor whose last-dim
    stride is already 1; otherwise a contiguous copy is returned.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round ``x`` up to the nearest multiple of ``m`` (ceil division)."""
    return -(-x // m) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round a head dimension up to the smallest compiled-in kernel size.

    Each candidate size is only eligible if the corresponding
    ``FLASHATTENTION_DISABLE_HDIM*`` build flag is off; anything larger than
    all enabled candidates falls back to 256.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    candidates = (
        (64, "FLASHATTENTION_DISABLE_HDIM64"),
        (96, "FLASHATTENTION_DISABLE_HDIM96"),
        (128, "FLASHATTENTION_DISABLE_HDIM128"),
        (192, "FLASHATTENTION_DISABLE_HDIM192"),
        (256, "FLASHATTENTION_DISABLE_HDIM256"),
    )
    for size, disable_flag in candidates:
        if not flags[disable_flag] and head_size <= size:
            return size
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level FA3 forward: normalize tensor layouts and call the CUDA op.

    Returns ``(out, softmax_lse, out_accum, softmax_lse_accum)``; the two
    accumulator tensors are empty placeholders unless the kernel produced
    split accumulators.

    NOTE(review): ``mutates_args=()`` although ``k_new``/``v_new`` are
    documented elsewhere in this file as being appended to the cache in
    place by this op — confirm whether the cache tensors should be listed.
    """
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v may be contiguous in either the head dim (stride(-1)==1) or the
    # seqlen dim (stride(-3)==1); only copy when neither holds.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # The CUDA op takes all arguments positionally; this order must match the
    # C++ schema exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op outputs must be real tensors, so replace absent accumulators
    # with empty tensors instead of returning None.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    Used by torch.compile / torch.export to trace through the custom op.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from the Q/K head dimension)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (only if the caller did not preallocate one)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, one value per query row/head)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split fp32 accumulators are materialized.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Accumulator tensors are not produced when num_splits == 1; mirror
        # the real op's empty-tensor placeholders.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level FA3 backward: call the CUDA ``bwd`` op.

    Returns ``(dq, dk, dv, softmax_d)``. The gradient tensors are allocated
    by the C++ side (we pass None for dq/dk/dv below).
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv. Argument order must match the
    # C++ schema exactly.
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention backward.

    Mirrors the shape logic of the CUDA ``bwd`` kernel so torch.compile /
    torch.export can trace through the custom op: gradients match q/k/v, and
    ``softmax_d`` is padded up to the kernel block size kBlockM.
    """
    is_varlen_q = cu_seqlens_q is not None
    # Fixed: previously tested cu_seqlens_q again (copy-paste bug), so a
    # k-only varlen call would not be detected as varlen.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window covering the whole sequence is equivalent to no window.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    # Causal masking is expressed as a right window of 0.
    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus use cuda compute capability 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM mirrors the SM90 backward kernel's M-tile size selection.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    # Varlen softmax_d is padded by one extra block per batch element.
    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash forward tensors/params on ctx for the custom-op autograd hook.

    ``inputs`` is the positional argument tuple of
    ``flash_attn_3::_flash_attn_forward``; the negative indices below count
    back from ``sm_margin`` (the last parameter): -1 sm_margin, -6 softcap,
    -7 attention_chunk, -8/-9 window_size right/left, -10 causal,
    -11 softmax_scale. Keep these in sync with the op signature.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd hook for the ``_flash_attn_forward`` custom op.

    Only ``dout`` (the gradient of the attention output) is used; incoming
    gradients for the accumulator outputs are ignored. Always runs with
    deterministic=False.

    NOTE(review): the trailing ``(None,) * 21`` is expected to pad the
    gradient tuple to one entry per op input — confirm the count matches the
    forward signature's remaining parameters.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention over a packed QKV tensor.

    Accepts either a 5-D qkv of shape (..., 3, nheads, headdim) or a 4-D
    qkv where Q/K/V heads are concatenated along the head dimension (GQA
    packing, requires ``num_heads_q``).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale: 1/sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (..., 3, nheads, headdim): Q, K, V stacked on dim -3.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # 4-D GQA packing: heads laid out as [num_heads_q | nheads_k | nheads_k].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the input layout so backward can repack the gradients.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients back into the forward input's layout.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient slot per forward argument (qkv + 12 non-tensor params).
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length (batched) FA3 attention.

    Forward dispatches to the ``_flash_attn_forward`` custom op with all
    varlen/kv-cache arguments disabled; backward calls
    ``_flash_attn_backward`` with the tensors saved on ctx.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale is 1/sqrt(total query head dim), including qv if given.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient slot per forward argument (q, k, v + 14 non-tensor params).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd binding for the variable-length (packed, cu_seqlens) FA3 kernels.

    Sequences are concatenated along dim 0 and delimited by ``cu_seqlens_q`` /
    ``cu_seqlens_k``. The ``apply`` signature mirrors
    :func:`flash_attn_varlen_func` (with seqused before max_seqlen).
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale: 1/sqrt(total query head dim), counting qv when present.
        if softmax_scale is None:
            qv_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + qv_dim) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # NOTE(review): upstream saved out_padded here; saving `out` assumes the
        # kernel did not pad the value head dim — backward trims defensively.
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Trim any head-dimension padding the kernel may have added.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient slot per forward() argument after ctx (23 total).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Flash attention for Q, K, V already packed into one tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descaling factors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, restricts attention to chunks of this size
            (forward only; the backward pass does not support it).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: optional int. Number of query heads when qkv packs GQA heads
            along the head dimension instead of a separate stack dimension.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # Thin wrapper: all work happens in the autograd Function.
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Flash attention over batched, fixed-length Q, K, V.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional (batch_size, seqlen, nheads, headdim_v) tensor appended to q when
            computing QK^T (its head dim contributes to the default softmax scale).
        q_descale, k_descale, v_descale: optional descaling factors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, restricts attention to chunks of this size
            (forward only; the backward pass does not support it).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Number of KV splits for the forward kernel (1 = no split).
        pack_gqa: optional bool. Whether to pack GQA query heads for the kernel; None lets
            the kernel choose.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    # Thin wrapper: all work happens in the autograd Function.
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length flash attention over sequences packed along dim 0.

    Sequence boundaries are given by the cumulative-length tensors
    ``cu_seqlens_q`` / ``cu_seqlens_k``; ``seqused_q`` / ``seqused_k`` optionally
    cap the number of tokens actually used per sequence. Other arguments match
    :func:`flash_attn_func`.
    """
    # NOTE: the autograd Function expects seqused_* BEFORE max_seqlen_*,
    # i.e. in a different order than this public signature.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine per-split partial outputs and logsumexps into the final attention output.

    Delegates directly to the FA3 ``fwd_combine`` CUDA op; ``out`` and
    ``out_dtype`` are optional destination tensor / dtype overrides.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """Attention against a (possibly paged) KV cache, for incremental decoding.

    If k and v are not None, k_cache and v_cache are updated *inplace* with the new
    values, starting at the positions given by cache_seqlens, and attention is then
    computed against the updated cache — all in one kernel. The cache must be large
    enough to hold the appended values.

    If rotary_cos/rotary_sin are passed in, rotary embedding is applied to k (and to q:
    at indices cache_seqlens, cache_seqlens + 1, ... when causal or local, otherwise at
    index cache_seqlens only, i.e. all of q is treated as being at position cache_seqlens).

    Supports MQA/GQA (fewer KV heads than Q heads; Q heads must be divisible by KV heads).
    With causal=True the mask is aligned to the bottom-right corner of the attention
    matrix, and window_size=(left, right) != (-1, -1) gives sliding-window local attention
    over [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right].

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) without page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) with a page_table
            (paged KV cache; page_block_size is arbitrary).
        v_cache: same layout as k_cache with last dim headdim_v.
        k, v [optional]: (batch_size, seqlen_new, nheads_k, headdim/headdim_v) values to
            append into the cache at indices given by cache_seqlens.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos, rotary_sin [optional]: (seqlen_ro, rotary_dim / 2); rotary_dim must be
            divisible by 16. Only applicable when k and v are passed in.
        cache_seqlens: int or (batch_size,) int32 — current KV-cache lengths.
        cache_batch_idx [optional]: (batch_size,) int32 indices into the cache; duplicated
            indices with k/v updates give unspecified results.
        cache_leftpad [optional]: (batch_size,) int32 — index where each KV cache starts.
        page_table [optional]: (batch_size, max_num_blocks_per_seq) int32.
        cu_seqlens_q, cu_seqlens_k_new, max_seqlen_q: varlen-query controls.
        q_descale, k_descale, v_descale: optional FP8 descaling factors.
        softmax_scale: float, default 1 / sqrt(headdim) (incl. qv head dim).
        causal, window_size, attention_chunk, softcap: masking controls as in flash_attn_func.
        rotary_interleaved: bool. True pairs dims 0&1, 2&3, ...; False pairs
            0 & rotary_dim/2, ... (GPT-NeoX style).
        scheduler_metadata: optional tensor from get_scheduler_metadata.
        num_splits: int. 0 = heuristic, 1 = no split, >1 = that many KV splits.
        pack_gqa, sm_margin: kernel tuning knobs.
        return_softmax_lse: bool. Whether to also return the logsumexp of the scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    # Default scale counts qv's head dim when supplied.
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    # Broadcast a scalar cache length to a per-batch int32 tensor.
    if isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # `rest` carries the split accumulators; expose them alongside the LSE.
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute FA3 scheduler metadata for repeated flash_attn_with_kvcache calls.

    Returns an opaque tensor that can be passed as ``scheduler_metadata`` to
    avoid recomputing the kernel schedule on every decode step.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    # Value head dim defaults to the query/key head dim.
    if headdim_v is None:
        headdim_v = headdim
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e722dcebc49024ab2dd2c58c7b27efb9a0618ce104f66b6b21cc0d75bf5de4b6
3
+ size 814568840
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
# Records which kernel variants were compiled out of this build.
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the compile-time build-flag configuration."""
    from pprint import pprint
    pprint(CONFIG)
build/torch29-cxx11-cu126-x86_64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return ``x`` unchanged when it is None or already unit-stride in its last
    dimension; otherwise return a contiguous copy."""
    if x is None:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round ``x`` up to the nearest multiple of ``m`` (ceiling division)."""
    return -(-x // m) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round ``head_size`` up to the smallest head-dim bucket compiled into this
    build (64/96/128/192/256), falling back to 256."""
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        # A bucket is usable only when its kernels were not compiled out.
        if not flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"] and head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level dispatch to the FA3 CUDA forward kernel.

    Returns ``(out, softmax_lse, out_accum, softmax_lse_accum)``; the last two
    are empty tensors unless the kernel produced split accumulators.
    """
    # The kernels require unit stride in the last dim; v may alternatively be
    # unit-stride in the seqlen dim (column-major values).
    q, k, k_new, v_new = (maybe_contiguous(t) for t in (q, k, k_new, v_new))
    if v.stride(-1) != 1 and v.stride(-3) != 1:
        v = v.contiguous()
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = (
        maybe_contiguous(t) for t in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    )
    seqused_q, seqused_k = (maybe_contiguous(t) for t in (seqused_q, seqused_k))
    page_table, kv_batch_idx, leftpad_k = (
        maybe_contiguous(t) for t in (page_table, kv_batch_idx, leftpad_k)
    )
    rotary_cos, rotary_sin = (maybe_contiguous(t) for t in (rotary_cos, rotary_sin))
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op return schemas cannot carry None: substitute empty tensors for
    # the accumulators when the kernel did not use splits.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)
    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    """
    # Varlen mode is signalled by the presence of cu_seqlens_q.
    is_varlen_q = cu_seqlens_q is not None

    if is_varlen_q:
        # varlen layout: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1
        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch layout: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]

    head_size_v = v.shape[-1]

    # FP8 inputs produce BF16 outputs; everything else keeps q's dtype.
    out_dtype = torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype

    if out is None:
        out_shape = (
            (total_q, num_heads, head_size_v)
            if is_varlen_q
            else (batch_size, seqlen_q, num_heads, head_size_v)
        )
        out = torch.empty(out_shape, dtype=out_dtype, device=q.device)

    lse_shape = (
        (num_heads, total_q) if is_varlen_q else (batch_size, num_heads, seqlen_q)
    )
    softmax_lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        if is_varlen_q:
            accum_shape = (num_splits, num_heads, total_q)
        else:
            accum_shape = (num_splits, batch_size, num_heads, seqlen_q)
        out_accum = torch.empty(accum_shape + (head_size_v,), dtype=torch.float32, device=q.device)
        softmax_lse_accum = torch.empty(accum_shape, dtype=torch.float32, device=q.device)
    else:
        # Accumulators are unused without splits; stand in with empty tensors.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch the FA3 CUDA backward kernel.

    Returns (dq, dk, dv, softmax_d); the gradient buffers are allocated by
    the C++ side (we pass None for dq/dk/dv below).
    """
    dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))
    result = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    # bwd may return extra buffers after softmax_d; only the first four matter here.
    return result[0], result[1], result[2], result[3]
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention backward.

    Mirrors the shape/dtype logic of the CUDA kernel so the op can be traced
    (torch.compile / torch.export) without executing it.
    Returns (dq, dk, dv, softmax_d).
    """
    is_varlen_q = cu_seqlens_q is not None
    # BUG FIX: this previously re-tested cu_seqlens_q, so a varlen-K-only call
    # was mis-classified as non-varlen and softmax_d got the wrong shape.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        # varlen mode: q is (total_q, num_heads, head_size)
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window covering the whole sequence is equivalent to "no window" (-1).
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Causal is canonicalized as (infinite left window, 0 right window).
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus use cuda compute capability 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # Mirror the kernel's sm90 kBlockM tile-size heuristic; softmax_d padding
    # below must round to the same multiple the kernel uses.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Gradients match their inputs exactly (shape, dtype, device).
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash forward-pass tensors and config flags on ctx for _backward.

    `inputs` is the full positional argument tuple of the forward custom op;
    the scalar/config arguments are addressed from the tail of that tuple.
    """
    query, key, value = inputs[0], inputs[1], inputs[2]
    attn_out, lse, _accum_out, _accum_lse = output
    ctx.save_for_backward(query, key, value, attn_out, lse)
    # Tail-relative indices into the forward op's argument list.
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd bridge for the flash_attn_3 forward custom op (non-varlen path)."""
    q, k, v, out, softmax_lse = ctx.saved_tensors
    grads_qkv = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    dq, dk, dv = grads_qkv[0], grads_qkv[1], grads_qkv[2]
    # Remaining forward inputs are non-differentiable.
    return (dq, dk, dv) + (None,) * 21


_flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """autograd.Function for attention over a packed QKV tensor.

    Accepts either a 5-D packing (b, s, 3, h, d) or a 4-D head-packed layout
    (b, s, h_q + 2*h_k, d); the latter requires num_heads_q.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(headdim).
        scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        if qkv.dim() == 5:
            # (b, s, 3, h, d): split along the packed axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # (b, s, h_total, d): heads packed as [q-heads | k-heads | v-heads].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *_ = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()  # remembered so backward can re-pack the grads
        ctx.sm_margin = sm_margin
        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Re-pack the three gradients into the caller's qkv layout.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # undo any head-dim padding
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """autograd.Function for standard batched, fixed-length attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined q / qv head dimension.
            extra_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + extra_dim) ** (-0.5)
        out, softmax_lse, *_ = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Trim any head-dim padding before handing gradients back.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """autograd.Function for variable-length (ragged, cu_seqlens-indexed) attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale uses the combined q / qv head dimension.
            extra_dim = qv.shape[-1] if qv is not None else 0
            softmax_scale = (q.shape[-1] + extra_dim) ** (-0.5)
        out, softmax_lse, *_ = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        if return_softmax:
            return out, softmax_lse
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        (q, k, v, out, softmax_lse,
         cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Trim any head-dim padding before handing gradients back.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention over a single packed QKV tensor.

    If Q, K, V are already stacked into one tensor, this is faster than calling
    flash_attn_func on Q, K, V separately, since the backward pass avoids an
    explicit concatenation of the Q/K/V gradients. For MQA/GQA see
    flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention:
    query at position i attends to keys in [i - window_size[0],
    i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) with
            num_heads_q given.
        softmax_scale: float. Scaling of QK^T before softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Apply a causal attention mask (auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), sliding window attention.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Use the deterministic backward implementation
            (slightly slower, more memory). Forward is always deterministic.
        num_heads_q: int. Required for the 4-D head-packed qkv layout.
        return_attn_probs: bool. Also return the softmax logsumexp
            (testing only; values are not guaranteed to have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen). Logsumexp of each row of QK^T * scaling.
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Batched fixed-length flash attention.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing KV
    with fewer heads than Q; nheads must be divisible by nheads_k. E.g. with
    6 Q heads and 2 KV heads, Q heads 0-2 attend to KV head 0 and Q heads 3-5
    attend to KV head 1.

    If causal=True, the causal mask is aligned to the bottom-right corner of
    the attention matrix. For seqlen_q = 2, seqlen_k = 5 (1 = keep, 0 = mask):
        1 1 1 1 0
        1 1 1 1 1
    For seqlen_q = 5, seqlen_k = 2:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the corresponding output row is zero.

    If window_size != (-1, -1), implements sliding window local attention:
    query at position i attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0],
     i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. Scaling of QK^T before softmax.
            Defaults to 1 / sqrt(headdim) (including qv's headdim if given).
        causal: bool. Apply a causal attention mask.
        window_size: (left, right). If not (-1, -1), sliding window attention.
        deterministic: bool. Use the deterministic backward implementation
            (slightly slower, more memory). Forward is always deterministic.
        return_attn_probs: bool. Also return the softmax logsumexp
            (testing only; values are not guaranteed to have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen). Logsumexp of each row of QK^T * scaling.
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length flash attention over ragged batches.

    Sequences are packed along the first dimension and addressed via the
    cumulative-length tensors cu_seqlens_q / cu_seqlens_k; max_seqlen_q and
    max_seqlen_k bound the per-sequence lengths. Otherwise behaves like
    flash_attn_func (MQA/GQA, causal masking, sliding windows, softcap).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine per-split partial outputs and log-sum-exps into the final attention output."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """Attention against a (possibly paged) KV cache, for incremental decoding.

    If k and v are given, k_cache and v_cache are updated *inplace* with the new
    values starting at the indices in cache_seqlens, and attention is computed
    against the updated cache — all in one kernel. The caller must ensure the
    cache is large enough to hold the appended values.

    If rotary_cos and rotary_sin are given (only applicable when k and v are
    passed in), rotary embedding is applied to k (at positions cache_seqlens,
    cache_seqlens + 1, ...) and to q (at the same positions when causal or
    local, otherwise at position cache_seqlens for all of q). rotary_dim must
    be divisible by 16.

    Supports MQA/GQA (KV with fewer heads than Q; nheads divisible by
    nheads_k). If causal=True the mask is aligned to the bottom-right corner
    of the attention matrix; fully-masked rows produce zero output. If
    window_size != (-1, -1), sliding window local attention: query i attends
    to keys in [i + seqlen_k - seqlen_q - window_size[0],
    i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) without a
            page_table, or (num_blocks, page_block_size, nheads_k, headdim)
            with one (paged KV cache; page_block_size can be arbitrary).
        v_cache: same layouts with headdim_v.
        k, v [optional]: (batch_size, seqlen_new, nheads_k, headdim[_v]).
            New keys/values appended to the cache at cache_seqlens.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos, rotary_sin [optional]: (seqlen_ro, rotary_dim / 2).
        cache_seqlens: int or (batch_size,) int32. Current KV cache lengths.
        cache_batch_idx: (batch_size,) int32. Indices into the KV cache;
            defaults to [0, 1, ..., batch_size - 1]. With duplicate indices
            and k/v given, the updated values may come from any duplicate.
        cache_leftpad: (batch_size,) int32. Index where each cache entry
            starts; defaults to 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq) int32.
        softmax_scale: float. Defaults to 1 / sqrt(headdim).
        causal: bool. Apply a causal attention mask.
        window_size: (left, right). If not (-1, -1), sliding window attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. True pairs dims 0&1, 2&3, ...; False pairs
            dim i with dim i + rotary_dim/2 (GPT-NeoX style).
        num_splits: int. > 1 splits the KV sequence into that many chunks;
            1 disables splitting; 0 picks a heuristic. Leave at 0 unless tuning.
        return_softmax_lse: bool. Also return the attention logsumexp.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]:
            (batch_size, nheads, seqlen). Logsumexp of each row of QK^T * scaling.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        extra_dim = qv.shape[-1] if qv is not None else 0
        softmax_scale = (q.shape[-1] + extra_dim) ** (-0.5)
    if isinstance(cache_seqlens, int):
        # Broadcast a scalar length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    if return_softmax_lse:
        return (out, softmax_lse, *rest)
    return out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute tile-scheduler metadata for the FA3 kernel.

    The returned tensor can be passed as scheduler_metadata to
    flash_attn_with_kvcache to skip the scheduling step inside the kernel.
    headdim_v defaults to headdim when not given.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    headdim_v = headdim if headdim_v is None else headdim_v
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
build/torch29-cxx11-cu128-aarch64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
# Records which compile-time feature switches this binary was built with
# (True means the feature was compiled OUT). Consumed by round_up_headdim
# in flash_attn_interface to pick a supported head dimension.
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}

def show():
    """Pretty-print the build configuration to stdout."""
    from pprint import pprint
    pprint(CONFIG)
7
+
build/torch29-cxx11-cu128-aarch64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* with a unit-stride last dimension, copying only if needed.

    ``None`` passes through unchanged; tensors whose last dimension is already
    contiguous (stride 1) are returned as-is to avoid a copy.
    """
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m*."""
    remainder = x % m
    return x if remainder == 0 else x + (m - remainder)
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest head dimension compiled into this build.

    Consults the auto-generated build flags: a bucket is usable only when its
    ``FLASHATTENTION_DISABLE_HDIM<n>`` flag is False. Falls back to 256 when no
    enabled bucket fits.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        if not flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"] and head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level FA3 forward: normalizes strides and calls the CUDA kernel.

    Registered as a torch custom op so it composes with torch.compile/export.
    Returns ``(out, softmax_lse, out_accum, softmax_lse_accum)``; the two
    accumulators are only populated by the split-KV path and are replaced by
    empty tensors when the kernel returns None (custom ops cannot return None).
    """
    # The kernel needs unit-stride last dims for most inputs.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # v may be either row-major or column-major (stride-1 in dim -3); only
    # copy when neither layout holds.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    # Positional call: order must match the C++ binding exactly.
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Custom ops must return tensors, not None: substitute empty placeholders.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual
    computation, so torch.compile/torch.export can trace the op.
    """

    # Determine if we're in varlen mode (packed sequences, 3-D q).
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from head_size, e.g. for MLA-style
    # kernels where V is narrower/wider than Q/K).
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (unless a preallocated one was passed in)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor; always float32 regardless of input dtype.
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV path: per-split partial outputs accumulated in float32.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Tensors are not set when num_splits < 1; mirror the real op's
        # empty-tensor placeholders.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    # NOTE(review): "sequed_q"/"sequed_k" look like typos of "seqused_q/k";
    # kept as-is because they are part of the registered op signature.
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level FA3 backward: returns ``(dq, dk, dv, softmax_d)``.

    Gradient buffers are allocated by the C++ side (``None`` is passed for
    dq/dk/dv below). Any extra values the binding returns are discarded.
    """
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake implementation of flash attention backward.

    Mirrors the C++ kernel's output-shape logic (block-size selection and
    rounding) so torch.compile/torch.export can trace the op without running
    it. Returns ``(dq, dk, dv, softmax_d)`` with correct shapes/dtypes.
    """
    is_varlen_q = cu_seqlens_q is not None
    # Fixed: previously checked cu_seqlens_q again (copy-paste typo), so a
    # k-only varlen call would be misclassified and softmax_d mis-shaped.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        # batch mode: q is (batch, seqlen_q, heads, head_size)
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        # varlen mode: q is (total_q, heads, head_size)
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Normalize window sizes: a window covering the whole sequence is
    # equivalent to no window at all.
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Causal is representable as (left=-1, right=0) after normalization.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # kBlockM selection mirrors the kernel's tile-size heuristic for SM90.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash what the custom-op autograd needs onto ``ctx``.

    ``inputs`` is the full positional argument tuple of
    ``_flash_attn_forward``; negative indices address its tail:
    -11=softmax_scale, -10=causal, -9/-8=window_size_left/right,
    -7=attention_chunk, -6=softcap, -1=sm_margin.
    """
    q, k, v = inputs[:3]
    # Only the first two outputs matter for backward; the split-KV
    # accumulators are ignored.
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd bridge for the ``_flash_attn_forward`` custom op.

    Handles the dense (non-varlen) path only: all cu_seqlens/seqused/max_seqlen
    arguments are passed as None. Gradients of the split-KV accumulator
    outputs (``*grads``) are ignored.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    # Only q/k/v receive gradients; every other forward argument gets None.
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention on a packed QKV tensor.

    Accepts either a 5-D ``(batch, seqlen, 3, nheads, headdim)`` tensor or a
    4-D ``(batch, seqlen, nheads_q + 2*nheads_k, headdim)`` GQA layout (the
    latter requires ``num_heads_q``).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling: 1/sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): split along the QKV axis.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # GQA layout: heads dimension holds [q_heads | k_heads | v_heads].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()  # remembered so backward can repack gradients
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the expected format
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # Fixed: forward takes 13 non-ctx arguments (qkv, softmax_scale,
        # causal, q/k/v_descale, window_size, attention_chunk, softcap,
        # deterministic, num_heads_q, sm_margin, return_softmax), so backward
        # must return 13 gradients. The previous version returned only 12
        # (one None short), which makes autograd raise at backward time.
        return (dqkv,) + (None,) * 12
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for dense (fixed-length) FA3 attention on separate Q/K/V."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling includes the extra qv head dim when provided.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        # Kernel limitation: chunked attention has no backward.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # 3 tensor grads + 14 Nones = 17, matching forward's 17 non-ctx args.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed-sequence) FA3 attention.

    Sequences are packed along dim 0 and delimited by ``cu_seqlens_q`` /
    ``cu_seqlens_k`` (cumulative offsets); ``seqused_q/k`` optionally limit
    how much of each sequence is used.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scaling includes the extra qv head dim when provided.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        # Kernel limitation: chunked attention has no backward.
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # 3 tensor grads + 20 Nones = 23, matching forward's 23 non-ctx args.
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention on a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pass a 4-D ``qkv``
    together with ``num_heads_q``, or see flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) with num_heads_q set.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional dequantization scale tensors
            (used with FP8 inputs — presumably per-head; verify against kernel docs).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, chunked attention (forward only; backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Required for the 4-D GQA-packed layout; number of query heads.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp. This option
            is for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 for fixed-length (dense) batches.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim) (including qv's head dim when qv is given).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query-value tensor appended to the attention computation.
        q_descale, k_descale, v_descale: optional dequantization scale tensors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, chunked attention (forward only; backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; can be tuned for speed.
        pack_gqa: optional bool. Whether to pack GQA heads; can be tuned for speed.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp. This option
            is for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
+ def flash_attn_varlen_func(
846
+ q,
847
+ k,
848
+ v,
849
+ cu_seqlens_q,
850
+ cu_seqlens_k,
851
+ max_seqlen_q,
852
+ max_seqlen_k,
853
+ seqused_q=None,
854
+ seqused_k=None,
855
+ softmax_scale=None,
856
+ causal=False,
857
+ qv=None,
858
+ q_descale=None, k_descale=None, v_descale=None,
859
+ window_size=(-1, -1),
860
+ attention_chunk=0,
861
+ softcap=0.0,
862
+ num_splits=1,
863
+ pack_gqa=None,
864
+ deterministic=False,
865
+ sm_margin=0,
866
+ return_attn_probs=False,
867
+ ):
868
+ return FlashAttnVarlenFunc.apply(
869
+ q,
870
+ k,
871
+ v,
872
+ cu_seqlens_q,
873
+ cu_seqlens_k,
874
+ seqused_q,
875
+ seqused_k,
876
+ max_seqlen_q,
877
+ max_seqlen_k,
878
+ softmax_scale,
879
+ causal,
880
+ qv,
881
+ q_descale, k_descale, v_descale,
882
+ window_size,
883
+ attention_chunk,
884
+ softcap,
885
+ num_splits,
886
+ pack_gqa,
887
+ deterministic,
888
+ sm_margin,
889
+ return_attn_probs,
890
+ )
891
+
892
+
893
+ def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
894
+ return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
+ def flash_attn_with_kvcache(
898
+ q,
899
+ k_cache,
900
+ v_cache,
901
+ k=None,
902
+ v=None,
903
+ qv=None,
904
+ rotary_cos=None,
905
+ rotary_sin=None,
906
+ cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
907
+ cache_batch_idx: Optional[torch.Tensor] = None,
908
+ cache_leftpad: Optional[torch.Tensor] = None,
909
+ page_table: Optional[torch.Tensor] = None,
910
+ cu_seqlens_q: Optional[torch.Tensor] = None,
911
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
912
+ max_seqlen_q: Optional[int] = None,
913
+ rotary_seqlens: Optional[torch.Tensor] = None,
914
+ q_descale: Optional[torch.Tensor] = None,
915
+ k_descale: Optional[torch.Tensor] = None,
916
+ v_descale: Optional[torch.Tensor] = None,
917
+ softmax_scale=None,
918
+ causal=False,
919
+ window_size=(-1, -1), # -1 means infinite context window
920
+ attention_chunk=0,
921
+ softcap=0.0, # 0.0 means deactivated
922
+ rotary_interleaved=True,
923
+ scheduler_metadata=None,
924
+ num_splits=0, # Can be tuned for speed
925
+ pack_gqa=None, # Can be tuned for speed
926
+ sm_margin=0, # Can be tuned if some SMs are used for communication
927
+ return_softmax_lse=False,
928
+ ):
929
+ """
930
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
931
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
932
+ the previous step, and update them with the new keys/values from the current step, and do
933
+ attention with the updated cache, all in 1 kernel.
934
+
935
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
936
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
937
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
938
+
939
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
940
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
941
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
942
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
943
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
944
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
945
+
946
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
947
+
948
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
949
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
950
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
951
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
952
+
953
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
954
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
955
+ 1 1 1 1 0
956
+ 1 1 1 1 1
957
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
958
+ 0 0
959
+ 0 0
960
+ 0 0
961
+ 1 0
962
+ 1 1
963
+ If the row of the mask is all zero, the output will be zero.
964
+
965
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
966
+ will only attend to keys between
967
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
968
+
969
+ Note: Does not support backward pass.
970
+
971
+ Arguments:
972
+ q: (batch_size, seqlen, nheads, headdim)
973
+ k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
974
+ or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
975
+ page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
976
+ v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
977
+ or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
978
+ k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
979
+ k with k_cache, starting at the indices specified by cache_seqlens.
980
+ v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
981
+ qv [optional]: (batch_size, seqlen, nheads, headdim_v)
982
+ rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
983
+ to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
984
+ rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
985
+ cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
986
+ KV cache.
987
+ cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
988
+ If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
989
+ If the indices are not distinct, and k and v are provided, the values updated in the cache
990
+ might come from any of the duplicate indices.
991
+ cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
992
+ page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
993
+ softmax_scale: float. The scaling of QK^T before applying softmax.
994
+ Default to 1 / sqrt(headdim).
995
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
996
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
997
+ softcap: float. Anything > 0 activates softcapping attention.
998
+ rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
999
+ If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
1000
+ rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
1001
+ (i.e. GPT-NeoX style).
1002
+ num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
1003
+ If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
1004
+ to automatically determine the number of splits.
1005
+ Don't change this unless you know what you are doing.
1006
+ return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
1007
+
1008
+ Return:
1009
+ out: (batch_size, seqlen, nheads, headdim).
1010
+ softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
1011
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1012
+ normalization factor).
1013
+ """
1014
+ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
1015
+ assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
1016
+ if softmax_scale is None:
1017
+ softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
1018
+ if cache_seqlens is not None and isinstance(cache_seqlens, int):
1019
+ cache_seqlens = torch.full(
1020
+ (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
1021
+ )
1022
+ cache_seqlens = maybe_contiguous(cache_seqlens)
1023
+ out, softmax_lse, *rest = _flash_attn_forward(
1024
+ q,
1025
+ k_cache,
1026
+ v_cache,
1027
+ k,
1028
+ v,
1029
+ qv,
1030
+ None, # out
1031
+ cu_seqlens_q,
1032
+ None, # cu_seqlens_k
1033
+ cu_seqlens_k_new,
1034
+ None, # seqused_q
1035
+ cache_seqlens,
1036
+ max_seqlen_q,
1037
+ None, # max_seqlen_k
1038
+ page_table,
1039
+ cache_batch_idx,
1040
+ cache_leftpad,
1041
+ rotary_cos,
1042
+ rotary_sin,
1043
+ rotary_seqlens,
1044
+ q_descale, k_descale, v_descale,
1045
+ softmax_scale,
1046
+ causal=causal,
1047
+ window_size_left=window_size[0],
1048
+ window_size_right=window_size[1],
1049
+ attention_chunk=attention_chunk,
1050
+ softcap=softcap,
1051
+ rotary_interleaved=rotary_interleaved,
1052
+ scheduler_metadata=scheduler_metadata,
1053
+ num_splits=num_splits,
1054
+ pack_gqa=pack_gqa,
1055
+ sm_margin=sm_margin,
1056
+ )
1057
+ # return (out, softmax_lse) if return_softmax_lse else out
1058
+ return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
+ def get_scheduler_metadata(
1062
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
1063
+ cache_seqlens: torch.Tensor,
1064
+ qkv_dtype=torch.bfloat16,
1065
+ headdim_v=None,
1066
+ cu_seqlens_q: Optional[torch.Tensor] = None,
1067
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
1068
+ cache_leftpad: Optional[torch.Tensor] = None,
1069
+ page_size: Optional[int] = None,
1070
+ max_seqlen_k_new=0,
1071
+ causal=False,
1072
+ window_size=(-1, -1), # -1 means infinite context window
1073
+ attention_chunk=0,
1074
+ has_softcap=False,
1075
+ num_splits=0, # Can be tuned for speed
1076
+ pack_gqa=None, # Can be tuned for speed
1077
+ sm_margin=0, # Can be tuned if some SMs are used for communication
1078
+ ):
1079
+ cache_seqlens = maybe_contiguous(cache_seqlens)
1080
+ if headdim_v is None:
1081
+ headdim_v = headdim
1082
+ scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
1083
+ batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
1084
+ qkv_dtype,
1085
+ cache_seqlens,
1086
+ cu_seqlens_q,
1087
+ None, # cu_seqlens_k
1088
+ cu_seqlens_k_new,
1089
+ None, # seqused_q
1090
+ cache_leftpad,
1091
+ page_size,
1092
+ max_seqlen_k_new,
1093
+ causal,
1094
+ window_size[0], window_size[1],
1095
+ attention_chunk,
1096
+ has_softcap,
1097
+ num_splits,
1098
+ pack_gqa,
1099
+ sm_margin,
1100
+ )
1101
+ return scheduler_metadata
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e722dcebc49024ab2dd2c58c7b27efb9a0618ce104f66b6b21cc0d75bf5de4b6
3
+ size 814568840
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Auto-generated by flash attention 3 setup.py
CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}


def show():
    """Pretty-print the build-time configuration flags of this wheel."""
    import pprint

    pprint.pprint(CONFIG)
7
+
build/torch29-cxx11-cu128-x86_64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
+ def maybe_contiguous(x):
17
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
18
+
19
+
20
+ def round_multiple(x, m):
21
+ return (x + m - 1) // m * m
22
+
23
+
24
+ def round_up_headdim(head_size: int) -> int:
25
+ from .flash_attn_config import CONFIG
26
+
27
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]:
28
+ if head_size <= 64:
29
+ return 64
30
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]:
31
+ if head_size <= 96:
32
+ return 96
33
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]:
34
+ if head_size <= 128:
35
+ return 128
36
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]:
37
+ if head_size <= 192:
38
+ return 192
39
+ if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]:
40
+ if head_size <= 256:
41
+ return 256
42
+ return 256
43
+
44
+
45
+ @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
46
+ def _flash_attn_forward(
47
+ q: torch.Tensor,
48
+ k: torch.Tensor,
49
+ v: torch.Tensor,
50
+ k_new: Optional[torch.Tensor] = None,
51
+ v_new: Optional[torch.Tensor] = None,
52
+ qv: Optional[torch.Tensor] = None,
53
+ out: Optional[torch.Tensor] = None,
54
+ cu_seqlens_q: Optional[torch.Tensor] = None,
55
+ cu_seqlens_k: Optional[torch.Tensor] = None,
56
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
57
+ seqused_q: Optional[torch.Tensor] = None,
58
+ seqused_k: Optional[torch.Tensor] = None,
59
+ max_seqlen_q: Optional[int] = None,
60
+ max_seqlen_k: Optional[int] = None,
61
+ page_table: Optional[torch.Tensor] = None,
62
+ kv_batch_idx: Optional[torch.Tensor] = None,
63
+ leftpad_k: Optional[torch.Tensor] = None,
64
+ rotary_cos: Optional[torch.Tensor] = None,
65
+ rotary_sin: Optional[torch.Tensor] = None,
66
+ seqlens_rotary: Optional[torch.Tensor] = None,
67
+ q_descale: Optional[torch.Tensor] = None,
68
+ k_descale: Optional[torch.Tensor] = None,
69
+ v_descale: Optional[torch.Tensor] = None,
70
+ softmax_scale: Optional[float] = None,
71
+ causal: bool = False,
72
+ window_size_left: int = -1,
73
+ window_size_right: int = -1,
74
+ attention_chunk: int = 0,
75
+ softcap: float = 0.0,
76
+ rotary_interleaved: bool = True,
77
+ scheduler_metadata: Optional[torch.Tensor] = None,
78
+ num_splits: int = 1,
79
+ pack_gqa: Optional[bool] = None,
80
+ sm_margin: int = 0,
81
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
82
+ q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
83
+ v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
84
+ cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
85
+ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
86
+ ]
87
+ seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
88
+ page_table, kv_batch_idx, leftpad_k = [
89
+ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
90
+ ]
91
+ rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
92
+ seqlens_rotary = maybe_contiguous(seqlens_rotary)
93
+ out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
94
+ q,
95
+ k,
96
+ v,
97
+ k_new,
98
+ v_new,
99
+ qv,
100
+ out,
101
+ cu_seqlens_q,
102
+ cu_seqlens_k,
103
+ cu_seqlens_k_new,
104
+ seqused_q,
105
+ seqused_k,
106
+ max_seqlen_q,
107
+ max_seqlen_k,
108
+ page_table,
109
+ kv_batch_idx,
110
+ leftpad_k,
111
+ rotary_cos,
112
+ rotary_sin,
113
+ seqlens_rotary,
114
+ q_descale,
115
+ k_descale,
116
+ v_descale,
117
+ softmax_scale,
118
+ causal,
119
+ window_size_left,
120
+ window_size_right,
121
+ attention_chunk,
122
+ softcap,
123
+ rotary_interleaved,
124
+ scheduler_metadata,
125
+ num_splits,
126
+ pack_gqa,
127
+ sm_margin,
128
+ )
129
+
130
+ if out_accum is None:
131
+ out_accum = torch.tensor([], device=out.device)
132
+
133
+ if softmax_lse_accum is None:
134
+ softmax_lse_accum = torch.tensor([], device=out.device)
135
+
136
+ return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
+ @torch.library.register_fake("flash_attn_3::_flash_attn_forward")
140
+ def _flash_attn_forward_fake(
141
+ q: torch.Tensor,
142
+ k: torch.Tensor,
143
+ v: torch.Tensor,
144
+ k_new: Optional[torch.Tensor] = None,
145
+ v_new: Optional[torch.Tensor] = None,
146
+ qv: Optional[torch.Tensor] = None,
147
+ out: Optional[torch.Tensor] = None,
148
+ cu_seqlens_q: Optional[torch.Tensor] = None,
149
+ cu_seqlens_k: Optional[torch.Tensor] = None,
150
+ cu_seqlens_k_new: Optional[torch.Tensor] = None,
151
+ seqused_q: Optional[torch.Tensor] = None,
152
+ seqused_k: Optional[torch.Tensor] = None,
153
+ max_seqlen_q: Optional[int] = None,
154
+ max_seqlen_k: Optional[int] = None,
155
+ page_table: Optional[torch.Tensor] = None,
156
+ kv_batch_idx: Optional[torch.Tensor] = None,
157
+ leftpad_k: Optional[torch.Tensor] = None,
158
+ rotary_cos: Optional[torch.Tensor] = None,
159
+ rotary_sin: Optional[torch.Tensor] = None,
160
+ seqlens_rotary: Optional[torch.Tensor] = None,
161
+ q_descale: Optional[torch.Tensor] = None,
162
+ k_descale: Optional[torch.Tensor] = None,
163
+ v_descale: Optional[torch.Tensor] = None,
164
+ softmax_scale: Optional[float] = None,
165
+ causal: bool = False,
166
+ window_size_left: int = -1,
167
+ window_size_right: int = -1,
168
+ attention_chunk: int = 0,
169
+ softcap: float = 0.0,
170
+ rotary_interleaved: bool = True,
171
+ scheduler_metadata: Optional[torch.Tensor] = None,
172
+ num_splits: int = 1,
173
+ pack_gqa: Optional[bool] = None,
174
+ sm_margin: int = 0,
175
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
176
+ """
177
+ Symbolic fake implementation of flash attention forward.
178
+ Returns tensors with the correct shapes and dtypes without actual computation.
179
+ """
180
+
181
+ # Determine if we're in varlen mode
182
+ is_varlen_q = cu_seqlens_q is not None
183
+
184
+ # Get dimensions from query tensor
185
+ if is_varlen_q:
186
+ # varlen mode: q is (total_q, num_heads, head_size)
187
+ total_q, num_heads, head_size = q.shape
188
+ batch_size = cu_seqlens_q.shape[0] - 1
189
+
190
+ if max_seqlen_q is None:
191
+ raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
192
+ seqlen_q = max_seqlen_q
193
+ else:
194
+ # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
195
+ batch_size, seqlen_q, num_heads, head_size = q.shape
196
+ total_q = batch_size * q.shape[1]
197
+ # Get value head dimension
198
+ head_size_v = v.shape[-1]
199
+
200
+ # Determine output dtype (FP8 inputs produce BF16 outputs)
201
+ q_type = q.dtype
202
+ if q_type == torch.float8_e4m3fn:
203
+ out_dtype = torch.bfloat16
204
+ else:
205
+ out_dtype = q_type
206
+
207
+ # Create output tensor
208
+ if out is None:
209
+ if is_varlen_q:
210
+ out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
211
+ else:
212
+ out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
213
+
214
+ # Create softmax_lse tensor
215
+ if is_varlen_q:
216
+ softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
217
+ else:
218
+ softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
219
+
220
+ # TODO(guilhermeleobas): Implement "get_num_splits"
221
+ # There's an heuristic to compute num_splits when "num_splits <= 0"
222
+ # assert that num_splits is > 0 for now
223
+ if num_splits <= 0:
224
+ raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")
225
+
226
+ if num_splits > 1:
227
+ if is_varlen_q:
228
+ out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
229
+ softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
230
+ else:
231
+ out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
232
+ softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
233
+ else:
234
+ # Tensors are not set when num_splits < 1
235
+ out_accum = torch.tensor([], device=out.device)
236
+ softmax_lse_accum = torch.tensor([], device=out.device)
237
+
238
+ return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
+ @torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
242
+ def _flash_attn_backward(
243
+ dout: torch.Tensor,
244
+ q: torch.Tensor,
245
+ k: torch.Tensor,
246
+ v: torch.Tensor,
247
+ out: torch.Tensor,
248
+ softmax_lse: torch.Tensor,
249
+ cu_seqlens_q: Optional[torch.Tensor] = None,
250
+ cu_seqlens_k: Optional[torch.Tensor] = None,
251
+ sequed_q: Optional[torch.Tensor] = None,
252
+ sequed_k: Optional[torch.Tensor] = None,
253
+ max_seqlen_q: Optional[int] = None,
254
+ max_seqlen_k: Optional[int] = None,
255
+ softmax_scale: Optional[float] = None,
256
+ is_causal: bool = False,
257
+ window_size_left: int = -1,
258
+ window_size_right: int = -1,
259
+ softcap: float = 0.0,
260
+ deterministic: bool = False,
261
+ sm_margin: int = 0,
262
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
263
+ dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
264
+ # C++ now allocates and returns dq, dk, dv
265
+ dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
266
+ dout,
267
+ q,
268
+ k,
269
+ v,
270
+ out,
271
+ softmax_lse,
272
+ None, # dq - let C++ allocate
273
+ None, # dk - let C++ allocate
274
+ None, # dv - let C++ allocate
275
+ cu_seqlens_q,
276
+ cu_seqlens_k,
277
+ sequed_q,
278
+ sequed_k,
279
+ max_seqlen_q,
280
+ max_seqlen_k,
281
+ softmax_scale,
282
+ is_causal,
283
+ window_size_left,
284
+ window_size_right,
285
+ softcap,
286
+ deterministic,
287
+ sm_margin,
288
+ )
289
+ return dq, dk, dv, softmax_d
290
+
291
+
292
+ @torch.library.register_fake("flash_attn_3::_flash_attn_backward")
293
+ def _flash_attn_backward_fake(
294
+ dout: torch.Tensor,
295
+ q: torch.Tensor,
296
+ k: torch.Tensor,
297
+ v: torch.Tensor,
298
+ out: torch.Tensor,
299
+ softmax_lse: torch.Tensor,
300
+ cu_seqlens_q: Optional[torch.Tensor] = None,
301
+ cu_seqlens_k: Optional[torch.Tensor] = None,
302
+ sequed_q: Optional[torch.Tensor] = None,
303
+ sequed_k: Optional[torch.Tensor] = None,
304
+ max_seqlen_q: Optional[int] = None,
305
+ max_seqlen_k: Optional[int] = None,
306
+ softmax_scale: Optional[float] = None,
307
+ is_causal: bool = False,
308
+ window_size_left: int = -1,
309
+ window_size_right: int = -1,
310
+ softcap: float = 0.0,
311
+ deterministic: bool = False,
312
+ sm_margin: int = 0,
313
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
314
+
315
+ is_varlen_q = cu_seqlens_q is not None
316
+ is_varlen_k = cu_seqlens_q is not None
317
+ is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None
318
+
319
+ if not is_varlen_q:
320
+ batch_size = q.size(0)
321
+ seqlen_q = q.size(1)
322
+ seqlen_k = k.size(1)
323
+ total_q = batch_size * q.size(1)
324
+ else:
325
+ batch_size = cu_seqlens_q.size(0) - 1
326
+ total_q = q.size(0)
327
+ seqlen_q = max_seqlen_q
328
+ seqlen_k = max_seqlen_k
329
+
330
+ if window_size_left >= seqlen_k - 1:
331
+ window_size_left = -1
332
+
333
+ if window_size_right >= seqlen_q - 1:
334
+ window_size_right = -1
335
+
336
+ if is_causal:
337
+ window_size_right = 0
338
+
339
+ is_causal = window_size_left < 0 and window_size_right == 0
340
+
341
+ head_size = q.size(-1)
342
+ head_size_v = v.size(-1)
343
+ head_size_rounded = round_up_headdim(max(head_size, head_size_v))
344
+
345
+ # Hopper gpus uses cuda compute capabilities 9.0
346
+ cap = torch.cuda.get_device_capability(q.device)
347
+ arch = cap[0] * 10 + cap[1]
348
+
349
+ is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal
350
+
351
+ if arch < 90:
352
+ raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")
353
+
354
+ if head_size_rounded <= 64:
355
+ kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
356
+ elif head_size_rounded <= 96:
357
+ kBlockM_sm90 = 64
358
+ elif head_size_rounded <= 128:
359
+ kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
360
+ else:
361
+ kBlockM_sm90 = 64
362
+
363
+ kBlockM = kBlockM_sm90
364
+
365
+ num_heads = q.shape[-2]
366
+ seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)
367
+
368
+ total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)
369
+
370
+ # Allocate gradient tensors
371
+ dq = torch.empty_like(q)
372
+ dk = torch.empty_like(k)
373
+ dv = torch.empty_like(v)
374
+
375
+ if not is_varlen:
376
+ softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
377
+ else:
378
+ softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)
379
+
380
+ return dq, dk, dv, softmax_d
381
+
382
+
383
+ def setup_context(ctx, inputs, output):
384
+ q, k, v = inputs[:3]
385
+ out, softmax_lse, _, _ = output
386
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
387
+ ctx.softmax_scale = inputs[-11]
388
+ ctx.causal = inputs[-10]
389
+ ctx.window_size = [inputs[-9], inputs[-8]]
390
+ ctx.attention_chunk = inputs[-7]
391
+ ctx.softcap = inputs[-6]
392
+ ctx.sm_margin = inputs[-1]
393
+
394
+
395
+ def _backward(ctx, dout, *grads):
396
+ q, k, v, out, softmax_lse = ctx.saved_tensors
397
+ dq, dk, dv, _ = _flash_attn_backward(
398
+ dout,
399
+ q,
400
+ k,
401
+ v,
402
+ out,
403
+ softmax_lse,
404
+ None, None, # cu_seqlens_q, cu_seqlens_k,
405
+ None, None, # sequed_q, sequed_k,
406
+ None, None, # max_seqlen_q, max_seqlen_k,
407
+ ctx.softmax_scale,
408
+ ctx.causal,
409
+ ctx.window_size[0],
410
+ ctx.window_size[1],
411
+ ctx.softcap,
412
+ False, # deterministic
413
+ ctx.sm_margin,
414
+ )
415
+ return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
423
+ @staticmethod
424
+ def forward(
425
+ ctx,
426
+ qkv,
427
+ softmax_scale,
428
+ causal,
429
+ q_descale=None, k_descale=None, v_descale=None,
430
+ window_size=(-1, -1),
431
+ attention_chunk=0,
432
+ softcap=0.0,
433
+ deterministic=False,
434
+ num_heads_q=None,
435
+ sm_margin=0,
436
+ return_softmax=False,
437
+ ):
438
+ if softmax_scale is None:
439
+ softmax_scale = qkv.shape[-1] ** (-0.5)
440
+ if qkv.dim() == 5:
441
+ assert qkv.shape[-3] == 3
442
+ q, k, v = qkv.unbind(dim=-3)
443
+ else:
444
+ assert qkv.dim() == 4
445
+ assert num_heads_q is not None
446
+ num_heads_k = (qkv.shape[2] - num_heads_q) // 2
447
+ assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
448
+ q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
449
+ out, softmax_lse, *rest = _flash_attn_forward(
450
+ q,
451
+ k,
452
+ v,
453
+ None, None, # k_new, v_new
454
+ None, # qv
455
+ None, # out
456
+ None, None, None, # cu_seqlens_q/k/k_new
457
+ None, None, # seqused_q/k
458
+ None, None, # max_seqlen_q/k
459
+ None, None, None, # page_table, kv_batch_idx, leftpad_k,
460
+ None, None, None, # rotary_cos/sin, seqlens_rotary
461
+ q_descale, k_descale, v_descale,
462
+ softmax_scale,
463
+ causal=causal,
464
+ window_size_left=window_size[0],
465
+ window_size_right=window_size[1],
466
+ attention_chunk=attention_chunk,
467
+ softcap=softcap,
468
+ sm_margin=sm_margin,
469
+ )
470
+ # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
471
+ ctx.save_for_backward(q, k, v, out, softmax_lse)
472
+ ctx.softmax_scale = softmax_scale
473
+ ctx.causal = causal
474
+ ctx.window_size = window_size
475
+ ctx.attention_chunk = attention_chunk
476
+ ctx.softcap = softcap
477
+ ctx.deterministic = deterministic
478
+ ctx.ndim = qkv.dim()
479
+ ctx.sm_margin = sm_margin
480
+ return (out, softmax_lse) if return_softmax else out
481
+
482
+ @staticmethod
483
+ def backward(ctx, dout, *args):
484
+ q, k, v, out, softmax_lse = ctx.saved_tensors
485
+ assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
486
+ # Get gradients from the backward function
487
+ dq, dk, dv, _ = _flash_attn_backward(
488
+ dout,
489
+ q,
490
+ k,
491
+ v,
492
+ out,
493
+ softmax_lse,
494
+ None, None, # cu_seqlens_q, cu_seqlens_k,
495
+ None, None, # sequed_q, sequed_k,
496
+ None, None, # max_seqlen_q, max_seqlen_k,
497
+ ctx.softmax_scale,
498
+ ctx.causal,
499
+ ctx.window_size[0],
500
+ ctx.window_size[1],
501
+ ctx.softcap,
502
+ ctx.deterministic,
503
+ ctx.sm_margin,
504
+ )
505
+ # Pack the gradients into the expected format
506
+ if ctx.ndim == 5:
507
+ dqkv = torch.stack([dq, dk, dv], dim=-3)
508
+ else:
509
+ # Concatenate along the heads dimension
510
+ dqkv = torch.cat([dq, dk, dv], dim=-2)
511
+ dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
512
+ return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for FA3 fixed-length (batched) attention.

    forward() dispatches to the `_flash_attn_forward` custom op; backward()
    dispatches to `_flash_attn_backward`. Non-tensor configuration needed by
    the backward pass is stashed on `ctx`.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale: 1/sqrt(total query head dim), including qv if present.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        # Non-tensor config required by backward().
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # The kernel may pad the head dim internally; trim grads back to input sizes.
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient slot per forward() argument (non-tensors get None).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for FA3 variable-length (packed/unpadded) attention.

    Sequences are concatenated along the token dimension and delimited by
    `cu_seqlens_q` / `cu_seqlens_k` (cumulative-length offsets). Dispatches to
    the same `_flash_attn_forward` custom op as the batched path, just with the
    varlen metadata filled in.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Default scale: 1/sqrt(total query head dim), including qv if present.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        # Non-tensor config required by backward().
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # The kernel may pad the head dim internally; trim grads back to input sizes.
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient slot per forward() argument (non-tensors get None).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention where Q, K, V are packed into a single tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), pack along the head dimension
    and pass num_heads_q, or see flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) when packed along the
            head dimension (requires num_heads_q).
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional descale tensors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, chunked attention (forward only; backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Number of query heads; required when qkv is packed along heads.
        sm_margin: int. Number of SMs to leave free (e.g. for overlapping communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 over separate Q, K, V tensors (fixed-length batch).

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor folded into the attention (see FA3 docs).
        q_descale, k_descale, v_descale: optional descale tensors for FP8 inputs.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, chunked attention (forward only; backward asserts 0).
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; 1 means no split.
        pack_gqa: Optional[bool]. Whether to pack GQA heads; None lets the kernel decide.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g. for overlapping communication kernels).
        return_attn_probs: bool. Whether to also return the softmax logsumexp. This option is
            for testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length FlashAttention-3.

    Like flash_attn_func, but q/k/v hold multiple sequences concatenated along
    the token dimension, delimited by the cumulative-length offset tensors
    cu_seqlens_q / cu_seqlens_k. max_seqlen_q / max_seqlen_k are the maximum
    per-sequence lengths; seqused_q / seqused_k optionally cap how many tokens
    of each sequence are actually used. Returns (out, softmax_lse) when
    return_attn_probs is True, otherwise out.

    Note: forward() takes seqused_* before max_seqlen_*, so the argument order
    below differs from this function's signature — keep them in sync.
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs and their logsumexps (split-KV merge)."""
    combined = flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
    return combined
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale: 1/sqrt(total query head dim), including qv if present.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    # return (out, softmax_lse) if return_softmax_lse else out
    # `rest` carries the split-KV accumulators returned by the custom op.
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute FA3 kernel scheduling metadata for a given problem shape.

    The returned tensor can be passed as `scheduler_metadata` to
    flash_attn_with_kvcache to amortize scheduling work across calls with the
    same shapes. The arguments mirror the corresponding flash_attn_with_kvcache
    parameters; `headdim_v` defaults to `headdim`.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
    return scheduler_metadata
build/torch29-cxx11-cu130-aarch64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
def show():
    """Pretty-print the build-flag configuration dictionary."""
    import pprint

    pprint.pprint(CONFIG)
7
+
build/torch29-cxx11-cu130-aarch64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return `x` with a contiguous last dimension; `None` passes through."""
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round `x` up to the nearest multiple of `m` (ceiling division)."""
    return -(-x // m) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round `head_size` up to the smallest head-dim bucket enabled in this build.

    Buckets disabled at compile time (per flash_attn_config.CONFIG) are skipped;
    falls back to 256 when nothing smaller fits.
    """
    from .flash_attn_config import CONFIG

    build_flags = CONFIG["build_flags"]
    buckets = (
        ("FLASHATTENTION_DISABLE_HDIM64", 64),
        ("FLASHATTENTION_DISABLE_HDIM96", 96),
        ("FLASHATTENTION_DISABLE_HDIM128", 128),
        ("FLASHATTENTION_DISABLE_HDIM192", 192),
        ("FLASHATTENTION_DISABLE_HDIM256", 256),
    )
    for disable_flag, hdim in buckets:
        if not build_flags[disable_flag] and head_size <= hdim:
            return hdim
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Single entry point to the flash_attn_3 CUDA forward kernel.

    Registered as a PyTorch custom op so it composes with torch.compile (a fake
    implementation is registered separately). Handles batched, varlen, and
    kv-cache modes depending on which optional arguments are provided.
    Returns (out, softmax_lse, out_accum, softmax_lse_accum), where the *_accum
    tensors are split-KV partial results (empty placeholders when unused).
    """
    # Normalize strides: the kernel expects contiguous last dims.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    # V may keep a non-contiguous last dim when stride(-3) == 1 — presumably a
    # head-major layout the kernel supports; TODO confirm against the kernel.
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # Replace missing accumulators with empty tensors — the declared return
    # schema is four Tensors, so None is not allowed here.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Meta/fake kernel for the forward op.

    Allocates empty tensors with the shapes/dtypes the real CUDA kernel
    would produce, so torch.compile / torch.export can trace the op
    without running it.
    """
    # Varlen ("packed") mode is signalled by the presence of cu_seqlens_q.
    varlen = cu_seqlens_q is not None

    if varlen:
        # Packed layout: q is (total_q, num_heads, head_size).
        total_q, num_heads, _ = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1
        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # Dense layout: q is (batch_size, seqlen_q, num_heads, head_size).
        batch_size, seqlen_q, num_heads, _ = q.shape
        total_q = batch_size * seqlen_q

    # The value head dimension may differ from the query head dimension.
    head_size_v = v.shape[-1]

    # FP8 inputs produce BF16 outputs; everything else keeps q's dtype.
    out_dtype = torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype

    if out is None:
        out_shape = (
            (total_q, num_heads, head_size_v)
            if varlen
            else (batch_size, seqlen_q, num_heads, head_size_v)
        )
        out = torch.empty(out_shape, dtype=out_dtype, device=q.device)

    lse_shape = (num_heads, total_q) if varlen else (batch_size, num_heads, seqlen_q)
    softmax_lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits".
    # The real kernel has a heuristic that picks num_splits when it is <= 0;
    # that heuristic cannot be mirrored symbolically, so reject it here.
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        if varlen:
            accum_shape = (num_splits, num_heads, total_q, head_size_v)
            lse_accum_shape = (num_splits, num_heads, total_q)
        else:
            accum_shape = (num_splits, batch_size, num_heads, seqlen_q, head_size_v)
            lse_accum_shape = (num_splits, batch_size, num_heads, seqlen_q)
        out_accum = torch.empty(accum_shape, dtype=torch.float32, device=q.device)
        softmax_lse_accum = torch.empty(lse_accum_shape, dtype=torch.float32, device=q.device)
    else:
        # Split accumulators are unused with a single split; return empties.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom backward op: dispatch to the FA3 CUDA kernel.

    Returns (dq, dk, dv, softmax_d); the C++ side allocates the gradient
    tensors itself (dq/dk/dv slots are passed as None).
    """
    dout, q, k, v, out = (maybe_contiguous(t) for t in (dout, q, k, v, out))
    results = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    # The kernel may append extra outputs; only the first four are public.
    dq, dk, dv, softmax_d = results[:4]
    return dq, dk, dv, softmax_d
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Meta/fake kernel for the backward op.

    Mirrors the shape logic of the CUDA backward so torch.compile /
    torch.export can trace it: dq/dk/dv match q/k/v, and softmax_d is
    padded to the kernel's block size (kBlockM) along the sequence dim.
    Only Hopper (SM 9.0+) block-size heuristics are modelled here.
    """
    is_varlen_q = cu_seqlens_q is not None
    # Bug fix: this previously re-tested cu_seqlens_q, so a call that was
    # varlen only on the key side was not detected as varlen.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Windows that cover the whole sequence are equivalent to "no window".
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    # Causal masking is expressed as a right window of 0.
    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus use cuda compute capability 9.0.
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # SM90 block-size (kBlockM) heuristic, keyed on rounded head dim and
    # mask/softcap configuration; must stay in sync with the CUDA kernels.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    # Varlen softmax_d is padded per batch entry, then rounded to kBlockM.
    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Gradients share shape/dtype with their forward counterparts.
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
def setup_context(ctx, inputs, output):
    """Stash forward tensors and scalar config on ctx for `_backward`.

    `inputs` is the full positional argument tuple of the forward op; the
    scalar settings are addressed from the END of the tuple so that the
    many optional tensor arguments in the middle cannot shift them.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
def _backward(ctx, dout, *grads):
    """Autograd backward for the custom forward op (dense, non-varlen path).

    Only dq/dk/dv receive gradients; every other forward input gets None.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k
        None, None,  # sequed_q, sequed_k
        None, None,  # max_seqlen_q, max_seqlen_k
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return (dq, dk, dv) + (None,) * 21
# Attach the autograd formula to the custom forward op so traced/compiled
# graphs can differentiate through flash_attn_3::_flash_attn_forward.
_flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention over a packed QKV tensor.

    Accepts either a 5-D packed tensor (batch, seqlen, 3, nheads, headdim)
    or — for GQA-style packing — a 4-D tensor
    (batch, seqlen, nheads_q + 2 * nheads_k, headdim) plus `num_heads_q`.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        # Split the packed tensor into q/k/v views.
        if qkv.dim() == 5:
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # GQA packing: heads concatenated as [q-heads, k-heads, v-heads].
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,        # k_new, v_new
            None,              # qv
            None,              # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,        # seqused_q/k
            None, None,        # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        # Remember the packing layout so backward can re-pack gradients.
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k
            None, None,  # sequed_q, sequed_k
            None, None,  # max_seqlen_q, max_seqlen_k
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Re-pack the gradients to mirror the input layout.
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return (dqkv,) + (None,) * 12
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for dense (fixed-length) flash attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Scale by 1/sqrt of the combined q (+ optional qv) head dim.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,        # k_new, v_new
            qv,                # qv
            None,              # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,        # seqused_q/k
            None, None,        # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k
            None, None,  # sequed_q, sequed_k
            None, None,  # max_seqlen_q, max_seqlen_k
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Strip any head-dimension padding the kernel may have added.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return (dq, dk, dv) + (None,) * 14
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed) flash attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            # Scale by 1/sqrt of the combined q (+ optional qv) head dim.
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,        # k_new, v_new
            qv,                # qv
            None,              # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,              # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Strip any head-dimension padding the kernel may have added.
        dq = dq[..., : q.shape[-1]]
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return (dq, dk, dv) + (None,) * 20
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Flash attention over Q, K, V stacked into one packed tensor.

    Faster than `flash_attn_func` on separate Q/K/V because the backward
    pass avoids explicitly concatenating the gradients. For MQA/GQA, pass a
    4-D tensor plus `num_heads_q` (see `FlashAttnQKVPackedFunc`), or use
    `flash_attn_func` directly.

    If `window_size != (-1, -1)`, sliding-window local attention is used:
    query i attends to keys in [i - window_size[0], i + window_size[1]]
    inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        softmax_scale: float. Scaling of QK^T before softmax;
            defaults to 1 / sqrt(headdim).
        causal: bool. Apply a causal mask (auto-regressive modeling).
        window_size: (left, right) sliding-window bounds; (-1, -1) disables.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Use the deterministic (slower, more memory)
            backward pass; the forward pass is always deterministic.
        return_attn_probs: bool. Also return the softmax logsumexp
            (testing only; values are not guaranteed to be properly scaled).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen), the logsumexp of each row of
            QK^T * scaling.
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Flash attention over separate Q, K, V tensors.

    Supports MQA/GQA: K/V may have fewer heads than Q, as long as Q's head
    count is divisible by K/V's (query heads are grouped evenly over KV
    heads).

    With `causal=True` the causal mask is aligned to the BOTTOM-RIGHT
    corner of the attention matrix, e.g. for seqlen_q=2, seqlen_k=5
    (1 = keep, 0 = masked out):
        1 1 1 1 0
        1 1 1 1 1
    and for seqlen_q=5, seqlen_k=2:
        0 0
        0 0
        0 0
        1 0
        1 1
    A fully-masked row produces a zero output row.

    If `window_size != (-1, -1)`, sliding-window local attention: query i
    attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0],
     i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. Scaling of QK^T before softmax;
            defaults to 1 / sqrt(headdim).
        causal: bool. Apply a causal mask (auto-regressive modeling).
        window_size: (left, right) sliding-window bounds; (-1, -1) disables.
        deterministic: bool. Use the deterministic (slower, more memory)
            backward pass; the forward pass is always deterministic.
        return_attn_probs: bool. Also return the softmax logsumexp
            (testing only; values are not guaranteed to be properly scaled).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen), the logsumexp of each row of
            QK^T * scaling.
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length flash attention over packed (un-padded) sequences.

    Q, K, V hold all sequences concatenated along dim 0; `cu_seqlens_q` /
    `cu_seqlens_k` are the (batch+1,)-long int32 cumulative-length offsets
    and `max_seqlen_q` / `max_seqlen_k` the per-batch maxima. Otherwise
    behaves like `flash_attn_func` (MQA/GQA, causal, sliding window, ...).
    """
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine per-split partial outputs and logsumexps into the final result."""
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """Attention against a (possibly paged) KV cache, for incremental decoding.

    If `k`/`v` are given, they are written *in place* into
    `k_cache`/`v_cache` at the positions indicated by `cache_seqlens`
    before attention runs, so cache update and attention happen in one
    kernel. If `rotary_cos`/`rotary_sin` are given, rotary embeddings are
    applied to the appended `k` (and to `q` at the appropriate positions,
    depending on causal/local masking). Does NOT support a backward pass.

    Supports MQA/GQA (fewer KV heads than Q heads, Q head count divisible
    by KV head count), paged caches via `page_table`
    (k_cache/v_cache then have shape (num_blocks, page_block_size,
    nheads_k, headdim[_v])), per-batch cache indirection via
    `cache_batch_idx`, left-padded caches via `cache_leftpad`, sliding
    windows, chunked attention, softcapping, and FP8 descale tensors.

    With `causal=True`, the causal mask is aligned to the bottom-right
    corner of the attention matrix (a fully masked row yields zeros).
    With `window_size != (-1, -1)`, query i attends to keys in
    [i + seqlen_k - seqlen_q - window_size[0],
     i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    `cache_seqlens` may be a python int (broadcast over the batch) or an
    int32 tensor of shape (batch_size,). `num_splits=0` lets the kernel
    pick a split-KV heuristic; `num_splits=1` disables splitting.

    See tests/test_flash_attn.py::test_flash_attn_kvcache for usage.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]:
            (batch_size, nheads, seqlen) logsumexp of QK^T * scaling
            (plus any split accumulators the kernel returns).
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    # Broadcast a scalar cache length into a per-batch int32 tensor.
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
    cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute FA3 scheduler metadata for a given problem shape.

    The returned tensor can be passed as `scheduler_metadata` to
    `flash_attn_with_kvcache` to amortize the scheduling work across calls
    with identical shapes/settings.
    """
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        # Value head dim defaults to the query/key head dim.
        headdim_v = headdim
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )
build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/_C.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e722dcebc49024ab2dd2c58c7b27efb9a0618ce104f66b6b21cc0d75bf5de4b6
3
+ size 814568840
build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flash_attn_interface import *
build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Auto-generated by flash attention 3 setup.py
2
+ CONFIG = {'build_flags': {'FLASHATTENTION_DISABLE_BACKWARD': False, 'FLASHATTENTION_DISABLE_SPLIT': False, 'FLASHATTENTION_DISABLE_PAGEDKV': False, 'FLASHATTENTION_DISABLE_APPENDKV': False, 'FLASHATTENTION_DISABLE_LOCAL': False, 'FLASHATTENTION_DISABLE_SOFTCAP': False, 'FLASHATTENTION_DISABLE_PACKGQA': False, 'FLASHATTENTION_DISABLE_FP16': True, 'FLASHATTENTION_DISABLE_FP8': False, 'FLASHATTENTION_DISABLE_VARLEN': False, 'FLASHATTENTION_DISABLE_CLUSTER': False, 'FLASHATTENTION_DISABLE_HDIM64': False, 'FLASHATTENTION_DISABLE_HDIM96': False, 'FLASHATTENTION_DISABLE_HDIM128': False, 'FLASHATTENTION_DISABLE_HDIM192': False, 'FLASHATTENTION_DISABLE_HDIM256': False, 'FLASHATTENTION_DISABLE_SM8x': True, 'FLASHATTENTION_ENABLE_VCOLMAJOR': False}}
3
+
4
+ def show():
5
+ from pprint import pprint
6
+ pprint(CONFIG)
7
+
build/torch29-cxx11-cu130-x86_64-linux/flash_attention_hopper/flash_attn_interface.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union, List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ from . import _C # Registers operators with PyTorch
11
+
12
+ # isort: on
13
+
14
+ flash_attn_3_cuda = torch.ops.flash_attn_3
15
+
16
def maybe_contiguous(x):
    """Return *x* contiguous along its last dim; ``None`` passes through."""
    if x is None:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
18
+
19
+
20
def round_multiple(x, m):
    """Round *x* up to the nearest multiple of *m*."""
    quot, rem = divmod(x, m)
    return x if rem == 0 else (quot + 1) * m
22
+
23
+
24
def round_up_headdim(head_size: int) -> int:
    """Round *head_size* up to the smallest enabled compiled head-dim bucket.

    Buckets that were disabled at build time (per ``flash_attn_config``) are
    skipped; anything larger than every enabled bucket maps to 256.
    """
    from .flash_attn_config import CONFIG

    flags = CONFIG["build_flags"]
    for bucket in (64, 96, 128, 192, 256):
        if flags[f"FLASHATTENTION_DISABLE_HDIM{bucket}"]:
            continue
        if head_size <= bucket:
            return bucket
    return 256
43
+
44
+
45
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level FA3 forward: normalize strides, then call the CUDA kernel.

    Returns ``(out, softmax_lse, out_accum, softmax_lse_accum)``; the last
    two are only populated by the kernel for split-KV execution and are
    replaced with empty tensors here so the custom-op schema always gets
    four tensors back.  NOTE: ``flash_attn_3_cuda.fwd`` is called purely
    positionally — keep the argument order in sync with the C++ schema.
    """
    # Kernels require last-dim-contiguous inputs; v additionally accepts a
    # head-dim-major layout (stride(-3) == 1), so only copy when neither holds.
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    # custom_op schemas cannot return Optional tensors: substitute empties.
    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
137
+
138
+
139
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    """

    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1

        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]
    # Get value head dimension (may differ from QK head dim)
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor (reuse caller-provided `out` buffer when given)
    if out is None:
        if is_varlen_q:
            out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
        else:
            out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor (always fp32, matching the kernel's output)
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's an heuristic to compute num_splits when "num_splits <= 0"
    # assert that num_splits is > 0 for now
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        # Split-KV: per-split fp32 accumulators for output and LSE.
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Accumulators are unused when num_splits == 1; return empty
        # placeholders to satisfy the fixed four-tensor schema.
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
239
+
240
+
241
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=(), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level FA3 backward: returns ``(dq, dk, dv, softmax_d)``.

    Gradient buffers are allocated by the C++ side (``None`` is passed for
    dq/dk/dv below).  Arguments are forwarded positionally — keep in sync
    with the ``flash_attn_3_cuda.bwd`` schema.
    """
    # Kernels require last-dim-contiguous inputs.
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    # C++ now allocates and returns dq, dk, dv
    dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None,  # dq - let C++ allocate
        None,  # dk - let C++ allocate
        None,  # dv - let C++ allocate
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return dq, dk, dv, softmax_d
290
+
291
+
292
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Symbolic fake of the FA3 backward op.

    Mirrors the C++ side's allocation logic to return ``dq``/``dk``/``dv``
    (same shapes as q/k/v) and ``softmax_d``, whose shape depends on varlen
    mode and the kernel's M-block size (``kBlockM``).
    """

    is_varlen_q = cu_seqlens_q is not None
    # BUGFIX: previously tested cu_seqlens_q again; varlen-K must be keyed
    # off cu_seqlens_k.
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        # In varlen mode max_seqlen_q/k are required by the C++ op;
        # assumed non-None here — TODO confirm against the kernel contract.
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # Windows that cover the whole sequence are equivalent to "infinite".
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    # Causal is re-derived in canonical form: no left limit, right limit 0.
    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper gpus uses cuda compute capabilities 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if arch < 90:
        raise ValueError(f"Only cuda compute capabilities 9.0 or newer are supported. Got {arch=}")

    # Mirror the SM90 kernel's kBlockM selection so softmax_d padding matches.
    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM = kBlockM_sm90

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    # Varlen softmax_d is padded per batch entry, then rounded to kBlockM.
    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    # Allocate gradient tensors
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return dq, dk, dv, softmax_d
381
+
382
+
383
def setup_context(ctx, inputs, output):
    """Stash the tensors and flags the backward pass needs on *ctx*.

    ``inputs`` is the positional argument tuple of ``_flash_attn_forward``.
    Fields are recovered by negative index from the end of the signature;
    this is fragile — keep the offsets in sync with the op's parameter list.
    """
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]  # softmax_scale
    ctx.causal = inputs[-10]         # causal
    ctx.window_size = [inputs[-9], inputs[-8]]  # (left, right)
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]
393
+
394
+
395
def _backward(ctx, dout, *grads):
    """Autograd bridge for the ``_flash_attn_forward`` custom op.

    Only dq/dk/dv are differentiable; all other forward arguments get None
    gradients.  NOTE(review): 3 + 21 = 24 gradients are returned while the
    forward op takes 34 arguments — presumably only the leading tensor-ish
    arguments need gradients here; verify against torch.library's
    register_autograd contract.
    """
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv, _ = _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)
416
+
417
+
418
+ _flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
419
+
420
+
421
+
422
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd wrapper for attention on a packed QKV tensor.

    Accepts either a 5-D ``(batch, seqlen, 3, nheads, headdim)`` tensor or a
    4-D ``(batch, seqlen, nheads_q + 2*nheads_k, headdim)`` GQA layout (the
    latter requires ``num_heads_q``).
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(headdim).
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            # (batch, seqlen, 3, nheads, headdim): split into q/k/v views.
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            # GQA packing along the heads dim; infer nheads_k from the total.
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()  # remembered so backward re-packs in the same layout
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        # Get gradients from the backward function
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        # Pack the gradients into the expected format
        if ctx.ndim == 5:
            dqkv = torch.stack([dq, dk, dv], dim=-3)
        else:
            # Concatenate along the heads dimension
            dqkv = torch.cat([dq, dk, dv], dim=-2)
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        # One gradient per forward input (non-tensors get None).
        return dqkv, None, None, None, None, None, None, None, None, None, None, None
513
+
514
+
515
class FlashAttnFunc(torch.autograd.Function):
    """Autograd wrapper for fixed-length (batched) FA3 attention."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv if set.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input (non-tensors get None).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
599
+
600
+
601
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd wrapper for variable-length (packed/varlen) FA3 attention.

    Sequences are concatenated along dim 0 and delimited by the cumulative
    sequence-length tensors ``cu_seqlens_q`` / ``cu_seqlens_k``.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        # Default scale is 1/sqrt(total query head dim), including qv if set.
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv, _ = _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        # One gradient per forward input (non-tensors get None).
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
700
+
701
+
702
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """Attention on a packed QKV tensor.

    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, nheads_q + 2 * nheads_k, headdim) with num_heads_q set (GQA packing).
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        q_descale, k_descale, v_descale: optional FP8 descale tensors.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        num_heads_q: int. Required for the 4-D GQA-packed qkv layout.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
762
+
763
+
764
def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """FlashAttention-3 for fixed-length (batched) inputs.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        qv: optional extra query tensor concatenated into the attention score computation.
        q_descale, k_descale, v_descale: optional FP8 descale tensors.
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        num_splits: int. Split-KV factor; can be tuned for speed.
        pack_gqa: optional bool. Whether to pack GQA heads; can be tuned for speed.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        sm_margin: int. Number of SMs to leave free (e.g., for communication kernels).
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp. This option is for
            testing only.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
843
+
844
+
845
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
    return_attn_probs=False,
):
    """Variable-length (packed) version of flash_attn_func.

    Sequences of different lengths are concatenated into single q/k/v tensors
    with no padding; per-sequence boundaries are described by the cumulative
    sequence-length tensors cu_seqlens_q / cu_seqlens_k.

    Arguments:
        q: (total_q, nheads, headdim), where total_q is the total number of
            query tokens across the batch.
        k: (total_k, nheads_k, headdim).
        v: (total_k, nheads_k, headdim_v).
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. Cumulative sequence
            lengths used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. Same, for k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        seqused_q [optional]: (batch_size,), dtype torch.int32. If given, only
            this many query tokens of each sequence are used.
        seqused_k [optional]: (batch_size,), dtype torch.int32. Same, for k/v.

    The remaining arguments (softmax_scale, causal, qv, q/k/v_descale,
    window_size, attention_chunk, softcap, num_splits, pack_gqa, deterministic,
    sm_margin, return_attn_probs) have the same meaning as in flash_attn_func.

    Return:
        out: (total_q, nheads, headdim), plus the softmax logsumexp if
        return_attn_probs=True (as in flash_attn_func).
    """
    # torch.autograd.Function.apply accepts positional arguments only, so the
    # order here must match FlashAttnVarlenFunc.forward exactly. Note that
    # seqused_q/seqused_k come *before* max_seqlen_q/max_seqlen_k in forward,
    # unlike this wrapper's signature.
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        sm_margin,
        return_attn_probs,
    )
891
+
892
+
893
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
    """Combine partial attention outputs from multiple KV splits.

    Merges per-split partial outputs and their log-sum-exp statistics into the
    final attention output, delegating to the CUDA kernel fwd_combine.

    Arguments:
        out_partial: partial attention outputs, one per split.
        lse_partial: log-sum-exp of the attention scores, one per split.
        out [optional]: tensor to write the combined result into.
        out_dtype [optional]: dtype of the combined output.

    Return:
        Whatever flash_attn_3_cuda.fwd_combine returns — presumably the
        combined output (and LSE); exact shapes are defined by the kernel.
    """
    return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype)
895
+
896
+
897
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
            page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.).
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
        qv [optional]: (batch_size, seqlen, nheads, headdim_v)
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
                 might come from any of the duplicate indices.
        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cu_seqlens_q [optional]: (batch_size + 1,), dtype torch.int32. Cumulative query
            sequence lengths, for variable-length (packed) queries.
        cu_seqlens_k_new [optional]: (batch_size + 1,), dtype torch.int32. Cumulative
            lengths of the new keys/values being appended to the cache.
        max_seqlen_q [optional]: int. Maximum query sequence length; presumably required
            when cu_seqlens_q is given — verify against _flash_attn_forward.
        rotary_seqlens [optional]: (batch_size,), dtype torch.int32. Per-sequence offsets
            for rotary embedding (in place of cache_seqlens) — TODO confirm semantics.
        q_descale, k_descale, v_descale [optional]: dequantization scales, presumably for
            FP8 inputs — verify against the kernel.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        attention_chunk: int. If > 0, presumably restricts attention to chunks of this
            size — confirm against the kernel.
        softcap: float. Anything > 0 activates softcapping attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        scheduler_metadata [optional]: precomputed metadata from get_scheduler_metadata.
        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
           to automatically determine the number of splits.
           Don't change this unless you know what you are doing.
        pack_gqa [optional]: bool. Whether to pack GQA heads; tuning knob passed
            through to the kernel.
        sm_margin: int. Number of SMs to leave free (e.g. for communication kernels).
        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor). Any extra values returned by the forward kernel are
            appended after softmax_lse.
    """
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale uses the combined head dim of q and (optional) qv.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar length to a per-batch int32 tensor on the cache's device.
        cache_seqlens = torch.full(
            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    out, softmax_lse, *rest = _flash_attn_forward(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=attention_chunk,
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        scheduler_metadata=scheduler_metadata,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return (out, softmax_lse, *rest) if return_softmax_lse else out
1059
+
1060
+
1061
def get_scheduler_metadata(
    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
    cache_seqlens: torch.Tensor,
    qkv_dtype=torch.bfloat16,
    headdim_v=None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new=0,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk=0,
    has_softcap=False,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
):
    """Precompute scheduler metadata for flash_attn_with_kvcache.

    Runs the CUDA-side scheduler ahead of time so the attention call itself
    can skip that work; pass the result as scheduler_metadata to
    flash_attn_with_kvcache. The problem-shape arguments (batch_size,
    max_seqlen_q/k, head counts, head dims, dtype) and the tuning knobs
    (num_splits, pack_gqa, sm_margin) must match the subsequent attention call
    for the metadata to be valid — verify against the kernel's contract.

    Arguments:
        batch_size: int. Batch size of the attention call.
        max_seqlen_q: int. Maximum query sequence length.
        max_seqlen_k: int. Maximum key sequence length.
        num_heads_q: int. Number of query heads.
        num_heads_kv: int. Number of key/value heads.
        headdim: int. Head dimension of q/k.
        cache_seqlens: (batch_size,), dtype torch.int32. Sequence lengths of the KV cache.
        qkv_dtype: torch.dtype of q/k/v. Default torch.bfloat16.
        headdim_v [optional]: int. Head dimension of v; defaults to headdim.
        cu_seqlens_q, cu_seqlens_k_new, cache_leftpad, page_size, max_seqlen_k_new,
        causal, window_size, attention_chunk, num_splits, pack_gqa, sm_margin:
            same meaning as the corresponding flash_attn_with_kvcache arguments.
        has_softcap: bool. Whether the attention call will use softcapping (softcap > 0).

    Return:
        Opaque scheduler metadata produced by the CUDA extension.
    """
    # The kernel presumably expects a contiguous int32 tensor here.
    cache_seqlens = maybe_contiguous(cache_seqlens)
    if headdim_v is None:
        headdim_v = headdim
    return flash_attn_3_cuda.get_scheduler_metadata(
        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
        qkv_dtype,
        cache_seqlens,
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_leftpad,
        page_size,
        max_seqlen_k_new,
        causal,
        window_size[0], window_size[1],
        attention_chunk,
        has_softcap,
        num_splits,
        pack_gqa,
        sm_margin,
    )