Kernels
danieldk (HF Staff) committed
Commit 172e232 · verified
1 Parent(s): 4592a27

Build uploaded using `kernels`.
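For orientation, builds published this way are normally consumed through the Hugging Face `kernels` package, which fetches the matching `build/<variant>` directory from the Hub and imports its top-level `__init__.py`. A minimal, hedged sketch of that flow — the `get_kernel` entry point is the package's documented loader, but the repository id below is a placeholder, not something stated in this diff:

```python
# Hedged sketch: loading a prebuilt kernel with the `kernels` package.
# "kernels-community/flash-attn2" is a placeholder repo id, not taken from this commit.
import torch
from kernels import get_kernel

flash_attn = get_kernel("kernels-community/flash-attn2")

# The xpu* build variants in this commit target Intel XPU devices.
q = torch.randn(2, 128, 8, 64, device="xpu", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="xpu", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="xpu", dtype=torch.float16)

out = flash_attn.flash_attn_func(q, k, v, causal=True)
```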

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. build/torch28-cxx11-xpu20251-x86_64-linux/__init__.py +393 -0
  2. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2/_flash_attn2_870e782_dirty.abi3.so → _flash_attn2_5dab8ba_dirty.abi3.so} +2 -2
  3. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2/_ops.py → _ops.py} +3 -3
  4. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2/bert_padding.py → bert_padding.py} +0 -0
  5. build/torch28-cxx11-xpu20251-x86_64-linux/flash_attn/__init__.py +393 -0
  6. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2/_flash_attn2_870e782_dirty.abi3.so → torch28-cxx11-xpu20251-x86_64-linux/flash_attn/_flash_attn_c984dd4_dirty.abi3.so} +2 -2
  7. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux/flash_attn}/_ops.py +3 -3
  8. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux/flash_attn}/bert_padding.py +0 -0
  9. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/flash_attn_interface.py +0 -0
  10. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/layers/__init__.py +0 -0
  11. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/layers/patch_embed.py +0 -0
  12. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/layers/rotary.py +0 -0
  13. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/__init__.py +0 -0
  14. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/activations.py +0 -0
  15. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/fused_dense.py +0 -0
  16. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/layer_norm.py +0 -0
  17. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/rms_norm.py +0 -0
  18. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/__init__.py +0 -0
  19. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/cross_entropy.py +0 -0
  20. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/k_activations.py +0 -0
  21. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/layer_norm.py +0 -0
  22. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/linear.py +0 -0
  23. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/mlp.py +0 -0
  24. build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/rotary.py +0 -0
  25. build/torch28-cxx11-xpu20251-x86_64-linux/flash_attn2/__init__.py +26 -393
  26. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/flash_attn_interface.py +0 -0
  27. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/layers/__init__.py +0 -0
  28. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/layers/patch_embed.py +0 -0
  29. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/layers/rotary.py +0 -0
  30. build/torch28-cxx11-xpu20251-x86_64-linux/metadata.json +1 -0
  31. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/__init__.py +0 -0
  32. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/activations.py +0 -0
  33. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/fused_dense.py +0 -0
  34. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/layer_norm.py +0 -0
  35. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/rms_norm.py +0 -0
  36. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/__init__.py +0 -0
  37. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/cross_entropy.py +0 -0
  38. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/k_activations.py +0 -0
  39. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/layer_norm.py +0 -0
  40. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/linear.py +0 -0
  41. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/mlp.py +0 -0
  42. build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/rotary.py +0 -0
  43. build/torch29-cxx11-xpu20252-x86_64-linux/__init__.py +393 -0
  44. build/torch29-cxx11-xpu20252-x86_64-linux/_flash_attn2_5dab8ba_dirty.abi3.so +3 -0
  45. build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py +9 -0
  46. build/torch29-cxx11-xpu20252-x86_64-linux/bert_padding.py +218 -0
  47. build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/__init__.py +393 -0
  48. build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/_flash_attn_c984dd4_dirty.abi3.so +3 -0
  49. build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/_ops.py +9 -0
  50. build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/bert_padding.py +218 -0
build/torch28-cxx11-xpu20251-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,393 @@
+from typing import Optional, List
+import torch
+from ._ops import ops as flash_attn_ops
+from .flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
+)
+
+
+def fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    p_dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    return_softmax: bool = False,
+    gen: Optional[torch.Generator] = None,
+) -> List[torch.Tensor]:
+    """
+    Forward pass for multi-head attention.
+
+    Args:
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        out: Optional output tensor, same shape as q
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        return_softmax: Whether to return softmax weights
+        gen: Optional random number generator
+
+    Returns:
+        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+    """
+    if softmax_scale is None:
+        attention_head_dim = q.shape[-1]
+        softmax_scale = 1.0 / (attention_head_dim**0.5)
+
+    return flash_attn_ops.fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+
+
+def varlen_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    max_seqlen_q: int = 0,
+    max_seqlen_k: int = 0,
+    p_dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    zero_tensors: bool = False,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    return_softmax: bool = False,
+    gen: Optional[torch.Generator] = None,
+) -> List[torch.Tensor]:
+    """
+    Forward pass for multi-head attention with variable sequence lengths.
+
+    Args:
+        q: Query tensor of shape [total_q, num_heads, head_size]
+        k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+        out: Optional output tensor of shape [total_q, num_heads, head_size]
+        seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
+        leftpad_k: Optional left padding for keys of shape [batch_size]
+        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        max_seqlen_q: Maximum sequence length for queries
+        max_seqlen_k: Maximum sequence length for keys
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        zero_tensors: Whether to zero tensors before computation
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        return_softmax: Whether to return softmax weights
+        gen: Optional random number generator
+
+    Returns:
+        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+    """
+    if softmax_scale is None:
+        attention_head_dim = q.shape[-1]
+        softmax_scale = 1.0 / (attention_head_dim**0.5)
+
+    return flash_attn_ops.varlen_fwd(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_k,
+        leftpad_k,
+        block_table,
+        alibi_slopes,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+
+
+def bwd(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor] = None,
+    dk: Optional[torch.Tensor] = None,
+    dv: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    p_dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    gen: Optional[torch.Generator] = None,
+    rng_state: Optional[torch.Tensor] = None,
+) -> List[torch.Tensor]:
+    """
+    Backward pass for multi-head attention.
+
+    Args:
+        dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+        softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+        dq: Optional gradient tensor for queries, same shape as q
+        dk: Optional gradient tensor for keys, same shape as k
+        dv: Optional gradient tensor for values, same shape as v
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        deterministic: Whether to use deterministic algorithms
+        gen: Optional random number generator
+        rng_state: Optional RNG state from forward pass
+
+    Returns:
+        List of tensors: [dq, dk, dv]
+    """
+    if softmax_scale is None:
+        attention_head_dim = q.shape[-1]
+        softmax_scale = 1.0 / (attention_head_dim**0.5)
+
+    return flash_attn_ops.bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        deterministic,
+        gen,
+        rng_state,
+    )
+
+
+def varlen_bwd(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    dq: Optional[torch.Tensor] = None,
+    dk: Optional[torch.Tensor] = None,
+    dv: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    max_seqlen_q: int = 0,
+    max_seqlen_k: int = 0,
+    p_dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    zero_tensors: bool = False,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    gen: Optional[torch.Generator] = None,
+    rng_state: Optional[torch.Tensor] = None,
+) -> List[torch.Tensor]:
+    """
+    Backward pass for multi-head attention with variable sequence lengths.
+
+    Args:
+        dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+        out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+        softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+        dq: Optional gradient tensor for queries, same shape as q
+        dk: Optional gradient tensor for keys, same shape as k
+        dv: Optional gradient tensor for values, same shape as v
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        max_seqlen_q: Maximum sequence length for queries
+        max_seqlen_k: Maximum sequence length for keys
+        p_dropout: Dropout probability
+        softmax_scale: Scale factor for softmax
+        zero_tensors: Whether to zero tensors before computation
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        deterministic: Whether to use deterministic algorithms
+        gen: Optional random number generator
+        rng_state: Optional RNG state from forward pass
+
+    Returns:
+        List of tensors: [dq, dk, dv]
+    """
+    if softmax_scale is None:
+        attention_head_dim = q.shape[-1]
+        softmax_scale = 1.0 / (attention_head_dim**0.5)
+
+    return flash_attn_ops.varlen_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        alibi_slopes,
+        max_seqlen_q,
+        max_seqlen_k,
+        p_dropout,
+        softmax_scale,
+        zero_tensors,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        deterministic,
+        gen,
+        rng_state,
+    )
+
+
+def fwd_kvcache(
+    q: torch.Tensor,
+    kcache: torch.Tensor,
+    vcache: torch.Tensor,
+    k: Optional[torch.Tensor] = None,
+    v: Optional[torch.Tensor] = None,
+    seqlens_k: Optional[torch.Tensor] = None,
+    rotary_cos: Optional[torch.Tensor] = None,
+    rotary_sin: Optional[torch.Tensor] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    is_rotary_interleaved: bool = False,
+    num_splits: int = 1,
+) -> List[torch.Tensor]:
+    """
+    Forward pass for multi-head attention with KV cache.
+
+    Args:
+        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+        kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+        k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+        v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+        seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+        rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+        rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+        cache_batch_idx: Optional indices to index into the KV cache
+        leftpad_k: Optional left padding for keys of shape [batch_size]
+        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+        out: Optional output tensor, same shape as q
+        softmax_scale: Scale factor for softmax
+        is_causal: Whether to use causal attention
+        window_size_left: Window size for left context (-1 for unlimited)
+        window_size_right: Window size for right context (-1 for unlimited)
+        softcap: Soft cap for attention weights
+        is_rotary_interleaved: Whether rotary embeddings are interleaved
+        num_splits: Number of splits for computation
+
+    Returns:
+        List of tensors: [output, softmax_lse]
+    """
+    if softmax_scale is None:
+        attention_head_dim = q.shape[-1]
+        softmax_scale = 1.0 / (attention_head_dim**0.5)
+
+    return flash_attn_ops.fwd_kvcache(
+        q,
+        kcache,
+        vcache,
+        k,
+        v,
+        seqlens_k,
+        rotary_cos,
+        rotary_sin,
+        cache_batch_idx,
+        leftpad_k,
+        block_table,
+        alibi_slopes,
+        out,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        is_rotary_interleaved,
+        num_splits,
+    )
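The wrappers above mainly normalize arguments before dispatching to the compiled op; in particular, when `softmax_scale` is left as `None` they substitute 1/sqrt(head_size) (for head_size 64 that is 1/8 = 0.125). A hedged sketch of calling the low-level `fwd` wrapper from a kernel loaded as in the earlier example (the repo id is still a placeholder, and the device/shapes are illustrative):

```python
import torch
from kernels import get_kernel  # assumes the `kernels` package; repo id is a placeholder

flash_attn = get_kernel("kernels-community/flash-attn2")

# Shapes follow the docstring: [batch_size, seqlen, num_heads, head_size].
q = torch.randn(1, 64, 4, 64, device="xpu", dtype=torch.float16)
k = torch.randn(1, 64, 4, 64, device="xpu", dtype=torch.float16)
v = torch.randn(1, 64, 4, 64, device="xpu", dtype=torch.float16)

# softmax_scale=None lets the wrapper default to 1 / sqrt(64) = 0.125.
result = flash_attn.fwd(q, k, v, is_causal=True)
out, softmax_lse = result[0], result[1]
```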
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2/_flash_attn2_870e782_dirty.abi3.so → _flash_attn2_5dab8ba_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:dc780f13b456fb2bc84975e2d50c257650a5b109b86c92376621eea7a983c3cd
-size 5959592
+oid sha256:f14e01c60f4a293eab27d1b34e072c8b6e37ca3a7e9cbd5b6a2bb83c195579bb
+size 8973288
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2/_ops.py → _ops.py} RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _flash_attn2_870e782_dirty
-ops = torch.ops._flash_attn2_870e782_dirty
+from . import _flash_attn2_5dab8ba_dirty
+ops = torch.ops._flash_attn2_5dab8ba_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_flash_attn2_870e782_dirty::{op_name}"
+    return f"_flash_attn2_5dab8ba_dirty::{op_name}"
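The `_ops.py` change above is the part of the build that binds the Python wrappers to the freshly compiled extension: importing the hash-named module registers its operators under `torch.ops`, and `add_op_namespace_prefix` produces fully qualified op names in that namespace. A rough illustration of the pattern (the hashed name is build-specific and changes on every rebuild, so treat it as an example, not a stable identifier):

```python
import torch

# Importing the shared object (in the package, via its own `from . import ...`)
# registers ops such as fwd, bwd, and varlen_fwd under this torch.ops namespace.
ops = torch.ops._flash_attn2_5dab8ba_dirty

def add_op_namespace_prefix(op_name: str) -> str:
    # Fully qualified name, e.g. "_flash_attn2_5dab8ba_dirty::fwd",
    # for use wherever a namespaced operator name is required.
    return f"_flash_attn2_5dab8ba_dirty::{op_name}"
```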
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2/bert_padding.py → bert_padding.py} RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,393 @@
1
+ from typing import Optional, List
2
+ import torch
3
+ from ._ops import ops as flash_attn_ops
4
+ from .flash_attn_interface import (
5
+ flash_attn_func,
6
+ flash_attn_kvpacked_func,
7
+ flash_attn_qkvpacked_func,
8
+ flash_attn_varlen_func,
9
+ flash_attn_varlen_kvpacked_func,
10
+ flash_attn_varlen_qkvpacked_func,
11
+ flash_attn_with_kvcache,
12
+ )
13
+
14
+
15
+ def fwd(
16
+ q: torch.Tensor,
17
+ k: torch.Tensor,
18
+ v: torch.Tensor,
19
+ out: Optional[torch.Tensor] = None,
20
+ alibi_slopes: Optional[torch.Tensor] = None,
21
+ p_dropout: float = 0.0,
22
+ softmax_scale: Optional[float] = None,
23
+ is_causal: bool = False,
24
+ window_size_left: int = -1,
25
+ window_size_right: int = -1,
26
+ softcap: float = 0.0,
27
+ return_softmax: bool = False,
28
+ gen: Optional[torch.Generator] = None,
29
+ ) -> List[torch.Tensor]:
30
+ """
31
+ Forward pass for multi-head attention.
32
+
33
+ Args:
34
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
35
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
36
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
37
+ out: Optional output tensor, same shape as q
38
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
39
+ p_dropout: Dropout probability
40
+ softmax_scale: Scale factor for softmax
41
+ is_causal: Whether to use causal attention
42
+ window_size_left: Window size for left context (-1 for unlimited)
43
+ window_size_right: Window size for right context (-1 for unlimited)
44
+ softcap: Soft cap for attention weights
45
+ return_softmax: Whether to return softmax weights
46
+ gen: Optional random number generator
47
+
48
+ Returns:
49
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
50
+ """
51
+ if softmax_scale is None:
52
+ attention_head_dim = q.shape[-1]
53
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
54
+
55
+ return flash_attn_ops.fwd(
56
+ q,
57
+ k,
58
+ v,
59
+ out,
60
+ alibi_slopes,
61
+ p_dropout,
62
+ softmax_scale,
63
+ is_causal,
64
+ window_size_left,
65
+ window_size_right,
66
+ softcap,
67
+ return_softmax,
68
+ gen,
69
+ )
70
+
71
+
72
+ def varlen_fwd(
73
+ q: torch.Tensor,
74
+ k: torch.Tensor,
75
+ v: torch.Tensor,
76
+ cu_seqlens_q: torch.Tensor,
77
+ cu_seqlens_k: torch.Tensor,
78
+ out: Optional[torch.Tensor] = None,
79
+ seqused_k: Optional[torch.Tensor] = None,
80
+ leftpad_k: Optional[torch.Tensor] = None,
81
+ block_table: Optional[torch.Tensor] = None,
82
+ alibi_slopes: Optional[torch.Tensor] = None,
83
+ max_seqlen_q: int = 0,
84
+ max_seqlen_k: int = 0,
85
+ p_dropout: float = 0.0,
86
+ softmax_scale: Optional[float] = None,
87
+ zero_tensors: bool = False,
88
+ is_causal: bool = False,
89
+ window_size_left: int = -1,
90
+ window_size_right: int = -1,
91
+ softcap: float = 0.0,
92
+ return_softmax: bool = False,
93
+ gen: Optional[torch.Generator] = None,
94
+ ) -> List[torch.Tensor]:
95
+ """
96
+ Forward pass for multi-head attention with variable sequence lengths.
97
+
98
+ Args:
99
+ q: Query tensor of shape [total_q, num_heads, head_size]
100
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
101
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
102
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
103
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
104
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
105
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
106
+ leftpad_k: Optional left padding for keys of shape [batch_size]
107
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
108
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
109
+ max_seqlen_q: Maximum sequence length for queries
110
+ max_seqlen_k: Maximum sequence length for keys
111
+ p_dropout: Dropout probability
112
+ softmax_scale: Scale factor for softmax
113
+ zero_tensors: Whether to zero tensors before computation
114
+ is_causal: Whether to use causal attention
115
+ window_size_left: Window size for left context (-1 for unlimited)
116
+ window_size_right: Window size for right context (-1 for unlimited)
117
+ softcap: Soft cap for attention weights
118
+ return_softmax: Whether to return softmax weights
119
+ gen: Optional random number generator
120
+
121
+ Returns:
122
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
123
+ """
124
+ if softmax_scale is None:
125
+ attention_head_dim = q.shape[-1]
126
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
127
+
128
+ return flash_attn_ops.varlen_fwd(
129
+ q,
130
+ k,
131
+ v,
132
+ out,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ seqused_k,
136
+ leftpad_k,
137
+ block_table,
138
+ alibi_slopes,
139
+ max_seqlen_q,
140
+ max_seqlen_k,
141
+ p_dropout,
142
+ softmax_scale,
143
+ zero_tensors,
144
+ is_causal,
145
+ window_size_left,
146
+ window_size_right,
147
+ softcap,
148
+ return_softmax,
149
+ gen,
150
+ )
151
+
152
+
153
+ def bwd(
154
+ dout: torch.Tensor,
155
+ q: torch.Tensor,
156
+ k: torch.Tensor,
157
+ v: torch.Tensor,
158
+ out: torch.Tensor,
159
+ softmax_lse: torch.Tensor,
160
+ dq: Optional[torch.Tensor] = None,
161
+ dk: Optional[torch.Tensor] = None,
162
+ dv: Optional[torch.Tensor] = None,
163
+ alibi_slopes: Optional[torch.Tensor] = None,
164
+ p_dropout: float = 0.0,
165
+ softmax_scale: Optional[float] = None,
166
+ is_causal: bool = False,
167
+ window_size_left: int = -1,
168
+ window_size_right: int = -1,
169
+ softcap: float = 0.0,
170
+ deterministic: bool = False,
171
+ gen: Optional[torch.Generator] = None,
172
+ rng_state: Optional[torch.Tensor] = None,
173
+ ) -> List[torch.Tensor]:
174
+ """
175
+ Backward pass for multi-head attention.
176
+
177
+ Args:
178
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
179
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
180
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
181
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
182
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
183
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
184
+ dq: Optional gradient tensor for queries, same shape as q
185
+ dk: Optional gradient tensor for keys, same shape as k
186
+ dv: Optional gradient tensor for values, same shape as v
187
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
188
+ p_dropout: Dropout probability
189
+ softmax_scale: Scale factor for softmax
190
+ is_causal: Whether to use causal attention
191
+ window_size_left: Window size for left context (-1 for unlimited)
192
+ window_size_right: Window size for right context (-1 for unlimited)
193
+ softcap: Soft cap for attention weights
194
+ deterministic: Whether to use deterministic algorithms
195
+ gen: Optional random number generator
196
+ rng_state: Optional RNG state from forward pass
197
+
198
+ Returns:
199
+ List of tensors: [dq, dk, dv]
200
+ """
201
+ if softmax_scale is None:
202
+ attention_head_dim = q.shape[-1]
203
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
204
+
205
+ return flash_attn_ops.bwd(
206
+ dout,
207
+ q,
208
+ k,
209
+ v,
210
+ out,
211
+ softmax_lse,
212
+ dq,
213
+ dk,
214
+ dv,
215
+ alibi_slopes,
216
+ p_dropout,
217
+ softmax_scale,
218
+ is_causal,
219
+ window_size_left,
220
+ window_size_right,
221
+ softcap,
222
+ deterministic,
223
+ gen,
224
+ rng_state,
225
+ )
226
+
227
+
228
+ def varlen_bwd(
229
+ dout: torch.Tensor,
230
+ q: torch.Tensor,
231
+ k: torch.Tensor,
232
+ v: torch.Tensor,
233
+ out: torch.Tensor,
234
+ softmax_lse: torch.Tensor,
235
+ cu_seqlens_q: torch.Tensor,
236
+ cu_seqlens_k: torch.Tensor,
237
+ dq: Optional[torch.Tensor] = None,
238
+ dk: Optional[torch.Tensor] = None,
239
+ dv: Optional[torch.Tensor] = None,
240
+ alibi_slopes: Optional[torch.Tensor] = None,
241
+ max_seqlen_q: int = 0,
242
+ max_seqlen_k: int = 0,
243
+ p_dropout: float = 0.0,
244
+ softmax_scale: Optional[float] = None,
245
+ zero_tensors: bool = False,
246
+ is_causal: bool = False,
247
+ window_size_left: int = -1,
248
+ window_size_right: int = -1,
249
+ softcap: float = 0.0,
250
+ deterministic: bool = False,
251
+ gen: Optional[torch.Generator] = None,
252
+ rng_state: Optional[torch.Tensor] = None,
253
+ ) -> List[torch.Tensor]:
254
+ """
255
+ Backward pass for multi-head attention with variable sequence lengths.
256
+
257
+ Args:
258
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
259
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
260
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
261
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
262
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
263
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
264
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
265
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
266
+ dq: Optional gradient tensor for queries, same shape as q
267
+ dk: Optional gradient tensor for keys, same shape as k
268
+ dv: Optional gradient tensor for values, same shape as v
269
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
270
+ max_seqlen_q: Maximum sequence length for queries
271
+ max_seqlen_k: Maximum sequence length for keys
272
+ p_dropout: Dropout probability
273
+ softmax_scale: Scale factor for softmax
274
+ zero_tensors: Whether to zero tensors before computation
275
+ is_causal: Whether to use causal attention
276
+ window_size_left: Window size for left context (-1 for unlimited)
277
+ window_size_right: Window size for right context (-1 for unlimited)
278
+ softcap: Soft cap for attention weights
279
+ deterministic: Whether to use deterministic algorithms
280
+ gen: Optional random number generator
281
+ rng_state: Optional RNG state from forward pass
282
+
283
+ Returns:
284
+ List of tensors: [dq, dk, dv]
285
+ """
286
+ if softmax_scale is None:
287
+ attention_head_dim = q.shape[-1]
288
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
289
+
290
+ return flash_attn_ops.varlen_bwd(
291
+ dout,
292
+ q,
293
+ k,
294
+ v,
295
+ out,
296
+ softmax_lse,
297
+ dq,
298
+ dk,
299
+ dv,
300
+ cu_seqlens_q,
301
+ cu_seqlens_k,
302
+ alibi_slopes,
303
+ max_seqlen_q,
304
+ max_seqlen_k,
305
+ p_dropout,
306
+ softmax_scale,
307
+ zero_tensors,
308
+ is_causal,
309
+ window_size_left,
310
+ window_size_right,
311
+ softcap,
312
+ deterministic,
313
+ gen,
314
+ rng_state,
315
+ )
316
+
317
+
318
+ def fwd_kvcache(
319
+ q: torch.Tensor,
320
+ kcache: torch.Tensor,
321
+ vcache: torch.Tensor,
322
+ k: Optional[torch.Tensor] = None,
323
+ v: Optional[torch.Tensor] = None,
324
+ seqlens_k: Optional[torch.Tensor] = None,
325
+ rotary_cos: Optional[torch.Tensor] = None,
326
+ rotary_sin: Optional[torch.Tensor] = None,
327
+ cache_batch_idx: Optional[torch.Tensor] = None,
328
+ leftpad_k: Optional[torch.Tensor] = None,
329
+ block_table: Optional[torch.Tensor] = None,
330
+ alibi_slopes: Optional[torch.Tensor] = None,
331
+ out: Optional[torch.Tensor] = None,
332
+ softmax_scale: Optional[float] = None,
333
+ is_causal: bool = False,
334
+ window_size_left: int = -1,
335
+ window_size_right: int = -1,
336
+ softcap: float = 0.0,
337
+ is_rotary_interleaved: bool = False,
338
+ num_splits: int = 1,
339
+ ) -> List[torch.Tensor]:
340
+ """
341
+ Forward pass for multi-head attention with KV cache.
342
+
343
+ Args:
344
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
345
+ kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
346
+ vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
347
+ k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
348
+ v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
349
+ seqlens_k: Optional sequence lengths for keys of shape [batch_size]
350
+ rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
351
+ rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
352
+ cache_batch_idx: Optional indices to index into the KV cache
353
+ leftpad_k: Optional left padding for keys of shape [batch_size]
354
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
355
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
356
+ out: Optional output tensor, same shape as q
357
+ softmax_scale: Scale factor for softmax
358
+ is_causal: Whether to use causal attention
359
+ window_size_left: Window size for left context (-1 for unlimited)
360
+ window_size_right: Window size for right context (-1 for unlimited)
361
+ softcap: Soft cap for attention weights
362
+ is_rotary_interleaved: Whether rotary embeddings are interleaved
363
+ num_splits: Number of splits for computation
364
+
365
+ Returns:
366
+ List of tensors: [output, softmax_lse]
367
+ """
368
+ if softmax_scale is None:
369
+ attention_head_dim = q.shape[-1]
370
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
371
+
372
+ return flash_attn_ops.fwd_kvcache(
373
+ q,
374
+ kcache,
375
+ vcache,
376
+ k,
377
+ v,
378
+ seqlens_k,
379
+ rotary_cos,
380
+ rotary_sin,
381
+ cache_batch_idx,
382
+ leftpad_k,
383
+ block_table,
384
+ alibi_slopes,
385
+ out,
386
+ softmax_scale,
387
+ is_causal,
388
+ window_size_left,
389
+ window_size_right,
390
+ softcap,
391
+ is_rotary_interleaved,
392
+ num_splits,
393
+ )
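Both copies of this `__init__.py` expose the same varlen interface, where the batch is packed into a single `[total_q, num_heads, head_size]` tensor and `cu_seqlens_*` carries the cumulative offsets. A hedged sketch of preparing those inputs for `varlen_fwd` (kernel loaded via `kernels` as before; the repo id, device, and sequence lengths are illustrative):

```python
import torch
from kernels import get_kernel  # placeholder repo id, as in the earlier sketches

flash_attn = get_kernel("kernels-community/flash-attn2")

# Two sequences of length 3 and 5, packed into one [total, num_heads, head_size] tensor.
seqlens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0)).to("xpu")
# cu_seqlens == tensor([0, 3, 8]); total_q = total_k = 8

q = torch.randn(8, 4, 64, device="xpu", dtype=torch.float16)
k = torch.randn(8, 4, 64, device="xpu", dtype=torch.float16)
v = torch.randn(8, 4, 64, device="xpu", dtype=torch.float16)

result = flash_attn.varlen_fwd(
    q, k, v, cu_seqlens, cu_seqlens,
    max_seqlen_q=5, max_seqlen_k=5, is_causal=True,
)
out = result[0]
```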
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2/_flash_attn2_870e782_dirty.abi3.so → torch28-cxx11-xpu20251-x86_64-linux/flash_attn/_flash_attn_c984dd4_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c11add54a0ef55746db2b14aa68814fe2440923444872747902d10f232bc6273
-size 5127712
+oid sha256:3e6ca073b589dbefd15e0160369a130677854636cae9de41f29ab6cb8d4c2123
+size 3730720
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux/flash_attn}/_ops.py RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _flash_attn2_870e782_dirty
-ops = torch.ops._flash_attn2_870e782_dirty
+from . import _flash_attn_c984dd4_dirty
+ops = torch.ops._flash_attn_c984dd4_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_flash_attn2_870e782_dirty::{op_name}"
+    return f"_flash_attn_c984dd4_dirty::{op_name}"
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux/flash_attn}/bert_padding.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/flash_attn_interface.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/layers/__init__.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/layers/patch_embed.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/layers/rotary.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/__init__.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/activations.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/fused_dense.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/layer_norm.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/rms_norm.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/__init__.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/cross_entropy.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/k_activations.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/layer_norm.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/linear.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/mlp.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/{flash_attn2 → flash_attn}/ops/triton/rotary.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/flash_attn2/__init__.py CHANGED
@@ -1,393 +1,26 @@
1
- from typing import Optional, List
2
- import torch
3
- from ._ops import ops as flash_attn_ops
4
- from .flash_attn_interface import (
5
- flash_attn_func,
6
- flash_attn_kvpacked_func,
7
- flash_attn_qkvpacked_func,
8
- flash_attn_varlen_func,
9
- flash_attn_varlen_kvpacked_func,
10
- flash_attn_varlen_qkvpacked_func,
11
- flash_attn_with_kvcache,
12
- )
13
-
14
-
15
- def fwd(
16
- q: torch.Tensor,
17
- k: torch.Tensor,
18
- v: torch.Tensor,
19
- out: Optional[torch.Tensor] = None,
20
- alibi_slopes: Optional[torch.Tensor] = None,
21
- p_dropout: float = 0.0,
22
- softmax_scale: Optional[float] = None,
23
- is_causal: bool = False,
24
- window_size_left: int = -1,
25
- window_size_right: int = -1,
26
- softcap: float = 0.0,
27
- return_softmax: bool = False,
28
- gen: Optional[torch.Generator] = None,
29
- ) -> List[torch.Tensor]:
30
- """
31
- Forward pass for multi-head attention.
32
-
33
- Args:
34
- q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
35
- k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
36
- v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
37
- out: Optional output tensor, same shape as q
38
- alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
39
- p_dropout: Dropout probability
40
- softmax_scale: Scale factor for softmax
41
- is_causal: Whether to use causal attention
42
- window_size_left: Window size for left context (-1 for unlimited)
43
- window_size_right: Window size for right context (-1 for unlimited)
44
- softcap: Soft cap for attention weights
45
- return_softmax: Whether to return softmax weights
46
- gen: Optional random number generator
47
-
48
- Returns:
49
- List of tensors: [output, softmax_lse, (softmax if return_softmax)]
50
- """
51
- if softmax_scale is None:
52
- attention_head_dim = q.shape[-1]
53
- softmax_scale = 1.0 / (attention_head_dim**0.5)
54
-
55
- return flash_attn_ops.fwd(
56
- q,
57
- k,
58
- v,
59
- out,
60
- alibi_slopes,
61
- p_dropout,
62
- softmax_scale,
63
- is_causal,
64
- window_size_left,
65
- window_size_right,
66
- softcap,
67
- return_softmax,
68
- gen,
69
- )
70
-
71
-
72
- def varlen_fwd(
73
- q: torch.Tensor,
74
- k: torch.Tensor,
75
- v: torch.Tensor,
76
- cu_seqlens_q: torch.Tensor,
77
- cu_seqlens_k: torch.Tensor,
78
- out: Optional[torch.Tensor] = None,
79
- seqused_k: Optional[torch.Tensor] = None,
80
- leftpad_k: Optional[torch.Tensor] = None,
81
- block_table: Optional[torch.Tensor] = None,
82
- alibi_slopes: Optional[torch.Tensor] = None,
83
- max_seqlen_q: int = 0,
84
- max_seqlen_k: int = 0,
85
- p_dropout: float = 0.0,
86
- softmax_scale: Optional[float] = None,
87
- zero_tensors: bool = False,
88
- is_causal: bool = False,
89
- window_size_left: int = -1,
90
- window_size_right: int = -1,
91
- softcap: float = 0.0,
92
- return_softmax: bool = False,
93
- gen: Optional[torch.Generator] = None,
94
- ) -> List[torch.Tensor]:
95
- """
96
- Forward pass for multi-head attention with variable sequence lengths.
97
-
98
- Args:
99
- q: Query tensor of shape [total_q, num_heads, head_size]
100
- k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
101
- v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
102
- cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
103
- cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
104
- out: Optional output tensor of shape [total_q, num_heads, head_size]
105
- seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
106
- leftpad_k: Optional left padding for keys of shape [batch_size]
107
- block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
108
- alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
109
- max_seqlen_q: Maximum sequence length for queries
110
- max_seqlen_k: Maximum sequence length for keys
111
- p_dropout: Dropout probability
112
- softmax_scale: Scale factor for softmax
113
- zero_tensors: Whether to zero tensors before computation
114
- is_causal: Whether to use causal attention
115
- window_size_left: Window size for left context (-1 for unlimited)
116
- window_size_right: Window size for right context (-1 for unlimited)
117
- softcap: Soft cap for attention weights
118
- return_softmax: Whether to return softmax weights
119
- gen: Optional random number generator
120
-
121
- Returns:
122
- List of tensors: [output, softmax_lse, (softmax if return_softmax)]
123
- """
124
- if softmax_scale is None:
125
- attention_head_dim = q.shape[-1]
126
- softmax_scale = 1.0 / (attention_head_dim**0.5)
127
-
128
- return flash_attn_ops.varlen_fwd(
129
- q,
130
- k,
131
- v,
132
- out,
133
- cu_seqlens_q,
134
- cu_seqlens_k,
135
- seqused_k,
136
- leftpad_k,
137
- block_table,
138
- alibi_slopes,
139
- max_seqlen_q,
140
- max_seqlen_k,
141
- p_dropout,
142
- softmax_scale,
143
- zero_tensors,
144
- is_causal,
145
- window_size_left,
146
- window_size_right,
147
- softcap,
148
- return_softmax,
149
- gen,
150
- )
151
-
152
-
153
- def bwd(
154
- dout: torch.Tensor,
155
- q: torch.Tensor,
156
- k: torch.Tensor,
157
- v: torch.Tensor,
158
- out: torch.Tensor,
159
- softmax_lse: torch.Tensor,
160
- dq: Optional[torch.Tensor] = None,
161
- dk: Optional[torch.Tensor] = None,
162
- dv: Optional[torch.Tensor] = None,
163
- alibi_slopes: Optional[torch.Tensor] = None,
164
- p_dropout: float = 0.0,
165
- softmax_scale: Optional[float] = None,
166
- is_causal: bool = False,
167
- window_size_left: int = -1,
168
- window_size_right: int = -1,
169
- softcap: float = 0.0,
170
- deterministic: bool = False,
171
- gen: Optional[torch.Generator] = None,
172
- rng_state: Optional[torch.Tensor] = None,
173
- ) -> List[torch.Tensor]:
174
- """
175
- Backward pass for multi-head attention.
176
-
177
- Args:
178
- dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
179
- q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
180
- k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
181
- v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
182
- out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
183
- softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
184
- dq: Optional gradient tensor for queries, same shape as q
185
- dk: Optional gradient tensor for keys, same shape as k
186
- dv: Optional gradient tensor for values, same shape as v
187
- alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
188
- p_dropout: Dropout probability
189
- softmax_scale: Scale factor for softmax
190
- is_causal: Whether to use causal attention
191
- window_size_left: Window size for left context (-1 for unlimited)
192
- window_size_right: Window size for right context (-1 for unlimited)
193
- softcap: Soft cap for attention weights
194
- deterministic: Whether to use deterministic algorithms
195
- gen: Optional random number generator
196
- rng_state: Optional RNG state from forward pass
197
-
198
- Returns:
199
- List of tensors: [dq, dk, dv]
200
- """
201
- if softmax_scale is None:
202
- attention_head_dim = q.shape[-1]
203
- softmax_scale = 1.0 / (attention_head_dim**0.5)
204
-
205
- return flash_attn_ops.bwd(
206
- dout,
207
- q,
208
- k,
209
- v,
210
- out,
211
- softmax_lse,
212
- dq,
213
- dk,
214
- dv,
215
- alibi_slopes,
216
- p_dropout,
217
- softmax_scale,
218
- is_causal,
219
- window_size_left,
220
- window_size_right,
221
- softcap,
222
- deterministic,
223
- gen,
224
- rng_state,
225
- )
226
-
227
-
228
- def varlen_bwd(
229
- dout: torch.Tensor,
230
- q: torch.Tensor,
231
- k: torch.Tensor,
232
- v: torch.Tensor,
233
- out: torch.Tensor,
234
- softmax_lse: torch.Tensor,
235
- cu_seqlens_q: torch.Tensor,
236
- cu_seqlens_k: torch.Tensor,
237
- dq: Optional[torch.Tensor] = None,
238
- dk: Optional[torch.Tensor] = None,
239
- dv: Optional[torch.Tensor] = None,
240
- alibi_slopes: Optional[torch.Tensor] = None,
241
- max_seqlen_q: int = 0,
242
- max_seqlen_k: int = 0,
243
- p_dropout: float = 0.0,
244
- softmax_scale: Optional[float] = None,
245
- zero_tensors: bool = False,
246
- is_causal: bool = False,
247
- window_size_left: int = -1,
248
- window_size_right: int = -1,
249
- softcap: float = 0.0,
250
- deterministic: bool = False,
251
- gen: Optional[torch.Generator] = None,
252
- rng_state: Optional[torch.Tensor] = None,
253
- ) -> List[torch.Tensor]:
254
- """
255
- Backward pass for multi-head attention with variable sequence lengths.
256
-
257
- Args:
258
- dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
259
- q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
260
- k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
261
- v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
262
- out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
263
- softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
264
- cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
265
- cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
266
- dq: Optional gradient tensor for queries, same shape as q
267
- dk: Optional gradient tensor for keys, same shape as k
268
- dv: Optional gradient tensor for values, same shape as v
269
- alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
270
- max_seqlen_q: Maximum sequence length for queries
271
- max_seqlen_k: Maximum sequence length for keys
272
- p_dropout: Dropout probability
273
- softmax_scale: Scale factor for softmax
274
- zero_tensors: Whether to zero tensors before computation
275
- is_causal: Whether to use causal attention
276
- window_size_left: Window size for left context (-1 for unlimited)
277
- window_size_right: Window size for right context (-1 for unlimited)
278
- softcap: Soft cap for attention weights
279
- deterministic: Whether to use deterministic algorithms
280
- gen: Optional random number generator
281
- rng_state: Optional RNG state from forward pass
282
-
283
- Returns:
284
- List of tensors: [dq, dk, dv]
285
- """
286
- if softmax_scale is None:
287
- attention_head_dim = q.shape[-1]
288
- softmax_scale = 1.0 / (attention_head_dim**0.5)
289
-
290
- return flash_attn_ops.varlen_bwd(
291
- dout,
292
- q,
293
- k,
294
- v,
295
- out,
296
- softmax_lse,
297
- dq,
298
- dk,
299
- dv,
300
- cu_seqlens_q,
301
- cu_seqlens_k,
302
- alibi_slopes,
303
- max_seqlen_q,
304
- max_seqlen_k,
305
- p_dropout,
306
- softmax_scale,
307
- zero_tensors,
308
- is_causal,
309
- window_size_left,
310
- window_size_right,
311
- softcap,
312
- deterministic,
313
- gen,
314
- rng_state,
315
- )
316
-
317
-
318
- def fwd_kvcache(
319
- q: torch.Tensor,
320
- kcache: torch.Tensor,
321
- vcache: torch.Tensor,
322
- k: Optional[torch.Tensor] = None,
323
- v: Optional[torch.Tensor] = None,
324
- seqlens_k: Optional[torch.Tensor] = None,
325
- rotary_cos: Optional[torch.Tensor] = None,
326
- rotary_sin: Optional[torch.Tensor] = None,
327
- cache_batch_idx: Optional[torch.Tensor] = None,
328
- leftpad_k: Optional[torch.Tensor] = None,
329
- block_table: Optional[torch.Tensor] = None,
330
- alibi_slopes: Optional[torch.Tensor] = None,
331
- out: Optional[torch.Tensor] = None,
332
- softmax_scale: Optional[float] = None,
333
- is_causal: bool = False,
334
- window_size_left: int = -1,
335
- window_size_right: int = -1,
336
- softcap: float = 0.0,
337
- is_rotary_interleaved: bool = False,
338
- num_splits: int = 1,
339
- ) -> List[torch.Tensor]:
340
- """
341
- Forward pass for multi-head attention with KV cache.
342
-
343
- Args:
344
- q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
345
- kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
346
- vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
347
- k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
348
- v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
349
- seqlens_k: Optional sequence lengths for keys of shape [batch_size]
350
- rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
351
- rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
352
- cache_batch_idx: Optional indices to index into the KV cache
353
- leftpad_k: Optional left padding for keys of shape [batch_size]
354
- block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
355
- alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
356
- out: Optional output tensor, same shape as q
357
- softmax_scale: Scale factor for softmax
358
- is_causal: Whether to use causal attention
359
- window_size_left: Window size for left context (-1 for unlimited)
360
- window_size_right: Window size for right context (-1 for unlimited)
361
- softcap: Soft cap for attention weights
362
- is_rotary_interleaved: Whether rotary embeddings are interleaved
363
- num_splits: Number of splits for computation
364
-
365
- Returns:
366
- List of tensors: [output, softmax_lse]
367
- """
368
- if softmax_scale is None:
369
- attention_head_dim = q.shape[-1]
370
- softmax_scale = 1.0 / (attention_head_dim**0.5)
371
-
372
- return flash_attn_ops.fwd_kvcache(
373
- q,
374
- kcache,
375
- vcache,
376
- k,
377
- v,
378
- seqlens_k,
379
- rotary_cos,
380
- rotary_sin,
381
- cache_batch_idx,
382
- leftpad_k,
383
- block_table,
384
- alibi_slopes,
385
- out,
386
- softmax_scale,
387
- is_causal,
388
- window_size_left,
389
- window_size_right,
390
- softcap,
391
- is_rotary_interleaved,
392
- num_splits,
393
- )
 
+import ctypes
+import sys
+
+import importlib
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path.
+    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+    module_name = path_hash
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+    module = importlib.util.module_from_spec(spec)
+    if module is None:
+        raise ImportError(f"Cannot load module {module_name} from spec")
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)  # type: ignore
+    return module
+
+
+globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
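The net effect of this hunk is that `flash_attn2/__init__.py` shrinks from the full 393-line wrapper module to a small compatibility shim: it loads the new top-level `__init__.py` under a path-derived module name (so `sys.modules` entries stay unique per build directory) and re-exports its globals. In other words, the legacy `flash_attn2` namespace should expose the same wrapper functions as the new top level. A hedged sketch of what that looks like from the consumer side, assuming the build directory is importable under a hypothetical package name:

```python
# Illustrative only: "flash_attn_build" is a hypothetical package name standing in
# for one of the build variant directories in this commit.
import flash_attn_build.flash_attn2 as legacy

# The shim executes the parent __init__.py and copies its globals, so the legacy
# namespace carries the same wrappers as the new top-level module.
assert {"fwd", "varlen_fwd", "bwd", "varlen_bwd", "fwd_kvcache"} <= set(dir(legacy))
```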
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/flash_attn_interface.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/layers/__init__.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/layers/patch_embed.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/layers/rotary.py RENAMED
File without changes
build/torch28-cxx11-xpu20251-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+{"python-depends":[]}
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/__init__.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/activations.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/fused_dense.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/layer_norm.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/rms_norm.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/__init__.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/cross_entropy.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/k_activations.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/layer_norm.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/linear.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/mlp.py RENAMED
File without changes
build/{torch29-cxx11-xpu20252-x86_64-linux/flash_attn2 → torch28-cxx11-xpu20251-x86_64-linux}/ops/triton/rotary.py RENAMED
File without changes
build/torch29-cxx11-xpu20252-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,393 @@
1
+ from typing import Optional, List
2
+ import torch
3
+ from ._ops import ops as flash_attn_ops
4
+ from .flash_attn_interface import (
5
+ flash_attn_func,
6
+ flash_attn_kvpacked_func,
7
+ flash_attn_qkvpacked_func,
8
+ flash_attn_varlen_func,
9
+ flash_attn_varlen_kvpacked_func,
10
+ flash_attn_varlen_qkvpacked_func,
11
+ flash_attn_with_kvcache,
12
+ )
13
+
14
+
15
+ def fwd(
16
+ q: torch.Tensor,
17
+ k: torch.Tensor,
18
+ v: torch.Tensor,
19
+ out: Optional[torch.Tensor] = None,
20
+ alibi_slopes: Optional[torch.Tensor] = None,
21
+ p_dropout: float = 0.0,
22
+ softmax_scale: Optional[float] = None,
23
+ is_causal: bool = False,
24
+ window_size_left: int = -1,
25
+ window_size_right: int = -1,
26
+ softcap: float = 0.0,
27
+ return_softmax: bool = False,
28
+ gen: Optional[torch.Generator] = None,
29
+ ) -> List[torch.Tensor]:
30
+ """
31
+ Forward pass for multi-head attention.
32
+
33
+ Args:
34
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
35
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
36
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
37
+ out: Optional output tensor, same shape as q
38
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
39
+ p_dropout: Dropout probability
40
+ softmax_scale: Scale factor for softmax
41
+ is_causal: Whether to use causal attention
42
+ window_size_left: Window size for left context (-1 for unlimited)
43
+ window_size_right: Window size for right context (-1 for unlimited)
44
+ softcap: Soft cap for attention weights
45
+ return_softmax: Whether to return softmax weights
46
+ gen: Optional random number generator
47
+
48
+ Returns:
49
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
50
+ """
51
+ if softmax_scale is None:
52
+ attention_head_dim = q.shape[-1]
53
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
54
+
55
+ return flash_attn_ops.fwd(
56
+ q,
57
+ k,
58
+ v,
59
+ out,
60
+ alibi_slopes,
61
+ p_dropout,
62
+ softmax_scale,
63
+ is_causal,
64
+ window_size_left,
65
+ window_size_right,
66
+ softcap,
67
+ return_softmax,
68
+ gen,
69
+ )
70
+
71
+
72
+ def varlen_fwd(
73
+ q: torch.Tensor,
74
+ k: torch.Tensor,
75
+ v: torch.Tensor,
76
+ cu_seqlens_q: torch.Tensor,
77
+ cu_seqlens_k: torch.Tensor,
78
+ out: Optional[torch.Tensor] = None,
79
+ seqused_k: Optional[torch.Tensor] = None,
80
+ leftpad_k: Optional[torch.Tensor] = None,
81
+ block_table: Optional[torch.Tensor] = None,
82
+ alibi_slopes: Optional[torch.Tensor] = None,
83
+ max_seqlen_q: int = 0,
84
+ max_seqlen_k: int = 0,
85
+ p_dropout: float = 0.0,
86
+ softmax_scale: Optional[float] = None,
87
+ zero_tensors: bool = False,
88
+ is_causal: bool = False,
89
+ window_size_left: int = -1,
90
+ window_size_right: int = -1,
91
+ softcap: float = 0.0,
92
+ return_softmax: bool = False,
93
+ gen: Optional[torch.Generator] = None,
94
+ ) -> List[torch.Tensor]:
95
+ """
96
+ Forward pass for multi-head attention with variable sequence lengths.
97
+
98
+ Args:
99
+ q: Query tensor of shape [total_q, num_heads, head_size]
100
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
101
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
102
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
103
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
104
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
105
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
106
+ leftpad_k: Optional left padding for keys of shape [batch_size]
107
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
108
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
109
+ max_seqlen_q: Maximum sequence length for queries
110
+ max_seqlen_k: Maximum sequence length for keys
111
+ p_dropout: Dropout probability
112
+ softmax_scale: Scale factor for softmax
113
+ zero_tensors: Whether to zero tensors before computation
114
+ is_causal: Whether to use causal attention
115
+ window_size_left: Window size for left context (-1 for unlimited)
116
+ window_size_right: Window size for right context (-1 for unlimited)
117
+ softcap: Soft cap for attention weights
118
+ return_softmax: Whether to return softmax weights
119
+ gen: Optional random number generator
120
+
121
+ Returns:
122
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
123
+ """
124
+ if softmax_scale is None:
125
+ attention_head_dim = q.shape[-1]
126
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
127
+
128
+ return flash_attn_ops.varlen_fwd(
129
+ q,
130
+ k,
131
+ v,
132
+ out,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ seqused_k,
136
+ leftpad_k,
137
+ block_table,
138
+ alibi_slopes,
139
+ max_seqlen_q,
140
+ max_seqlen_k,
141
+ p_dropout,
142
+ softmax_scale,
143
+ zero_tensors,
144
+ is_causal,
145
+ window_size_left,
146
+ window_size_right,
147
+ softcap,
148
+ return_softmax,
149
+ gen,
150
+ )
151
+
152
+
153
+ def bwd(
154
+ dout: torch.Tensor,
155
+ q: torch.Tensor,
156
+ k: torch.Tensor,
157
+ v: torch.Tensor,
158
+ out: torch.Tensor,
159
+ softmax_lse: torch.Tensor,
160
+ dq: Optional[torch.Tensor] = None,
161
+ dk: Optional[torch.Tensor] = None,
162
+ dv: Optional[torch.Tensor] = None,
163
+ alibi_slopes: Optional[torch.Tensor] = None,
164
+ p_dropout: float = 0.0,
165
+ softmax_scale: Optional[float] = None,
166
+ is_causal: bool = False,
167
+ window_size_left: int = -1,
168
+ window_size_right: int = -1,
169
+ softcap: float = 0.0,
170
+ deterministic: bool = False,
171
+ gen: Optional[torch.Generator] = None,
172
+ rng_state: Optional[torch.Tensor] = None,
173
+ ) -> List[torch.Tensor]:
174
+ """
175
+ Backward pass for multi-head attention.
176
+
177
+ Args:
178
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
179
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
180
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
181
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
182
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
183
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
184
+ dq: Optional gradient tensor for queries, same shape as q
185
+ dk: Optional gradient tensor for keys, same shape as k
186
+ dv: Optional gradient tensor for values, same shape as v
187
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
188
+ p_dropout: Dropout probability
189
+ softmax_scale: Scale factor for softmax
190
+ is_causal: Whether to use causal attention
191
+ window_size_left: Window size for left context (-1 for unlimited)
192
+ window_size_right: Window size for right context (-1 for unlimited)
193
+ softcap: Soft cap for attention weights
194
+ deterministic: Whether to use deterministic algorithms
195
+ gen: Optional random number generator
196
+ rng_state: Optional RNG state from forward pass
197
+
198
+ Returns:
199
+ List of tensors: [dq, dk, dv]
200
+ """
201
+ if softmax_scale is None:
202
+ attention_head_dim = q.shape[-1]
203
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
204
+
205
+ return flash_attn_ops.bwd(
206
+ dout,
207
+ q,
208
+ k,
209
+ v,
210
+ out,
211
+ softmax_lse,
212
+ dq,
213
+ dk,
214
+ dv,
215
+ alibi_slopes,
216
+ p_dropout,
217
+ softmax_scale,
218
+ is_causal,
219
+ window_size_left,
220
+ window_size_right,
221
+ softcap,
222
+ deterministic,
223
+ gen,
224
+ rng_state,
225
+ )
226
+
227
+
228
+ def varlen_bwd(
229
+ dout: torch.Tensor,
230
+ q: torch.Tensor,
231
+ k: torch.Tensor,
232
+ v: torch.Tensor,
233
+ out: torch.Tensor,
234
+ softmax_lse: torch.Tensor,
235
+ cu_seqlens_q: torch.Tensor,
236
+ cu_seqlens_k: torch.Tensor,
237
+ dq: Optional[torch.Tensor] = None,
238
+ dk: Optional[torch.Tensor] = None,
239
+ dv: Optional[torch.Tensor] = None,
240
+ alibi_slopes: Optional[torch.Tensor] = None,
241
+ max_seqlen_q: int = 0,
242
+ max_seqlen_k: int = 0,
243
+ p_dropout: float = 0.0,
244
+ softmax_scale: Optional[float] = None,
245
+ zero_tensors: bool = False,
246
+ is_causal: bool = False,
247
+ window_size_left: int = -1,
248
+ window_size_right: int = -1,
249
+ softcap: float = 0.0,
250
+ deterministic: bool = False,
251
+ gen: Optional[torch.Generator] = None,
252
+ rng_state: Optional[torch.Tensor] = None,
253
+ ) -> List[torch.Tensor]:
254
+ """
255
+ Backward pass for multi-head attention with variable sequence lengths.
256
+
257
+ Args:
258
+ dout: Gradient tensor of shape [total_q, num_heads, head_size]
259
+ q: Query tensor of shape [total_q, num_heads, head_size]
260
+ k: Key tensor of shape [total_k, num_heads_k, head_size]
261
+ v: Value tensor of shape [total_k, num_heads_k, head_size]
262
+ out: Output tensor from forward pass of shape [total_q, num_heads, head_size]
263
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
264
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
265
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
266
+ dq: Optional gradient tensor for queries, same shape as q
267
+ dk: Optional gradient tensor for keys, same shape as k
268
+ dv: Optional gradient tensor for values, same shape as v
269
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
270
+ max_seqlen_q: Maximum sequence length for queries
271
+ max_seqlen_k: Maximum sequence length for keys
272
+ p_dropout: Dropout probability
273
+ softmax_scale: Scale factor for softmax
274
+ zero_tensors: Whether to zero tensors before computation
275
+ is_causal: Whether to use causal attention
276
+ window_size_left: Window size for left context (-1 for unlimited)
277
+ window_size_right: Window size for right context (-1 for unlimited)
278
+ softcap: Soft cap for attention weights
279
+ deterministic: Whether to use deterministic algorithms
280
+ gen: Optional random number generator
281
+ rng_state: Optional RNG state from forward pass
282
+
283
+ Returns:
284
+ List of tensors: [dq, dk, dv]
285
+ """
286
+ if softmax_scale is None:
287
+ attention_head_dim = q.shape[-1]
288
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
289
+
290
+ return flash_attn_ops.varlen_bwd(
291
+ dout,
292
+ q,
293
+ k,
294
+ v,
295
+ out,
296
+ softmax_lse,
297
+ dq,
298
+ dk,
299
+ dv,
300
+ cu_seqlens_q,
301
+ cu_seqlens_k,
302
+ alibi_slopes,
303
+ max_seqlen_q,
304
+ max_seqlen_k,
305
+ p_dropout,
306
+ softmax_scale,
307
+ zero_tensors,
308
+ is_causal,
309
+ window_size_left,
310
+ window_size_right,
311
+ softcap,
312
+ deterministic,
313
+ gen,
314
+ rng_state,
315
+ )
316
+
317
+
318
+ def fwd_kvcache(
319
+ q: torch.Tensor,
320
+ kcache: torch.Tensor,
321
+ vcache: torch.Tensor,
322
+ k: Optional[torch.Tensor] = None,
323
+ v: Optional[torch.Tensor] = None,
324
+ seqlens_k: Optional[torch.Tensor] = None,
325
+ rotary_cos: Optional[torch.Tensor] = None,
326
+ rotary_sin: Optional[torch.Tensor] = None,
327
+ cache_batch_idx: Optional[torch.Tensor] = None,
328
+ leftpad_k: Optional[torch.Tensor] = None,
329
+ block_table: Optional[torch.Tensor] = None,
330
+ alibi_slopes: Optional[torch.Tensor] = None,
331
+ out: Optional[torch.Tensor] = None,
332
+ softmax_scale: Optional[float] = None,
333
+ is_causal: bool = False,
334
+ window_size_left: int = -1,
335
+ window_size_right: int = -1,
336
+ softcap: float = 0.0,
337
+ is_rotary_interleaved: bool = False,
338
+ num_splits: int = 1,
339
+ ) -> List[torch.Tensor]:
340
+ """
341
+ Forward pass for multi-head attention with KV cache.
342
+
343
+ Args:
344
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
345
+ kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
346
+ vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
347
+ k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
348
+ v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
349
+ seqlens_k: Optional sequence lengths for keys of shape [batch_size]
350
+ rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
351
+ rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
352
+ cache_batch_idx: Optional indices to index into the KV cache
353
+ leftpad_k: Optional left padding for keys of shape [batch_size]
354
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
355
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
356
+ out: Optional output tensor, same shape as q
357
+ softmax_scale: Scale factor for softmax
358
+ is_causal: Whether to use causal attention
359
+ window_size_left: Window size for left context (-1 for unlimited)
360
+ window_size_right: Window size for right context (-1 for unlimited)
361
+ softcap: Soft cap for attention weights
362
+ is_rotary_interleaved: Whether rotary embeddings are interleaved
363
+ num_splits: Number of splits for computation
364
+
365
+ Returns:
366
+ List of tensors: [output, softmax_lse]
367
+ """
368
+ if softmax_scale is None:
369
+ attention_head_dim = q.shape[-1]
370
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
371
+
372
+ return flash_attn_ops.fwd_kvcache(
373
+ q,
374
+ kcache,
375
+ vcache,
376
+ k,
377
+ v,
378
+ seqlens_k,
379
+ rotary_cos,
380
+ rotary_sin,
381
+ cache_batch_idx,
382
+ leftpad_k,
383
+ block_table,
384
+ alibi_slopes,
385
+ out,
386
+ softmax_scale,
387
+ is_causal,
388
+ window_size_left,
389
+ window_size_right,
390
+ softcap,
391
+ is_rotary_interleaved,
392
+ num_splits,
393
+ )
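The wrappers above default `softmax_scale` to `1/sqrt(head_size)` and pass everything else straight to the compiled op. A minimal call sketch for `fwd`; the shapes, `bfloat16` dtype, and the `xpu` device are illustrative assumptions, not part of this commit:

```python
import torch

# Hypothetical shapes: batch=2, seqlen=128, 8 heads, head_size=64.
q = torch.randn(2, 128, 8, 64, device="xpu", dtype=torch.bfloat16)
k = torch.randn(2, 128, 8, 64, device="xpu", dtype=torch.bfloat16)
v = torch.randn(2, 128, 8, 64, device="xpu", dtype=torch.bfloat16)

# Causal attention with the default 1/sqrt(64) softmax scale.
out, softmax_lse = fwd(q, k, v, is_causal=True)[:2]
print(out.shape)  # expected: torch.Size([2, 128, 8, 64])
```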
build/torch29-cxx11-xpu20252-x86_64-linux/_flash_attn2_5dab8ba_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e7e91cf691aa55f859b6f983b4e3aecbf08e04f24ea4fc322e3c8123d060c9d
3
+ size 7279344
build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn2_5dab8ba_dirty
3
+ ops = torch.ops._flash_attn2_5dab8ba_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn2_5dab8ba_dirty::{op_name}"
build/torch29-cxx11-xpu20252-x86_64-linux/bert_padding.py ADDED
@@ -0,0 +1,218 @@
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
19
+ ).reshape(-1, *other_shape)
20
+
21
+ @staticmethod
22
+ def backward(ctx, grad_output):
23
+ (indices,) = ctx.saved_tensors
24
+ assert grad_output.ndim >= 2
25
+ other_shape = grad_output.shape[1:]
26
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
27
+ grad_input = torch.zeros(
28
+ [ctx.first_axis_dim, grad_output.shape[1]],
29
+ device=grad_output.device,
30
+ dtype=grad_output.dtype,
31
+ )
32
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
33
+ # grad_input[indices] = grad_output
34
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
35
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
36
+
37
+
38
+ index_first_axis = IndexFirstAxis.apply
39
+
40
+
41
+ class IndexPutFirstAxis(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(ctx, values, indices, first_axis_dim):
44
+ ctx.save_for_backward(indices)
45
+ assert indices.ndim == 1
46
+ assert values.ndim >= 2
47
+ output = torch.zeros(
48
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
49
+ )
50
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ output[indices] = values
52
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
53
+ return output
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_output):
57
+ (indices,) = ctx.saved_tensors
58
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
59
+ grad_values = grad_output[indices]
60
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
61
+ return grad_values, None, None
62
+
63
+
64
+ index_put_first_axis = IndexPutFirstAxis.apply
65
+
66
+
67
+ class IndexFirstAxisResidual(torch.autograd.Function):
68
+ @staticmethod
69
+ def forward(ctx, input, indices):
70
+ ctx.save_for_backward(indices)
71
+ assert input.ndim >= 2
72
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
73
+ second_dim = other_shape.numel()
74
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
75
+ output = input[indices]
76
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
77
+ # memory format to channel_first. In other words, input might not be contiguous.
78
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
79
+ return output, input.detach()
80
+
81
+ @staticmethod
82
+ def backward(ctx, grad_output, grad_residual):
83
+ (indices,) = ctx.saved_tensors
84
+ assert grad_output.ndim >= 2
85
+ other_shape = grad_output.shape[1:]
86
+ assert grad_residual.shape[1:] == other_shape
87
+ grad_input = grad_residual
88
+ # grad_input[indices] += grad_output
89
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
90
+ indices = indices.expand_as(grad_output)
91
+ grad_input.scatter_add_(0, indices, grad_output)
92
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
93
+
94
+
95
+ index_first_axis_residual = IndexFirstAxisResidual.apply
96
+
97
+
98
+ def unpad_input(hidden_states, attention_mask, unused_mask=None):
99
+ """
100
+ Arguments:
101
+ hidden_states: (batch, seqlen, ...)
102
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
103
+ unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
104
+ Return:
105
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
106
+ indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
107
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
108
+ max_seqlen_in_batch: int
109
+ seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
110
+ """
111
+ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
112
+ seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
113
+ used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
114
+ indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
115
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
116
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
117
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
118
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
119
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
120
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
121
+ # so we write custom forward and backward to make it a bit faster.
122
+ return (
123
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
124
+ indices,
125
+ cu_seqlens,
126
+ max_seqlen_in_batch,
127
+ used_seqlens_in_batch,
128
+ )
129
+
130
+
131
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
132
+ """
133
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is used to mask the other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task in large language models).
134
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
135
+
136
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
137
+ ```
138
+ [
139
+ [2, 3, 0, 0, 0, 0],
140
+ [3, 2, 0, 0, 0, 0],
141
+ [6, 0, 0, 0, 0, 0]
142
+ ]
143
+ ```
144
+ , which refers to the 3D-attention mask:
145
+ ```
146
+ [
147
+ [
148
+ [1, 0, 0, 0, 0, 0],
149
+ [1, 1, 0, 0, 0, 0],
150
+ [0, 0, 1, 0, 0, 0],
151
+ [0, 0, 1, 1, 0, 0],
152
+ [0, 0, 1, 1, 1, 0],
153
+ [0, 0, 0, 0, 0, 1]
154
+ ],
155
+ [
156
+ [1, 0, 0, 0, 0, 0],
157
+ [1, 1, 0, 0, 0, 0],
158
+ [1, 1, 1, 0, 0, 0],
159
+ [0, 0, 0, 1, 0, 0],
160
+ [0, 0, 0, 1, 1, 0],
161
+ [0, 0, 0, 0, 0, 1]
162
+ ],
163
+ [
164
+ [1, 0, 0, 0, 0, 0],
165
+ [1, 1, 0, 0, 0, 0],
166
+ [1, 1, 1, 0, 0, 0],
167
+ [1, 1, 1, 1, 0, 0],
168
+ [1, 1, 1, 1, 1, 0],
169
+ [1, 1, 1, 1, 1, 1]
170
+ ]
171
+ ]
172
+ ```.
173
+
174
+ Arguments:
175
+ hidden_states: (batch, seqlen, ...)
176
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
177
+ Return:
178
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
179
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
180
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
181
+ max_seqlen_in_batch: int
182
+ """
183
+ length = attention_mask_in_length.sum(dim=-1)
184
+ seqlen = attention_mask_in_length.size(-1)
185
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
187
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
191
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
193
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
194
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
195
+ # so we write custom forward and backward to make it a bit faster.
196
+ return (
197
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
198
+ indices,
199
+ cu_seqlens,
200
+ max_seqlen_in_batch,
201
+ )
202
+
203
+
204
+ def pad_input(hidden_states, indices, batch, seqlen):
205
+ """
206
+ Arguments:
207
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
208
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
209
+ batch: int, batch size for the padded sequence.
210
+ seqlen: int, maximum sequence length for the padded sequence.
211
+ Return:
212
+ hidden_states: (batch, seqlen, ...)
213
+ """
214
+ dim = hidden_states.shape[-1]
215
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
+ # output[indices] = hidden_states
217
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
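A hedged round-trip sketch for `unpad_input` and `pad_input`: drop padding tokens before a varlen attention call, then scatter the results back into the padded layout. The shapes and mask are illustrative:

```python
import torch

batch, seqlen, dim = 2, 6, 16
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)
# unpadded: (8, 16) since 3 + 5 tokens are valid; cu_seqlens == [0, 3, 8]

repadded = pad_input(unpadded, indices, batch, seqlen)
assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])
```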
build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,393 @@
1
+ from typing import Optional, List
2
+ import torch
3
+ from ._ops import ops as flash_attn_ops
4
+ from .flash_attn_interface import (
5
+ flash_attn_func,
6
+ flash_attn_kvpacked_func,
7
+ flash_attn_qkvpacked_func,
8
+ flash_attn_varlen_func,
9
+ flash_attn_varlen_kvpacked_func,
10
+ flash_attn_varlen_qkvpacked_func,
11
+ flash_attn_with_kvcache,
12
+ )
13
+
14
+
15
+ def fwd(
16
+ q: torch.Tensor,
17
+ k: torch.Tensor,
18
+ v: torch.Tensor,
19
+ out: Optional[torch.Tensor] = None,
20
+ alibi_slopes: Optional[torch.Tensor] = None,
21
+ p_dropout: float = 0.0,
22
+ softmax_scale: Optional[float] = None,
23
+ is_causal: bool = False,
24
+ window_size_left: int = -1,
25
+ window_size_right: int = -1,
26
+ softcap: float = 0.0,
27
+ return_softmax: bool = False,
28
+ gen: Optional[torch.Generator] = None,
29
+ ) -> List[torch.Tensor]:
30
+ """
31
+ Forward pass for multi-head attention.
32
+
33
+ Args:
34
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
35
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
36
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
37
+ out: Optional output tensor, same shape as q
38
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
39
+ p_dropout: Dropout probability
40
+ softmax_scale: Scale factor for softmax
41
+ is_causal: Whether to use causal attention
42
+ window_size_left: Window size for left context (-1 for unlimited)
43
+ window_size_right: Window size for right context (-1 for unlimited)
44
+ softcap: Soft cap for attention weights
45
+ return_softmax: Whether to return softmax weights
46
+ gen: Optional random number generator
47
+
48
+ Returns:
49
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
50
+ """
51
+ if softmax_scale is None:
52
+ attention_head_dim = q.shape[-1]
53
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
54
+
55
+ return flash_attn_ops.fwd(
56
+ q,
57
+ k,
58
+ v,
59
+ out,
60
+ alibi_slopes,
61
+ p_dropout,
62
+ softmax_scale,
63
+ is_causal,
64
+ window_size_left,
65
+ window_size_right,
66
+ softcap,
67
+ return_softmax,
68
+ gen,
69
+ )
70
+
71
+
72
+ def varlen_fwd(
73
+ q: torch.Tensor,
74
+ k: torch.Tensor,
75
+ v: torch.Tensor,
76
+ cu_seqlens_q: torch.Tensor,
77
+ cu_seqlens_k: torch.Tensor,
78
+ out: Optional[torch.Tensor] = None,
79
+ seqused_k: Optional[torch.Tensor] = None,
80
+ leftpad_k: Optional[torch.Tensor] = None,
81
+ block_table: Optional[torch.Tensor] = None,
82
+ alibi_slopes: Optional[torch.Tensor] = None,
83
+ max_seqlen_q: int = 0,
84
+ max_seqlen_k: int = 0,
85
+ p_dropout: float = 0.0,
86
+ softmax_scale: Optional[float] = None,
87
+ zero_tensors: bool = False,
88
+ is_causal: bool = False,
89
+ window_size_left: int = -1,
90
+ window_size_right: int = -1,
91
+ softcap: float = 0.0,
92
+ return_softmax: bool = False,
93
+ gen: Optional[torch.Generator] = None,
94
+ ) -> List[torch.Tensor]:
95
+ """
96
+ Forward pass for multi-head attention with variable sequence lengths.
97
+
98
+ Args:
99
+ q: Query tensor of shape [total_q, num_heads, head_size]
100
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
101
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
102
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
103
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
104
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
105
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
106
+ leftpad_k: Optional left padding for keys of shape [batch_size]
107
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
108
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
109
+ max_seqlen_q: Maximum sequence length for queries
110
+ max_seqlen_k: Maximum sequence length for keys
111
+ p_dropout: Dropout probability
112
+ softmax_scale: Scale factor for softmax
113
+ zero_tensors: Whether to zero tensors before computation
114
+ is_causal: Whether to use causal attention
115
+ window_size_left: Window size for left context (-1 for unlimited)
116
+ window_size_right: Window size for right context (-1 for unlimited)
117
+ softcap: Soft cap for attention weights
118
+ return_softmax: Whether to return softmax weights
119
+ gen: Optional random number generator
120
+
121
+ Returns:
122
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
123
+ """
124
+ if softmax_scale is None:
125
+ attention_head_dim = q.shape[-1]
126
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
127
+
128
+ return flash_attn_ops.varlen_fwd(
129
+ q,
130
+ k,
131
+ v,
132
+ out,
133
+ cu_seqlens_q,
134
+ cu_seqlens_k,
135
+ seqused_k,
136
+ leftpad_k,
137
+ block_table,
138
+ alibi_slopes,
139
+ max_seqlen_q,
140
+ max_seqlen_k,
141
+ p_dropout,
142
+ softmax_scale,
143
+ zero_tensors,
144
+ is_causal,
145
+ window_size_left,
146
+ window_size_right,
147
+ softcap,
148
+ return_softmax,
149
+ gen,
150
+ )
151
+
152
+
153
+ def bwd(
154
+ dout: torch.Tensor,
155
+ q: torch.Tensor,
156
+ k: torch.Tensor,
157
+ v: torch.Tensor,
158
+ out: torch.Tensor,
159
+ softmax_lse: torch.Tensor,
160
+ dq: Optional[torch.Tensor] = None,
161
+ dk: Optional[torch.Tensor] = None,
162
+ dv: Optional[torch.Tensor] = None,
163
+ alibi_slopes: Optional[torch.Tensor] = None,
164
+ p_dropout: float = 0.0,
165
+ softmax_scale: Optional[float] = None,
166
+ is_causal: bool = False,
167
+ window_size_left: int = -1,
168
+ window_size_right: int = -1,
169
+ softcap: float = 0.0,
170
+ deterministic: bool = False,
171
+ gen: Optional[torch.Generator] = None,
172
+ rng_state: Optional[torch.Tensor] = None,
173
+ ) -> List[torch.Tensor]:
174
+ """
175
+ Backward pass for multi-head attention.
176
+
177
+ Args:
178
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
179
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
180
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
181
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
182
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
183
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
184
+ dq: Optional gradient tensor for queries, same shape as q
185
+ dk: Optional gradient tensor for keys, same shape as k
186
+ dv: Optional gradient tensor for values, same shape as v
187
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
188
+ p_dropout: Dropout probability
189
+ softmax_scale: Scale factor for softmax
190
+ is_causal: Whether to use causal attention
191
+ window_size_left: Window size for left context (-1 for unlimited)
192
+ window_size_right: Window size for right context (-1 for unlimited)
193
+ softcap: Soft cap for attention weights
194
+ deterministic: Whether to use deterministic algorithms
195
+ gen: Optional random number generator
196
+ rng_state: Optional RNG state from forward pass
197
+
198
+ Returns:
199
+ List of tensors: [dq, dk, dv]
200
+ """
201
+ if softmax_scale is None:
202
+ attention_head_dim = q.shape[-1]
203
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
204
+
205
+ return flash_attn_ops.bwd(
206
+ dout,
207
+ q,
208
+ k,
209
+ v,
210
+ out,
211
+ softmax_lse,
212
+ dq,
213
+ dk,
214
+ dv,
215
+ alibi_slopes,
216
+ p_dropout,
217
+ softmax_scale,
218
+ is_causal,
219
+ window_size_left,
220
+ window_size_right,
221
+ softcap,
222
+ deterministic,
223
+ gen,
224
+ rng_state,
225
+ )
226
+
227
+
228
+ def varlen_bwd(
229
+ dout: torch.Tensor,
230
+ q: torch.Tensor,
231
+ k: torch.Tensor,
232
+ v: torch.Tensor,
233
+ out: torch.Tensor,
234
+ softmax_lse: torch.Tensor,
235
+ cu_seqlens_q: torch.Tensor,
236
+ cu_seqlens_k: torch.Tensor,
237
+ dq: Optional[torch.Tensor] = None,
238
+ dk: Optional[torch.Tensor] = None,
239
+ dv: Optional[torch.Tensor] = None,
240
+ alibi_slopes: Optional[torch.Tensor] = None,
241
+ max_seqlen_q: int = 0,
242
+ max_seqlen_k: int = 0,
243
+ p_dropout: float = 0.0,
244
+ softmax_scale: Optional[float] = None,
245
+ zero_tensors: bool = False,
246
+ is_causal: bool = False,
247
+ window_size_left: int = -1,
248
+ window_size_right: int = -1,
249
+ softcap: float = 0.0,
250
+ deterministic: bool = False,
251
+ gen: Optional[torch.Generator] = None,
252
+ rng_state: Optional[torch.Tensor] = None,
253
+ ) -> List[torch.Tensor]:
254
+ """
255
+ Backward pass for multi-head attention with variable sequence lengths.
256
+
257
+ Args:
258
+ dout: Gradient tensor of shape [total_q, num_heads, head_size]
259
+ q: Query tensor of shape [total_q, num_heads, head_size]
260
+ k: Key tensor of shape [total_k, num_heads_k, head_size]
261
+ v: Value tensor of shape [total_k, num_heads_k, head_size]
262
+ out: Output tensor from forward pass of shape [total_q, num_heads, head_size]
263
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
264
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
265
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
266
+ dq: Optional gradient tensor for queries, same shape as q
267
+ dk: Optional gradient tensor for keys, same shape as k
268
+ dv: Optional gradient tensor for values, same shape as v
269
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
270
+ max_seqlen_q: Maximum sequence length for queries
271
+ max_seqlen_k: Maximum sequence length for keys
272
+ p_dropout: Dropout probability
273
+ softmax_scale: Scale factor for softmax
274
+ zero_tensors: Whether to zero tensors before computation
275
+ is_causal: Whether to use causal attention
276
+ window_size_left: Window size for left context (-1 for unlimited)
277
+ window_size_right: Window size for right context (-1 for unlimited)
278
+ softcap: Soft cap for attention weights
279
+ deterministic: Whether to use deterministic algorithms
280
+ gen: Optional random number generator
281
+ rng_state: Optional RNG state from forward pass
282
+
283
+ Returns:
284
+ List of tensors: [dq, dk, dv]
285
+ """
286
+ if softmax_scale is None:
287
+ attention_head_dim = q.shape[-1]
288
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
289
+
290
+ return flash_attn_ops.varlen_bwd(
291
+ dout,
292
+ q,
293
+ k,
294
+ v,
295
+ out,
296
+ softmax_lse,
297
+ dq,
298
+ dk,
299
+ dv,
300
+ cu_seqlens_q,
301
+ cu_seqlens_k,
302
+ alibi_slopes,
303
+ max_seqlen_q,
304
+ max_seqlen_k,
305
+ p_dropout,
306
+ softmax_scale,
307
+ zero_tensors,
308
+ is_causal,
309
+ window_size_left,
310
+ window_size_right,
311
+ softcap,
312
+ deterministic,
313
+ gen,
314
+ rng_state,
315
+ )
316
+
317
+
318
+ def fwd_kvcache(
319
+ q: torch.Tensor,
320
+ kcache: torch.Tensor,
321
+ vcache: torch.Tensor,
322
+ k: Optional[torch.Tensor] = None,
323
+ v: Optional[torch.Tensor] = None,
324
+ seqlens_k: Optional[torch.Tensor] = None,
325
+ rotary_cos: Optional[torch.Tensor] = None,
326
+ rotary_sin: Optional[torch.Tensor] = None,
327
+ cache_batch_idx: Optional[torch.Tensor] = None,
328
+ leftpad_k: Optional[torch.Tensor] = None,
329
+ block_table: Optional[torch.Tensor] = None,
330
+ alibi_slopes: Optional[torch.Tensor] = None,
331
+ out: Optional[torch.Tensor] = None,
332
+ softmax_scale: Optional[float] = None,
333
+ is_causal: bool = False,
334
+ window_size_left: int = -1,
335
+ window_size_right: int = -1,
336
+ softcap: float = 0.0,
337
+ is_rotary_interleaved: bool = False,
338
+ num_splits: int = 1,
339
+ ) -> List[torch.Tensor]:
340
+ """
341
+ Forward pass for multi-head attention with KV cache.
342
+
343
+ Args:
344
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
345
+ kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
346
+ vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
347
+ k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
348
+ v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
349
+ seqlens_k: Optional sequence lengths for keys of shape [batch_size]
350
+ rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
351
+ rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
352
+ cache_batch_idx: Optional indices to index into the KV cache
353
+ leftpad_k: Optional left padding for keys of shape [batch_size]
354
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
355
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
356
+ out: Optional output tensor, same shape as q
357
+ softmax_scale: Scale factor for softmax
358
+ is_causal: Whether to use causal attention
359
+ window_size_left: Window size for left context (-1 for unlimited)
360
+ window_size_right: Window size for right context (-1 for unlimited)
361
+ softcap: Soft cap for attention weights
362
+ is_rotary_interleaved: Whether rotary embeddings are interleaved
363
+ num_splits: Number of splits for computation
364
+
365
+ Returns:
366
+ List of tensors: [output, softmax_lse]
367
+ """
368
+ if softmax_scale is None:
369
+ attention_head_dim = q.shape[-1]
370
+ softmax_scale = 1.0 / (attention_head_dim**0.5)
371
+
372
+ return flash_attn_ops.fwd_kvcache(
373
+ q,
374
+ kcache,
375
+ vcache,
376
+ k,
377
+ v,
378
+ seqlens_k,
379
+ rotary_cos,
380
+ rotary_sin,
381
+ cache_batch_idx,
382
+ leftpad_k,
383
+ block_table,
384
+ alibi_slopes,
385
+ out,
386
+ softmax_scale,
387
+ is_causal,
388
+ window_size_left,
389
+ window_size_right,
390
+ softcap,
391
+ is_rotary_interleaved,
392
+ num_splits,
393
+ )
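For the KV-cache path, a hedged single-token decode sketch using `fwd_kvcache`: the new key/value pair is written into the cache at position `seqlens_k` and attention runs over the cached prefix plus the new token. Shapes, dtype, and the `xpu` device are illustrative assumptions:

```python
import torch

batch, max_cache_len, heads, head_dim = 1, 256, 8, 64
kcache = torch.zeros(batch, max_cache_len, heads, head_dim, device="xpu", dtype=torch.float16)
vcache = torch.zeros_like(kcache)

# One new query token plus its key/value; 17 tokens are assumed to be cached already.
q = torch.randn(batch, 1, heads, head_dim, device="xpu", dtype=torch.float16)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)
seqlens_k = torch.full((batch,), 17, device="xpu", dtype=torch.int32)

out, softmax_lse = fwd_kvcache(
    q, kcache, vcache, k=k_new, v=v_new, seqlens_k=seqlens_k, is_causal=True
)[:2]
print(out.shape)  # expected: torch.Size([1, 1, 8, 64])
```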
build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/_flash_attn_c984dd4_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98c57dda92346f88486e974d74aee9b0fb1b1f663506666929d3c02aa897e528
3
+ size 3420928
build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _flash_attn_c984dd4_dirty
3
+ ops = torch.ops._flash_attn_c984dd4_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flash_attn_c984dd4_dirty::{op_name}"
build/torch29-cxx11-xpu20252-x86_64-linux/flash_attn/bert_padding.py ADDED
@@ -0,0 +1,218 @@
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
19
+ ).reshape(-1, *other_shape)
20
+
21
+ @staticmethod
22
+ def backward(ctx, grad_output):
23
+ (indices,) = ctx.saved_tensors
24
+ assert grad_output.ndim >= 2
25
+ other_shape = grad_output.shape[1:]
26
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
27
+ grad_input = torch.zeros(
28
+ [ctx.first_axis_dim, grad_output.shape[1]],
29
+ device=grad_output.device,
30
+ dtype=grad_output.dtype,
31
+ )
32
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
33
+ # grad_input[indices] = grad_output
34
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
35
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
36
+
37
+
38
+ index_first_axis = IndexFirstAxis.apply
39
+
40
+
41
+ class IndexPutFirstAxis(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(ctx, values, indices, first_axis_dim):
44
+ ctx.save_for_backward(indices)
45
+ assert indices.ndim == 1
46
+ assert values.ndim >= 2
47
+ output = torch.zeros(
48
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
49
+ )
50
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ output[indices] = values
52
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
53
+ return output
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_output):
57
+ (indices,) = ctx.saved_tensors
58
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
59
+ grad_values = grad_output[indices]
60
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
61
+ return grad_values, None, None
62
+
63
+
64
+ index_put_first_axis = IndexPutFirstAxis.apply
65
+
66
+
67
+ class IndexFirstAxisResidual(torch.autograd.Function):
68
+ @staticmethod
69
+ def forward(ctx, input, indices):
70
+ ctx.save_for_backward(indices)
71
+ assert input.ndim >= 2
72
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
73
+ second_dim = other_shape.numel()
74
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
75
+ output = input[indices]
76
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
77
+ # memory format to channel_first. In other words, input might not be contiguous.
78
+ # If we don't detach, PyTorch complains that output is a view that is being modified in place
79
+ return output, input.detach()
80
+
81
+ @staticmethod
82
+ def backward(ctx, grad_output, grad_residual):
83
+ (indices,) = ctx.saved_tensors
84
+ assert grad_output.ndim >= 2
85
+ other_shape = grad_output.shape[1:]
86
+ assert grad_residual.shape[1:] == other_shape
87
+ grad_input = grad_residual
88
+ # grad_input[indices] += grad_output
89
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
90
+ indices = indices.expand_as(grad_output)
91
+ grad_input.scatter_add_(0, indices, grad_output)
92
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
93
+
94
+
95
+ index_first_axis_residual = IndexFirstAxisResidual.apply
96
+
97
+
98
+ def unpad_input(hidden_states, attention_mask, unused_mask=None):
99
+ """
100
+ Arguments:
101
+ hidden_states: (batch, seqlen, ...)
102
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
103
+ unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
104
+ Return:
105
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
106
+ indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
107
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
108
+ max_seqlen_in_batch: int
109
+ seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
110
+ """
111
+ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
112
+ seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
113
+ used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
114
+ indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
115
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
116
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
117
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
118
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
119
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
120
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
121
+ # so we write custom forward and backward to make it a bit faster.
122
+ return (
123
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
124
+ indices,
125
+ cu_seqlens,
126
+ max_seqlen_in_batch,
127
+ used_seqlens_in_batch,
128
+ )
129
+
130
+
131
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
132
+ """
133
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is used to mask the other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task in large language models).
134
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
135
+
136
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
137
+ ```
138
+ [
139
+ [2, 3, 0, 0, 0, 0],
140
+ [3, 2, 0, 0, 0, 0],
141
+ [6, 0, 0, 0, 0, 0]
142
+ ]
143
+ ```
144
+ , which refers to the 3D-attention mask:
145
+ ```
146
+ [
147
+ [
148
+ [1, 0, 0, 0, 0, 0],
149
+ [1, 1, 0, 0, 0, 0],
150
+ [0, 0, 1, 0, 0, 0],
151
+ [0, 0, 1, 1, 0, 0],
152
+ [0, 0, 1, 1, 1, 0],
153
+ [0, 0, 0, 0, 0, 1]
154
+ ],
155
+ [
156
+ [1, 0, 0, 0, 0, 0],
157
+ [1, 1, 0, 0, 0, 0],
158
+ [1, 1, 1, 0, 0, 0],
159
+ [0, 0, 0, 1, 0, 0],
160
+ [0, 0, 0, 1, 1, 0],
161
+ [0, 0, 0, 0, 0, 1]
162
+ ],
163
+ [
164
+ [1, 0, 0, 0, 0, 0],
165
+ [1, 1, 0, 0, 0, 0],
166
+ [1, 1, 1, 0, 0, 0],
167
+ [1, 1, 1, 1, 0, 0],
168
+ [1, 1, 1, 1, 1, 0],
169
+ [1, 1, 1, 1, 1, 1]
170
+ ]
171
+ ]
172
+ ```.
173
+
174
+ Arguments:
175
+ hidden_states: (batch, seqlen, ...)
176
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
177
+ Return:
178
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
179
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
180
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
181
+ max_seqlen_in_batch: int
182
+ """
183
+ length = attention_mask_in_length.sum(dim=-1)
184
+ seqlen = attention_mask_in_length.size(-1)
185
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
187
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
191
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
193
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
194
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
195
+ # so we write custom forward and backward to make it a bit faster.
196
+ return (
197
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
198
+ indices,
199
+ cu_seqlens,
200
+ max_seqlen_in_batch,
201
+ )
202
+
203
+
204
+ def pad_input(hidden_states, indices, batch, seqlen):
205
+ """
206
+ Arguments:
207
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
208
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
209
+ batch: int, batch size for the padded sequence.
210
+ seqlen: int, maximum sequence length for the padded sequence.
211
+ Return:
212
+ hidden_states: (batch, seqlen, ...)
213
+ """
214
+ dim = hidden_states.shape[-1]
215
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
+ # output[indices] = hidden_states
217
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
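Finally, a hedged sketch of `unpad_input_for_concatenated_sequences`, reusing the packing mask from the docstring above (rows packing 2+3, 3+2, and 6 tokens); the hidden-state shape is illustrative:

```python
import torch

batch, seqlen, dim = 3, 6, 16
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask_in_length = torch.tensor([
    [2, 3, 0, 0, 0, 0],
    [3, 2, 0, 0, 0, 0],
    [6, 0, 0, 0, 0, 0],
])

unpadded, indices, cu_seqlens, max_seqlen = unpad_input_for_concatenated_sequences(
    hidden_states, attention_mask_in_length
)
# unpadded: (16, 16) -- 5 + 5 + 6 valid tokens
# cu_seqlens: [0, 2, 5, 8, 10, 16] -- one boundary per packed sub-sequence
```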