Kernels
This view is limited to 50 files because it contains too many changes. See the raw diff here.
Files changed (50)
  1. README.md +3 -3
  2. benchmark.py +0 -17
  3. benchmarks/benchmark.py +0 -17
  4. build/torch210-cxx11-cpu-x86_64-linux/flash_attn2/__init__.py +0 -26
  5. build/torch210-cxx11-cpu-x86_64-linux/metadata.json +0 -4
  6. build/torch210-cxx11-cu126-x86_64-linux/flash_attn2/__init__.py +0 -26
  7. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +0 -4
  8. build/torch210-cxx11-cu128-x86_64-linux/flash_attn2/__init__.py +0 -26
  9. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +0 -4
  10. build/torch210-cxx11-cu130-x86_64-linux/flash_attn2/__init__.py +0 -26
  11. build/torch210-cxx11-cu130-x86_64-linux/flash_attn_interface.py +0 -1620
  12. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +0 -4
  13. build/torch210-cxx11-cu130-x86_64-linux/ops/triton/rotary.py +0 -186
  14. build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py +0 -9
  15. build/torch210-cxx11-xpu20253-x86_64-linux/flash_attn2/__init__.py +0 -26
  16. build/torch210-cxx11-xpu20253-x86_64-linux/flash_attn_interface.py +0 -1620
  17. build/torch210-cxx11-xpu20253-x86_64-linux/metadata.json +0 -4
  18. build/torch210-cxx11-xpu20253-x86_64-linux/ops/triton/rotary.py +0 -186
  19. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/__init__.py +0 -0
  20. build/{torch210-cxx11-cu126-x86_64-linux/_flash_attn2_588b404.abi3.so → torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so} +2 -2
  21. build/{torch210-cxx11-cu128-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/_ops.py +3 -3
  22. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/bert_padding.py +0 -0
  23. build/torch27-cxx11-cu118-x86_64-linux/{flash_attn2 → flash_attn}/flash_attn_interface.py +0 -0
  24. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/layers/__init__.py +0 -0
  25. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/layers/patch_embed.py +0 -0
  26. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/layers/rotary.py +0 -0
  27. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/__init__.py +0 -0
  28. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/activations.py +0 -0
  29. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/fused_dense.py +0 -0
  30. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/layer_norm.py +0 -0
  31. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/rms_norm.py +0 -0
  32. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/__init__.py +0 -0
  33. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/cross_entropy.py +0 -0
  34. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/k_activations.py +0 -0
  35. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/layer_norm.py +0 -0
  36. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/linear.py +0 -0
  37. build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/mlp.py +0 -0
  38. build/torch27-cxx11-cu118-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/rotary.py +0 -0
  39. build/torch27-cxx11-cu118-x86_64-linux/flash_attn2/_flash_attn_9e27194.abi3.so +0 -3
  40. build/torch27-cxx11-cu118-x86_64-linux/flash_attn2/_ops.py +0 -9
  41. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/__init__.py +0 -0
  42. build/{torch210-cxx11-xpu20253-x86_64-linux/_flash_attn2_588b404.abi3.so → torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so} +2 -2
  43. build/{torch210-cxx11-cu130-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/_ops.py +3 -3
  44. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/bert_padding.py +0 -0
  45. build/torch27-cxx11-cu126-x86_64-linux/{flash_attn2 → flash_attn}/flash_attn_interface.py +0 -0
  46. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/layers/__init__.py +0 -0
  47. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/layers/patch_embed.py +0 -0
  48. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/layers/rotary.py +0 -0
  49. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/ops/__init__.py +0 -0
  50. build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/ops/activations.py +0 -0
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  license: bsd-3-clause
  tags:
- - kernels
+ - kernel
  ---

- <!-- ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/flash-attn2) -->
+ <!-- ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/flash-attn) -->

  # Flash Attention

@@ -27,7 +27,7 @@ from kernels import get_kernel

  # Setup
  torch.manual_seed(42)
- flash_attn = get_kernel("kernels-community/flash-attn2")
+ flash_attn = get_kernel("kernels-community/flash-attn")
  device = torch.device("cuda")

  # Create test tensors
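For context on the rename, here is a minimal usage sketch of the updated README snippet. It is illustrative only: it assumes the renamed Hub repo `kernels-community/flash-attn` re-exports `flash_attn_func` as defined in the `flash_attn_interface.py` shown later in this diff, and the tensor shapes and dtype are made up for the example.

```python
import torch
from kernels import get_kernel

# Load the renamed kernel from the Hub (previously kernels-community/flash-attn2).
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

# Illustrative test tensors: (batch, seqlen, nheads, headdim) in fp16 on GPU.
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)

# Assumes the loaded module exposes flash_attn_func (see flash_attn_interface.py below).
out = flash_attn.flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```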
benchmark.py DELETED
@@ -1,17 +0,0 @@
- from kernels.benchmarks import (
-     FlashAttentionBenchmark,
-     FlashAttentionCausalBenchmark,
-     FlashAttentionVarlenBenchmark,
- )
-
-
- class FlashAttn(FlashAttentionBenchmark):
-     pass
-
-
- class FlashAttnCausal(FlashAttentionCausalBenchmark):
-     pass
-
-
- class FlashAttnVarlen(FlashAttentionVarlenBenchmark):
-     pass
benchmarks/benchmark.py DELETED
@@ -1,17 +0,0 @@
- from kernels.benchmarks import (
-     FlashAttentionBenchmark,
-     FlashAttentionCausalBenchmark,
-     FlashAttentionVarlenBenchmark,
- )
-
-
- class FlashAttn(FlashAttentionBenchmark):
-     pass
-
-
- class FlashAttnCausal(FlashAttentionCausalBenchmark):
-     pass
-
-
- class FlashAttnVarlen(FlashAttentionVarlenBenchmark):
-     pass
build/torch210-cxx11-cpu-x86_64-linux/flash_attn2/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
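The deleted shim above exists so that identically named files from different build variants can be imported side by side: each module is registered in `sys.modules` under a name derived from a hash of its absolute path rather than its file name, so two copies of `__init__.py` never collide. A standalone sketch of the same pattern (illustrative only, with hypothetical paths; not part of the diff):

```python
import ctypes
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_by_path(file_path: Path) -> ModuleType:
    # Key the module on a hash of its absolute path so that two files which
    # are both named e.g. "__init__.py" get distinct sys.modules entries.
    name = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    spec = importlib.util.spec_from_file_location(name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


# Hypothetical: two same-named files load as two independent modules.
# mod_a = load_by_path(Path("variant_a/__init__.py"))
# mod_b = load_by_path(Path("variant_b/__init__.py"))
```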
build/torch210-cxx11-cpu-x86_64-linux/metadata.json DELETED
@@ -1,4 +0,0 @@
- {
-   "version": 1,
-   "python-depends": []
- }
build/torch210-cxx11-cu126-x86_64-linux/flash_attn2/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu126-x86_64-linux/metadata.json DELETED
@@ -1,4 +0,0 @@
- {
-   "version": 1,
-   "python-depends": []
- }
build/torch210-cxx11-cu128-x86_64-linux/flash_attn2/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/metadata.json DELETED
@@ -1,4 +0,0 @@
- {
-   "version": 1,
-   "python-depends": []
- }
build/torch210-cxx11-cu130-x86_64-linux/flash_attn2/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import sys
-
- import importlib
- from pathlib import Path
- from types import ModuleType
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is, after adding it to `sys.modules`,
-     # it would also be used for other imports. So, we make a module name that
-     # depends on the path for it to be unique using the hex-encoded hash of
-     # the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-x86_64-linux/flash_attn_interface.py DELETED
@@ -1,1620 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Sequence, Tuple, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
- import os
8
-
9
- # # isort: off
10
- # # We need to import the CUDA kernels after importing torch
11
- # USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
12
- # if USE_TRITON_ROCM:
13
- # from .flash_attn_triton_amd import interface_fa as flash_attn
14
- # else:
15
- # import flash_attn_2_cuda as flash_attn
16
-
17
-
18
- from ._ops import ops as flash_attn
19
-
20
- # # isort: on
21
-
22
- def maybe_contiguous(x):
23
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
24
-
25
-
26
- def _get_device():
27
- if torch.xpu.is_available():
28
- return "xpu"
29
- elif torch.cuda.is_available():
30
- return "cuda"
31
- else:
32
- return "cpu"
33
-
34
- _XPU_AVAILABLE = torch.xpu.is_available() if hasattr(torch, "xpu") else False # TODO remove hasattr check when bwd is supported on XPU
35
-
36
-
37
- def _get_block_size_n(device, head_dim, is_dropout, is_causal):
38
- # This should match the block sizes in the CUDA kernel
39
- assert head_dim <= 256
40
- major, minor = torch.cuda.get_device_capability(device)
41
- is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
42
- is_sm80 = major == 8 and minor == 0
43
- is_sm90 = major == 9 and minor == 0
44
- if head_dim <= 32:
45
- return 128
46
- if head_dim <= 64:
47
- return 128 if not is_dropout else 64
48
- elif head_dim <= 96:
49
- return 64
50
- elif head_dim <= 128:
51
- if is_sm8x:
52
- return 64 if (not is_dropout and is_causal) else 32
53
- else:
54
- return 64 if not is_dropout else 32
55
- elif head_dim <= 192:
56
- return 64
57
- elif head_dim <= 224:
58
- return 64
59
- elif head_dim <= 256:
60
- return 64
61
-
62
-
63
- def round_multiple(x, m):
64
- return (x + m - 1) // m * m
65
-
66
-
67
- # torch.compile() support is only enabled for pytorch >= 2.4
68
- # The reason for this is that we are using the new custom_op and register_fake
69
- # APIs, which support inplace modification of inputs in the function itself
70
- if torch.__version__ >= "2.4.0":
71
- _torch_custom_op_wrapper = torch.library.custom_op
72
- _torch_register_fake_wrapper = torch.library.register_fake
73
- else:
74
- def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
75
- def wrap(func):
76
- return func
77
- if fn is None:
78
- return wrap
79
- return fn
80
- def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
81
- def wrap(func):
82
- return func
83
- if fn is None:
84
- return wrap
85
- return fn
86
- _torch_custom_op_wrapper = noop_custom_op_wrapper
87
- _torch_register_fake_wrapper = noop_register_fake_wrapper
88
-
89
-
90
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types=_get_device())
91
- def _flash_attn_forward(
92
- q: torch.Tensor,
93
- k: torch.Tensor,
94
- v: torch.Tensor,
95
- dropout_p: float,
96
- softmax_scale: float,
97
- causal: bool,
98
- window_size_left: int,
99
- window_size_right: int,
100
- softcap: float,
101
- alibi_slopes: Optional[torch.Tensor],
102
- return_softmax: bool
103
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
104
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
105
- out, softmax_lse, S_dmask, rng_state = flash_attn.fwd(
106
- q,
107
- k,
108
- v,
109
- None,
110
- alibi_slopes,
111
- dropout_p,
112
- softmax_scale,
113
- causal,
114
- window_size_left,
115
- window_size_right,
116
- softcap,
117
- return_softmax,
118
- None,
119
- )
120
- return out, softmax_lse, S_dmask, rng_state
121
-
122
-
123
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
124
- def _flash_attn_forward_fake(
125
- q: torch.Tensor,
126
- k: torch.Tensor,
127
- v: torch.Tensor,
128
- dropout_p: float,
129
- softmax_scale: float,
130
- causal: bool,
131
- window_size_left: int,
132
- window_size_right: int,
133
- softcap: float,
134
- alibi_slopes: Optional[torch.Tensor],
135
- return_softmax: bool
136
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
137
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
138
- batch_size, seqlen_q, num_heads, head_size = q.shape
139
- seqlen_k = k.shape[1]
140
- out = torch.empty_like(q)
141
- softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
142
- p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
143
- if return_softmax:
144
- p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
145
- rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
146
-
147
- return out, softmax_lse, p, rng_state
148
-
149
-
150
- if torch.__version__ >= "2.4.0":
151
- _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
152
- else:
153
- _wrapped_flash_attn_forward = _flash_attn_forward
154
-
155
-
156
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types=_get_device())
157
- def _flash_attn_varlen_forward(
158
- q: torch.Tensor,
159
- k: torch.Tensor,
160
- v: torch.Tensor,
161
- cu_seqlens_q: torch.Tensor,
162
- cu_seqlens_k: torch.Tensor,
163
- max_seqlen_q: int,
164
- max_seqlen_k: int,
165
- dropout_p: float,
166
- softmax_scale: float,
167
- causal: bool,
168
- window_size_left: int = -1,
169
- window_size_right: int = -1,
170
- softcap: float = 0.0,
171
- alibi_slopes: Optional[torch.Tensor] = None,
172
- return_softmax: bool = False,
173
- block_table: Optional[torch.Tensor] = None,
174
- leftpad_k: Optional[torch.Tensor] = None,
175
- seqused_k: Optional[torch.Tensor] = None,
176
- zero_tensors: bool = False,
177
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
178
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
179
- out, softmax_lse, S_dmask, rng_state = flash_attn.varlen_fwd(
180
- q,
181
- k,
182
- v,
183
- None,
184
- cu_seqlens_q,
185
- cu_seqlens_k,
186
- seqused_k,
187
- leftpad_k,
188
- block_table,
189
- alibi_slopes,
190
- max_seqlen_q,
191
- max_seqlen_k,
192
- dropout_p,
193
- softmax_scale,
194
- zero_tensors,
195
- causal,
196
- window_size_left,
197
- window_size_right,
198
- softcap,
199
- return_softmax,
200
- None,
201
- )
202
- # if out.isnan().any() or softmax_lse.isnan().any():
203
- # breakpoint()
204
- return out, softmax_lse, S_dmask, rng_state
205
-
206
-
207
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
208
- def _flash_attn_varlen_forward_fake(
209
- q: torch.Tensor,
210
- k: torch.Tensor,
211
- v: torch.Tensor,
212
- cu_seqlens_q: torch.Tensor,
213
- cu_seqlens_k: torch.Tensor,
214
- max_seqlen_q: int,
215
- max_seqlen_k: int,
216
- dropout_p: float,
217
- softmax_scale: float,
218
- causal: bool,
219
- window_size_left: int = -1,
220
- window_size_right: int = -1,
221
- softcap: float = 0.0,
222
- alibi_slopes: Optional[torch.Tensor] = None,
223
- return_softmax: bool = False,
224
- block_table: Optional[torch.Tensor] = None,
225
- leftpad_k: Optional[torch.Tensor] = None,
226
- seqused_k: Optional[torch.Tensor] = None,
227
- zero_tensors: bool = False,
228
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
229
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
230
- paged_kv = block_table is not None
231
- batch_size = cu_seqlens_q.numel() - 1
232
- total_q, num_heads, _ = q.shape
233
-
234
- out = torch.empty_like(q)
235
- softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
236
- p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
237
- seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
238
- seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
239
- if return_softmax:
240
- p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
241
- rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
242
- return out, softmax_lse, p, rng_state
243
-
244
-
245
- if torch.__version__ >= "2.4.0":
246
- _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
247
- else:
248
- _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
249
-
250
-
251
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types=_get_device())
252
- def _flash_attn_backward(
253
- dout: torch.Tensor,
254
- q: torch.Tensor,
255
- k: torch.Tensor,
256
- v: torch.Tensor,
257
- out: torch.Tensor,
258
- softmax_lse: torch.Tensor,
259
- dq: Optional[torch.Tensor],
260
- dk: Optional[torch.Tensor],
261
- dv: Optional[torch.Tensor],
262
- dropout_p: float,
263
- softmax_scale: float,
264
- causal: bool,
265
- window_size_left: int,
266
- window_size_right: int,
267
- softcap: float,
268
- alibi_slopes: Optional[torch.Tensor],
269
- deterministic: bool,
270
- rng_state: Optional[torch.Tensor] = None,
271
- ) -> torch.Tensor:
272
- # dq, dk, dv are allocated by us so they should already be contiguous
273
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
274
- (
275
- dq,
276
- dk,
277
- dv,
278
- softmax_d,
279
- ) = flash_attn.bwd(
280
- dout,
281
- q,
282
- k,
283
- v,
284
- out,
285
- softmax_lse,
286
- dq,
287
- dk,
288
- dv,
289
- alibi_slopes,
290
- dropout_p,
291
- softmax_scale,
292
- causal,
293
- window_size_left,
294
- window_size_right,
295
- softcap,
296
- deterministic,
297
- None,
298
- rng_state,
299
- )
300
- return softmax_d
301
-
302
-
303
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
304
- def _flash_attn_backward_fake(
305
- dout: torch.Tensor,
306
- q: torch.Tensor,
307
- k: torch.Tensor,
308
- v: torch.Tensor,
309
- out: torch.Tensor,
310
- softmax_lse: torch.Tensor,
311
- dq: Optional[torch.Tensor],
312
- dk: Optional[torch.Tensor],
313
- dv: Optional[torch.Tensor],
314
- dropout_p: float,
315
- softmax_scale: float,
316
- causal: bool,
317
- window_size_left: int,
318
- window_size_right: int,
319
- softcap: float,
320
- alibi_slopes: Optional[torch.Tensor],
321
- deterministic: bool,
322
- rng_state: Optional[torch.Tensor] = None,
323
- ) -> torch.Tensor:
324
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
325
- if dq is None:
326
- dq = torch.empty_like(q)
327
- if dk is None:
328
- dk = torch.empty_like(k)
329
- if dv is None:
330
- dv = torch.empty_like(v)
331
- batch_size, seqlen_q, num_heads, _ = q.shape
332
- softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
333
-
334
- return softmax_d
335
-
336
-
337
- if torch.__version__ >= "2.4.0":
338
- _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
339
- else:
340
- _wrapped_flash_attn_backward = _flash_attn_backward
341
-
342
-
343
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types=_get_device())
344
- def _flash_attn_varlen_backward(
345
- dout: torch.Tensor,
346
- q: torch.Tensor,
347
- k: torch.Tensor,
348
- v: torch.Tensor,
349
- out: torch.Tensor,
350
- softmax_lse: torch.Tensor,
351
- dq: Optional[torch.Tensor],
352
- dk: Optional[torch.Tensor],
353
- dv: Optional[torch.Tensor],
354
- cu_seqlens_q: torch.Tensor,
355
- cu_seqlens_k: torch.Tensor,
356
- max_seqlen_q: int,
357
- max_seqlen_k: int,
358
- dropout_p: float,
359
- softmax_scale: float,
360
- causal: bool,
361
- window_size_left: int,
362
- window_size_right: int,
363
- softcap: float,
364
- alibi_slopes: Optional[torch.Tensor],
365
- deterministic: bool,
366
- rng_state: Optional[torch.Tensor] = None,
367
- zero_tensors: bool = False,
368
- ) -> torch.Tensor:
369
- # dq, dk, dv are allocated by us so they should already be contiguous
370
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
371
- (
372
- dq,
373
- dk,
374
- dv,
375
- softmax_d,
376
- ) = flash_attn.varlen_bwd(
377
- dout,
378
- q,
379
- k,
380
- v,
381
- out,
382
- softmax_lse,
383
- dq,
384
- dk,
385
- dv,
386
- cu_seqlens_q,
387
- cu_seqlens_k,
388
- alibi_slopes,
389
- max_seqlen_q,
390
- max_seqlen_k,
391
- dropout_p,
392
- softmax_scale,
393
- zero_tensors,
394
- causal,
395
- window_size_left,
396
- window_size_right,
397
- softcap,
398
- deterministic,
399
- None,
400
- rng_state,
401
- )
402
- # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
403
- # breakpoint()
404
- return softmax_d
405
-
406
-
407
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
408
- def _flash_attn_varlen_backward_fake(
409
- dout: torch.Tensor,
410
- q: torch.Tensor,
411
- k: torch.Tensor,
412
- v: torch.Tensor,
413
- out: torch.Tensor,
414
- softmax_lse: torch.Tensor,
415
- dq: Optional[torch.Tensor],
416
- dk: Optional[torch.Tensor],
417
- dv: Optional[torch.Tensor],
418
- cu_seqlens_q: torch.Tensor,
419
- cu_seqlens_k: torch.Tensor,
420
- max_seqlen_q: int,
421
- max_seqlen_k: int,
422
- dropout_p: float,
423
- softmax_scale: float,
424
- causal: bool,
425
- window_size_left: int,
426
- window_size_right: int,
427
- softcap: float,
428
- alibi_slopes: Optional[torch.Tensor],
429
- deterministic: bool,
430
- rng_state: Optional[torch.Tensor] = None,
431
- zero_tensors: bool = False,
432
- ) -> torch.Tensor:
433
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
434
- batch_size = cu_seqlens_q.numel() - 1
435
- total_q, num_heads, _ = q.shape
436
-
437
- if dq is None:
438
- dq = torch.empty_like(q)
439
- if dk is None:
440
- dk = torch.empty_like(k)
441
- if dv is None:
442
- dv = torch.empty_like(v)
443
- softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
444
-
445
- return softmax_d
446
-
447
-
448
- if torch.__version__ >= "2.4.0":
449
- _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
450
- else:
451
- _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward
452
-
453
-
454
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
455
- @staticmethod
456
- def forward(
457
- ctx,
458
- qkv,
459
- dropout_p,
460
- softmax_scale,
461
- causal,
462
- window_size,
463
- softcap,
464
- alibi_slopes,
465
- deterministic,
466
- return_softmax,
467
- is_grad_enabled,
468
- ):
469
- is_grad = is_grad_enabled and qkv.requires_grad
470
- if softmax_scale is None:
471
- softmax_scale = qkv.shape[-1] ** (-0.5)
472
- q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
473
- head_size_og = q.size(3)
474
- if head_size_og % 8 != 0:
475
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
476
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
477
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
478
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
479
- q,
480
- k,
481
- v,
482
- dropout_p,
483
- softmax_scale,
484
- causal=causal,
485
- window_size_left=window_size[0],
486
- window_size_right=window_size[1],
487
- softcap=softcap,
488
- alibi_slopes=alibi_slopes,
489
- return_softmax=return_softmax and dropout_p > 0,
490
- )
491
- if is_grad:
492
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
493
- ctx.dropout_p = dropout_p
494
- ctx.softmax_scale = softmax_scale
495
- ctx.causal = causal
496
- ctx.window_size = window_size
497
- ctx.softcap = softcap
498
- ctx.alibi_slopes = alibi_slopes
499
- ctx.deterministic = deterministic
500
- out = out_padded[..., :head_size_og]
501
- return out if not return_softmax else (out, softmax_lse, S_dmask)
502
-
503
- @staticmethod
504
- def backward(ctx, dout, *args):
505
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
506
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
507
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
508
- head_size_og = dout.size(3)
509
- dout_padded = dout
510
- if head_size_og % 8 != 0:
511
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
512
- _wrapped_flash_attn_backward(
513
- dout_padded,
514
- q,
515
- k,
516
- v,
517
- out,
518
- softmax_lse,
519
- dqkv[:, :, 0],
520
- dqkv[:, :, 1],
521
- dqkv[:, :, 2],
522
- ctx.dropout_p,
523
- ctx.softmax_scale,
524
- ctx.causal,
525
- ctx.window_size[0],
526
- ctx.window_size[1],
527
- ctx.softcap,
528
- ctx.alibi_slopes,
529
- ctx.deterministic,
530
- rng_state=rng_state,
531
- )
532
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
533
- return dqkv, None, None, None, None, None, None, None, None, None
534
-
535
-
536
- class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
537
- @staticmethod
538
- def forward(
539
- ctx,
540
- qkv,
541
- cu_seqlens,
542
- max_seqlen,
543
- dropout_p,
544
- softmax_scale,
545
- causal,
546
- window_size,
547
- softcap,
548
- alibi_slopes,
549
- deterministic,
550
- return_softmax,
551
- is_grad_enabled,
552
- ):
553
- is_grad = is_grad_enabled and qkv.requires_grad
554
- if softmax_scale is None:
555
- softmax_scale = qkv.shape[-1] ** (-0.5)
556
- q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
557
- head_size_og = q.size(2)
558
- if head_size_og % 8 != 0:
559
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
560
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
561
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
562
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
563
- q,
564
- k,
565
- v,
566
- cu_seqlens,
567
- cu_seqlens,
568
- max_seqlen,
569
- max_seqlen,
570
- dropout_p,
571
- softmax_scale,
572
- causal=causal,
573
- window_size_left=window_size[0],
574
- window_size_right=window_size[1],
575
- softcap=softcap,
576
- alibi_slopes=alibi_slopes,
577
- return_softmax=return_softmax and dropout_p > 0,
578
- block_table=None,
579
- )
580
- if is_grad:
581
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
582
- ctx.dropout_p = dropout_p
583
- ctx.max_seqlen = max_seqlen
584
- ctx.softmax_scale = softmax_scale
585
- ctx.causal = causal
586
- ctx.window_size = window_size
587
- ctx.softcap = softcap
588
- ctx.alibi_slopes = alibi_slopes
589
- ctx.deterministic = deterministic
590
- out = out_padded[..., :head_size_og]
591
- return out if not return_softmax else (out, softmax_lse, S_dmask)
592
-
593
- @staticmethod
594
- def backward(ctx, dout, *args):
595
- q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
596
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
597
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
598
- head_size_og = dout.size(2)
599
- dout_padded = dout
600
- if head_size_og % 8 != 0:
601
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
602
- _wrapped_flash_attn_varlen_backward(
603
- dout_padded,
604
- q,
605
- k,
606
- v,
607
- out,
608
- softmax_lse,
609
- dqkv[:, 0],
610
- dqkv[:, 1],
611
- dqkv[:, 2],
612
- cu_seqlens,
613
- cu_seqlens,
614
- ctx.max_seqlen,
615
- ctx.max_seqlen,
616
- ctx.dropout_p,
617
- ctx.softmax_scale,
618
- ctx.causal,
619
- ctx.window_size[0],
620
- ctx.window_size[1],
621
- ctx.softcap,
622
- ctx.alibi_slopes,
623
- ctx.deterministic,
624
- rng_state=rng_state,
625
- )
626
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
627
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
628
-
629
-
630
- class FlashAttnKVPackedFunc(torch.autograd.Function):
631
- @staticmethod
632
- def forward(
633
- ctx,
634
- q,
635
- kv,
636
- dropout_p,
637
- softmax_scale,
638
- causal,
639
- window_size,
640
- softcap,
641
- alibi_slopes,
642
- deterministic,
643
- return_softmax,
644
- is_grad_enabled,
645
- ):
646
- is_grad = is_grad_enabled and any(
647
- x.requires_grad for x in [q, kv]
648
- )
649
- if softmax_scale is None:
650
- softmax_scale = q.shape[-1] ** (-0.5)
651
- k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
652
- head_size_og = q.size(3)
653
- if head_size_og % 8 != 0:
654
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
655
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
656
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
657
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
658
- q,
659
- k,
660
- v,
661
- dropout_p,
662
- softmax_scale,
663
- causal=causal,
664
- window_size_left=window_size[0],
665
- window_size_right=window_size[1],
666
- softcap=softcap,
667
- alibi_slopes=alibi_slopes,
668
- return_softmax=return_softmax and dropout_p > 0,
669
- )
670
- if is_grad:
671
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
672
- ctx.dropout_p = dropout_p
673
- ctx.softmax_scale = softmax_scale
674
- ctx.causal = causal
675
- ctx.window_size = window_size
676
- ctx.softcap = softcap
677
- ctx.alibi_slopes = alibi_slopes
678
- ctx.deterministic = deterministic
679
- out = out_padded[..., :head_size_og]
680
- return out if not return_softmax else (out, softmax_lse, S_dmask)
681
-
682
- @staticmethod
683
- def backward(ctx, dout, *args):
684
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
685
- dq = torch.empty_like(q)
686
- kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
687
- dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
688
- head_size_og = dout.size(3)
689
- dout_padded = dout
690
- if head_size_og % 8 != 0:
691
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
692
- _wrapped_flash_attn_backward(
693
- dout_padded,
694
- q,
695
- k,
696
- v,
697
- out,
698
- softmax_lse,
699
- dq,
700
- dkv[:, :, 0],
701
- dkv[:, :, 1],
702
- ctx.dropout_p,
703
- ctx.softmax_scale,
704
- ctx.causal,
705
- ctx.window_size[0],
706
- ctx.window_size[1],
707
- ctx.softcap,
708
- ctx.alibi_slopes,
709
- ctx.deterministic,
710
- rng_state=rng_state,
711
- )
712
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
713
- dkv = dkv[..., : dout.shape[-1]]
714
- return dq, dkv, None, None, None, None, None, None, None, None, None
715
-
716
-
717
- class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
718
- @staticmethod
719
- def forward(
720
- ctx,
721
- q,
722
- kv,
723
- cu_seqlens_q,
724
- cu_seqlens_k,
725
- max_seqlen_q,
726
- max_seqlen_k,
727
- dropout_p,
728
- softmax_scale,
729
- causal,
730
- window_size,
731
- softcap,
732
- alibi_slopes,
733
- deterministic,
734
- return_softmax,
735
- is_grad_enabled,
736
- ):
737
- is_grad = is_grad_enabled and any(
738
- x.requires_grad for x in [q, kv]
739
- )
740
- if softmax_scale is None:
741
- softmax_scale = q.shape[-1] ** (-0.5)
742
- k, v = kv[:, 0].detach(), kv[:, 1].detach()
743
- head_size_og = q.size(2)
744
- if head_size_og % 8 != 0:
745
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
746
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
747
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
748
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
749
- q,
750
- k,
751
- v,
752
- cu_seqlens_q,
753
- cu_seqlens_k,
754
- max_seqlen_q,
755
- max_seqlen_k,
756
- dropout_p,
757
- softmax_scale,
758
- causal=causal,
759
- window_size_left=window_size[0],
760
- window_size_right=window_size[1],
761
- softcap=softcap,
762
- alibi_slopes=alibi_slopes,
763
- return_softmax=return_softmax and dropout_p > 0,
764
- block_table=None,
765
- )
766
- if is_grad:
767
- ctx.save_for_backward(
768
- q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
769
- )
770
- ctx.dropout_p = dropout_p
771
- ctx.max_seqlen_q = max_seqlen_q
772
- ctx.max_seqlen_k = max_seqlen_k
773
- ctx.softmax_scale = softmax_scale
774
- ctx.causal = causal
775
- ctx.window_size = window_size
776
- ctx.softcap = softcap
777
- ctx.alibi_slopes = alibi_slopes
778
- ctx.deterministic = deterministic
779
- out = out_padded[..., :head_size_og]
780
- return out if not return_softmax else (out, softmax_lse, S_dmask)
781
-
782
- @staticmethod
783
- def backward(ctx, dout, *args):
784
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
785
- dq = torch.empty_like(q)
786
- kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
787
- dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
788
- head_size_og = dout.size(2)
789
- dout_padded = dout
790
- if head_size_og % 8 != 0:
791
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
792
- _wrapped_flash_attn_varlen_backward(
793
- dout_padded,
794
- q,
795
- k,
796
- v,
797
- out,
798
- softmax_lse,
799
- dq,
800
- dkv[:, 0],
801
- dkv[:, 1],
802
- cu_seqlens_q,
803
- cu_seqlens_k,
804
- ctx.max_seqlen_q,
805
- ctx.max_seqlen_k,
806
- ctx.dropout_p,
807
- ctx.softmax_scale,
808
- ctx.causal,
809
- ctx.window_size[0],
810
- ctx.window_size[1],
811
- ctx.softcap,
812
- ctx.alibi_slopes,
813
- ctx.deterministic,
814
- rng_state=rng_state,
815
- )
816
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
817
- dkv = dkv[..., : dout.shape[-1]]
818
- return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None
819
-
820
-
821
- class FlashAttnFunc(torch.autograd.Function):
822
- @staticmethod
823
- def forward(
824
- ctx,
825
- q,
826
- k,
827
- v,
828
- dropout_p,
829
- softmax_scale,
830
- causal,
831
- window_size,
832
- softcap,
833
- alibi_slopes,
834
- deterministic,
835
- return_softmax,
836
- is_grad_enabled,
837
- ):
838
- is_grad = is_grad_enabled and any(
839
- x.requires_grad for x in [q, k, v]
840
- )
841
- if softmax_scale is None:
842
- softmax_scale = q.shape[-1] ** (-0.5)
843
- head_size_og = q.size(3)
844
- if head_size_og % 8 != 0:
845
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
846
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
847
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
848
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
849
- q,
850
- k,
851
- v,
852
- dropout_p,
853
- softmax_scale,
854
- causal=causal,
855
- window_size_left=window_size[0],
856
- window_size_right=window_size[1],
857
- softcap=softcap,
858
- alibi_slopes=alibi_slopes,
859
- return_softmax=return_softmax and dropout_p > 0,
860
- )
861
- if is_grad:
862
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
863
- ctx.dropout_p = dropout_p
864
- ctx.softmax_scale = softmax_scale
865
- ctx.causal = causal
866
- ctx.window_size = window_size
867
- ctx.softcap = softcap
868
- ctx.alibi_slopes = alibi_slopes
869
- ctx.deterministic = deterministic
870
- out = out_padded[..., :head_size_og]
871
- return out if not return_softmax else (out, softmax_lse, S_dmask)
872
-
873
- @staticmethod
874
- def backward(ctx, dout, *args):
875
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
876
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
877
- head_size_og = dout.size(3)
878
- dout_padded = dout
879
- if head_size_og % 8 != 0:
880
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
881
- _wrapped_flash_attn_backward(
882
- dout_padded,
883
- q,
884
- k,
885
- v,
886
- out,
887
- softmax_lse,
888
- dq,
889
- dk,
890
- dv,
891
- ctx.dropout_p,
892
- ctx.softmax_scale,
893
- ctx.causal,
894
- ctx.window_size[0],
895
- ctx.window_size[1],
896
- ctx.softcap,
897
- ctx.alibi_slopes,
898
- ctx.deterministic,
899
- rng_state=rng_state,
900
- )
901
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
902
- dk = dk[..., : dout.shape[-1]]
903
- dv = dv[..., : dout.shape[-1]]
904
- return dq, dk, dv, None, None, None, None, None, None, None, None, None
905
-
906
-
907
- class FlashAttnVarlenFunc(torch.autograd.Function):
908
- @staticmethod
909
- def forward(
910
- ctx,
911
- q,
912
- k,
913
- v,
914
- cu_seqlens_q,
915
- cu_seqlens_k,
916
- max_seqlen_q,
917
- max_seqlen_k,
918
- dropout_p,
919
- softmax_scale,
920
- causal,
921
- window_size,
922
- softcap,
923
- alibi_slopes,
924
- deterministic,
925
- return_softmax,
926
- block_table,
927
- is_grad_enabled,
928
- ):
929
- is_grad = is_grad_enabled and any(
930
- x.requires_grad for x in [q, k, v]
931
- )
932
- if softmax_scale is None:
933
- softmax_scale = q.shape[-1] ** (-0.5)
934
- head_size_og = q.size(2)
935
- if head_size_og % 8 != 0:
936
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
937
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
938
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
939
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
940
- q,
941
- k,
942
- v,
943
- cu_seqlens_q,
944
- cu_seqlens_k,
945
- max_seqlen_q,
946
- max_seqlen_k,
947
- dropout_p,
948
- softmax_scale,
949
- causal=causal,
950
- window_size_left=window_size[0],
951
- window_size_right=window_size[1],
952
- softcap=softcap,
953
- alibi_slopes=alibi_slopes,
954
- return_softmax=return_softmax and dropout_p > 0,
955
- block_table=block_table,
956
- )
957
- if is_grad:
958
- ctx.save_for_backward(
959
- q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
960
- )
961
- ctx.dropout_p = dropout_p
962
- ctx.max_seqlen_q = max_seqlen_q
963
- ctx.max_seqlen_k = max_seqlen_k
964
- ctx.softmax_scale = softmax_scale
965
- ctx.causal = causal
966
- ctx.window_size = window_size
967
- ctx.softcap = softcap
968
- ctx.alibi_slopes = alibi_slopes
969
- ctx.deterministic = deterministic
970
-
971
- out = out_padded[..., :head_size_og]
972
- return out if not return_softmax else (out, softmax_lse, S_dmask)
973
-
974
- @staticmethod
975
- def backward(ctx, dout, *args):
976
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
977
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
978
- head_size_og = dout.size(2)
979
- dout_padded = dout
980
- if head_size_og % 8 != 0:
981
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
982
- _wrapped_flash_attn_varlen_backward(
983
- dout_padded,
984
- q,
985
- k,
986
- v,
987
- out,
988
- softmax_lse,
989
- dq,
990
- dk,
991
- dv,
992
- cu_seqlens_q,
993
- cu_seqlens_k,
994
- ctx.max_seqlen_q,
995
- ctx.max_seqlen_k,
996
- ctx.dropout_p,
997
- ctx.softmax_scale,
998
- ctx.causal,
999
- ctx.window_size[0],
1000
- ctx.window_size[1],
1001
- ctx.softcap,
1002
- ctx.alibi_slopes,
1003
- ctx.deterministic,
1004
- rng_state=rng_state,
1005
- )
1006
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
1007
- dk = dk[..., : dout.shape[-1]]
1008
- dv = dv[..., : dout.shape[-1]]
1009
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
1010
-
1011
-
1012
- def flash_attn_qkvpacked_func(
1013
- qkv,
1014
- dropout_p=0.0,
1015
- softmax_scale=None,
1016
- causal=False,
1017
- window_size=(-1, -1), # -1 means infinite context window
1018
- softcap=0.0, # <=0.0 means deactivate
1019
- alibi_slopes=None,
1020
- deterministic=False,
1021
- return_attn_probs=False,
1022
- ):
1023
- """dropout_p should be set to 0.0 during evaluation
1024
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
1025
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
1026
- of the gradients of Q, K, V.
1027
- For multi-query and grouped-query attention (MQA/GQA), please see
1028
- flash_attn_kvpacked_func and flash_attn_func.
1029
-
1030
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1031
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
1032
-
1033
- Arguments:
1034
- qkv: (batch_size, seqlen, 3, nheads, headdim)
1035
- dropout_p: float. Dropout probability.
1036
- softmax_scale: float. The scaling of QK^T before applying softmax.
1037
- Default to 1 / sqrt(headdim).
1038
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1039
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1040
- softcap: float. Anything > 0 activates softcapping attention.
1041
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
1042
- the attention score of query i and key j.
1043
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1044
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1045
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1046
- testing only. The returned probabilities are not guaranteed to be correct
1047
- (they might not have the right scaling).
1048
- Return:
1049
- out: (batch_size, seqlen, nheads, headdim).
1050
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
1051
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1052
- normalization factor).
1053
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1054
- The output of softmax (possibly with different scaling). It also encodes the dropout
1055
- pattern (negative means that location was dropped, nonnegative means it was kept).
1056
- """
1057
- return FlashAttnQKVPackedFunc.apply(
1058
- qkv,
1059
- dropout_p,
1060
- softmax_scale,
1061
- causal,
1062
- window_size,
1063
- softcap,
1064
- alibi_slopes,
1065
- deterministic,
1066
- return_attn_probs,
1067
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1068
- )
1069
-
1070
-
1071
- def flash_attn_kvpacked_func(
1072
- q,
1073
- kv,
1074
- dropout_p=0.0,
1075
- softmax_scale=None,
1076
- causal=False,
1077
- window_size=(-1, -1), # -1 means infinite context window
1078
- softcap=0.0, # 0.0 means deactivated
1079
- alibi_slopes=None,
1080
- deterministic=False,
1081
- return_attn_probs=False,
1082
- ):
1083
- """dropout_p should be set to 0.0 during evaluation
1084
- If K, V are already stacked into 1 tensor, this function will be faster than
1085
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
1086
- of the gradients of K, V.
1087
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1088
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1089
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1090
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1091
-
1092
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1093
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1094
- 1 1 1 1 0
1095
- 1 1 1 1 1
1096
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1097
- 0 0
1098
- 0 0
1099
- 0 0
1100
- 1 0
1101
- 1 1
1102
- If the row of the mask is all zero, the output will be zero.
1103
-
1104
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1105
- will only attend to keys between
1106
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1107
-
1108
- Arguments:
1109
- q: (batch_size, seqlen, nheads, headdim)
1110
- kv: (batch_size, seqlen, 2, nheads_k, headdim)
1111
- dropout_p: float. Dropout probability.
1112
- softmax_scale: float. The scaling of QK^T before applying softmax.
1113
- Default to 1 / sqrt(headdim).
1114
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1115
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1116
- softcap: float. Anything > 0 activates softcapping attention.
1117
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1118
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1119
- is added to the attention score of query i and key j.
1120
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1121
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1122
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1123
- testing only. The returned probabilities are not guaranteed to be correct
1124
- (they might not have the right scaling).
1125
- Return:
1126
- out: (batch_size, seqlen, nheads, headdim).
1127
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
1128
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1129
- normalization factor).
1130
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1131
- The output of softmax (possibly with different scaling). It also encodes the dropout
1132
- pattern (negative means that location was dropped, nonnegative means it was kept).
1133
- """
1134
- return FlashAttnKVPackedFunc.apply(
1135
- q,
1136
- kv,
1137
- dropout_p,
1138
- softmax_scale,
1139
- causal,
1140
- window_size,
1141
- softcap,
1142
- alibi_slopes,
1143
- deterministic,
1144
- return_attn_probs,
1145
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1146
- )
1147
-
1148
-
1149
- def flash_attn_func(
1150
- q,
1151
- k,
1152
- v,
1153
- dropout_p=0.0,
1154
- softmax_scale=None,
1155
- causal=False,
1156
- window_size=(-1, -1), # -1 means infinite context window
1157
- softcap=0.0, # 0.0 means deactivated
1158
- alibi_slopes=None,
1159
- deterministic=False,
1160
- return_attn_probs=False,
1161
- ):
1162
- """dropout_p should be set to 0.0 during evaluation
1163
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1164
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1165
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1166
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1167
-
1168
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1169
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1170
- 1 1 1 1 0
1171
- 1 1 1 1 1
1172
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1173
- 0 0
1174
- 0 0
1175
- 0 0
1176
- 1 0
1177
- 1 1
1178
- If the row of the mask is all zero, the output will be zero.
1179
-
1180
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1181
- will only attend to keys between
1182
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1183
-
1184
- Arguments:
1185
- q: (batch_size, seqlen, nheads, headdim)
1186
- k: (batch_size, seqlen, nheads_k, headdim)
1187
- v: (batch_size, seqlen, nheads_k, headdim)
1188
- dropout_p: float. Dropout probability.
1189
- softmax_scale: float. The scaling of QK^T before applying softmax.
1190
- Default to 1 / sqrt(headdim).
1191
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1192
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1193
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1194
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1195
- is added to the attention score of query i and key j.
1196
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1197
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1198
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1199
- testing only. The returned probabilities are not guaranteed to be correct
1200
- (they might not have the right scaling).
1201
- Return:
1202
- out: (batch_size, seqlen, nheads, headdim).
1203
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
1204
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1205
- normalization factor).
1206
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1207
- The output of softmax (possibly with different scaling). It also encodes the dropout
1208
- pattern (negative means that location was dropped, nonnegative means it was kept).
1209
- """
1210
- return FlashAttnFunc.apply(
1211
- q,
1212
- k,
1213
- v,
1214
- dropout_p,
1215
- softmax_scale,
1216
- causal,
1217
- window_size,
1218
- softcap,
1219
- alibi_slopes,
1220
- deterministic,
1221
- return_attn_probs,
1222
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1223
- )
1224
-
1225
-
1226
- def flash_attn_varlen_qkvpacked_func(
1227
- qkv,
1228
- cu_seqlens,
1229
- max_seqlen,
1230
- dropout_p=0.0,
1231
- softmax_scale=None,
1232
- causal=False,
1233
- window_size=(-1, -1), # -1 means infinite context window
1234
- softcap=0.0, # 0.0 means deactivated
1235
- alibi_slopes=None,
1236
- deterministic=False,
1237
- return_attn_probs=False,
1238
- ):
1239
- """dropout_p should be set to 0.0 during evaluation
1240
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
1241
- calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
1242
- of the gradients of Q, K, V.
1243
- For multi-query and grouped-query attention (MQA/GQA), please see
1244
- flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
1245
-
1246
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1247
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
1248
-
1249
- Arguments:
1250
- qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
1251
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1252
- of the sequences in the batch, used to index into qkv.
1253
- max_seqlen: int. Maximum sequence length in the batch.
1254
- dropout_p: float. Dropout probability.
1255
- softmax_scale: float. The scaling of QK^T before applying softmax.
1256
- Default to 1 / sqrt(headdim).
1257
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1258
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1259
- softcap: float. Anything > 0 activates softcapping attention.
1260
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
1261
- is added to the attention score of query i and key j.
1262
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1263
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1264
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1265
- testing only. The returned probabilities are not guaranteed to be correct
1266
- (they might not have the right scaling).
1267
- Return:
1268
- out: (total, nheads, headdim).
1269
- softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
1270
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1271
- normalization factor).
1272
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1273
- The output of softmax (possibly with different scaling). It also encodes the dropout
1274
- pattern (negative means that location was dropped, nonnegative means it was kept).
1275
- """
1276
- return FlashAttnVarlenQKVPackedFunc.apply(
1277
- qkv,
1278
- cu_seqlens,
1279
- max_seqlen,
1280
- dropout_p,
1281
- softmax_scale,
1282
- causal,
1283
- window_size,
1284
- softcap,
1285
- alibi_slopes,
1286
- deterministic,
1287
- return_attn_probs,
1288
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1289
- )
1290
-
1291
-
1292
- def flash_attn_varlen_kvpacked_func(
1293
- q,
1294
- kv,
1295
- cu_seqlens_q,
1296
- cu_seqlens_k,
1297
- max_seqlen_q,
1298
- max_seqlen_k,
1299
- dropout_p=0.0,
1300
- softmax_scale=None,
1301
- causal=False,
1302
- window_size=(-1, -1), # -1 means infinite context window
1303
- softcap=0.0, # 0.0 means deactivated
1304
- alibi_slopes=None,
1305
- deterministic=False,
1306
- return_attn_probs=False,
1307
- ):
1308
- """dropout_p should be set to 0.0 during evaluation
1309
- If K, V are already stacked into 1 tensor, this function will be faster than
1310
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
1311
- of the gradients of K, V.
1312
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1313
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1314
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1315
- 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1316
-
1317
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1318
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1319
- 1 1 1 1 0
1320
- 1 1 1 1 1
1321
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1322
- 0 0
1323
- 0 0
1324
- 0 0
1325
- 1 0
1326
- 1 1
1327
- If the row of the mask is all zero, the output will be zero.
1328
-
1329
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1330
- will only attend to keys between
1331
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1332
-
1333
- Arguments:
1334
- q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1335
- kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1336
- cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1337
- of the sequences in the batch, used to index into q.
1338
- cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1339
- of the sequences in the batch, used to index into kv.
1340
- max_seqlen_q: int. Maximum query sequence length in the batch.
1341
- max_seqlen_k: int. Maximum key sequence length in the batch.
1342
- dropout_p: float. Dropout probability.
1343
- softmax_scale: float. The scaling of QK^T before applying softmax.
1344
- Defaults to 1 / sqrt(headdim).
1345
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1346
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1347
- softcap: float. Anything > 0 activates softcapping attention.
1348
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1349
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1350
- is added to the attention score of query i and key j.
1351
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1352
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1353
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1354
- testing only. The returned probabilities are not guaranteed to be correct
1355
- (they might not have the right scaling).
1356
- Return:
1357
- out: (total_q, nheads, headdim).
1358
- softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
1359
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1360
- normalization factor).
1361
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1362
- The output of softmax (possibly with different scaling). It also encodes the dropout
1363
- pattern (negative means that location was dropped, nonnegative means it was kept).
1364
- """
1365
- return FlashAttnVarlenKVPackedFunc.apply(
1366
- q,
1367
- kv,
1368
- cu_seqlens_q,
1369
- cu_seqlens_k,
1370
- max_seqlen_q,
1371
- max_seqlen_k,
1372
- dropout_p,
1373
- softmax_scale,
1374
- causal,
1375
- window_size,
1376
- softcap,
1377
- alibi_slopes,
1378
- deterministic,
1379
- return_attn_probs,
1380
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1381
- )
1382
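To make the head-grouping rule above concrete, here is a tiny reference computation (not from the source) that reproduces the docstring's 6-head / 2-head example:

```python
# Under GQA grouping, query head h attends to KV head h // (nheads // nheads_k).
nheads, nheads_k = 6, 2
group = nheads // nheads_k          # 3 query heads share each KV head
mapping = {h: h // group for h in range(nheads)}
assert mapping == {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
```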
-
1383
-
1384
- def flash_attn_varlen_func(
1385
- q,
1386
- k,
1387
- v,
1388
- cu_seqlens_q,
1389
- cu_seqlens_k,
1390
- max_seqlen_q,
1391
- max_seqlen_k,
1392
- dropout_p=0.0,
1393
- softmax_scale=None,
1394
- causal=False,
1395
- window_size=(-1, -1), # -1 means infinite context window
1396
- softcap=0.0, # 0.0 means deactivated
1397
- alibi_slopes=None,
1398
- deterministic=False,
1399
- return_attn_probs=False,
1400
- block_table=None,
1401
- ):
1402
- """dropout_p should be set to 0.0 during evaluation
1403
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1404
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1405
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1406
- 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1407
-
1408
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1409
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1410
- 1 1 1 1 0
1411
- 1 1 1 1 1
1412
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1413
- 0 0
1414
- 0 0
1415
- 0 0
1416
- 1 0
1417
- 1 1
1418
- If the row of the mask is all zero, the output will be zero.
1419
-
1420
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1421
- will only attend to keys between
1422
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1423
-
1424
- Arguments:
1425
- q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1426
- k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1427
- v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1428
- cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1429
- of the sequences in the batch, used to index into q.
1430
- cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1431
- of the sequences in the batch, used to index into kv.
1432
- max_seqlen_q: int. Maximum query sequence length in the batch.
1433
- max_seqlen_k: int. Maximum key sequence length in the batch.
1434
- dropout_p: float. Dropout probability.
1435
- softmax_scale: float. The scaling of QK^T before applying softmax.
1436
- Defaults to 1 / sqrt(headdim).
1437
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1438
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1439
- softcap: float. Anything > 0 activates softcapping attention.
1440
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1441
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1442
- is added to the attention score of query i and key j.
1443
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1444
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1445
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1446
- testing only. The returned probabilities are not guaranteed to be correct
1447
- (they might not have the right scaling).
1448
- Return:
1449
- out: (total_q, nheads, headdim).
1450
- softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
1451
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1452
- normalization factor).
1453
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1454
- The output of softmax (possibly with different scaling). It also encodes the dropout
1455
- pattern (negative means that location was dropped, nonnegative means it was kept).
1456
- """
1457
- return FlashAttnVarlenFunc.apply(
1458
- q,
1459
- k,
1460
- v,
1461
- cu_seqlens_q,
1462
- cu_seqlens_k,
1463
- max_seqlen_q,
1464
- max_seqlen_k,
1465
- dropout_p,
1466
- softmax_scale,
1467
- causal,
1468
- window_size,
1469
- softcap,
1470
- alibi_slopes,
1471
- deterministic,
1472
- return_attn_probs,
1473
- block_table,
1474
- False if _XPU_AVAILABLE or q.device.type == "cpu" else torch.is_grad_enabled(),
1475
- )
1476
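The bottom-right-aligned causal mask and sliding-window rule described above can be reproduced eagerly for intuition. This reference builder is an illustration only; the kernel itself never materializes such a mask.

```python
import torch

def local_attention_mask(seqlen_q, seqlen_k, window_size=(-1, -1), causal=False):
    """True where query i may attend to key j, using bottom-right alignment."""
    i = torch.arange(seqlen_q)[:, None]
    j = torch.arange(seqlen_k)[None, :]
    shift = seqlen_k - seqlen_q                 # aligns the last query with the last key
    mask = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    if causal:
        mask &= j <= i + shift
    left, right = window_size
    if left >= 0:
        mask &= j >= i + shift - left
    if right >= 0:
        mask &= j <= i + shift + right
    return mask

# Matches the docstring example: seqlen_q=2, seqlen_k=5, causal=True
# -> [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
print(local_attention_mask(2, 5, causal=True).int())
```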
-
1477
-
1478
- def flash_attn_with_kvcache(
1479
- q,
1480
- k_cache,
1481
- v_cache,
1482
- k=None,
1483
- v=None,
1484
- rotary_cos=None,
1485
- rotary_sin=None,
1486
- cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
1487
- cache_batch_idx: Optional[torch.Tensor] = None,
1488
- cache_leftpad: Optional[torch.Tensor] = None,
1489
- block_table: Optional[torch.Tensor] = None,
1490
- softmax_scale=None,
1491
- causal=False,
1492
- window_size=(-1, -1), # -1 means infinite context window
1493
- softcap=0.0, # 0.0 means deactivated
1494
- rotary_interleaved=True,
1495
- alibi_slopes=None,
1496
- num_splits=0,
1497
- return_softmax_lse=False,
1498
- ):
1499
- """
1500
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
1501
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
1502
- the previous step, and update them with the new keys/values from the current step, and do
1503
- attention with the updated cache, all in 1 kernel.
1504
-
1505
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
1506
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
1507
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
1508
-
1509
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
1510
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
1511
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
1512
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
1513
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
1514
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
1515
-
1516
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
1517
-
1518
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1519
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1520
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1521
- 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1522
-
1523
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1524
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1525
- 1 1 1 1 0
1526
- 1 1 1 1 1
1527
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1528
- 0 0
1529
- 0 0
1530
- 0 0
1531
- 1 0
1532
- 1 1
1533
- If the row of the mask is all zero, the output will be zero.
1534
-
1535
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1536
- will only attend to keys between
1537
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1538
-
1539
- Note: Does not support backward pass.
1540
-
1541
- Arguments:
1542
- q: (batch_size, seqlen, nheads, headdim)
1543
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
1544
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
1545
- page_block_size must be a multiple of 256.
1546
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
1547
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
1548
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
1549
- k with k_cache, starting at the indices specified by cache_seqlens.
1550
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
1551
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
1552
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
1553
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
1554
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
1555
- KV cache.
1556
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
1557
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
1558
- If the indices are not distinct, and k and v are provided, the values updated in the cache
1559
- might come from any of the duplicate indices.
1560
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
1561
- block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
1562
- softmax_scale: float. The scaling of QK^T before applying softmax.
1563
- Defaults to 1 / sqrt(headdim).
1564
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1565
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1566
- softcap: float. Anything > 0 activates softcapping attention.
1567
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
1568
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
1569
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
1570
- (i.e. GPT-NeoX style).
1571
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1572
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1573
- is added to the attention score of query i and key j.
1574
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
1575
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
1576
- to automatically determine the number of splits.
1577
- Don't change this unless you know what you are doing.
1578
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
1579
-
1580
- Return:
1581
- out: (batch_size, seqlen, nheads, headdim).
1582
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
1583
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1584
- normalization factor).
1585
- """
1586
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
1587
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
1588
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1589
- if softmax_scale is None:
1590
- softmax_scale = q.shape[-1] ** (-0.5)
1591
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
1592
- cache_seqlens = torch.full(
1593
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
1594
- )
1595
- cache_seqlens = maybe_contiguous(cache_seqlens)
1596
- cache_batch_idx = maybe_contiguous(cache_batch_idx)
1597
- block_table = maybe_contiguous(block_table)
1598
- out, softmax_lse = flash_attn.fwd_kvcache(
1599
- q,
1600
- k_cache,
1601
- v_cache,
1602
- k,
1603
- v,
1604
- cache_seqlens,
1605
- rotary_cos,
1606
- rotary_sin,
1607
- cache_batch_idx,
1608
- cache_leftpad,
1609
- block_table,
1610
- alibi_slopes,
1611
- None,
1612
- softmax_scale,
1613
- causal,
1614
- window_size[0],
1615
- window_size[1],
1616
- softcap,
1617
- rotary_interleaved,
1618
- num_splits,
1619
- )
1620
- return (out, softmax_lse) if return_softmax_lse else out
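A sketch of one incremental decoding step with a pre-allocated cache, as described above. It is not part of the deleted file; the import path is an assumption and a CUDA device is assumed.

```python
import torch
# assumption: the installed build exposes the function under this package name
# from flash_attn2 import flash_attn_with_kvcache

batch, max_seqlen, nheads_k, headdim = 2, 1024, 8, 64
k_cache = torch.zeros(batch, max_seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device="cuda")  # tokens already cached

q = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)  # one new query token
k_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)

# out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
#                               cache_seqlens=cache_seqlens, causal=True)
# k_cache / v_cache are updated in place at position 17 of each sequence.
```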
build/torch210-cxx11-cu130-x86_64-linux/metadata.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "version": 1,
3
- "python-depends": []
4
- }
build/torch210-cxx11-cu130-x86_64-linux/ops/triton/rotary.py DELETED
@@ -1,186 +0,0 @@
1
- # Copyright (c) 2025, Tri Dao.
2
- # As of 2025-04-23, we require triton >= 3.0
3
-
4
- from typing import Optional, Union
5
-
6
- import torch
7
-
8
- import triton
9
- import triton.language as tl
10
-
11
-
12
- @triton.jit
13
- def rotary_kernel(
14
- OUT, # Pointers to matrices
15
- X,
16
- COS,
17
- SIN,
18
- CU_SEQLENS,
19
- SEQLEN_OFFSETS, # this could be int or a pointer
20
- # Matrix dimensions
21
- seqlen,
22
- nheads,
23
- seqlen_ro,
24
- # strides
25
- stride_out_batch,
26
- stride_out_seqlen,
27
- stride_out_nheads,
28
- stride_out_headdim,
29
- stride_x_batch,
30
- stride_x_seqlen,
31
- stride_x_nheads,
32
- stride_x_headdim,
33
- # Meta-parameters
34
- # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that
35
- # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128
36
- ROTARY_DIM: tl.constexpr,
37
- IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
38
- IS_VARLEN: tl.constexpr,
39
- INTERLEAVED: tl.constexpr,
40
- CONJUGATE: tl.constexpr,
41
- BLOCK_H: tl.constexpr,
42
- BLOCK_M: tl.constexpr,
43
- ):
44
- BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM)
45
- ROTARY_DIM_HALF = ROTARY_DIM // 2
46
- pid_head = tl.program_id(axis=0)
47
- pid_m = tl.program_id(axis=1)
48
- pid_batch = tl.program_id(axis=2)
49
-
50
- if not IS_VARLEN:
51
- X = X + pid_batch * stride_x_batch
52
- OUT = OUT + pid_batch * stride_out_batch
53
- else:
54
- start_idx = tl.load(CU_SEQLENS + pid_batch)
55
- seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
56
- X = X + start_idx * stride_x_seqlen
57
- OUT = OUT + start_idx * stride_out_seqlen
58
-
59
- if pid_m * BLOCK_M >= seqlen:
60
- return
61
-
62
- rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H)
63
- rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
64
- if not IS_SEQLEN_OFFSETS_TENSOR:
65
- rm_cs = rm + SEQLEN_OFFSETS
66
- else:
67
- rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
68
-
69
- rk_half = tl.arange(0, BLOCK_K // 2)
70
- COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
71
- SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
72
- mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF)
73
- cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)
74
- sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)
75
- if CONJUGATE:
76
- sin = -sin
77
-
78
- if not INTERLEAVED:
79
- # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
80
- X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim)
81
- OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim)
82
- mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF)
83
- x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32)
84
- x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32)
85
- o0 = x0 * cos - x1 * sin
86
- o1 = x0 * sin + x1 * cos
87
- tl.store(OUT, o0, mask=mask)
88
- tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask)
89
- else:
90
- rk = tl.arange(0, BLOCK_K)
91
- X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim)
92
- OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim)
93
- mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM)
94
- x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
95
- x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2]))
96
- o0 = x0 * cos - x1 * sin
97
- o1 = x0 * sin + x1 * cos
98
- o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K])
99
- tl.store(OUT, o, mask=mask)
100
-
101
-
102
- def apply_rotary(
103
- x: torch.Tensor,
104
- cos: torch.Tensor,
105
- sin: torch.Tensor,
106
- seqlen_offsets: Union[int, torch.Tensor] = 0,
107
- cu_seqlens: Optional[torch.Tensor] = None,
108
- max_seqlen: Optional[int] = None,
109
- interleaved=False,
110
- inplace=False,
111
- conjugate=False,
112
- ) -> torch.Tensor:
113
- """
114
- Arguments:
115
- x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
116
- else (total_seqlen, nheads, headdim).
117
- cos: (seqlen_ro, rotary_dim / 2)
118
- sin: (seqlen_ro, rotary_dim / 2)
119
- seqlen_offsets: integer or integer tensor of size (batch,)
120
- cu_seqlens: (batch + 1,) or None
121
- max_seqlen: int
122
- Returns:
123
- y: (batch, seqlen, nheads, headdim)
124
- """
125
- is_varlen = cu_seqlens is not None
126
- if not is_varlen:
127
- batch, seqlen, nheads, headdim = x.shape
128
- else:
129
- assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
130
- total_seqlen, nheads, headdim = x.shape
131
- batch_p_1 = cu_seqlens.shape[0]
132
- batch = batch_p_1 - 1
133
- seqlen = max_seqlen
134
- seqlen_ro, rotary_dim = cos.shape
135
- assert sin.shape == cos.shape
136
- rotary_dim *= 2
137
- assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
138
- assert headdim <= 256, "Only support headdim <= 256"
139
- assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
140
-
141
- cos, sin = cos.contiguous(), sin.contiguous()
142
- if isinstance(seqlen_offsets, torch.Tensor):
143
- assert seqlen_offsets.shape == (batch,)
144
- assert seqlen_offsets.dtype in [torch.int32, torch.int64]
145
- seqlen_offsets = seqlen_offsets.contiguous()
146
- else:
147
- assert seqlen_offsets + seqlen <= seqlen_ro
148
-
149
- output = torch.empty_like(x) if not inplace else x
150
- if rotary_dim < headdim and not inplace:
151
- output[..., rotary_dim:].copy_(x[..., rotary_dim:])
152
-
153
- grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa
154
- BLOCK_M = 8 if rotary_dim <= 128 else 4
155
-
156
- # Need this, otherwise Triton tries to launch from cuda:0 and we get
157
- # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
158
- device_ctx = torch.cuda.device(x.device.index) if x.device.type == 'cuda' else torch.xpu.device(x.device.index)
159
- with device_ctx:
160
- torch.library.wrap_triton(rotary_kernel)[grid](
161
- output, # data ptrs
162
- x,
163
- cos,
164
- sin,
165
- cu_seqlens,
166
- seqlen_offsets,
167
- seqlen, # shapes
168
- nheads,
169
- seqlen_ro,
170
- output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
171
- output.stride(-3), # seqlen_stride or total_seqlen_stride
172
- output.stride(-2), # nheads_stride
173
- output.stride(-1), # headdim_stride
174
- x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
175
- x.stride(-3), # seqlen stride or total_seqlen_stride
176
- x.stride(-2), # nheads stride
177
- x.stride(-1), # headdim stride
178
- rotary_dim,
179
- isinstance(seqlen_offsets, torch.Tensor),
180
- is_varlen,
181
- interleaved,
182
- conjugate,
183
- BLOCK_M=BLOCK_M,
184
- BLOCK_H=2,
185
- )
186
- return output
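A plain PyTorch reference for the non-interleaved rotary path is handy when checking the Triton kernel's output on the fixed-length layout. This is a sketch, not part of the deleted file:

```python
import torch

def apply_rotary_ref(x, cos, sin):
    """x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2)."""
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x0, x1 = x_rot.chunk(2, dim=-1)               # first / second half of the rotary channels
    cos, sin = cos[:, None, :], sin[:, None, :]   # broadcast over heads
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat([o0, o1, x_pass], dim=-1)

x = torch.randn(2, 16, 4, 64)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 16, 2).float() / 16))
angles = torch.outer(torch.arange(16).float(), inv_freq)    # (seqlen, rotary_dim / 2)
y = apply_rotary_ref(x, angles.cos(), angles.sin())
assert y.shape == x.shape
```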
build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _flash_attn2_588b404
3
- ops = torch.ops._flash_attn2_588b404
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_flash_attn2_588b404::{op_name}"
build/torch210-cxx11-xpu20253-x86_64-linux/flash_attn2/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib.util
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
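The shim above re-imports the parent package's `__init__.py` under a path-derived module name so it never collides with an installed package of the same name. A self-contained illustration of that loading trick (the file name and contents below are made up):

```python
import importlib.util
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    src = Path(d) / "payload.py"                        # hypothetical module file
    src.write_text("VALUE = 42\n")
    name = f"mod_{abs(hash(str(src.absolute()))):x}"    # unique-ish, path-derived name
    spec = importlib.util.spec_from_file_location(name, src)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module                          # register before executing
    spec.loader.exec_module(module)
    assert module.VALUE == 42
    assert name != "payload"                            # natural name stays untouched
```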
build/torch210-cxx11-xpu20253-x86_64-linux/flash_attn_interface.py DELETED
@@ -1,1620 +0,0 @@
1
- # Copyright (c) 2023, Tri Dao.
2
-
3
- from typing import Optional, Sequence, Tuple, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
- import os
8
-
9
- # # isort: off
10
- # # We need to import the CUDA kernels after importing torch
11
- # USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
12
- # if USE_TRITON_ROCM:
13
- # from .flash_attn_triton_amd import interface_fa as flash_attn
14
- # else:
15
- # import flash_attn_2_cuda as flash_attn
16
-
17
-
18
- from ._ops import ops as flash_attn
19
-
20
- # # isort: on
21
-
22
- def maybe_contiguous(x):
23
- return x.contiguous() if x is not None and x.stride(-1) != 1 else x
24
-
25
-
26
- def _get_device():
27
- if torch.xpu.is_available():
28
- return "xpu"
29
- elif torch.cuda.is_available():
30
- return "cuda"
31
- else:
32
- return "cpu"
33
-
34
- _XPU_AVAILABLE = torch.xpu.is_available() if hasattr(torch, "xpu") else False # TODO remove hasattr check when bwd is supported on XPU
35
-
36
-
37
- def _get_block_size_n(device, head_dim, is_dropout, is_causal):
38
- # This should match the block sizes in the CUDA kernel
39
- assert head_dim <= 256
40
- major, minor = torch.cuda.get_device_capability(device)
41
- is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
42
- is_sm80 = major == 8 and minor == 0
43
- is_sm90 = major == 9 and minor == 0
44
- if head_dim <= 32:
45
- return 128
46
- if head_dim <= 64:
47
- return 128 if not is_dropout else 64
48
- elif head_dim <= 96:
49
- return 64
50
- elif head_dim <= 128:
51
- if is_sm8x:
52
- return 64 if (not is_dropout and is_causal) else 32
53
- else:
54
- return 64 if not is_dropout else 32
55
- elif head_dim <= 192:
56
- return 64
57
- elif head_dim <= 224:
58
- return 64
59
- elif head_dim <= 256:
60
- return 64
61
-
62
-
63
- def round_multiple(x, m):
64
- return (x + m - 1) // m * m
65
-
66
-
67
- # torch.compile() support is only enabled for pytorch >= 2.4
68
- # The reason for this is that we are using the new custom_op and register_fake
69
- # APIs, which support inplace modification of inputs in the function itself
70
- if torch.__version__ >= "2.4.0":
71
- _torch_custom_op_wrapper = torch.library.custom_op
72
- _torch_register_fake_wrapper = torch.library.register_fake
73
- else:
74
- def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
75
- def wrap(func):
76
- return func
77
- if fn is None:
78
- return wrap
79
- return fn
80
- def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
81
- def wrap(func):
82
- return func
83
- if fn is None:
84
- return wrap
85
- return fn
86
- _torch_custom_op_wrapper = noop_custom_op_wrapper
87
- _torch_register_fake_wrapper = noop_register_fake_wrapper
88
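A self-contained sketch of the same version-gating pattern with a toy op; the namespace and op name below are made up, only the structure mirrors the code above:

```python
import torch

if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:  # older torch: decorators degrade to no-ops and the plain function is called
    def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        return fn if fn is not None else (lambda f: f)

    def _register_fake(name, fn=None, /, *, lib=None, _stacklevel=1):
        return fn if fn is not None else (lambda f: f)

@_custom_op("demo_ns::scale", mutates_args=(), device_types="cpu")
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

@_register_fake("demo_ns::scale")
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    return torch.empty_like(x)   # meta implementation: shapes/dtypes only

print(scale(torch.ones(3), 2.0))   # tensor([2., 2., 2.])
```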
-
89
-
90
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types=_get_device())
91
- def _flash_attn_forward(
92
- q: torch.Tensor,
93
- k: torch.Tensor,
94
- v: torch.Tensor,
95
- dropout_p: float,
96
- softmax_scale: float,
97
- causal: bool,
98
- window_size_left: int,
99
- window_size_right: int,
100
- softcap: float,
101
- alibi_slopes: Optional[torch.Tensor],
102
- return_softmax: bool
103
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
104
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
105
- out, softmax_lse, S_dmask, rng_state = flash_attn.fwd(
106
- q,
107
- k,
108
- v,
109
- None,
110
- alibi_slopes,
111
- dropout_p,
112
- softmax_scale,
113
- causal,
114
- window_size_left,
115
- window_size_right,
116
- softcap,
117
- return_softmax,
118
- None,
119
- )
120
- return out, softmax_lse, S_dmask, rng_state
121
-
122
-
123
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
124
- def _flash_attn_forward_fake(
125
- q: torch.Tensor,
126
- k: torch.Tensor,
127
- v: torch.Tensor,
128
- dropout_p: float,
129
- softmax_scale: float,
130
- causal: bool,
131
- window_size_left: int,
132
- window_size_right: int,
133
- softcap: float,
134
- alibi_slopes: Optional[torch.Tensor],
135
- return_softmax: bool
136
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
137
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
138
- batch_size, seqlen_q, num_heads, head_size = q.shape
139
- seqlen_k = k.shape[1]
140
- out = torch.empty_like(q)
141
- softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
142
- p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
143
- if return_softmax:
144
- p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
145
- rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
146
-
147
- return out, softmax_lse, p, rng_state
148
-
149
-
150
- if torch.__version__ >= "2.4.0":
151
- _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
152
- else:
153
- _wrapped_flash_attn_forward = _flash_attn_forward
154
-
155
-
156
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types=_get_device())
157
- def _flash_attn_varlen_forward(
158
- q: torch.Tensor,
159
- k: torch.Tensor,
160
- v: torch.Tensor,
161
- cu_seqlens_q: torch.Tensor,
162
- cu_seqlens_k: torch.Tensor,
163
- max_seqlen_q: int,
164
- max_seqlen_k: int,
165
- dropout_p: float,
166
- softmax_scale: float,
167
- causal: bool,
168
- window_size_left: int = -1,
169
- window_size_right: int = -1,
170
- softcap: float = 0.0,
171
- alibi_slopes: Optional[torch.Tensor] = None,
172
- return_softmax: bool = False,
173
- block_table: Optional[torch.Tensor] = None,
174
- leftpad_k: Optional[torch.Tensor] = None,
175
- seqused_k: Optional[torch.Tensor] = None,
176
- zero_tensors: bool = False,
177
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
178
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
179
- out, softmax_lse, S_dmask, rng_state = flash_attn.varlen_fwd(
180
- q,
181
- k,
182
- v,
183
- None,
184
- cu_seqlens_q,
185
- cu_seqlens_k,
186
- seqused_k,
187
- leftpad_k,
188
- block_table,
189
- alibi_slopes,
190
- max_seqlen_q,
191
- max_seqlen_k,
192
- dropout_p,
193
- softmax_scale,
194
- zero_tensors,
195
- causal,
196
- window_size_left,
197
- window_size_right,
198
- softcap,
199
- return_softmax,
200
- None,
201
- )
202
- # if out.isnan().any() or softmax_lse.isnan().any():
203
- # breakpoint()
204
- return out, softmax_lse, S_dmask, rng_state
205
-
206
-
207
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
208
- def _flash_attn_varlen_forward_fake(
209
- q: torch.Tensor,
210
- k: torch.Tensor,
211
- v: torch.Tensor,
212
- cu_seqlens_q: torch.Tensor,
213
- cu_seqlens_k: torch.Tensor,
214
- max_seqlen_q: int,
215
- max_seqlen_k: int,
216
- dropout_p: float,
217
- softmax_scale: float,
218
- causal: bool,
219
- window_size_left: int = -1,
220
- window_size_right: int = -1,
221
- softcap: float = 0.0,
222
- alibi_slopes: Optional[torch.Tensor] = None,
223
- return_softmax: bool = False,
224
- block_table: Optional[torch.Tensor] = None,
225
- leftpad_k: Optional[torch.Tensor] = None,
226
- seqused_k: Optional[torch.Tensor] = None,
227
- zero_tensors: bool = False,
228
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
229
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
230
- paged_kv = block_table is not None
231
- batch_size = cu_seqlens_q.numel() - 1
232
- total_q, num_heads, _ = q.shape
233
-
234
- out = torch.empty_like(q)
235
- softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
236
- p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
237
- seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
238
- seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
239
- if return_softmax:
240
- p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
241
- rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
242
- return out, softmax_lse, p, rng_state
243
-
244
-
245
- if torch.__version__ >= "2.4.0":
246
- _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
247
- else:
248
- _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
249
-
250
-
251
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types=_get_device())
252
- def _flash_attn_backward(
253
- dout: torch.Tensor,
254
- q: torch.Tensor,
255
- k: torch.Tensor,
256
- v: torch.Tensor,
257
- out: torch.Tensor,
258
- softmax_lse: torch.Tensor,
259
- dq: Optional[torch.Tensor],
260
- dk: Optional[torch.Tensor],
261
- dv: Optional[torch.Tensor],
262
- dropout_p: float,
263
- softmax_scale: float,
264
- causal: bool,
265
- window_size_left: int,
266
- window_size_right: int,
267
- softcap: float,
268
- alibi_slopes: Optional[torch.Tensor],
269
- deterministic: bool,
270
- rng_state: Optional[torch.Tensor] = None,
271
- ) -> torch.Tensor:
272
- # dq, dk, dv are allocated by us so they should already be contiguous
273
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
274
- (
275
- dq,
276
- dk,
277
- dv,
278
- softmax_d,
279
- ) = flash_attn.bwd(
280
- dout,
281
- q,
282
- k,
283
- v,
284
- out,
285
- softmax_lse,
286
- dq,
287
- dk,
288
- dv,
289
- alibi_slopes,
290
- dropout_p,
291
- softmax_scale,
292
- causal,
293
- window_size_left,
294
- window_size_right,
295
- softcap,
296
- deterministic,
297
- None,
298
- rng_state,
299
- )
300
- return softmax_d
301
-
302
-
303
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
304
- def _flash_attn_backward_fake(
305
- dout: torch.Tensor,
306
- q: torch.Tensor,
307
- k: torch.Tensor,
308
- v: torch.Tensor,
309
- out: torch.Tensor,
310
- softmax_lse: torch.Tensor,
311
- dq: Optional[torch.Tensor],
312
- dk: Optional[torch.Tensor],
313
- dv: Optional[torch.Tensor],
314
- dropout_p: float,
315
- softmax_scale: float,
316
- causal: bool,
317
- window_size_left: int,
318
- window_size_right: int,
319
- softcap: float,
320
- alibi_slopes: Optional[torch.Tensor],
321
- deterministic: bool,
322
- rng_state: Optional[torch.Tensor] = None,
323
- ) -> torch.Tensor:
324
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
325
- if dq is None:
326
- dq = torch.empty_like(q)
327
- if dk is None:
328
- dk = torch.empty_like(k)
329
- if dv is None:
330
- dv = torch.empty_like(v)
331
- batch_size, seqlen_q, num_heads, _ = q.shape
332
- softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
333
-
334
- return softmax_d
335
-
336
-
337
- if torch.__version__ >= "2.4.0":
338
- _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
339
- else:
340
- _wrapped_flash_attn_backward = _flash_attn_backward
341
-
342
-
343
- @_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types=_get_device())
344
- def _flash_attn_varlen_backward(
345
- dout: torch.Tensor,
346
- q: torch.Tensor,
347
- k: torch.Tensor,
348
- v: torch.Tensor,
349
- out: torch.Tensor,
350
- softmax_lse: torch.Tensor,
351
- dq: Optional[torch.Tensor],
352
- dk: Optional[torch.Tensor],
353
- dv: Optional[torch.Tensor],
354
- cu_seqlens_q: torch.Tensor,
355
- cu_seqlens_k: torch.Tensor,
356
- max_seqlen_q: int,
357
- max_seqlen_k: int,
358
- dropout_p: float,
359
- softmax_scale: float,
360
- causal: bool,
361
- window_size_left: int,
362
- window_size_right: int,
363
- softcap: float,
364
- alibi_slopes: Optional[torch.Tensor],
365
- deterministic: bool,
366
- rng_state: Optional[torch.Tensor] = None,
367
- zero_tensors: bool = False,
368
- ) -> torch.Tensor:
369
- # dq, dk, dv are allocated by us so they should already be contiguous
370
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
371
- (
372
- dq,
373
- dk,
374
- dv,
375
- softmax_d,
376
- ) = flash_attn.varlen_bwd(
377
- dout,
378
- q,
379
- k,
380
- v,
381
- out,
382
- softmax_lse,
383
- dq,
384
- dk,
385
- dv,
386
- cu_seqlens_q,
387
- cu_seqlens_k,
388
- alibi_slopes,
389
- max_seqlen_q,
390
- max_seqlen_k,
391
- dropout_p,
392
- softmax_scale,
393
- zero_tensors,
394
- causal,
395
- window_size_left,
396
- window_size_right,
397
- softcap,
398
- deterministic,
399
- None,
400
- rng_state,
401
- )
402
- # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
403
- # breakpoint()
404
- return softmax_d
405
-
406
-
407
- @_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
408
- def _flash_attn_varlen_backward_fake(
409
- dout: torch.Tensor,
410
- q: torch.Tensor,
411
- k: torch.Tensor,
412
- v: torch.Tensor,
413
- out: torch.Tensor,
414
- softmax_lse: torch.Tensor,
415
- dq: Optional[torch.Tensor],
416
- dk: Optional[torch.Tensor],
417
- dv: Optional[torch.Tensor],
418
- cu_seqlens_q: torch.Tensor,
419
- cu_seqlens_k: torch.Tensor,
420
- max_seqlen_q: int,
421
- max_seqlen_k: int,
422
- dropout_p: float,
423
- softmax_scale: float,
424
- causal: bool,
425
- window_size_left: int,
426
- window_size_right: int,
427
- softcap: float,
428
- alibi_slopes: Optional[torch.Tensor],
429
- deterministic: bool,
430
- rng_state: Optional[torch.Tensor] = None,
431
- zero_tensors: bool = False,
432
- ) -> torch.Tensor:
433
- dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
434
- batch_size = cu_seqlens_q.numel() - 1
435
- total_q, num_heads, _ = q.shape
436
-
437
- if dq is None:
438
- dq = torch.empty_like(q)
439
- if dk is None:
440
- dk = torch.empty_like(k)
441
- if dv is None:
442
- dv = torch.empty_like(v)
443
- softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
444
-
445
- return softmax_d
446
-
447
-
448
- if torch.__version__ >= "2.4.0":
449
- _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
450
- else:
451
- _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward
452
-
453
-
454
- class FlashAttnQKVPackedFunc(torch.autograd.Function):
455
- @staticmethod
456
- def forward(
457
- ctx,
458
- qkv,
459
- dropout_p,
460
- softmax_scale,
461
- causal,
462
- window_size,
463
- softcap,
464
- alibi_slopes,
465
- deterministic,
466
- return_softmax,
467
- is_grad_enabled,
468
- ):
469
- is_grad = is_grad_enabled and qkv.requires_grad
470
- if softmax_scale is None:
471
- softmax_scale = qkv.shape[-1] ** (-0.5)
472
- q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
473
- head_size_og = q.size(3)
474
- if head_size_og % 8 != 0:
475
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
476
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
477
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
478
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
479
- q,
480
- k,
481
- v,
482
- dropout_p,
483
- softmax_scale,
484
- causal=causal,
485
- window_size_left=window_size[0],
486
- window_size_right=window_size[1],
487
- softcap=softcap,
488
- alibi_slopes=alibi_slopes,
489
- return_softmax=return_softmax and dropout_p > 0,
490
- )
491
- if is_grad:
492
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
493
- ctx.dropout_p = dropout_p
494
- ctx.softmax_scale = softmax_scale
495
- ctx.causal = causal
496
- ctx.window_size = window_size
497
- ctx.softcap = softcap
498
- ctx.alibi_slopes = alibi_slopes
499
- ctx.deterministic = deterministic
500
- out = out_padded[..., :head_size_og]
501
- return out if not return_softmax else (out, softmax_lse, S_dmask)
502
-
503
- @staticmethod
504
- def backward(ctx, dout, *args):
505
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
506
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
507
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
508
- head_size_og = dout.size(3)
509
- dout_padded = dout
510
- if head_size_og % 8 != 0:
511
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
512
- _wrapped_flash_attn_backward(
513
- dout_padded,
514
- q,
515
- k,
516
- v,
517
- out,
518
- softmax_lse,
519
- dqkv[:, :, 0],
520
- dqkv[:, :, 1],
521
- dqkv[:, :, 2],
522
- ctx.dropout_p,
523
- ctx.softmax_scale,
524
- ctx.causal,
525
- ctx.window_size[0],
526
- ctx.window_size[1],
527
- ctx.softcap,
528
- ctx.alibi_slopes,
529
- ctx.deterministic,
530
- rng_state=rng_state,
531
- )
532
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
533
- return dqkv, None, None, None, None, None, None, None, None, None
534
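The autograd wrappers in this file pad the head dimension up to a multiple of 8 before calling the kernel and slice the result back afterwards. A small stand-alone illustration of that round trip (shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 8, 76)                 # headdim 76 is not a multiple of 8
head_size_og = q.size(-1)
pad = (8 - head_size_og % 8) % 8
q_padded = F.pad(q, [0, pad])                 # pad only the last (headdim) dimension
assert q_padded.size(-1) % 8 == 0

out_padded = q_padded.clone()                 # stand-in for the kernel output
out = out_padded[..., :head_size_og]          # slice back to the original head size
assert out.shape == q.shape
```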
-
535
-
536
- class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
537
- @staticmethod
538
- def forward(
539
- ctx,
540
- qkv,
541
- cu_seqlens,
542
- max_seqlen,
543
- dropout_p,
544
- softmax_scale,
545
- causal,
546
- window_size,
547
- softcap,
548
- alibi_slopes,
549
- deterministic,
550
- return_softmax,
551
- is_grad_enabled,
552
- ):
553
- is_grad = is_grad_enabled and qkv.requires_grad
554
- if softmax_scale is None:
555
- softmax_scale = qkv.shape[-1] ** (-0.5)
556
- q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
557
- head_size_og = q.size(2)
558
- if head_size_og % 8 != 0:
559
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
560
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
561
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
562
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
563
- q,
564
- k,
565
- v,
566
- cu_seqlens,
567
- cu_seqlens,
568
- max_seqlen,
569
- max_seqlen,
570
- dropout_p,
571
- softmax_scale,
572
- causal=causal,
573
- window_size_left=window_size[0],
574
- window_size_right=window_size[1],
575
- softcap=softcap,
576
- alibi_slopes=alibi_slopes,
577
- return_softmax=return_softmax and dropout_p > 0,
578
- block_table=None,
579
- )
580
- if is_grad:
581
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
582
- ctx.dropout_p = dropout_p
583
- ctx.max_seqlen = max_seqlen
584
- ctx.softmax_scale = softmax_scale
585
- ctx.causal = causal
586
- ctx.window_size = window_size
587
- ctx.softcap = softcap
588
- ctx.alibi_slopes = alibi_slopes
589
- ctx.deterministic = deterministic
590
- out = out_padded[..., :head_size_og]
591
- return out if not return_softmax else (out, softmax_lse, S_dmask)
592
-
593
- @staticmethod
594
- def backward(ctx, dout, *args):
595
- q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
596
- qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
597
- dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
598
- head_size_og = dout.size(2)
599
- dout_padded = dout
600
- if head_size_og % 8 != 0:
601
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
602
- _wrapped_flash_attn_varlen_backward(
603
- dout_padded,
604
- q,
605
- k,
606
- v,
607
- out,
608
- softmax_lse,
609
- dqkv[:, 0],
610
- dqkv[:, 1],
611
- dqkv[:, 2],
612
- cu_seqlens,
613
- cu_seqlens,
614
- ctx.max_seqlen,
615
- ctx.max_seqlen,
616
- ctx.dropout_p,
617
- ctx.softmax_scale,
618
- ctx.causal,
619
- ctx.window_size[0],
620
- ctx.window_size[1],
621
- ctx.softcap,
622
- ctx.alibi_slopes,
623
- ctx.deterministic,
624
- rng_state=rng_state,
625
- )
626
- dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
627
- return dqkv, None, None, None, None, None, None, None, None, None, None, None
628
-
629
-
630
- class FlashAttnKVPackedFunc(torch.autograd.Function):
631
- @staticmethod
632
- def forward(
633
- ctx,
634
- q,
635
- kv,
636
- dropout_p,
637
- softmax_scale,
638
- causal,
639
- window_size,
640
- softcap,
641
- alibi_slopes,
642
- deterministic,
643
- return_softmax,
644
- is_grad_enabled,
645
- ):
646
- is_grad = is_grad_enabled and any(
647
- x.requires_grad for x in [q, kv]
648
- )
649
- if softmax_scale is None:
650
- softmax_scale = q.shape[-1] ** (-0.5)
651
- k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
652
- head_size_og = q.size(3)
653
- if head_size_og % 8 != 0:
654
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
655
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
656
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
657
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
658
- q,
659
- k,
660
- v,
661
- dropout_p,
662
- softmax_scale,
663
- causal=causal,
664
- window_size_left=window_size[0],
665
- window_size_right=window_size[1],
666
- softcap=softcap,
667
- alibi_slopes=alibi_slopes,
668
- return_softmax=return_softmax and dropout_p > 0,
669
- )
670
- if is_grad:
671
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
672
- ctx.dropout_p = dropout_p
673
- ctx.softmax_scale = softmax_scale
674
- ctx.causal = causal
675
- ctx.window_size = window_size
676
- ctx.softcap = softcap
677
- ctx.alibi_slopes = alibi_slopes
678
- ctx.deterministic = deterministic
679
- out = out_padded[..., :head_size_og]
680
- return out if not return_softmax else (out, softmax_lse, S_dmask)
681
-
682
- @staticmethod
683
- def backward(ctx, dout, *args):
684
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
685
- dq = torch.empty_like(q)
686
- kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
687
- dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
688
- head_size_og = dout.size(3)
689
- dout_padded = dout
690
- if head_size_og % 8 != 0:
691
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
692
- _wrapped_flash_attn_backward(
693
- dout_padded,
694
- q,
695
- k,
696
- v,
697
- out,
698
- softmax_lse,
699
- dq,
700
- dkv[:, :, 0],
701
- dkv[:, :, 1],
702
- ctx.dropout_p,
703
- ctx.softmax_scale,
704
- ctx.causal,
705
- ctx.window_size[0],
706
- ctx.window_size[1],
707
- ctx.softcap,
708
- ctx.alibi_slopes,
709
- ctx.deterministic,
710
- rng_state=rng_state,
711
- )
712
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
713
- dkv = dkv[..., : dout.shape[-1]]
714
- return dq, dkv, None, None, None, None, None, None, None, None, None
715
-
716
-
717
- class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
718
- @staticmethod
719
- def forward(
720
- ctx,
721
- q,
722
- kv,
723
- cu_seqlens_q,
724
- cu_seqlens_k,
725
- max_seqlen_q,
726
- max_seqlen_k,
727
- dropout_p,
728
- softmax_scale,
729
- causal,
730
- window_size,
731
- softcap,
732
- alibi_slopes,
733
- deterministic,
734
- return_softmax,
735
- is_grad_enabled,
736
- ):
737
- is_grad = is_grad_enabled and any(
738
- x.requires_grad for x in [q, kv]
739
- )
740
- if softmax_scale is None:
741
- softmax_scale = q.shape[-1] ** (-0.5)
742
- k, v = kv[:, 0].detach(), kv[:, 1].detach()
743
- head_size_og = q.size(2)
744
- if head_size_og % 8 != 0:
745
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
746
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
747
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
748
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
749
- q,
750
- k,
751
- v,
752
- cu_seqlens_q,
753
- cu_seqlens_k,
754
- max_seqlen_q,
755
- max_seqlen_k,
756
- dropout_p,
757
- softmax_scale,
758
- causal=causal,
759
- window_size_left=window_size[0],
760
- window_size_right=window_size[1],
761
- softcap=softcap,
762
- alibi_slopes=alibi_slopes,
763
- return_softmax=return_softmax and dropout_p > 0,
764
- block_table=None,
765
- )
766
- if is_grad:
767
- ctx.save_for_backward(
768
- q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
769
- )
770
- ctx.dropout_p = dropout_p
771
- ctx.max_seqlen_q = max_seqlen_q
772
- ctx.max_seqlen_k = max_seqlen_k
773
- ctx.softmax_scale = softmax_scale
774
- ctx.causal = causal
775
- ctx.window_size = window_size
776
- ctx.softcap = softcap
777
- ctx.alibi_slopes = alibi_slopes
778
- ctx.deterministic = deterministic
779
- out = out_padded[..., :head_size_og]
780
- return out if not return_softmax else (out, softmax_lse, S_dmask)
781
-
782
- @staticmethod
783
- def backward(ctx, dout, *args):
784
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
785
- dq = torch.empty_like(q)
786
- kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
787
- dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
788
- head_size_og = dout.size(2)
789
- dout_padded = dout
790
- if head_size_og % 8 != 0:
791
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
792
- _wrapped_flash_attn_varlen_backward(
793
- dout_padded,
794
- q,
795
- k,
796
- v,
797
- out,
798
- softmax_lse,
799
- dq,
800
- dkv[:, 0],
801
- dkv[:, 1],
802
- cu_seqlens_q,
803
- cu_seqlens_k,
804
- ctx.max_seqlen_q,
805
- ctx.max_seqlen_k,
806
- ctx.dropout_p,
807
- ctx.softmax_scale,
808
- ctx.causal,
809
- ctx.window_size[0],
810
- ctx.window_size[1],
811
- ctx.softcap,
812
- ctx.alibi_slopes,
813
- ctx.deterministic,
814
- rng_state=rng_state,
815
- )
816
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
817
- dkv = dkv[..., : dout.shape[-1]]
818
- return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None
819
-
820
-
821
- class FlashAttnFunc(torch.autograd.Function):
822
- @staticmethod
823
- def forward(
824
- ctx,
825
- q,
826
- k,
827
- v,
828
- dropout_p,
829
- softmax_scale,
830
- causal,
831
- window_size,
832
- softcap,
833
- alibi_slopes,
834
- deterministic,
835
- return_softmax,
836
- is_grad_enabled,
837
- ):
838
- is_grad = is_grad_enabled and any(
839
- x.requires_grad for x in [q, k, v]
840
- )
841
- if softmax_scale is None:
842
- softmax_scale = q.shape[-1] ** (-0.5)
843
- head_size_og = q.size(3)
844
- if head_size_og % 8 != 0:
845
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
846
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
847
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
848
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
849
- q,
850
- k,
851
- v,
852
- dropout_p,
853
- softmax_scale,
854
- causal=causal,
855
- window_size_left=window_size[0],
856
- window_size_right=window_size[1],
857
- softcap=softcap,
858
- alibi_slopes=alibi_slopes,
859
- return_softmax=return_softmax and dropout_p > 0,
860
- )
861
- if is_grad:
862
- ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
863
- ctx.dropout_p = dropout_p
864
- ctx.softmax_scale = softmax_scale
865
- ctx.causal = causal
866
- ctx.window_size = window_size
867
- ctx.softcap = softcap
868
- ctx.alibi_slopes = alibi_slopes
869
- ctx.deterministic = deterministic
870
- out = out_padded[..., :head_size_og]
871
- return out if not return_softmax else (out, softmax_lse, S_dmask)
872
-
873
- @staticmethod
874
- def backward(ctx, dout, *args):
875
- q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
876
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
877
- head_size_og = dout.size(3)
878
- dout_padded = dout
879
- if head_size_og % 8 != 0:
880
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
881
- _wrapped_flash_attn_backward(
882
- dout_padded,
883
- q,
884
- k,
885
- v,
886
- out,
887
- softmax_lse,
888
- dq,
889
- dk,
890
- dv,
891
- ctx.dropout_p,
892
- ctx.softmax_scale,
893
- ctx.causal,
894
- ctx.window_size[0],
895
- ctx.window_size[1],
896
- ctx.softcap,
897
- ctx.alibi_slopes,
898
- ctx.deterministic,
899
- rng_state=rng_state,
900
- )
901
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
902
- dk = dk[..., : dout.shape[-1]]
903
- dv = dv[..., : dout.shape[-1]]
904
- return dq, dk, dv, None, None, None, None, None, None, None, None, None
905
-
906
-
907
- class FlashAttnVarlenFunc(torch.autograd.Function):
908
- @staticmethod
909
- def forward(
910
- ctx,
911
- q,
912
- k,
913
- v,
914
- cu_seqlens_q,
915
- cu_seqlens_k,
916
- max_seqlen_q,
917
- max_seqlen_k,
918
- dropout_p,
919
- softmax_scale,
920
- causal,
921
- window_size,
922
- softcap,
923
- alibi_slopes,
924
- deterministic,
925
- return_softmax,
926
- block_table,
927
- is_grad_enabled,
928
- ):
929
- is_grad = is_grad_enabled and any(
930
- x.requires_grad for x in [q, k, v]
931
- )
932
- if softmax_scale is None:
933
- softmax_scale = q.shape[-1] ** (-0.5)
934
- head_size_og = q.size(2)
935
- if head_size_og % 8 != 0:
936
- q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
937
- k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
938
- v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
939
- out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
940
- q,
941
- k,
942
- v,
943
- cu_seqlens_q,
944
- cu_seqlens_k,
945
- max_seqlen_q,
946
- max_seqlen_k,
947
- dropout_p,
948
- softmax_scale,
949
- causal=causal,
950
- window_size_left=window_size[0],
951
- window_size_right=window_size[1],
952
- softcap=softcap,
953
- alibi_slopes=alibi_slopes,
954
- return_softmax=return_softmax and dropout_p > 0,
955
- block_table=block_table,
956
- )
957
- if is_grad:
958
- ctx.save_for_backward(
959
- q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
960
- )
961
- ctx.dropout_p = dropout_p
962
- ctx.max_seqlen_q = max_seqlen_q
963
- ctx.max_seqlen_k = max_seqlen_k
964
- ctx.softmax_scale = softmax_scale
965
- ctx.causal = causal
966
- ctx.window_size = window_size
967
- ctx.softcap = softcap
968
- ctx.alibi_slopes = alibi_slopes
969
- ctx.deterministic = deterministic
970
-
971
- out = out_padded[..., :head_size_og]
972
- return out if not return_softmax else (out, softmax_lse, S_dmask)
973
-
974
- @staticmethod
975
- def backward(ctx, dout, *args):
976
- q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
977
- dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
978
- head_size_og = dout.size(2)
979
- dout_padded = dout
980
- if head_size_og % 8 != 0:
981
- dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
982
- _wrapped_flash_attn_varlen_backward(
983
- dout_padded,
984
- q,
985
- k,
986
- v,
987
- out,
988
- softmax_lse,
989
- dq,
990
- dk,
991
- dv,
992
- cu_seqlens_q,
993
- cu_seqlens_k,
994
- ctx.max_seqlen_q,
995
- ctx.max_seqlen_k,
996
- ctx.dropout_p,
997
- ctx.softmax_scale,
998
- ctx.causal,
999
- ctx.window_size[0],
1000
- ctx.window_size[1],
1001
- ctx.softcap,
1002
- ctx.alibi_slopes,
1003
- ctx.deterministic,
1004
- rng_state=rng_state,
1005
- )
1006
- dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
1007
- dk = dk[..., : dout.shape[-1]]
1008
- dv = dv[..., : dout.shape[-1]]
1009
- return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
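The varlen path above addresses packed (unpadded) tokens through cu_seqlens. A minimal sketch of how such cumulative-length tensors are typically built from per-sequence lengths (illustrative only):
import torch

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)   # per-sequence lengths in the batch
# cu_seqlens has batch_size + 1 entries with a leading zero: [0, 5, 8, 15]
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())                         # 7
total_tokens = int(cu_seqlens[-1])                      # 15 = leading dimension of the packed q/k/v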
1010
-
1011
-
1012
- def flash_attn_qkvpacked_func(
1013
- qkv,
1014
- dropout_p=0.0,
1015
- softmax_scale=None,
1016
- causal=False,
1017
- window_size=(-1, -1), # -1 means infinite context window
1018
-     softcap=0.0, # <=0.0 means deactivated
1019
- alibi_slopes=None,
1020
- deterministic=False,
1021
- return_attn_probs=False,
1022
- ):
1023
- """dropout_p should be set to 0.0 during evaluation
1024
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
1025
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
1026
- of the gradients of Q, K, V.
1027
- For multi-query and grouped-query attention (MQA/GQA), please see
1028
- flash_attn_kvpacked_func and flash_attn_func.
1029
-
1030
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1031
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
1032
-
1033
- Arguments:
1034
- qkv: (batch_size, seqlen, 3, nheads, headdim)
1035
- dropout_p: float. Dropout probability.
1036
- softmax_scale: float. The scaling of QK^T before applying softmax.
1037
- Default to 1 / sqrt(headdim).
1038
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1039
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1040
- softcap: float. Anything > 0 activates softcapping attention.
1041
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
1042
- the attention score of query i and key j.
1043
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1044
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1045
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1046
- testing only. The returned probabilities are not guaranteed to be correct
1047
- (they might not have the right scaling).
1048
- Return:
1049
- out: (batch_size, seqlen, nheads, headdim).
1050
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
1051
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1052
- normalization factor).
1053
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1054
- The output of softmax (possibly with different scaling). It also encodes the dropout
1055
- pattern (negative means that location was dropped, nonnegative means it was kept).
1056
- """
1057
- return FlashAttnQKVPackedFunc.apply(
1058
- qkv,
1059
- dropout_p,
1060
- softmax_scale,
1061
- causal,
1062
- window_size,
1063
- softcap,
1064
- alibi_slopes,
1065
- deterministic,
1066
- return_attn_probs,
1067
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1068
- )
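A minimal usage sketch for the packed-QKV entry point above, assuming a CUDA device, fp16 inputs, and an importable package; the import path mirrors upstream flash-attn and may differ for this kernel build:
import torch
from flash_attn import flash_attn_qkvpacked_func  # assumed import path; adjust to this package's name

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # dropout off for evaluation
assert out.shape == (batch, seqlen, nheads, headdim)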
1069
-
1070
-
1071
- def flash_attn_kvpacked_func(
1072
- q,
1073
- kv,
1074
- dropout_p=0.0,
1075
- softmax_scale=None,
1076
- causal=False,
1077
- window_size=(-1, -1), # -1 means infinite context window
1078
- softcap=0.0, # 0.0 means deactivated
1079
- alibi_slopes=None,
1080
- deterministic=False,
1081
- return_attn_probs=False,
1082
- ):
1083
- """dropout_p should be set to 0.0 during evaluation
1084
- If K, V are already stacked into 1 tensor, this function will be faster than
1085
- calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
1086
- of the gradients of K, V.
1087
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1088
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1089
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1090
-     0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1091
-
1092
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1093
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1094
- 1 1 1 1 0
1095
- 1 1 1 1 1
1096
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1097
- 0 0
1098
- 0 0
1099
- 0 0
1100
- 1 0
1101
- 1 1
1102
- If the row of the mask is all zero, the output will be zero.
1103
-
1104
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1105
- will only attend to keys between
1106
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1107
-
1108
- Arguments:
1109
- q: (batch_size, seqlen, nheads, headdim)
1110
- kv: (batch_size, seqlen, 2, nheads_k, headdim)
1111
- dropout_p: float. Dropout probability.
1112
- softmax_scale: float. The scaling of QK^T before applying softmax.
1113
- Default to 1 / sqrt(headdim).
1114
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1115
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1116
- softcap: float. Anything > 0 activates softcapping attention.
1117
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1118
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1119
- is added to the attention score of query i and key j.
1120
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1121
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1122
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1123
- testing only. The returned probabilities are not guaranteed to be correct
1124
- (they might not have the right scaling).
1125
- Return:
1126
- out: (batch_size, seqlen, nheads, headdim).
1127
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
1128
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1129
- normalization factor).
1130
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1131
- The output of softmax (possibly with different scaling). It also encodes the dropout
1132
- pattern (negative means that location was dropped, nonnegative means it was kept).
1133
- """
1134
- return FlashAttnKVPackedFunc.apply(
1135
- q,
1136
- kv,
1137
- dropout_p,
1138
- softmax_scale,
1139
- causal,
1140
- window_size,
1141
- softcap,
1142
- alibi_slopes,
1143
- deterministic,
1144
- return_attn_probs,
1145
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1146
- )
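A usage sketch for the packed-KV variant above, illustrating the MQA/GQA head-ratio rule from the docstring; same assumptions (CUDA device, fp16, assumed import path):
import torch
from flash_attn import flash_attn_kvpacked_func  # assumed import path

batch, seqlen_q, seqlen_k, headdim = 2, 128, 512, 64
nheads, nheads_k = 8, 2                             # 8 query heads share 2 KV heads (GQA)
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
kv = torch.randn(batch, seqlen_k, 2, nheads_k, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_kvpacked_func(q, kv, causal=True)  # causal mask aligned to the bottom-right corner
assert out.shape == (batch, seqlen_q, nheads, headdim)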
1147
-
1148
-
1149
- def flash_attn_func(
1150
- q,
1151
- k,
1152
- v,
1153
- dropout_p=0.0,
1154
- softmax_scale=None,
1155
- causal=False,
1156
- window_size=(-1, -1), # -1 means infinite context window
1157
- softcap=0.0, # 0.0 means deactivated
1158
- alibi_slopes=None,
1159
- deterministic=False,
1160
- return_attn_probs=False,
1161
- ):
1162
- """dropout_p should be set to 0.0 during evaluation
1163
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1164
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1165
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1166
-     0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1167
-
1168
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1169
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1170
- 1 1 1 1 0
1171
- 1 1 1 1 1
1172
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1173
- 0 0
1174
- 0 0
1175
- 0 0
1176
- 1 0
1177
- 1 1
1178
- If the row of the mask is all zero, the output will be zero.
1179
-
1180
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1181
- will only attend to keys between
1182
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1183
-
1184
- Arguments:
1185
- q: (batch_size, seqlen, nheads, headdim)
1186
- k: (batch_size, seqlen, nheads_k, headdim)
1187
- v: (batch_size, seqlen, nheads_k, headdim)
1188
- dropout_p: float. Dropout probability.
1189
- softmax_scale: float. The scaling of QK^T before applying softmax.
1190
- Default to 1 / sqrt(headdim).
1191
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1192
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1193
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1194
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1195
- is added to the attention score of query i and key j.
1196
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1197
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1198
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1199
- testing only. The returned probabilities are not guaranteed to be correct
1200
- (they might not have the right scaling).
1201
- Return:
1202
- out: (batch_size, seqlen, nheads, headdim).
1203
- softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
1204
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1205
- normalization factor).
1206
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1207
- The output of softmax (possibly with different scaling). It also encodes the dropout
1208
- pattern (negative means that location was dropped, nonnegative means it was kept).
1209
- """
1210
- return FlashAttnFunc.apply(
1211
- q,
1212
- k,
1213
- v,
1214
- dropout_p,
1215
- softmax_scale,
1216
- causal,
1217
- window_size,
1218
- softcap,
1219
- alibi_slopes,
1220
- deterministic,
1221
- return_attn_probs,
1222
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1223
- )
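A usage sketch for flash_attn_func above, including a sliding-window call; same assumptions (CUDA device, fp16, assumed import path):
import torch
from flash_attn import flash_attn_func  # assumed import path

batch, seqlen, nheads, nheads_k, headdim = 2, 2048, 16, 4, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads_k, headdim, device="cuda", dtype=torch.float16)
# Full causal attention.
out = flash_attn_func(q, k, v, causal=True)
# Causal attention restricted to the previous 256 keys (sliding window).
out_local = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))
assert out.shape == out_local.shape == (batch, seqlen, nheads, headdim)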
1224
-
1225
-
1226
- def flash_attn_varlen_qkvpacked_func(
1227
- qkv,
1228
- cu_seqlens,
1229
- max_seqlen,
1230
- dropout_p=0.0,
1231
- softmax_scale=None,
1232
- causal=False,
1233
- window_size=(-1, -1), # -1 means infinite context window
1234
- softcap=0.0, # 0.0 means deactivated
1235
- alibi_slopes=None,
1236
- deterministic=False,
1237
- return_attn_probs=False,
1238
- ):
1239
- """dropout_p should be set to 0.0 during evaluation
1240
- If Q, K, V are already stacked into 1 tensor, this function will be faster than
1241
- calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
1242
- of the gradients of Q, K, V.
1243
- For multi-query and grouped-query attention (MQA/GQA), please see
1244
- flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
1245
-
1246
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1247
- will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
1248
-
1249
- Arguments:
1250
- qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
1251
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1252
- of the sequences in the batch, used to index into qkv.
1253
- max_seqlen: int. Maximum sequence length in the batch.
1254
- dropout_p: float. Dropout probability.
1255
- softmax_scale: float. The scaling of QK^T before applying softmax.
1256
- Default to 1 / sqrt(headdim).
1257
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1258
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1259
- softcap: float. Anything > 0 activates softcapping attention.
1260
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
1261
- is added to the attention score of query i and key j.
1262
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1263
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1264
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1265
- testing only. The returned probabilities are not guaranteed to be correct
1266
- (they might not have the right scaling).
1267
- Return:
1268
- out: (total, nheads, headdim).
1269
- softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
1270
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1271
- normalization factor).
1272
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1273
- The output of softmax (possibly with different scaling). It also encodes the dropout
1274
- pattern (negative means that location was dropped, nonnegative means it was kept).
1275
- """
1276
- return FlashAttnVarlenQKVPackedFunc.apply(
1277
- qkv,
1278
- cu_seqlens,
1279
- max_seqlen,
1280
- dropout_p,
1281
- softmax_scale,
1282
- causal,
1283
- window_size,
1284
- softcap,
1285
- alibi_slopes,
1286
- deterministic,
1287
- return_attn_probs,
1288
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1289
- )
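A sketch of the variable-length packed-QKV call above, building cu_seqlens from per-sequence lengths; same assumptions (CUDA device, fp16, assumed import path):
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func  # assumed import path

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
total, max_seqlen, nheads, headdim = int(cu_seqlens[-1]), int(seqlens.max()), 8, 64
qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=True)
assert out.shape == (total, nheads, headdim)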
1290
-
1291
-
1292
- def flash_attn_varlen_kvpacked_func(
1293
- q,
1294
- kv,
1295
- cu_seqlens_q,
1296
- cu_seqlens_k,
1297
- max_seqlen_q,
1298
- max_seqlen_k,
1299
- dropout_p=0.0,
1300
- softmax_scale=None,
1301
- causal=False,
1302
- window_size=(-1, -1), # -1 means infinite context window
1303
- softcap=0.0, # 0.0 means deactivated
1304
- alibi_slopes=None,
1305
- deterministic=False,
1306
- return_attn_probs=False,
1307
- ):
1308
- """dropout_p should be set to 0.0 during evaluation
1309
- If K, V are already stacked into 1 tensor, this function will be faster than
1310
-     calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
1311
- of the gradients of K, V.
1312
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1313
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1314
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1315
-     0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1316
-
1317
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1318
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1319
- 1 1 1 1 0
1320
- 1 1 1 1 1
1321
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1322
- 0 0
1323
- 0 0
1324
- 0 0
1325
- 1 0
1326
- 1 1
1327
- If the row of the mask is all zero, the output will be zero.
1328
-
1329
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1330
- will only attend to keys between
1331
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1332
-
1333
- Arguments:
1334
- q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1335
- kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1336
- cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1337
- of the sequences in the batch, used to index into q.
1338
- cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1339
- of the sequences in the batch, used to index into kv.
1340
- max_seqlen_q: int. Maximum query sequence length in the batch.
1341
- max_seqlen_k: int. Maximum key sequence length in the batch.
1342
- dropout_p: float. Dropout probability.
1343
- softmax_scale: float. The scaling of QK^T before applying softmax.
1344
- Default to 1 / sqrt(headdim).
1345
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1346
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1347
- softcap: float. Anything > 0 activates softcapping attention.
1348
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1349
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1350
- is added to the attention score of query i and key j.
1351
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1352
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1353
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1354
- testing only. The returned probabilities are not guaranteed to be correct
1355
- (they might not have the right scaling).
1356
- Return:
1357
- out: (total, nheads, headdim).
1358
- softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
1359
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1360
- normalization factor).
1361
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1362
- The output of softmax (possibly with different scaling). It also encodes the dropout
1363
- pattern (negative means that location was dropped, nonnegative means it was kept).
1364
- """
1365
- return FlashAttnVarlenKVPackedFunc.apply(
1366
- q,
1367
- kv,
1368
- cu_seqlens_q,
1369
- cu_seqlens_k,
1370
- max_seqlen_q,
1371
- max_seqlen_k,
1372
- dropout_p,
1373
- softmax_scale,
1374
- causal,
1375
- window_size,
1376
- softcap,
1377
- alibi_slopes,
1378
- deterministic,
1379
- return_attn_probs,
1380
- False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1381
- )
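A sketch of the variable-length packed-KV call above with separate query and key token counts; same assumptions (CUDA device, fp16, assumed import path; to_cu is a local helper defined here):
import torch
from flash_attn import flash_attn_varlen_kvpacked_func  # assumed import path

def to_cu(seqlens):
    # Cumulative lengths with a leading zero, as expected by the varlen kernels.
    return torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

seqlens_q = torch.tensor([4, 2], dtype=torch.int32, device="cuda")
seqlens_k = torch.tensor([16, 9], dtype=torch.int32, device="cuda")
cu_q, cu_k = to_cu(seqlens_q), to_cu(seqlens_k)
nheads, nheads_k, headdim = 8, 2, 64
q = torch.randn(int(cu_q[-1]), nheads, headdim, device="cuda", dtype=torch.float16)
kv = torch.randn(int(cu_k[-1]), 2, nheads_k, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_varlen_kvpacked_func(
    q, kv, cu_q, cu_k, int(seqlens_q.max()), int(seqlens_k.max()), causal=True
)
assert out.shape == (int(cu_q[-1]), nheads, headdim)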
1382
-
1383
-
1384
- def flash_attn_varlen_func(
1385
- q,
1386
- k,
1387
- v,
1388
- cu_seqlens_q,
1389
- cu_seqlens_k,
1390
- max_seqlen_q,
1391
- max_seqlen_k,
1392
- dropout_p=0.0,
1393
- softmax_scale=None,
1394
- causal=False,
1395
- window_size=(-1, -1), # -1 means infinite context window
1396
- softcap=0.0, # 0.0 means deactivated
1397
- alibi_slopes=None,
1398
- deterministic=False,
1399
- return_attn_probs=False,
1400
- block_table=None,
1401
- ):
1402
- """dropout_p should be set to 0.0 during evaluation
1403
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1404
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1405
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1406
-     0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1407
-
1408
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1409
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1410
- 1 1 1 1 0
1411
- 1 1 1 1 1
1412
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1413
- 0 0
1414
- 0 0
1415
- 0 0
1416
- 1 0
1417
- 1 1
1418
- If the row of the mask is all zero, the output will be zero.
1419
-
1420
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1421
- will only attend to keys between
1422
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1423
-
1424
- Arguments:
1425
- q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1426
- k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1427
- v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1428
- cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1429
- of the sequences in the batch, used to index into q.
1430
- cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1431
- of the sequences in the batch, used to index into kv.
1432
- max_seqlen_q: int. Maximum query sequence length in the batch.
1433
- max_seqlen_k: int. Maximum key sequence length in the batch.
1434
- dropout_p: float. Dropout probability.
1435
- softmax_scale: float. The scaling of QK^T before applying softmax.
1436
- Default to 1 / sqrt(headdim).
1437
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1438
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1439
- softcap: float. Anything > 0 activates softcapping attention.
1440
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1441
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1442
- is added to the attention score of query i and key j.
1443
- deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1444
- which is slightly slower and uses more memory. The forward pass is always deterministic.
1445
- return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1446
- testing only. The returned probabilities are not guaranteed to be correct
1447
- (they might not have the right scaling).
1448
- Return:
1449
- out: (total, nheads, headdim).
1450
- softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
1451
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1452
- normalization factor).
1453
- S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
1454
- The output of softmax (possibly with different scaling). It also encodes the dropout
1455
- pattern (negative means that location was dropped, nonnegative means it was kept).
1456
- """
1457
- return FlashAttnVarlenFunc.apply(
1458
- q,
1459
- k,
1460
- v,
1461
- cu_seqlens_q,
1462
- cu_seqlens_k,
1463
- max_seqlen_q,
1464
- max_seqlen_k,
1465
- dropout_p,
1466
- softmax_scale,
1467
- causal,
1468
- window_size,
1469
- softcap,
1470
- alibi_slopes,
1471
- deterministic,
1472
- return_attn_probs,
1473
- block_table,
1474
- False if _XPU_AVAILABLE or q.device.type == "cpu" else torch.is_grad_enabled(),
1475
- )
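A self-attention sketch for the unpacked variable-length call above; same assumptions (CUDA device, fp16, assumed import path):
import torch
from flash_attn import flash_attn_varlen_func  # assumed import path

seqlens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
total, max_seqlen, nheads, nheads_k, headdim = int(cu_seqlens[-1]), int(seqlens.max()), 8, 8, 64
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads_k, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads_k, headdim, device="cuda", dtype=torch.float16)
# Self-attention: the same cu_seqlens/max_seqlen are used for queries and keys.
out = flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True
)
assert out.shape == (total, nheads, headdim)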
1476
-
1477
-
1478
- def flash_attn_with_kvcache(
1479
- q,
1480
- k_cache,
1481
- v_cache,
1482
- k=None,
1483
- v=None,
1484
- rotary_cos=None,
1485
- rotary_sin=None,
1486
-     cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
1487
- cache_batch_idx: Optional[torch.Tensor] = None,
1488
- cache_leftpad: Optional[torch.Tensor] = None,
1489
- block_table: Optional[torch.Tensor] = None,
1490
- softmax_scale=None,
1491
- causal=False,
1492
- window_size=(-1, -1), # -1 means infinite context window
1493
- softcap=0.0, # 0.0 means deactivated
1494
- rotary_interleaved=True,
1495
- alibi_slopes=None,
1496
- num_splits=0,
1497
- return_softmax_lse=False,
1498
- ):
1499
- """
1500
- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
1501
- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
1502
- the previous step, and update them with the new keys/values from the current step, and do
1503
- attention with the updated cache, all in 1 kernel.
1504
-
1505
- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
1506
- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
1507
- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
1508
-
1509
- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
1510
- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
1511
- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
1512
- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
1513
- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
1514
- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
1515
-
1516
- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
1517
-
1518
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
1519
- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1520
-     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
1521
-     0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
1522
-
1523
- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1524
- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1525
- 1 1 1 1 0
1526
- 1 1 1 1 1
1527
- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1528
- 0 0
1529
- 0 0
1530
- 0 0
1531
- 1 0
1532
- 1 1
1533
- If the row of the mask is all zero, the output will be zero.
1534
-
1535
- If window_size != (-1, -1), implements sliding window local attention. Query at position i
1536
- will only attend to keys between
1537
- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1538
-
1539
- Note: Does not support backward pass.
1540
-
1541
- Arguments:
1542
- q: (batch_size, seqlen, nheads, headdim)
1543
- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
1544
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
1545
- page_block_size must be a multiple of 256.
1546
- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
1547
- or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
1548
- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
1549
- k with k_cache, starting at the indices specified by cache_seqlens.
1550
- v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
1551
- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
1552
- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
1553
- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
1554
- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
1555
- KV cache.
1556
- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
1557
- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
1558
- If the indices are not distinct, and k and v are provided, the values updated in the cache
1559
- might come from any of the duplicate indices.
1560
- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
1561
- block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
1562
- softmax_scale: float. The scaling of QK^T before applying softmax.
1563
- Default to 1 / sqrt(headdim).
1564
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1565
- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1566
- softcap: float. Anything > 0 activates softcapping attention.
1567
- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
1568
- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
1569
- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
1570
- (i.e. GPT-NeoX style).
1571
- alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1572
- (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1573
- is added to the attention score of query i and key j.
1574
- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
1575
- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
1576
- to automatically determine the number of splits.
1577
- Don't change this unless you know what you are doing.
1578
- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
1579
-
1580
- Return:
1581
- out: (batch_size, seqlen, nheads, headdim).
1582
- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
1583
- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1584
- normalization factor).
1585
- """
1586
- assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
1587
- assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
1588
- q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1589
- if softmax_scale is None:
1590
- softmax_scale = q.shape[-1] ** (-0.5)
1591
- if cache_seqlens is not None and isinstance(cache_seqlens, int):
1592
- cache_seqlens = torch.full(
1593
- (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
1594
- )
1595
- cache_seqlens = maybe_contiguous(cache_seqlens)
1596
- cache_batch_idx = maybe_contiguous(cache_batch_idx)
1597
- block_table = maybe_contiguous(block_table)
1598
- out, softmax_lse = flash_attn.fwd_kvcache(
1599
- q,
1600
- k_cache,
1601
- v_cache,
1602
- k,
1603
- v,
1604
- cache_seqlens,
1605
- rotary_cos,
1606
- rotary_sin,
1607
- cache_batch_idx,
1608
- cache_leftpad,
1609
- block_table,
1610
- alibi_slopes,
1611
- None,
1612
- softmax_scale,
1613
- causal,
1614
- window_size[0],
1615
- window_size[1],
1616
- softcap,
1617
- rotary_interleaved,
1618
- num_splits,
1619
- )
1620
- return (out, softmax_lse) if return_softmax_lse else out
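A single decode-step sketch for the KV-cache entry point above, assuming a CUDA device, fp16 tensors, and the same import-path caveat; the cache is pre-allocated to a maximum length and updated in place:
import torch
from flash_attn import flash_attn_with_kvcache  # assumed import path

batch, max_len, nheads, nheads_k, headdim = 2, 256, 8, 8, 64
k_cache = torch.zeros(batch, max_len, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(batch, max_len, nheads_k, headdim, device="cuda", dtype=torch.float16)
# Suppose 10 tokens per sequence are already cached (those slots would have been filled by earlier steps).
cache_seqlens = torch.full((batch,), 10, dtype=torch.int32, device="cuda")

# One new token per sequence: attend over the cache plus the new key/value, updating the cache in place.
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn(batch, 1, nheads_k, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
)
cache_seqlens += 1  # advance the per-sequence cache lengths for the next decode step
assert out.shape == (batch, 1, nheads, headdim)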
 
 
 
 
build/torch210-cxx11-xpu20253-x86_64-linux/metadata.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "version": 1,
3
- "python-depends": []
4
- }
 
 
 
 
 
build/torch210-cxx11-xpu20253-x86_64-linux/ops/triton/rotary.py DELETED
@@ -1,186 +0,0 @@
1
- # Copyright (c) 2025, Tri Dao.
2
- # As of 2025-04-23, we require triton >= 3.0
3
-
4
- from typing import Optional, Union
5
-
6
- import torch
7
-
8
- import triton
9
- import triton.language as tl
10
-
11
-
12
- @triton.jit
13
- def rotary_kernel(
14
- OUT, # Pointers to matrices
15
- X,
16
- COS,
17
- SIN,
18
- CU_SEQLENS,
19
- SEQLEN_OFFSETS, # this could be int or a pointer
20
- # Matrix dimensions
21
- seqlen,
22
- nheads,
23
- seqlen_ro,
24
- # strides
25
- stride_out_batch,
26
- stride_out_seqlen,
27
- stride_out_nheads,
28
- stride_out_headdim,
29
- stride_x_batch,
30
- stride_x_seqlen,
31
- stride_x_nheads,
32
- stride_x_headdim,
33
- # Meta-parameters
34
- # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that
35
- # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128
36
- ROTARY_DIM: tl.constexpr,
37
- IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
38
- IS_VARLEN: tl.constexpr,
39
- INTERLEAVED: tl.constexpr,
40
- CONJUGATE: tl.constexpr,
41
- BLOCK_H: tl.constexpr,
42
- BLOCK_M: tl.constexpr,
43
- ):
44
- BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM)
45
- ROTARY_DIM_HALF = ROTARY_DIM // 2
46
- pid_head = tl.program_id(axis=0)
47
- pid_m = tl.program_id(axis=1)
48
- pid_batch = tl.program_id(axis=2)
49
-
50
- if not IS_VARLEN:
51
- X = X + pid_batch * stride_x_batch
52
- OUT = OUT + pid_batch * stride_out_batch
53
- else:
54
- start_idx = tl.load(CU_SEQLENS + pid_batch)
55
- seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
56
- X = X + start_idx * stride_x_seqlen
57
- OUT = OUT + start_idx * stride_out_seqlen
58
-
59
- if pid_m * BLOCK_M >= seqlen:
60
- return
61
-
62
- rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H)
63
- rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
64
- if not IS_SEQLEN_OFFSETS_TENSOR:
65
- rm_cs = rm + SEQLEN_OFFSETS
66
- else:
67
- rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
68
-
69
- rk_half = tl.arange(0, BLOCK_K // 2)
70
- COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
71
- SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
72
- mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF)
73
- cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)
74
- sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)
75
- if CONJUGATE:
76
- sin = -sin
77
-
78
- if not INTERLEAVED:
79
- # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
80
- X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim)
81
- OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim)
82
- mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF)
83
- x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32)
84
- x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32)
85
- o0 = x0 * cos - x1 * sin
86
- o1 = x0 * sin + x1 * cos
87
- tl.store(OUT, o0, mask=mask)
88
- tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask)
89
- else:
90
- rk = tl.arange(0, BLOCK_K)
91
- X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim)
92
- OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim)
93
- mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM)
94
- x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
95
- x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2]))
96
- o0 = x0 * cos - x1 * sin
97
- o1 = x0 * sin + x1 * cos
98
- o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K])
99
- tl.store(OUT, o, mask=mask)
100
-
101
-
102
- def apply_rotary(
103
- x: torch.Tensor,
104
- cos: torch.Tensor,
105
- sin: torch.Tensor,
106
- seqlen_offsets: Union[int, torch.Tensor] = 0,
107
- cu_seqlens: Optional[torch.Tensor] = None,
108
- max_seqlen: Optional[int] = None,
109
- interleaved=False,
110
- inplace=False,
111
- conjugate=False,
112
- ) -> torch.Tensor:
113
- """
114
- Arguments:
115
- x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
116
- else (total_seqlen, nheads, headdim).
117
- cos: (seqlen_ro, rotary_dim / 2)
118
- sin: (seqlen_ro, rotary_dim / 2)
119
- seqlen_offsets: integer or integer tensor of size (batch,)
120
- cu_seqlens: (batch + 1,) or None
121
- max_seqlen: int
122
- Returns:
123
- y: (batch, seqlen, nheads, headdim)
124
- """
125
- is_varlen = cu_seqlens is not None
126
- if not is_varlen:
127
- batch, seqlen, nheads, headdim = x.shape
128
- else:
129
- assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
130
- total_seqlen, nheads, headdim = x.shape
131
- batch_p_1 = cu_seqlens.shape[0]
132
- batch = batch_p_1 - 1
133
- seqlen = max_seqlen
134
- seqlen_ro, rotary_dim = cos.shape
135
- assert sin.shape == cos.shape
136
- rotary_dim *= 2
137
- assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
138
- assert headdim <= 256, "Only support headdim <= 256"
139
- assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
140
-
141
- cos, sin = cos.contiguous(), sin.contiguous()
142
- if isinstance(seqlen_offsets, torch.Tensor):
143
- assert seqlen_offsets.shape == (batch,)
144
- assert seqlen_offsets.dtype in [torch.int32, torch.int64]
145
- seqlen_offsets = seqlen_offsets.contiguous()
146
- else:
147
- assert seqlen_offsets + seqlen <= seqlen_ro
148
-
149
- output = torch.empty_like(x) if not inplace else x
150
- if rotary_dim < headdim and not inplace:
151
- output[..., rotary_dim:].copy_(x[..., rotary_dim:])
152
-
153
- grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa
154
- BLOCK_M = 8 if rotary_dim <= 128 else 4
155
-
156
- # Need this, otherwise Triton tries to launch from cuda:0 and we get
157
- # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
158
- device_ctx = torch.cuda.device(x.device.index) if x.device.type == 'cuda' else torch.xpu.device(x.device.index)
159
- with device_ctx:
160
- torch.library.wrap_triton(rotary_kernel)[grid](
161
- output, # data ptrs
162
- x,
163
- cos,
164
- sin,
165
- cu_seqlens,
166
- seqlen_offsets,
167
- seqlen, # shapes
168
- nheads,
169
- seqlen_ro,
170
- output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
171
- output.stride(-3), # seqlen_stride or total_seqlen_stride
172
- output.stride(-2), # nheads_stride
173
- output.stride(-1), # headdim_stride
174
- x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
175
- x.stride(-3), # seqlen stride or total_seqlen_stride
176
- x.stride(-2), # nheads stride
177
- x.stride(-1), # headdim stride
178
- rotary_dim,
179
- isinstance(seqlen_offsets, torch.Tensor),
180
- is_varlen,
181
- interleaved,
182
- conjugate,
183
- BLOCK_M=BLOCK_M,
184
- BLOCK_H=2,
185
- )
186
- return output
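A usage sketch for apply_rotary above, assuming a CUDA (or XPU) device with a recent Triton and that the module is importable at the path in the comment; the cos/sin table follows the (seqlen_ro, rotary_dim / 2) layout from the docstring:
import torch
from ops.triton.rotary import apply_rotary  # assumed import path; matches the module layout above

batch, seqlen, nheads, headdim = 2, 16, 4, 64
rotary_dim, base, device = 64, 10000.0, "cuda"

# Standard rotary frequency table: cos/sin of shape (seqlen_ro, rotary_dim / 2).
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim))
t = torch.arange(seqlen, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
y = apply_rotary(x, cos, sin, interleaved=False)   # rotates the first rotary_dim features of each head
assert y.shape == x.shape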
 
 
 
 
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux/_flash_attn2_588b404.abi3.so → torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:247ade2063814573447dcb697fd39e738bcf5f0f5d40ac87eaf6cf6dba29298f
3
- size 448708992
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:201b532a74bf5aeefdd6cfe7db479c2d089392ab53a34c699c56d78e225cd09a
3
+ size 445273568
build/{torch210-cxx11-cu128-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/_ops.py RENAMED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _flash_attn2_588b404
3
- ops = torch.ops._flash_attn2_588b404
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_flash_attn2_588b404::{op_name}"
 
1
  import torch
2
+ from . import _flash_attn_56449c1_dirty
3
+ ops = torch.ops._flash_attn_56449c1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_flash_attn_56449c1_dirty::{op_name}"
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/bert_padding.py RENAMED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/{flash_attn2 → flash_attn}/flash_attn_interface.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/layers/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/layers/patch_embed.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/layers/rotary.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/activations.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/fused_dense.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/layer_norm.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/rms_norm.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/cross_entropy.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/k_activations.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/layer_norm.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/linear.py RENAMED
File without changes
build/{torch210-cxx11-cpu-x86_64-linux → torch27-cxx11-cu118-x86_64-linux/flash_attn}/ops/triton/mlp.py RENAMED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/rotary.py RENAMED
File without changes
build/torch27-cxx11-cu118-x86_64-linux/flash_attn2/_flash_attn_9e27194.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:760f19632acec698b69e0898d336e1d5aa26d87318d26ab653366cc8d5a8eec7
3
- size 445273544
 
 
 
 
build/torch27-cxx11-cu118-x86_64-linux/flash_attn2/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _flash_attn_9e27194
3
- ops = torch.ops._flash_attn_9e27194
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_flash_attn_9e27194::{op_name}"
 
 
 
 
 
 
 
 
 
 
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/__init__.py RENAMED
File without changes
build/{torch210-cxx11-xpu20253-x86_64-linux/_flash_attn2_588b404.abi3.so → torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_56449c1_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1b431c2d9fcbb6f0923c3dfcaab8ee5f6df980fd39877cd6ff5f44373cd02271
3
- size 10797184
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:829746a9a9848d76f837c85613b9af3c367d51023134e99470bd208cddb0ba96
3
+ size 448639320
build/{torch210-cxx11-cu130-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/_ops.py RENAMED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _flash_attn2_588b404
3
- ops = torch.ops._flash_attn2_588b404
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_flash_attn2_588b404::{op_name}"
 
1
  import torch
2
+ from . import _flash_attn_56449c1_dirty
3
+ ops = torch.ops._flash_attn_56449c1_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_flash_attn_56449c1_dirty::{op_name}"
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/bert_padding.py RENAMED
File without changes
build/torch27-cxx11-cu126-x86_64-linux/{flash_attn2 → flash_attn}/flash_attn_interface.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/layers/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/layers/patch_embed.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/layers/rotary.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/ops/__init__.py RENAMED
File without changes
build/{torch210-cxx11-cu126-x86_64-linux → torch27-cxx11-cu126-x86_64-linux/flash_attn}/ops/activations.py RENAMED
File without changes