Kernels
Committed by danieldk (HF Staff)
Commit 98af189 · verified
1 parent: dd401a6

Build uploaded using `kernels`.
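Builds uploaded this way are meant to be fetched at runtime with the `kernels` library rather than installed as a wheel. A minimal sketch, assuming the Hub repository id is `kernels-community/flash-attn2` (the id is not stated in this commit) and that a build matching the local torch/CUDA combination exists:

import torch
from kernels import get_kernel  # the Hugging Face `kernels` package

# get_kernel downloads and imports the build variant matching the local
# torch version, C++ ABI and accelerator.
flash_attn2 = get_kernel("kernels-community/flash-attn2")  # hypothetical repo id

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn2.flash_attn_func(q, k, v, causal=True)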

This view is limited to 50 files because it contains too many changes.
Files changed (50):
1. build/torch210-cxx11-cpu-x86_64-linux/{_flash_attn2_5e9f49f.abi3.so → _flash_attn2_588b404.abi3.so} +1 -1
2. build/torch210-cxx11-cpu-x86_64-linux/_ops.py +3 -3
3. build/torch210-cxx11-cpu-x86_64-linux/metadata.json +4 -1
4. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +393 -0
5. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2/_flash_attn_9e27194.abi3.so → torch210-cxx11-cu126-x86_64-linux/_flash_attn2_588b404.abi3.so} +2 -2
6. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/_ops.py +3 -3
7. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/bert_padding.py +0 -0
8. build/torch210-cxx11-cu126-x86_64-linux/flash_attn2/__init__.py +26 -0
9. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/flash_attn_interface.py +29 -18
10. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/layers/__init__.py +0 -0
11. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/layers/patch_embed.py +0 -0
12. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/layers/rotary.py +0 -0
13. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +4 -0
14. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/__init__.py +0 -0
15. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/activations.py +0 -0
16. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/fused_dense.py +0 -0
17. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/layer_norm.py +0 -0
18. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/rms_norm.py +0 -0
19. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/__init__.py +0 -0
20. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/cross_entropy.py +0 -0
21. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/k_activations.py +0 -0
22. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/layer_norm.py +0 -0
23. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/linear.py +0 -0
24. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/mlp.py +0 -0
25. build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/rotary.py +2 -1
26. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +393 -0
27. build/torch210-cxx11-cu128-x86_64-linux/_flash_attn2_588b404.abi3.so +3 -0
28. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/_ops.py +3 -3
29. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/bert_padding.py +0 -0
30. build/torch210-cxx11-cu128-x86_64-linux/flash_attn2/__init__.py +26 -0
31. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/flash_attn_interface.py +29 -18
32. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/layers/__init__.py +0 -0
33. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/layers/patch_embed.py +0 -0
34. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/layers/rotary.py +0 -0
35. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +4 -0
36. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/__init__.py +0 -0
37. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/activations.py +0 -0
38. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/fused_dense.py +0 -0
39. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/layer_norm.py +0 -0
40. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/rms_norm.py +0 -0
41. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/__init__.py +0 -0
42. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/cross_entropy.py +0 -0
43. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/k_activations.py +0 -0
44. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/layer_norm.py +0 -0
45. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/linear.py +0 -0
46. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/mlp.py +0 -0
47. build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/ops/triton/rotary.py +2 -1
48. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +393 -0
49. build/torch210-cxx11-cu130-x86_64-linux/_flash_attn2_588b404.abi3.so +3 -0
50. build/{torch28-cxx11-cu129-x86_64-linux/flash_attn2 → torch210-cxx11-cu130-x86_64-linux}/_ops.py +3 -3
build/torch210-cxx11-cpu-x86_64-linux/{_flash_attn2_5e9f49f.abi3.so → _flash_attn2_588b404.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5f51f42997f5a6c02f137bfba9add39400e48486aaed79a78ec5081215c487e3
+oid sha256:1d90d30dbcf574c7a50f2c9774884370e71e1e177062c6a233fcc7e1940fffcb
 size 249504
build/torch210-cxx11-cpu-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _flash_attn2_5e9f49f
-ops = torch.ops._flash_attn2_5e9f49f
+from . import _flash_attn2_588b404
+ops = torch.ops._flash_attn2_588b404
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_flash_attn2_5e9f49f::{op_name}"
+    return f"_flash_attn2_588b404::{op_name}"
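The only substantive change in `_ops.py` is the build hash embedded in the extension module name. A small sketch of how the namespace helper is typically used from inside the built package (the op name `fwd` is illustrative, not taken from this diff):

import torch
from ._ops import ops, add_op_namespace_prefix  # only importable from within the package

# The helper yields the fully qualified name under which the C++ extension
# registered its ops, e.g. "_flash_attn2_588b404::fwd".
qualified_name = add_op_namespace_prefix("fwd")

# `ops` is torch.ops._flash_attn2_588b404, so ops.fwd resolves to the compiled kernel.
compiled_fwd = ops.fwd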
build/torch210-cxx11-cpu-x86_64-linux/metadata.json CHANGED
@@ -1 +1,4 @@
-{"python-depends":[]}
+{
+  "version": 1,
+  "python-depends": []
+}
build/torch210-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,393 @@
from typing import Optional, List
import torch
from ._ops import ops as flash_attn_ops
from .flash_attn_interface import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)


def fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention.

    Args:
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Optional output tensor, same shape as q
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        p_dropout: Dropout probability
        softmax_scale: Scale factor for softmax
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap for attention weights
        return_softmax: Whether to return softmax weights
        gen: Optional random number generator

    Returns:
        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
    """
    if softmax_scale is None:
        attention_head_dim = q.shape[-1]
        softmax_scale = 1.0 / (attention_head_dim**0.5)

    return flash_attn_ops.fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )


def varlen_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention with variable sequence lengths.

    Args:
        q: Query tensor of shape [total_q, num_heads, head_size]
        k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
        out: Optional output tensor of shape [total_q, num_heads, head_size]
        seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
        leftpad_k: Optional left padding for keys of shape [batch_size]
        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        max_seqlen_q: Maximum sequence length for queries
        max_seqlen_k: Maximum sequence length for keys
        p_dropout: Dropout probability
        softmax_scale: Scale factor for softmax
        zero_tensors: Whether to zero tensors before computation
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap for attention weights
        return_softmax: Whether to return softmax weights
        gen: Optional random number generator

    Returns:
        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
    """
    if softmax_scale is None:
        attention_head_dim = q.shape[-1]
        softmax_scale = 1.0 / (attention_head_dim**0.5)

    return flash_attn_ops.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )


def bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention.

    Args:
        dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
        softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
        dq: Optional gradient tensor for queries, same shape as q
        dk: Optional gradient tensor for keys, same shape as k
        dv: Optional gradient tensor for values, same shape as v
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        p_dropout: Dropout probability
        softmax_scale: Scale factor for softmax
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap for attention weights
        deterministic: Whether to use deterministic algorithms
        gen: Optional random number generator
        rng_state: Optional RNG state from forward pass

    Returns:
        List of tensors: [dq, dk, dv]
    """
    if softmax_scale is None:
        attention_head_dim = q.shape[-1]
        softmax_scale = 1.0 / (attention_head_dim**0.5)

    return flash_attn_ops.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )


def varlen_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention with variable sequence lengths.

    Args:
        dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
        softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
        dq: Optional gradient tensor for queries, same shape as q
        dk: Optional gradient tensor for keys, same shape as k
        dv: Optional gradient tensor for values, same shape as v
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        max_seqlen_q: Maximum sequence length for queries
        max_seqlen_k: Maximum sequence length for keys
        p_dropout: Dropout probability
        softmax_scale: Scale factor for softmax
        zero_tensors: Whether to zero tensors before computation
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap for attention weights
        deterministic: Whether to use deterministic algorithms
        gen: Optional random number generator
        rng_state: Optional RNG state from forward pass

    Returns:
        List of tensors: [dq, dk, dv]
    """
    if softmax_scale is None:
        attention_head_dim = q.shape[-1]
        softmax_scale = 1.0 / (attention_head_dim**0.5)

    return flash_attn_ops.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )


def fwd_kvcache(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    seqlens_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    is_rotary_interleaved: bool = False,
    num_splits: int = 1,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention with KV cache.

    Args:
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
        v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
        seqlens_k: Optional sequence lengths for keys of shape [batch_size]
        rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
        rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
        cache_batch_idx: Optional indices to index into the KV cache
        leftpad_k: Optional left padding for keys of shape [batch_size]
        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        out: Optional output tensor, same shape as q
        softmax_scale: Scale factor for softmax
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap for attention weights
        is_rotary_interleaved: Whether rotary embeddings are interleaved
        num_splits: Number of splits for computation

    Returns:
        List of tensors: [output, softmax_lse]
    """
    if softmax_scale is None:
        attention_head_dim = q.shape[-1]
        softmax_scale = 1.0 / (attention_head_dim**0.5)

    return flash_attn_ops.fwd_kvcache(
        q,
        kcache,
        vcache,
        k,
        v,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        leftpad_k,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        is_rotary_interleaved,
        num_splits,
    )
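The wrappers above only supply the default softmax scale (1/sqrt(head_dim)) before forwarding every argument to the compiled ops. A short usage sketch, assuming the package is importable as `flash_attn2` (for example after loading it through the `kernels` library) and a CUDA device is present:

import torch
import flash_attn2  # assumed import name for this build

batch, seqlen, heads, head_dim = 2, 256, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# softmax_scale is omitted, so fwd() uses 1.0 / head_dim ** 0.5 (0.125 for head_dim 64).
out, softmax_lse, *rest = flash_attn2.fwd(q, k, v, is_causal=True)
print(out.shape)  # (2, 256, 8, 64)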
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2/_flash_attn_9e27194.abi3.so → torch210-cxx11-cu126-x86_64-linux/_flash_attn2_588b404.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b567f9d044f6ca11f5a5fa2ba6d0fdb7573b7abcfe8d6ef875df44703ed020e1
-size 448643576
+oid sha256:247ade2063814573447dcb697fd39e738bcf5f0f5d40ac87eaf6cf6dba29298f
+size 448708992
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/_ops.py RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _flash_attn_9e27194
-ops = torch.ops._flash_attn_9e27194
+from . import _flash_attn2_588b404
+ops = torch.ops._flash_attn2_588b404
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_flash_attn_9e27194::{op_name}"
+    return f"_flash_attn2_588b404::{op_name}"
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/bert_padding.py RENAMED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/flash_attn2/__init__.py ADDED
@@ -0,0 +1,26 @@
import ctypes
import sys

import importlib
from pathlib import Path
from types import ModuleType

def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/flash_attn_interface.py RENAMED
@@ -10,12 +10,12 @@ import os
 # # We need to import the CUDA kernels after importing torch
 # USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
 # if USE_TRITON_ROCM:
-#     from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
+#     from .flash_attn_triton_amd import interface_fa as flash_attn
 # else:
-#     import flash_attn_2_cuda as flash_attn_gpu
+#     import flash_attn_2_cuda as flash_attn
 
 
-from ._ops import ops as flash_attn_gpu
+from ._ops import ops as flash_attn
 
 # # isort: on
 
@@ -23,6 +23,17 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 
+def _get_device():
+    if torch.xpu.is_available():
+        return "xpu"
+    elif torch.cuda.is_available():
+        return "cuda"
+    else:
+        return "cpu"
+
+_XPU_AVAILABLE = torch.xpu.is_available() if hasattr(torch, "xpu") else False  # TODO remove hasattr check when bwd is supported on XPU
+
+
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256
@@ -76,7 +87,7 @@ else:
     _torch_register_fake_wrapper = noop_register_fake_wrapper
 
 
-@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types=_get_device())
 def _flash_attn_forward(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -91,7 +102,7 @@ def _flash_attn_forward(
     return_softmax: bool
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
+    out, softmax_lse, S_dmask, rng_state = flash_attn.fwd(
         q,
         k,
         v,
@@ -142,7 +153,7 @@ else:
     _wrapped_flash_attn_forward = _flash_attn_forward
 
 
-@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types=_get_device())
 def _flash_attn_varlen_forward(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -165,7 +176,7 @@ def _flash_attn_varlen_forward(
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
+    out, softmax_lse, S_dmask, rng_state = flash_attn.varlen_fwd(
        q,
        k,
        v,
@@ -237,7 +248,7 @@ else:
     _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
 
 
-@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types=_get_device())
 def _flash_attn_backward(
     dout: torch.Tensor,
     q: torch.Tensor,
@@ -265,7 +276,7 @@ def _flash_attn_backward(
         dk,
         dv,
         softmax_d,
-    ) = flash_attn_gpu.bwd(
+    ) = flash_attn.bwd(
         dout,
         q,
         k,
@@ -329,7 +340,7 @@ else:
     _wrapped_flash_attn_backward = _flash_attn_backward
 
 
-@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types=_get_device())
 def _flash_attn_varlen_backward(
     dout: torch.Tensor,
     q: torch.Tensor,
@@ -362,7 +373,7 @@ def _flash_attn_varlen_backward(
         dk,
         dv,
         softmax_d,
-    ) = flash_attn_gpu.varlen_bwd(
+    ) = flash_attn.varlen_bwd(
         dout,
         q,
         k,
@@ -1053,7 +1064,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
-        torch.is_grad_enabled(),
+        False if _XPU_AVAILABLE else torch.is_grad_enabled(),
     )
 
 
@@ -1131,7 +1142,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
-        torch.is_grad_enabled(),
+        False if _XPU_AVAILABLE else torch.is_grad_enabled(),
     )
 
 
@@ -1208,7 +1219,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
-        torch.is_grad_enabled(),
+        False if _XPU_AVAILABLE else torch.is_grad_enabled(),
     )
 
 
@@ -1274,7 +1285,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
-        torch.is_grad_enabled(),
+        False if _XPU_AVAILABLE else torch.is_grad_enabled(),
     )
 
 
@@ -1366,7 +1377,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
-        torch.is_grad_enabled(),
+        False if _XPU_AVAILABLE else torch.is_grad_enabled(),
    )
 
 
@@ -1460,7 +1471,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
-        torch.is_grad_enabled(),
+        False if _XPU_AVAILABLE or q.device.type == "cpu" else torch.is_grad_enabled(),
     )
 
 
@@ -1584,7 +1595,7 @@ def flash_attn_with_kvcache(
     cache_seqlens = maybe_contiguous(cache_seqlens)
     cache_batch_idx = maybe_contiguous(cache_batch_idx)
     block_table = maybe_contiguous(block_table)
-    out, softmax_lse = flash_attn_gpu.fwd_kvcache(
+    out, softmax_lse = flash_attn.fwd_kvcache(
         q,
         k_cache,
         v_cache,
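The recurring edit in this file swaps the hard-coded device_types="cuda" for a value computed at import time, so a single registration covers CUDA, XPU and CPU builds, and it forces the final is_grad_enabled argument to False on XPU, where the backward pass is not yet supported. A minimal sketch of the device_types pattern in isolation, assuming PyTorch 2.4+ where torch.library.custom_op is available (the op below is hypothetical and not part of this repository):

import torch

def _pick_device_type() -> str:
    # Same preference order as the _get_device() helper added in this diff.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

@torch.library.custom_op("demo::scaled_copy", mutates_args=(), device_types=_pick_device_type())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Registered only for the device type detected at import time.
    return (x * scale).contiguous()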
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/layers/__init__.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/layers/patch_embed.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/layers/rotary.py RENAMED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,4 @@
{
  "version": 1,
  "python-depends": []
}
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/__init__.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/activations.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/fused_dense.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/layer_norm.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/rms_norm.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/__init__.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/cross_entropy.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/k_activations.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/layer_norm.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/linear.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/mlp.py RENAMED
File without changes
build/{torch28-cxx11-cu126-x86_64-linux/flash_attn2 → torch210-cxx11-cu126-x86_64-linux}/ops/triton/rotary.py RENAMED
@@ -155,7 +155,8 @@ def apply_rotary(
 
     # Need this, otherwise Triton tries to launch from cuda:0 and we get
     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    with torch.cuda.device(x.device.index):
+    device_ctx = torch.cuda.device(x.device.index) if x.device.type == 'cuda' else torch.xpu.device(x.device.index)
+    with device_ctx:
         torch.library.wrap_triton(rotary_kernel)[grid](
             output,  # data ptrs
             x,
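The rotary change generalizes the device guard around the Triton launch so it also works on XPU. A defensive variant of the same idea, shown only as an illustration (the nullcontext fallback is not part of the diff):

import contextlib
import torch

def launch_ctx(t: torch.Tensor):
    # Triton must launch with the tensor's own device active, otherwise it
    # rejects pointers that live on a different device.
    if t.device.type == "cuda":
        return torch.cuda.device(t.device.index)
    if t.device.type == "xpu" and hasattr(torch, "xpu"):
        return torch.xpu.device(t.device.index)
    return contextlib.nullcontext()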
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,393 @@
(Contents identical to build/torch210-cxx11-cu126-x86_64-linux/__init__.py shown above.)
build/torch210-cxx11-cu128-x86_64-linux/_flash_attn2_588b404.abi3.so ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09cfe096dc8f0010e99225d44263e4d9172d4b542d48d656b3b9fd718ca55b7d
size 1037803376
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/_ops.py RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _flash_attn_9e27194
-ops = torch.ops._flash_attn_9e27194
+from . import _flash_attn2_588b404
+ops = torch.ops._flash_attn2_588b404
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_flash_attn_9e27194::{op_name}"
+    return f"_flash_attn2_588b404::{op_name}"
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/bert_padding.py RENAMED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/flash_attn2/__init__.py ADDED
@@ -0,0 +1,26 @@
(Contents identical to build/torch210-cxx11-cu126-x86_64-linux/flash_attn2/__init__.py shown above.)
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 → torch210-cxx11-cu128-x86_64-linux}/flash_attn_interface.py RENAMED
(Same changes as the torch210-cxx11-cu126-x86_64-linux copy of flash_attn_interface.py shown above.)
  )
1224
 
1225
 
 
1285
  alibi_slopes,
1286
  deterministic,
1287
  return_attn_probs,
1288
+ False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1289
  )
1290
 
1291
 
 
1377
  alibi_slopes,
1378
  deterministic,
1379
  return_attn_probs,
1380
+ False if _XPU_AVAILABLE else torch.is_grad_enabled(),
1381
  )
1382
 
1383
 
 
1471
  deterministic,
1472
  return_attn_probs,
1473
  block_table,
1474
+ False if _XPU_AVAILABLE or q.device.type == "cpu" else torch.is_grad_enabled(),
1475
  )
1476
 
1477
 
 
1595
  cache_seqlens = maybe_contiguous(cache_seqlens)
1596
  cache_batch_idx = maybe_contiguous(cache_batch_idx)
1597
  block_table = maybe_contiguous(block_table)
1598
+ out, softmax_lse = flash_attn.fwd_kvcache(
1599
  q,
1600
  k_cache,
1601
  v_cache,
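
The flash_attn_interface.py hunks above swap the hard-coded device_types="cuda" for a runtime _get_device() lookup and gate torch.is_grad_enabled() behind _XPU_AVAILABLE, since the backward pass is not yet wired up on XPU. A minimal stand-alone sketch of that dispatch pattern (pick_device_type and wants_grad are illustrative names, not part of this build):

import torch

def pick_device_type() -> str:
    # Prefer XPU when the torch build exposes it, then CUDA, else CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def wants_grad() -> bool:
    # XPU has no backward kernel yet, so always take the inference-only path there.
    if pick_device_type() == "xpu":
        return False
    return torch.is_grad_enabled()
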
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/layers/__init__.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/layers/patch_embed.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/layers/rotary.py RENAMED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "version": 1,
3
+ "python-depends": []
4
+ }
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/__init__.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/activations.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/fused_dense.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/layer_norm.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/rms_norm.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/__init__.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/cross_entropy.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/k_activations.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/layer_norm.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/linear.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/mlp.py RENAMED
File without changes
build/{torch28-cxx11-cu128-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu128-x86_64-linux}/ops/triton/rotary.py RENAMED
@@ -155,7 +155,8 @@ def apply_rotary(
155
 
156
  # Need this, otherwise Triton tries to launch from cuda:0 and we get
157
  # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
158
- with torch.cuda.device(x.device.index):
 
159
  torch.library.wrap_triton(rotary_kernel)[grid](
160
  output, # data ptrs
161
  x,
 
155
 
156
  # Need this, otherwise Triton tries to launch from cuda:0 and we get
157
  # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
158
+ device_ctx = torch.cuda.device(x.device.index) if x.device.type == 'cuda' else torch.xpu.device(x.device.index)
159
+ with device_ctx:
160
  torch.library.wrap_triton(rotary_kernel)[grid](
161
  output, # data ptrs
162
  x,
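
The rotary.py hunk above derives the device context manager from the input tensor instead of assuming CUDA, so the Triton launch does not implicitly target cuda:0 (and does not fail outright on XPU). A small sketch of the same selection with the kernel launch elided (device_context is an illustrative helper, not code shipped in the build):

import torch

def device_context(x: torch.Tensor):
    # Triton launches on the current device; derive it from the tensor's device.
    if x.device.type == "cuda":
        return torch.cuda.device(x.device.index)
    if x.device.type == "xpu":
        return torch.xpu.device(x.device.index)
    raise ValueError(f"unsupported device type for Triton launch: {x.device.type}")

# with device_context(x):
#     ... launch the Triton rotary kernel here ...
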
build/torch210-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,393 @@
1
+ from typing import Optional, List
2
+ import torch
3
+ from ._ops import ops as flash_attn_ops
4
+ from .flash_attn_interface import (
5
+ flash_attn_func,
6
+ flash_attn_kvpacked_func,
7
+ flash_attn_qkvpacked_func,
8
+ flash_attn_varlen_func,
9
+ flash_attn_varlen_kvpacked_func,
10
+ flash_attn_varlen_qkvpacked_func,
11
+ flash_attn_with_kvcache,
12
+ )
13
+
14
+
15
+ def fwd(
16
+ q: torch.Tensor,
17
+ k: torch.Tensor,
18
+ v: torch.Tensor,
19
+ out: Optional[torch.Tensor] = None,
20
+ alibi_slopes: Optional[torch.Tensor] = None,
21
+ p_dropout: float = 0.0,
22
+ softmax_scale: Optional[float] = None,
23
+ is_causal: bool = False,
24
+ window_size_left: int = -1,
25
+ window_size_right: int = -1,
26
+ softcap: float = 0.0,
27
+ return_softmax: bool = False,
28
+ gen: Optional[torch.Generator] = None,
29
+ ) -> List[torch.Tensor]:
30
+ """
31
+ Forward pass for multi-head attention.
32
+
33
+ Args:
34
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
35
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
36
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
37
+ out: Optional output tensor, same shape as q
38
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
39
+ p_dropout: Dropout probability
40
+ softmax_scale: Scale factor for softmax
41
+ is_causal: Whether to use causal attention
42
+ window_size_left: Window size for left context (-1 for unlimited)
43
+ window_size_right: Window size for right context (-1 for unlimited)
44
+ softcap: Soft cap for attention weights
45
+ return_softmax: Whether to return softmax weights
46
+ gen: Optional random number generator
47
+
48
+ Returns:
49
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
50
+ """
51
+ if softmax_scale is None:
52
+ attention_head_dim = q.shape[-1]
53
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
54
+
55
+ return flash_attn_ops.fwd(
56
+ q,
57
+ k,
58
+ v,
59
+ out,
60
+ alibi_slopes,
61
+ p_dropout,
62
+ softmax_scale,
63
+ is_causal,
64
+ window_size_left,
65
+ window_size_right,
66
+ softcap,
67
+ return_softmax,
68
+ gen,
69
+ )
70
+
71
+
72
+ def varlen_fwd(
73
+ q: torch.Tensor,
74
+ k: torch.Tensor,
75
+ v: torch.Tensor,
76
+ cu_seqlens_q: torch.Tensor,
77
+ cu_seqlens_k: torch.Tensor,
78
+ out: Optional[torch.Tensor] = None,
79
+ seqused_k: Optional[torch.Tensor] = None,
80
+ leftpad_k: Optional[torch.Tensor] = None,
81
+ block_table: Optional[torch.Tensor] = None,
82
+ alibi_slopes: Optional[torch.Tensor] = None,
83
+ max_seqlen_q: int = 0,
84
+ max_seqlen_k: int = 0,
85
+ p_dropout: float = 0.0,
86
+ softmax_scale: Optional[float] = None,
87
+ zero_tensors: bool = False,
88
+ is_causal: bool = False,
89
+ window_size_left: int = -1,
90
+ window_size_right: int = -1,
91
+ softcap: float = 0.0,
92
+ return_softmax: bool = False,
93
+ gen: Optional[torch.Generator] = None,
94
+ ) -> List[torch.Tensor]:
95
+ """
96
+ Forward pass for multi-head attention with variable sequence lengths.
97
+
98
+ Args:
99
+ q: Query tensor of shape [total_q, num_heads, head_size]
100
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
101
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
102
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
103
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
104
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
105
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
106
+ leftpad_k: Optional left padding for keys of shape [batch_size]
107
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
108
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
109
+ max_seqlen_q: Maximum sequence length for queries
110
+ max_seqlen_k: Maximum sequence length for keys
111
+ p_dropout: Dropout probability
112
+ softmax_scale: Scale factor for softmax
113
+ zero_tensors: Whether to zero tensors before computation
114
+ is_causal: Whether to use causal attention
115
+ window_size_left: Window size for left context (-1 for unlimited)
116
+ window_size_right: Window size for right context (-1 for unlimited)
117
+ softcap: Soft cap for attention weights
118
+ return_softmax: Whether to return softmax weights
119
+ gen: Optional random number generator
120
+
121
+ Returns:
122
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
123
+ """
124
+ if softmax_scale is None:
125
+ attention_head_dim = q.shape[-1]
126
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
127
+
128
+ return flash_attn_ops.varlen_fwd(
129
+ q,
130
+ k,
131
+ v,
132
+ out,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ seqused_k,
136
+ leftpad_k,
137
+ block_table,
138
+ alibi_slopes,
139
+ max_seqlen_q,
140
+ max_seqlen_k,
141
+ p_dropout,
142
+ softmax_scale,
143
+ zero_tensors,
144
+ is_causal,
145
+ window_size_left,
146
+ window_size_right,
147
+ softcap,
148
+ return_softmax,
149
+ gen,
150
+ )
151
+
152
+
153
+ def bwd(
154
+ dout: torch.Tensor,
155
+ q: torch.Tensor,
156
+ k: torch.Tensor,
157
+ v: torch.Tensor,
158
+ out: torch.Tensor,
159
+ softmax_lse: torch.Tensor,
160
+ dq: Optional[torch.Tensor] = None,
161
+ dk: Optional[torch.Tensor] = None,
162
+ dv: Optional[torch.Tensor] = None,
163
+ alibi_slopes: Optional[torch.Tensor] = None,
164
+ p_dropout: float = 0.0,
165
+ softmax_scale: Optional[float] = None,
166
+ is_causal: bool = False,
167
+ window_size_left: int = -1,
168
+ window_size_right: int = -1,
169
+ softcap: float = 0.0,
170
+ deterministic: bool = False,
171
+ gen: Optional[torch.Generator] = None,
172
+ rng_state: Optional[torch.Tensor] = None,
173
+ ) -> List[torch.Tensor]:
174
+ """
175
+ Backward pass for multi-head attention.
176
+
177
+ Args:
178
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
179
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
180
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
181
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
182
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
183
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
184
+ dq: Optional gradient tensor for queries, same shape as q
185
+ dk: Optional gradient tensor for keys, same shape as k
186
+ dv: Optional gradient tensor for values, same shape as v
187
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
188
+ p_dropout: Dropout probability
189
+ softmax_scale: Scale factor for softmax
190
+ is_causal: Whether to use causal attention
191
+ window_size_left: Window size for left context (-1 for unlimited)
192
+ window_size_right: Window size for right context (-1 for unlimited)
193
+ softcap: Soft cap for attention weights
194
+ deterministic: Whether to use deterministic algorithms
195
+ gen: Optional random number generator
196
+ rng_state: Optional RNG state from forward pass
197
+
198
+ Returns:
199
+ List of tensors: [dq, dk, dv]
200
+ """
201
+ if softmax_scale is None:
202
+ attention_head_dim = q.shape[-1]
203
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
204
+
205
+ return flash_attn_ops.bwd(
206
+ dout,
207
+ q,
208
+ k,
209
+ v,
210
+ out,
211
+ softmax_lse,
212
+ dq,
213
+ dk,
214
+ dv,
215
+ alibi_slopes,
216
+ p_dropout,
217
+ softmax_scale,
218
+ is_causal,
219
+ window_size_left,
220
+ window_size_right,
221
+ softcap,
222
+ deterministic,
223
+ gen,
224
+ rng_state,
225
+ )
226
+
227
+
228
+ def varlen_bwd(
229
+ dout: torch.Tensor,
230
+ q: torch.Tensor,
231
+ k: torch.Tensor,
232
+ v: torch.Tensor,
233
+ out: torch.Tensor,
234
+ softmax_lse: torch.Tensor,
235
+ cu_seqlens_q: torch.Tensor,
236
+ cu_seqlens_k: torch.Tensor,
237
+ dq: Optional[torch.Tensor] = None,
238
+ dk: Optional[torch.Tensor] = None,
239
+ dv: Optional[torch.Tensor] = None,
240
+ alibi_slopes: Optional[torch.Tensor] = None,
241
+ max_seqlen_q: int = 0,
242
+ max_seqlen_k: int = 0,
243
+ p_dropout: float = 0.0,
244
+ softmax_scale: Optional[float] = None,
245
+ zero_tensors: bool = False,
246
+ is_causal: bool = False,
247
+ window_size_left: int = -1,
248
+ window_size_right: int = -1,
249
+ softcap: float = 0.0,
250
+ deterministic: bool = False,
251
+ gen: Optional[torch.Generator] = None,
252
+ rng_state: Optional[torch.Tensor] = None,
253
+ ) -> List[torch.Tensor]:
254
+ """
255
+ Backward pass for multi-head attention with variable sequence lengths.
256
+
257
+ Args:
258
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
259
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
260
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
261
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
262
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
263
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
264
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
265
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
266
+ dq: Optional gradient tensor for queries, same shape as q
267
+ dk: Optional gradient tensor for keys, same shape as k
268
+ dv: Optional gradient tensor for values, same shape as v
269
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
270
+ max_seqlen_q: Maximum sequence length for queries
271
+ max_seqlen_k: Maximum sequence length for keys
272
+ p_dropout: Dropout probability
273
+ softmax_scale: Scale factor for softmax
274
+ zero_tensors: Whether to zero tensors before computation
275
+ is_causal: Whether to use causal attention
276
+ window_size_left: Window size for left context (-1 for unlimited)
277
+ window_size_right: Window size for right context (-1 for unlimited)
278
+ softcap: Soft cap for attention weights
279
+ deterministic: Whether to use deterministic algorithms
280
+ gen: Optional random number generator
281
+ rng_state: Optional RNG state from forward pass
282
+
283
+ Returns:
284
+ List of tensors: [dq, dk, dv]
285
+ """
286
+ if softmax_scale is None:
287
+ attention_head_dim = q.shape[-1]
288
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
289
+
290
+ return flash_attn_ops.varlen_bwd(
291
+ dout,
292
+ q,
293
+ k,
294
+ v,
295
+ out,
296
+ softmax_lse,
297
+ dq,
298
+ dk,
299
+ dv,
300
+ cu_seqlens_q,
301
+ cu_seqlens_k,
302
+ alibi_slopes,
303
+ max_seqlen_q,
304
+ max_seqlen_k,
305
+ p_dropout,
306
+ softmax_scale,
307
+ zero_tensors,
308
+ is_causal,
309
+ window_size_left,
310
+ window_size_right,
311
+ softcap,
312
+ deterministic,
313
+ gen,
314
+ rng_state,
315
+ )
316
+
317
+
318
+ def fwd_kvcache(
319
+ q: torch.Tensor,
320
+ kcache: torch.Tensor,
321
+ vcache: torch.Tensor,
322
+ k: Optional[torch.Tensor] = None,
323
+ v: Optional[torch.Tensor] = None,
324
+ seqlens_k: Optional[torch.Tensor] = None,
325
+ rotary_cos: Optional[torch.Tensor] = None,
326
+ rotary_sin: Optional[torch.Tensor] = None,
327
+ cache_batch_idx: Optional[torch.Tensor] = None,
328
+ leftpad_k: Optional[torch.Tensor] = None,
329
+ block_table: Optional[torch.Tensor] = None,
330
+ alibi_slopes: Optional[torch.Tensor] = None,
331
+ out: Optional[torch.Tensor] = None,
332
+ softmax_scale: Optional[float] = None,
333
+ is_causal: bool = False,
334
+ window_size_left: int = -1,
335
+ window_size_right: int = -1,
336
+ softcap: float = 0.0,
337
+ is_rotary_interleaved: bool = False,
338
+ num_splits: int = 1,
339
+ ) -> List[torch.Tensor]:
340
+ """
341
+ Forward pass for multi-head attention with KV cache.
342
+
343
+ Args:
344
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
345
+ kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
346
+ vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
347
+ k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
348
+ v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
349
+ seqlens_k: Optional sequence lengths for keys of shape [batch_size]
350
+ rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
351
+ rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
352
+ cache_batch_idx: Optional indices to index into the KV cache
353
+ leftpad_k: Optional left padding for keys of shape [batch_size]
354
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
355
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
356
+ out: Optional output tensor, same shape as q
357
+ softmax_scale: Scale factor for softmax
358
+ is_causal: Whether to use causal attention
359
+ window_size_left: Window size for left context (-1 for unlimited)
360
+ window_size_right: Window size for right context (-1 for unlimited)
361
+ softcap: Soft cap for attention weights
362
+ is_rotary_interleaved: Whether rotary embeddings are interleaved
363
+ num_splits: Number of splits for computation
364
+
365
+ Returns:
366
+ List of tensors: [output, softmax_lse]
367
+ """
368
+ if softmax_scale is None:
369
+ attention_head_dim = q.shape[-1]
370
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
371
+
372
+ return flash_attn_ops.fwd_kvcache(
373
+ q,
374
+ kcache,
375
+ vcache,
376
+ k,
377
+ v,
378
+ seqlens_k,
379
+ rotary_cos,
380
+ rotary_sin,
381
+ cache_batch_idx,
382
+ leftpad_k,
383
+ block_table,
384
+ alibi_slopes,
385
+ out,
386
+ softmax_scale,
387
+ is_causal,
388
+ window_size_left,
389
+ window_size_right,
390
+ softcap,
391
+ is_rotary_interleaved,
392
+ num_splits,
393
+ )
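
The added __init__.py exposes keyword-friendly wrappers (fwd, varlen_fwd, bwd, varlen_bwd, fwd_kvcache) around the compiled ops and fills in softmax_scale as 1/sqrt(head_size) when it is omitted. A rough usage sketch of the fwd wrapper, assuming this build is importable as flash_attn2 and a CUDA device is available; shapes and the import name are illustrative only:

import torch
import flash_attn2  # assumed import name for this kernel build

batch, seqlen, heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# softmax_scale defaults to 1 / sqrt(head_dim) because it is omitted here.
out, softmax_lse, *rest = flash_attn2.fwd(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
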
build/torch210-cxx11-cu130-x86_64-linux/_flash_attn2_588b404.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:196d3756a7d099f5e23ddd53ebc47aadf558a96e1d7873f5a14faec09bb7b707
3
+ size 1009055064
build/{torch28-cxx11-cu129-x86_64-linux/flash_attn2 β†’ torch210-cxx11-cu130-x86_64-linux}/_ops.py RENAMED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _flash_attn_9e27194
3
- ops = torch.ops._flash_attn_9e27194
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_flash_attn_9e27194::{op_name}"
 
1
  import torch
2
+ from . import _flash_attn2_588b404
3
+ ops = torch.ops._flash_attn2_588b404
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_flash_attn2_588b404::{op_name}"
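
The renamed _ops.py re-exports the compiled extension's torch.ops namespace under the new build hash, and add_op_namespace_prefix builds fully qualified op names against that namespace. A hedged sketch of how it would typically be used from inside the package (the op name "fwd" comes from the interface above; this snippet is not part of the build):

from ._ops import ops, add_op_namespace_prefix  # only valid inside the package

# Fully qualified name, e.g. "_flash_attn2_588b404::fwd".
qualified_name = add_op_namespace_prefix("fwd")

# The same op is reachable directly through the re-exported namespace object.
fwd_op = ops.fwd
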