ArtemisTAO committed
Commit ec78611 · verified · 0 Parent(s)

Duplicate from ArtemisTAO/WIN_21_1

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +191 -0
  2. Dockerfile +48 -0
  3. LICENSE +201 -0
  4. README.md +13 -0
  5. assets/Qwen2.5_Omni.pdf +3 -0
  6. cookbooks/=4.41.0 +17 -0
  7. cookbooks/=4.50.0.dev0 +17 -0
  8. cookbooks/=4.51.0.dev0 +17 -0
  9. cookbooks/flash-attention/.github/workflows/publish.yml +218 -0
  10. cookbooks/flash-attention/.gitignore +31 -0
  11. cookbooks/flash-attention/.gitmodules +6 -0
  12. cookbooks/flash-attention/AUTHORS +1 -0
  13. cookbooks/flash-attention/LICENSE +29 -0
  14. cookbooks/flash-attention/MANIFEST.in +12 -0
  15. cookbooks/flash-attention/Makefile +9 -0
  16. cookbooks/flash-attention/README.md +524 -0
  17. cookbooks/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png +3 -0
  18. cookbooks/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png +3 -0
  19. cookbooks/flash-attention/assets/flash3_fp16_fwd.png +3 -0
  20. cookbooks/flash-attention/assets/flashattention_logo.png +3 -0
  21. cookbooks/flash-attention/assets/flashattn_banner.jpg +3 -0
  22. cookbooks/flash-attention/assets/flashattn_banner.pdf +3 -0
  23. cookbooks/flash-attention/assets/flashattn_memory.jpg +0 -0
  24. cookbooks/flash-attention/assets/flashattn_speedup.jpg +3 -0
  25. cookbooks/flash-attention/assets/flashattn_speedup_3090.jpg +3 -0
  26. cookbooks/flash-attention/assets/flashattn_speedup_a100_d128.jpg +3 -0
  27. cookbooks/flash-attention/assets/flashattn_speedup_t4.jpg +3 -0
  28. cookbooks/flash-attention/assets/flashattn_speedup_t4_fwd.jpg +3 -0
  29. cookbooks/flash-attention/assets/gpt2_training_curve.jpg +3 -0
  30. cookbooks/flash-attention/assets/gpt2_training_efficiency.jpg +3 -0
  31. cookbooks/flash-attention/assets/gpt3_training_curve.jpg +3 -0
  32. cookbooks/flash-attention/assets/gpt3_training_efficiency.jpg +3 -0
  33. cookbooks/flash-attention/benchmarks/benchmark_alibi.py +275 -0
  34. cookbooks/flash-attention/benchmarks/benchmark_causal.py +225 -0
  35. cookbooks/flash-attention/benchmarks/benchmark_flash_attention.py +180 -0
  36. cookbooks/flash-attention/benchmarks/benchmark_gemm.py +47 -0
  37. cookbooks/flash-attention/csrc/flash_attn/flash_api.cpp +1485 -0
  38. cookbooks/flash-attention/csrc/flash_attn/src/alibi.h +75 -0
  39. cookbooks/flash-attention/csrc/flash_attn/src/block_info.h +49 -0
  40. cookbooks/flash-attention/csrc/flash_attn/src/dropout.h +95 -0
  41. cookbooks/flash-attention/csrc/flash_attn/src/flash.h +194 -0
  42. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu +14 -0
  43. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu +14 -0
  44. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu +14 -0
  45. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu +14 -0
  46. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +14 -0
  47. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +14 -0
  48. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +14 -0
  49. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +14 -0
  50. cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu +14 -0
.gitattributes ADDED
@@ -0,0 +1,191 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/Qwen2.5_Omni.pdf filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flash3_fp16_fwd.png filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_banner.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_banner.pdf filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_speedup.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_speedup_3090.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_speedup_a100_d128.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_speedup_t4.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/flashattn_speedup_t4_fwd.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/gpt2_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/gpt2_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/gpt3_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+ cookbooks/flash-attention/assets/gpt3_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flash3_fp16_fwd.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_banner.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_banner.pdf filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_speedup.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_speedup_3090.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_speedup_a100_d128.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_speedup_t4.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/flashattn_speedup_t4_fwd.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/gpt2_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/gpt2_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/gpt3_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/assets/gpt3_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/composable_kernel/docs/data/ck_component.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/composable_kernel/docs/data/ck_layer.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/composable_kernel/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/composable_kernel/example/ck_tile/15_fused_moe/misc/moe-2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/M128xK4_scalefactor_gmem.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/conv2d-fprop-int4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32Mx32x4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32x32x4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_Atom.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.AB.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.C.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/TiledCopyA.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/TiledMmaC.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/composition1.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/composition2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/divide2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/divide3.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/gmma_coremat_cd_fp16.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/gmma_wg_n_slice.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide-2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/product2d.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/productblocked2d.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/productraked2d.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/slice.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/tC_partitioning.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cute/tv_layout.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-2.8-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-2.9-implicit-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-3.0-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-3.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-gemm-components.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-reduction-in-named-iterators.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-threadblock-mma-pipelined.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-warp-level-gemm-api-instantiation.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/cutlass-warp-thread-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue-no-labels.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/gemm-structural-components.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/ldmatrix-8x128bx4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/ldmatrix-tensorop-32x32x32.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/mma-8x8x32.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/non_persistent.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/persistent_clc.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/persistent_static.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/software-pipeline.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k0.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k1.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flash3_fp16_fwd.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_banner.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_banner.pdf filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_speedup.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_speedup_3090.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_speedup_a100_d128.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_speedup_t4.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/flashattn_speedup_t4_fwd.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/gpt2_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/gpt2_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/gpt3_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/assets/gpt3_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/composable_kernel/docs/data/ck_component.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/composable_kernel/docs/data/ck_layer.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/composable_kernel/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/composable_kernel/example/ck_tile/15_fused_moe/misc/moe-2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/M128xK4_scalefactor_gmem.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/conv2d-fprop-int4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32Mx32x4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32x32x4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_Atom.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.AB.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.C.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/TiledCopyA.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/TiledMmaC.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/composition1.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/composition2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/divide2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/divide3.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/gmma_coremat_cd_fp16.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/gmma_wg_n_slice.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide-2.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/product2d.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/productblocked2d.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/productraked2d.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/slice.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/tC_partitioning.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cute/tv_layout.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-2.8-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-2.9-implicit-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.0-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-gemm-components.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-reduction-in-named-iterators.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-threadblock-mma-pipelined.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-warp-level-gemm-api-instantiation.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-warp-thread-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue-no-labels.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/gemm-structural-components.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/ldmatrix-8x128bx4.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/ldmatrix-tensorop-32x32x32.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/mma-8x8x32.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/non_persistent.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/persistent_clc.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/persistent_static.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/software-pipeline.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k0.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k1.png filter=lfs diff=lfs merge=lfs -text
+ flash-attention/flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN.png filter=lfs diff=lfs merge=lfs -text
+ input_audio.wav filter=lfs diff=lfs merge=lfs -text
+ model/Qwen2.5-Omni-7B/tokenizer.json filter=lfs diff=lfs merge=lfs -text
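Each pattern above routes matching files through Git LFS, so the repository stores lightweight pointer files in place of the binaries themselves. As a hedged sketch (assuming Git LFS is installed locally), entries like these are normally produced with `git lfs track` rather than written by hand:

```bash
# Append a pattern line like the ones above to .gitattributes and stage it
git lfs track "*.pdf"
git add .gitattributes

# List the patterns currently routed through LFS
git lfs track
```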
Dockerfile ADDED
@@ -0,0 +1,48 @@
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     DEBIAN_FRONTEND=noninteractive \
+     HF_HOME=/app/models \
+     NUMBA_CACHE_DIR=/tmp/numba_cache \
+     TORCH_CUDA_ARCH_LIST=8.0
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     python3 \
+     python3-pip \
+     python3-dev \
+     build-essential \
+     git \
+     ffmpeg \
+     libsndfile1 \
+     libcusparse-dev-12-3 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Python build tools
+ RUN pip install --upgrade pip setuptools wheel packaging ninja
+
+ WORKDIR /app
+
+ # Create cache directory
+ RUN mkdir -p /tmp/numba_cache && \
+     chmod 777 /tmp/numba_cache
+
+ # Install PyTorch with CUDA 12.1 first
+ RUN pip install --pre torch torchvision torchaudio \
+     --index-url https://download.pytorch.org/whl/nightly/cu121
+
+ # Copy and install requirements
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Install flash-attn separately with no isolation
+ RUN pip install flash-attn==2.7.4.post1 --no-build-isolation
+
+ # Copy application files
+ COPY server.py .
+ COPY qwen-omni-utils/ ./qwen-omni-utils/
+ COPY model/ ./model/
+
+ EXPOSE 8000
+ CMD ["python3", "server.py"]
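For context, a minimal build-and-run sketch for this image; the tag `qwen-omni-server` is a placeholder, and `--gpus all` assumes the NVIDIA Container Toolkit is installed, since the CUDA base image and flash-attn need a GPU at runtime:

```bash
# Build from the directory containing this Dockerfile
docker build -t qwen-omni-server .

# Run the server, publishing the port declared by EXPOSE 8000
docker run --rm --gpus all -p 8000:8000 qwen-omni-server

# Optional sanity check: torch sees CUDA and flash-attn imports cleanly
docker run --rm --gpus all qwen-omni-server \
  python3 -c "import torch, flash_attn; print(torch.__version__, torch.cuda.is_available())"
```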
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2025 Alibaba Cloud
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ license: mit
+ tags:
+ - any-to-any
+ - omega
+ - omegalabs
+ - bittensor
+ - agi
+ ---
+
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
+
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
assets/Qwen2.5_Omni.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0c9e0042ad20bc0c95cbbfc96f63f4ff1f28727c5b32973e7fd597557b6b15f
+ size 4014433
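The three lines above are a Git LFS pointer, not the PDF itself: the pointer spec version, the SHA-256 of the real content, and its size in bytes (about 4 MB). After cloning, the actual file is materialized with `git lfs pull`; an equivalent pointer can be reproduced from a local copy of the file:

```bash
# Download and check out the real contents behind all LFS pointers
git lfs pull

# Print the pointer (version/oid/size) that LFS stores for a given file
git lfs pointer --file=assets/Qwen2.5_Omni.pdf
```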
cookbooks/=4.41.0 ADDED
@@ -0,0 +1,17 @@
+ Requirement already satisfied: transformers in /home/ubuntu/.venv/lib/python3.10/site-packages (4.51.0.dev0)
+ Requirement already satisfied: filelock in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (3.18.0)
+ Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.29.3)
+ Requirement already satisfied: numpy>=1.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.1.3)
+ Requirement already satisfied: packaging>=20.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (24.2)
+ Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (6.0.2)
+ Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2024.11.6)
+ Requirement already satisfied: requests in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.32.3)
+ Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.21.1)
+ Requirement already satisfied: safetensors>=0.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.5.3)
+ Requirement already satisfied: tqdm>=4.27 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (4.67.1)
+ Requirement already satisfied: fsspec>=2023.5.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2025.3.0)
+ Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.13.0)
+ Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.4.1)
+ Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.10)
+ Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2.3.0)
+ Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)
cookbooks/=4.50.0.dev0 ADDED
@@ -0,0 +1,17 @@
+ Requirement already satisfied: transformers in /home/ubuntu/.venv/lib/python3.10/site-packages (4.50.0.dev0)
+ Requirement already satisfied: filelock in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (3.18.0)
+ Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.29.3)
+ Requirement already satisfied: numpy>=1.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.1.3)
+ Requirement already satisfied: packaging>=20.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (24.2)
+ Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (6.0.2)
+ Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2024.11.6)
+ Requirement already satisfied: requests in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.32.3)
+ Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.21.1)
+ Requirement already satisfied: safetensors>=0.4.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.5.3)
+ Requirement already satisfied: tqdm>=4.27 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (4.67.1)
+ Requirement already satisfied: fsspec>=2023.5.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2025.3.0)
+ Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.13.0)
+ Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.4.1)
+ Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.10)
+ Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2.3.0)
+ Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)
cookbooks/=4.51.0.dev0 ADDED
@@ -0,0 +1,17 @@
+ Requirement already satisfied: transformers in /home/ubuntu/.venv/lib/python3.10/site-packages (4.50.0.dev0)
+ Requirement already satisfied: filelock in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (3.18.0)
+ Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.29.3)
+ Requirement already satisfied: numpy>=1.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.1.3)
+ Requirement already satisfied: packaging>=20.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (24.2)
+ Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (6.0.2)
+ Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2024.11.6)
+ Requirement already satisfied: requests in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.32.3)
+ Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.21.1)
+ Requirement already satisfied: safetensors>=0.4.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.5.3)
+ Requirement already satisfied: tqdm>=4.27 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (4.67.1)
+ Requirement already satisfied: fsspec>=2023.5.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2025.3.0)
+ Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.13.0)
+ Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.4.1)
+ Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.10)
+ Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2.3.0)
+ Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)
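The three `cookbooks/=4.41.0`-style files above contain pip console output, which suggests they are accidental shell artifacts rather than intentional content: in an unquoted command, the shell parses `>=4.41.0` as output redirection to a file named `=4.41.0`. A hedged reconstruction of the likely cause, and the quoting that avoids it:

```bash
# Likely origin (assumption): ">" is taken as redirection, creating "=4.41.0"
pip install transformers>=4.41.0

# Correct form: quote the specifier so ">=" reaches pip instead of the shell
pip install "transformers>=4.41.0"
```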
cookbooks/flash-attention/.github/workflows/publish.yml ADDED
@@ -0,0 +1,218 @@
+ # This workflow will:
+ # - Create a new Github release
+ # - Build wheels for supported architectures
+ # - Deploy the wheels to the Github release
+ # - Release the static code to PyPi
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+ name: Build wheels and deploy
+
+ on:
+   create:
+     tags:
+       - v*
+
+ jobs:
+
+   setup_release:
+     name: Create Release
+     runs-on: ubuntu-latest
+     steps:
+       - name: Get the tag version
+         id: extract_branch
+         run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+         shell: bash
+
+       - name: Create Release
+         id: create_release
+         uses: actions/create-release@v1
+         env:
+           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+         with:
+           tag_name: ${{ steps.extract_branch.outputs.branch }}
+           release_name: ${{ steps.extract_branch.outputs.branch }}
+
+   build_wheels:
+     name: Build Wheel
+     needs: setup_release
+     runs-on: ${{ matrix.os }}
+
+     strategy:
+       fail-fast: false
+       matrix:
+         # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
+         # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
+         os: [ubuntu-20.04]
+         python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
+         torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0']
+         cuda-version: ['12.4.1']
+         # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
+         # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
+         # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
+         # when building without C++11 ABI and using it on nvcr images.
+         cxx11_abi: ['FALSE', 'TRUE']
+         exclude:
+           # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+           # Pytorch < 2.5 does not support Python 3.13
+           - torch-version: '2.2.2'
+             python-version: '3.13'
+           - torch-version: '2.3.1'
+             python-version: '3.13'
+           - torch-version: '2.4.0'
+             python-version: '3.13'
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: ${{ matrix.python-version }}
+
+       - name: Set CUDA and PyTorch versions
+         run: |
+           echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
+           echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
+           echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
+           echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
+
+       - name: Free up disk space
+         if: ${{ runner.os == 'Linux' }}
+         # https://github.com/easimon/maximize-build-space/blob/master/action.yml
+         # https://github.com/easimon/maximize-build-space/tree/test-report
+         run: |
+           sudo rm -rf /usr/share/dotnet
+           sudo rm -rf /opt/ghc
+           sudo rm -rf /opt/hostedtoolcache/CodeQL
+
+       - name: Set up swap space
+         if: runner.os == 'Linux'
+         uses: pierotofy/set-swap-space@v1.0
+         with:
+           swap-size-gb: 10
+
+       - name: Install CUDA ${{ matrix.cuda-version }}
+         if: ${{ matrix.cuda-version != 'cpu' }}
+         uses: Jimver/cuda-toolkit@v0.2.19
+         id: cuda-toolkit
+         with:
+           cuda: ${{ matrix.cuda-version }}
+           linux-local-args: '["--toolkit"]'
+           # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
+           # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
+           method: 'network'
+           sub-packages: '["nvcc"]'
+
+       - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
+         run: |
+           pip install --upgrade pip
+           # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
+           pip install setuptools==75.8.0
+           # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
+           # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
+           pip install typing-extensions==4.12.2
+           # We want to figure out the CUDA version to download pytorch
+           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
+           # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+           # This code is ugly, maybe there's a better way to do this.
+           export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
+             minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+             maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
+             print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
+           )
+           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
+             # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+             # Can't use --no-deps because we need cudnn etc.
+             # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
+             pip install jinja2
+             pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+             pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+           else
+             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
+           fi
+           nvcc --version
+           python --version
+           python -c "import torch; print('PyTorch:', torch.__version__)"
+           python -c "import torch; print('CUDA:', torch.version.cuda)"
+           python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
+         shell:
+           bash
+
+       - name: Build wheel
+         run: |
+           # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
+           # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
+           # However this still fails so I'm using a newer version of setuptools
+           pip install setuptools==75.8.0
+           pip install ninja packaging wheel
+           export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
+           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+           # Limit MAX_JOBS otherwise the github runner goes OOM
+           # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
+           MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+           tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
+           wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
+           ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
+           echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
+
+       - name: Log Built Wheels
+         run: |
+           ls dist
+
+       - name: Get the tag version
+         id: extract_branch
+         run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+
+       - name: Get Release with tag
+         id: get_current_release
+         uses: joutvhu/get-release@v1
+         with:
+           tag_name: ${{ steps.extract_branch.outputs.branch }}
+         env:
+           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+       - name: Upload Release Asset
+         id: upload_release_asset
+         uses: actions/upload-release-asset@v1
+         env:
+           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+         with:
+           upload_url: ${{ steps.get_current_release.outputs.upload_url }}
+           asset_path: ./dist/${{env.wheel_name}}
+           asset_name: ${{env.wheel_name}}
+           asset_content_type: application/*
+
+   publish_package:
+     name: Publish package
+     needs: [build_wheels]
+
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+
+       - uses: actions/setup-python@v5
+         with:
+           python-version: '3.10'
+
+       - name: Install dependencies
+         run: |
+           pip install ninja packaging wheel twine
+           # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
+           pip install setuptools==75.8.0
+           # We don't want to download anything CUDA-related here
+           pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+       - name: Build core package
+         env:
+           FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
+         run: |
+           python setup.py sdist --dist-dir=dist
+
+       - name: Deploy
+         env:
+           TWINE_USERNAME: "__token__"
+           TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+         run: |
+           python -m twine upload dist/*
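One step worth unpacking: the `Build wheel` step renames each wheel so its filename carries the CUDA version, torch version, and C++11 ABI flag as a local version tag. A standalone sketch of that `sed` substitution (the wheel name and tag value here are illustrative):

```bash
# Replace the 2nd "-" in the filename, turning version 2.7.4.post1
# into 2.7.4.post1+cu12torch2.5cxx11abiFALSE
tmpname=cu12torch2.5cxx11abiFALSE
echo "flash_attn-2.7.4.post1-cp310-cp310-linux_x86_64.whl" | sed "s/-/+$tmpname-/2"
# -> flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```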
cookbooks/flash-attention/.gitignore ADDED
@@ -0,0 +1,31 @@
+ *.ncu-rep
+ .DS_store
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ bin/
+ build/
+ develop-eggs/
+ dist/
+ eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ .eggs/
+
+ # IDE-related
+ .idea/
+
+ # Dev
+ venv
cookbooks/flash-attention/.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "csrc/cutlass"]
+ 	path = csrc/cutlass
+ 	url = https://github.com/NVIDIA/cutlass.git
+ [submodule "csrc/composable_kernel"]
+ 	path = csrc/composable_kernel
+ 	url = https://github.com/ROCm/composable_kernel.git
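These submodules pin NVIDIA CUTLASS and ROCm composable_kernel, the template libraries the attention kernels compile against, so a source build needs them checked out. The usual way to fetch them after cloning:

```bash
# Fetch the two submodules declared in .gitmodules
git submodule update --init --recursive csrc/cutlass csrc/composable_kernel
```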
cookbooks/flash-attention/AUTHORS ADDED
@@ -0,0 +1 @@
+ Tri Dao, trid@cs.stanford.edu
cookbooks/flash-attention/LICENSE ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cookbooks/flash-attention/MANIFEST.in ADDED
@@ -0,0 +1,12 @@
+ recursive-include csrc *.cu
+ recursive-include csrc *.h
+ recursive-include csrc *.cuh
+ recursive-include csrc *.cpp
+ recursive-include csrc *.hpp
+ recursive-include csrc *.py
+
+ recursive-include flash_attn *.cu
+ recursive-include flash_attn *.h
+ recursive-include flash_attn *.cuh
+ recursive-include flash_attn *.cpp
+ recursive-include flash_attn *.hpp
cookbooks/flash-attention/Makefile ADDED
@@ -0,0 +1,9 @@
+
+ clean_dist:
+ 	rm -rf dist/*
+
+ create_dist: clean_dist
+ 	python setup.py sdist
+
+ upload_package: create_dist
+ 	twine upload dist/*
cookbooks/flash-attention/README.md ADDED
@@ -0,0 +1,524 @@
+ # FlashAttention
+ This repository provides the official implementation of FlashAttention and
+ FlashAttention-2 from the following papers.
+
+ **FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
+ Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
+ Paper: https://arxiv.org/abs/2205.14135
+ IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
+ ![FlashAttention](assets/flashattn_banner.jpg)
+
+ **FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
+ Tri Dao
+
+ Paper: https://tridao.me/publications/flash2/flash2.pdf
+
+ ![FlashAttention-2](assets/flashattention_logo.png)
+
+
+ ## Usage
+
+ We've been very happy to see FlashAttention being widely adopted in such a short
+ time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
+ contains a partial list of places where FlashAttention is being used.
+
+ FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
+ Please cite and credit FlashAttention if you use it.
+
+
+ ## FlashAttention-3 beta release
+ FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
+
+ Blogpost: https://tridao.me/blog/2024/flash3/
+
+ Paper: https://tridao.me/publications/flash3/flash3.pdf
+
+ ![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)
+
+ This is a beta release for testing / benchmarking before we integrate it with
+ the rest of the repo.
+
+ Currently released:
+ - FP16 / BF16 forward and backward, FP8 forward
+
+ Requirements: H100 / H800 GPU, CUDA >= 12.3.
+
+ We highly recommend CUDA 12.8 for best performance.
+
+ To install:
+ ```sh
+ cd hopper
+ python setup.py install
+ ```
+ To run the test:
+ ```sh
+ export PYTHONPATH=$PWD
+ pytest -q -s test_flash_attn.py
+ ```
+ Once the package is installed, you can import it as follows:
+ ```python
+ import flash_attn_interface
+ flash_attn_interface.flash_attn_func()
+ ```
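+
+ For illustration, a minimal call might look like the sketch below. The shapes and flags are illustrative assumptions, not part of the beta docs; check `hopper/flash_attn_interface.py` for the exact signature (depending on the version, the function may return the output alone or a tuple that also carries the softmax log-sum-exp).
+ ```python
+ import torch
+ import flash_attn_interface
+
+ # Requires an H100 / H800; shapes are (batch, seqlen, nheads, headdim)
+ q, k, v = [torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.float16)
+            for _ in range(3)]
+ out = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
+ ```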
+
+ ## Installation and features
+ **Requirements:**
+ - CUDA toolkit or ROCm toolkit
+ - PyTorch 2.2 and above.
+ - `packaging` Python package (`pip install packaging`)
+ - `ninja` Python package (`pip install ninja`) *
+ - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via GitHub issue.
+
+ \* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
+ --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
+ --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
+ `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
+ compiling can take a very long time (2h) since it does not use multiple CPU
+ cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using the CUDA toolkit.
+
+ **To install:**
+ ```sh
+ pip install flash-attn --no-build-isolation
+ ```
+ Alternatively you can compile from source:
+ ```sh
+ python setup.py install
+ ```
+
+ If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
+ run too many parallel compilation jobs that could exhaust the amount of RAM. To
+ limit the number of parallel compilation jobs, you can set the environment
+ variable `MAX_JOBS`:
+ ```sh
+ MAX_JOBS=4 pip install flash-attn --no-build-isolation
+ ```
+
+ **Interface:** `src/flash_attention_interface.py`
+
+ ### NVIDIA CUDA Support
+ **Requirements:**
+ - CUDA 12.0 and above.
+
+ We recommend the
+ [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
+ container from Nvidia, which has all the required tools to install FlashAttention.
+
+ FlashAttention-2 with CUDA currently supports:
+ 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
+ GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
+ GPUs for now.
+ 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
+ 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
+
+ ### AMD ROCm Support
+ The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (CK), which is the default, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.
+
+ **Requirements:**
+ - ROCm 6.0 and above.
+
+ We recommend the
+ [Pytorch](https://hub.docker.com/r/rocm/pytorch)
+ container from ROCm, which has all the required tools to install FlashAttention.
+
+ #### Composable Kernel Backend
+ The FlashAttention-2 ROCm CK backend currently supports:
+ 1. MI200 or MI300 GPUs.
+ 2. Datatype fp16 and bf16
+ 3. Head dimensions up to 256 in both the forward and backward passes.
+
+ #### Triton Backend
+ The Triton implementation of [FlashAttention-2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.
+
+ It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16 and fp32 datatypes.
+
+ These features are supported in both the forward and backward passes:
+ 1) Causal masking
+ 2) Variable sequence lengths
+ 3) Arbitrary Q and KV sequence lengths
+ 4) Arbitrary head sizes
+
+ These features are supported in the forward pass for now; we will add them to the backward pass soon:
+ 1) Multi and grouped query attention
+ 2) ALiBi and matrix bias
+
+ These features are in development:
+ 1) Paged Attention
+ 2) Sliding Window
+ 3) Rotary embeddings
+ 4) Dropout
+ 5) Performance Improvements
+
+ #### Getting Started
+ To get started with the Triton backend for AMD, follow the steps below.
+
+ First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).
+
+ ```
+ git clone https://github.com/triton-lang/triton
+ cd triton
+ git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
+ pip install --verbose -e python
+ ```
+ Then install and test FlashAttention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
+
+ ```
+ export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+ cd flash-attention
+ python setup.py install
+ pytest tests/test_flash_attn.py
+ ```
+
+
+ ## How to use FlashAttention
+
+ The main functions implement scaled dot product attention (softmax(Q @ K^T *
+ softmax_scale) @ V):
+ ```python
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+ ```
+
+ ```python
+ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
+                           window_size=(-1, -1), alibi_slopes=None, deterministic=False):
+ """dropout_p should be set to 0.0 during evaluation
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
+ of the gradients of Q, K, V.
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
+ Arguments:
+     qkv: (batch_size, seqlen, 3, nheads, headdim)
+     dropout_p: float. Dropout probability.
+     softmax_scale: float. The scaling of QK^T before applying softmax.
+         Default to 1 / sqrt(headdim).
+     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+         the attention score of query i and key j.
+     deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+         which is slightly slower and uses more memory. The forward pass is always deterministic.
+ Return:
+     out: (batch_size, seqlen, nheads, headdim).
+ """
+ ```
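+
+ As a quick illustration of the packed interface, here is a minimal sketch (the shapes follow the docstring above; the batch and sequence sizes are arbitrary choices):
+ ```python
+ import torch
+ from flash_attn import flash_attn_qkvpacked_func
+
+ batch, seqlen, nheads, headdim = 4, 2048, 16, 64
+ qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
+                   device="cuda", dtype=torch.bfloat16, requires_grad=True)
+ out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
+ out.sum().backward()  # gradients accumulate directly in the packed qkv tensor
+ ```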
+
+ ```python
+ flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                 window_size=(-1, -1), alibi_slopes=None, deterministic=False):
+ """dropout_p should be set to 0.0 during evaluation
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
+ will only attend to keys between
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+ Arguments:
+     q: (batch_size, seqlen, nheads, headdim)
+     k: (batch_size, seqlen, nheads_k, headdim)
+     v: (batch_size, seqlen, nheads_k, headdim)
+     dropout_p: float. Dropout probability.
+     softmax_scale: float. The scaling of QK^T before applying softmax.
+         Default to 1 / sqrt(headdim).
+     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+         (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+         is added to the attention score of query i and key j.
+     deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+         which is slightly slower and uses more memory. The forward pass is always deterministic.
+ Return:
+     out: (batch_size, seqlen, nheads, headdim).
+ """
+ ```
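+
+ A minimal GQA sketch matching the head-ratio example in the docstring (sizes are illustrative):
+ ```python
+ import torch
+ from flash_attn import flash_attn_func
+
+ batch, seqlen, headdim = 2, 4096, 128
+ q = torch.randn(batch, seqlen, 6, headdim, device="cuda", dtype=torch.float16)
+ # GQA: K and V have 2 heads, so each KV head serves 3 of the 6 query heads
+ k = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
+ v = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
+ out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen, 6, headdim)
+ ```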
+
+ ```python
+ def flash_attn_with_kvcache(
+     q,
+     k_cache,
+     v_cache,
+     k=None,
+     v=None,
+     rotary_cos=None,
+     rotary_sin=None,
+     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+     cache_batch_idx: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     softmax_scale=None,
+     causal=False,
+     window_size=(-1, -1),  # -1 means infinite context window
+     rotary_interleaved=True,
+     alibi_slopes=None,
+ ):
+ """
+ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
+ k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
+ the previous step, update them with the new keys/values from the current step, and do
+ attention with the updated cache, all in 1 kernel.
+
+ If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+ cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+ Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
+ rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+ If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+ and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+ If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+ indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+ See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+ For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+ 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+     1 1 1 1 0
+     1 1 1 1 1
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+     0 0
+     0 0
+     0 0
+     1 0
+     1 1
+ If the row of the mask is all zero, the output will be zero.
+
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
+ will only attend to keys between
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+ Note: Does not support backward pass.
+
+ Arguments:
+     q: (batch_size, seqlen, nheads, headdim)
+     k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+         or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
+         page_block_size must be a multiple of 256.
+     v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+         or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
+     k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
+         k with k_cache, starting at the indices specified by cache_seqlens.
+     v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
+     rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+         to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+     rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
+     cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
+         KV cache.
+     block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
+     cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+         If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+         If the indices are not distinct, and k and v are provided, the values updated in the cache
+         might come from any of the duplicate indices.
+     softmax_scale: float. The scaling of QK^T before applying softmax.
+         Default to 1 / sqrt(headdim).
+     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+     rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+         If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
+         rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+         (i.e. GPT-NeoX style).
+     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+         (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+         is added to the attention score of query i and key j.
+
+ Return:
+     out: (batch_size, seqlen, nheads, headdim).
+ """
+ ```
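+
+ A condensed decoding sketch based on the docstring above (cache sizes and the starting length of 100 are illustrative assumptions):
+ ```python
+ import torch
+ from flash_attn import flash_attn_with_kvcache
+
+ batch, nheads, headdim, max_seqlen = 2, 16, 128, 4096
+ k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
+ v_cache = torch.zeros_like(k_cache)
+ cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")  # tokens already cached
+
+ # One decoding step: a single new token per sequence
+ q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
+ k_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
+ v_new = torch.randn_like(k_new)
+ out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
+                               cache_seqlens=cache_seqlens, causal=True)
+ cache_seqlens += 1  # k_cache / v_cache were updated in place at position 100
+ ```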
+
+ To see how these functions are used in a multi-head attention layer (which
+ includes QKV projection and output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).
+
+ ## Changelog
+
+ ### 2.0: Complete rewrite, 2x faster
+ Upgrading from FlashAttention (1.x) to FlashAttention-2
+
+ These functions have been renamed:
+ - `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
+ - `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
+ - `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`
+
+ If the inputs have the same sequence lengths in the same batch, it is simpler
+ and faster to use these functions:
+ ```python
+ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
+ ```
+ ```python
+ flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
+ ```
+ ### 2.1: Change behavior of causal flag
+
+ If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
+ bottom right corner of the attention matrix, instead of the top-left corner.
+
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 =
+ masked out) is:
+ v2.0:
+     1 0 0 0 0
+     1 1 0 0 0
+ v2.1:
+     1 1 1 1 0
+     1 1 1 1 1
+
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+ v2.0:
+     1 0
+     1 1
+     1 1
+     1 1
+     1 1
+ v2.1:
+     0 0
+     0 0
+     0 0
+     1 0
+     1 1
+ If the row of the mask is all zero, the output will be zero.
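+
+ For reference, the bottom-right-aligned mask can be reproduced in a few lines of plain PyTorch (a sketch for checking semantics; FlashAttention itself never materializes this mask):
+ ```python
+ import torch
+
+ def causal_mask_bottom_right(seqlen_q, seqlen_k):
+     # 1 = keep, 0 = masked out; query i attends to key j iff j - i <= seqlen_k - seqlen_q
+     row = torch.arange(seqlen_q).unsqueeze(-1)
+     col = torch.arange(seqlen_k)
+     return (col - row <= seqlen_k - seqlen_q).int()
+
+ print(causal_mask_bottom_right(2, 5))  # reproduces the v2.1 example above
+ ```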
+
+ ### 2.2: Optimize for inference
+
+ Optimize for inference (iterative decoding) when the query has a very small sequence
+ length (e.g., query sequence length = 1). The bottleneck here is to load the KV
+ cache as fast as possible, so we split the loading across different thread
+ blocks, with a separate kernel to combine results.
+
+ See the function `flash_attn_with_kvcache` with more features for inference
+ (rotary embedding, updating the KV cache inplace).
+
+ Thanks to the xformers team, and in particular Daniel Haziza, for this
+ collaboration.
+
+ ### 2.3: Local (i.e., sliding window) attention
+
+ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
+ AI](https://mistral.ai/) and in particular Timothée Lacroix for this
+ contribution. Sliding window attention was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
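+
+ Sliding window attention is enabled through the `window_size` argument shown in the docstrings above; a minimal sketch (the window size of 1024 is an illustrative choice, not the Mistral setting):
+ ```python
+ import torch
+ from flash_attn import flash_attn_func
+
+ q, k, v = [torch.randn(1, 8192, 8, 64, device="cuda", dtype=torch.float16)
+            for _ in range(3)]
+ # Causal sliding-window attention: each query sees at most the 1024 previous keys
+ out = flash_attn_func(q, k, v, causal=True, window_size=(1024, 0))
+ ```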
+
+ ### 2.4: ALiBi (attention with linear bias), deterministic backward pass
+
+ Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+
+ Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.
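+
+ ALiBi is enabled through the `alibi_slopes` argument. A sketch using the geometric slopes from Press et al., 2021 (the helper below is illustrative and assumes the number of heads is a power of 2):
+ ```python
+ import torch
+ from flash_attn import flash_attn_func
+
+ def alibi_slopes(nheads):
+     # Geometric slopes 2^(-8/n), 2^(-16/n), ..., as in the ALiBi paper
+     return torch.tensor([2 ** (-8 * (i + 1) / nheads) for i in range(nheads)],
+                         device="cuda", dtype=torch.float32)
+
+ q, k, v = [torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
+            for _ in range(3)]
+ out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes(8))
+ ```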
+
+ ### 2.5: Paged KV cache
+
+ Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
+ Thanks to @beginlner for this contribution.
+
+ ### 2.6: Softcapping
+
+ Support attention with softcapping, as used in the Gemma-2 and Grok models.
+ Thanks to @Narsil and @lucidrains for this contribution.
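+
+ Assuming a flash-attn version >= 2.6, where `flash_attn_func` accepts a `softcap` argument (verify against your installed version), usage looks like this sketch:
+ ```python
+ import torch
+ from flash_attn import flash_attn_func
+
+ q, k, v = [torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
+            for _ in range(3)]
+ # Softcapping squashes attention scores via softcap * tanh(scores / softcap)
+ # before the softmax; 50.0 mirrors the value used by Gemma-2.
+ out = flash_attn_func(q, k, v, causal=True, softcap=50.0)
+ ```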
+
+ ### 2.7: Compatibility with torch compile
+
+ Thanks to @ani300 for this contribution.
+
+ ## Performance
+
+ We present the expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth; we see more speedup on slower GPU memory).
+
+ We currently have benchmarks for these GPUs:
+ * [A100](#a100)
+ * [H100](#h100)
+ <!-- * [RTX 3090](#rtx-3090) -->
+ <!-- * [T4](#t4) -->
+
+ ### A100
+
+ We display FlashAttention speedup using these parameters:
+ * Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
+ * Sequence length 512, 1k, 2k, 4k, 8k, 16k.
+ * Batch size set to 16k / seqlen.
+
+ #### Speedup
+
+ ![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)
+
+ #### Memory
+
+ ![FlashAttention memory](assets/flashattn_memory.jpg)
+
+ We show memory savings in this graph (note that the memory footprint is the same whether you use dropout or masking).
+ Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
+ We see 10X memory savings at sequence length 2K, and 20X at 4K.
+ As a result, FlashAttention can scale to much longer sequence lengths.
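+
+ To make the quadratic-vs-linear claim concrete, a back-of-the-envelope sketch (the batch and head counts are illustrative):
+ ```python
+ # Memory for the materialized (seqlen x seqlen) attention matrix in standard attention (fp16)
+ batch, nheads = 16, 16
+ for seqlen in [512, 2048, 4096]:
+     scores_bytes = batch * nheads * seqlen * seqlen * 2
+     print(f"seqlen={seqlen:5d}: {scores_bytes / 2**30:.2f} GiB for attention scores alone")
+ # FlashAttention never materializes this matrix; its extra memory grows only
+ # linearly in seqlen (the O(seqlen) softmax statistics kept per head).
+ ```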
+
+ ### H100
+
+ ![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)
+
+ ## Full model code and training script
+
+ We have released the full GPT model
+ [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
+ We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
+ cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
+ compared to the baseline implementation from Huggingface, reaching up to 225
+ TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
+ any activation checkpointing).
+
+ We also include a training
+ [script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
+ train GPT2 on Openwebtext and GPT3 on The Pile.
+
+ ## Triton implementation of FlashAttention
+
+ Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
+
+ As Triton is a higher-level language than CUDA, it might be easier to understand
+ and experiment with. The notations in the Triton implementation are also closer
+ to what's used in our paper.
+
+ We also have an experimental implementation in Triton that supports attention
+ bias (e.g. ALiBi):
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
+
+
+ ## Tests
+ We test that FlashAttention produces the same output and gradient as a reference
+ implementation, up to some numerical tolerance. In particular, we check that the
+ maximum numerical error of FlashAttention is at most twice the numerical error
+ of a baseline implementation in Pytorch (for different head dimensions, input
+ dtype, sequence length, causal / non-causal).
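+
+ A condensed sketch of that check (the actual tests in `tests/test_flash_attn.py` are more thorough; here PyTorch's SDPA stands in as the reference for brevity):
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from flash_attn import flash_attn_func
+
+ q, k, v = [torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
+            for _ in range(3)]
+ out = flash_attn_func(q, k, v, causal=True)
+
+ def ref_attn(q, k, v, upcast=True):
+     # Naive attention, optionally in fp32, as the ground truth / error baseline
+     dtype = torch.float32 if upcast else q.dtype
+     q_, k_, v_ = [t.to(dtype).transpose(1, 2) for t in (q, k, v)]
+     o = F.scaled_dot_product_attention(q_, k_, v_, is_causal=True)
+     return o.transpose(1, 2).to(q.dtype)
+
+ out_ref = ref_attn(q, k, v, upcast=True)   # fp32 reference
+ out_pt = ref_attn(q, k, v, upcast=False)   # same naive attention in fp16
+ assert (out - out_ref).abs().max() <= 2 * (out_pt - out_ref).abs().max()
+ ```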
+
+ To run the tests:
+ ```sh
+ pytest -q -s tests/test_flash_attn.py
+ ```
+ For the ROCm composable_kernel backend, run:
+ ```sh
+ pytest tests/test_flash_attn_ck.py
+ ```
+
+ ## When you encounter issues
+
+ This new release of FlashAttention-2 has been tested on several GPT-style
+ models, mostly on A100 GPUs.
+
+ If you encounter bugs, please open a GitHub Issue!
+
+ ## Citation
+ If you use this codebase, or otherwise found our work valuable, please cite:
+ ```
+ @inproceedings{dao2022flashattention,
+   title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+   booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
+   year={2022}
+ }
+ @inproceedings{dao2023flashattention2,
+   title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
+   author={Dao, Tri},
+   booktitle={International Conference on Learning Representations (ICLR)},
+   year={2024}
+ }
+ ```
cookbooks/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png ADDED
cookbooks/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png ADDED
cookbooks/flash-attention/assets/flash3_fp16_fwd.png ADDED
cookbooks/flash-attention/assets/flashattention_logo.png ADDED
cookbooks/flash-attention/assets/flashattn_banner.jpg ADDED
cookbooks/flash-attention/assets/flashattn_banner.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f4df0222057bbffcd2894fbae18bbfa6304e5d0583d47e44e9ac7a97bfb75ce
+ size 474702
cookbooks/flash-attention/assets/flashattn_memory.jpg ADDED
cookbooks/flash-attention/assets/flashattn_speedup.jpg ADDED
cookbooks/flash-attention/assets/flashattn_speedup_3090.jpg ADDED
cookbooks/flash-attention/assets/flashattn_speedup_a100_d128.jpg ADDED
cookbooks/flash-attention/assets/flashattn_speedup_t4.jpg ADDED
cookbooks/flash-attention/assets/flashattn_speedup_t4_fwd.jpg ADDED
cookbooks/flash-attention/assets/gpt2_training_curve.jpg ADDED
cookbooks/flash-attention/assets/gpt2_training_efficiency.jpg ADDED
cookbooks/flash-attention/assets/gpt3_training_curve.jpg ADDED
cookbooks/flash-attention/assets/gpt3_training_efficiency.jpg ADDED
cookbooks/flash-attention/benchmarks/benchmark_alibi.py ADDED
@@ -0,0 +1,275 @@
+ # Copyright (c) 2024, Sanghun Cho, Tri Dao.
+
+ import math
+ import torch
+ import torch.nn.functional as F
+
+ from einops import rearrange, repeat
+ from flash_attn.layers.rotary import apply_rotary_emb
+
+ from flash_attn.utils.benchmark import benchmark_fwd_bwd
+
+ from flash_attn import flash_attn_func
+
+ try:
+     import xformers.ops as xops
+ except ImportError:
+     xops = None
+
+
+ def generate_cos_sin(seqlen, rotary_dim, device, dtype):
+     assert rotary_dim % 2 == 0
+     angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
+     cos = torch.cos(angle).to(dtype=dtype)
+     sin = torch.sin(angle).to(dtype=dtype)
+     return cos, sin
+
+
+ def flash_rotary(q, k, v, cos, sin, causal=False):
+     # corrected by @tridao comments
+     q = apply_rotary_emb(
+         q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
+     )
+     k = apply_rotary_emb(
+         k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
+     )
+
+     return flash_attn_func(q, k, v, causal=causal)
+
+
+ def attn_bias_from_alibi_slopes(
+     slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
+ ):
+     batch, nheads = slopes.shape
+     device = slopes.device
+     slopes = rearrange(slopes, "b h -> b h 1 1")
+     if causal:
+         # Causal ALiBi bias only depends on the (negative) distance to the last key
+         return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
+     else:
+         row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+         col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+         sk = (
+             seqlen_k
+             if key_padding_mask is None
+             else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+         )
+         sq = (
+             seqlen_q
+             if query_padding_mask is None
+             else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+         )
+         relative_pos = torch.abs(row_idx + sk - sq - col_idx)
+         return -slopes * relative_pos.to(dtype=slopes.dtype)
+
+
+ def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+     assert mode in ["fwd", "bwd", "fwd_bwd"]
+     # fwd: QK^T and PV each cost 2 * seqlen^2 * headdim FLOPs per head; causal halves the work
+     f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+     return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+
+ def efficiency(flop, time):
+     return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+
+ def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
+     """
+     Arguments:
+         q, k, v: (batch_size, seqlen, nheads, head_dim)
+         dropout_p: float
+         attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
+     Output:
+         output: (batch_size, seqlen, nheads, head_dim)
+     """
+     batch_size, seqlen, nheads, d = q.shape
+     q = rearrange(q, 'b t h d -> (b h) t d')
+     k = rearrange(k, 'b s h d -> (b h) d s')
+     softmax_scale = 1.0 / math.sqrt(d)
+     # Preallocate attn_weights for `baddbmm`
+     if attn_bias is not None:
+         scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
+     else:
+         scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
+     scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),
+                        '(b h) t s -> b h t s', h=nheads)
+     if causal:
+         # "triu_tril_cuda_template" not implemented for 'BFloat16'
+         # So we have to construct the mask in float
+         causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+         # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+         scores = scores + causal_mask.to(dtype=scores.dtype)
+     attention = torch.softmax(scores, dim=-1)
+     attention_drop = F.dropout(attention, dropout_p)
+     output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+     return output.to(dtype=q.dtype)
+
+
+ def time_fwd_bwd(func, *args, **kwargs):
+     time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+     return time_f[1].mean, time_b[1].mean
+
+
+ repeats = 30
+ device = 'cuda'
+ dtype = torch.float16
+
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+ causal_vals = [False, True]
+ headdim_vals = [64, 128]
+ dim = 2048
+ dropout_p = 0.0
+
+ methods = (["fa2_alibi", "torch"]
+            + (["xformers"] if xops is not None else [])
+            + ["sdpa"]
+            + ["fa2_baseline"]
+            + ["fa2_rotary"])
+
+ time_f = {}
+ time_b = {}
+ time_f_b = {}
+ speed_f = {}
+ speed_b = {}
+ speed_f_b = {}
+ for causal in causal_vals:
+     for headdim in headdim_vals:
+         for batch_size, seqlen in bs_seqlen_vals:
+             config = (causal, headdim, batch_size, seqlen)
+             nheads = dim // headdim
+             q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                    requires_grad=True) for _ in range(3)]
+             # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+             alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
+             attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
+             attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
+             f, b = time_fwd_bwd(
+                 flash_attn_func,
+                 q, k, v,
+                 dropout_p,
+                 causal=causal,
+                 # alibi_slopes=alibi_slopes,
+                 alibi_slopes=None,
+                 repeats=repeats,
+                 verbose=False
+             )
+             time_f[config, "fa2_baseline"] = f
+             time_b[config, "fa2_baseline"] = b
+
+             q = q.detach().requires_grad_(True)
+             k = k.detach().requires_grad_(True)
+             v = v.detach().requires_grad_(True)
+             f, b = time_fwd_bwd(
+                 flash_attn_func,
+                 q, k, v,
+                 dropout_p,
+                 causal=causal,
+                 alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
+                 # alibi_slopes=None,
+                 repeats=repeats,
+                 verbose=False
+             )
+             time_f[config, "fa2_alibi"] = f
+             time_b[config, "fa2_alibi"] = b
+
+             try:
+                 q = q.detach().requires_grad_(True)
+                 k = k.detach().requires_grad_(True)
+                 v = v.detach().requires_grad_(True)
+                 f, b = time_fwd_bwd(
+                     attention_pytorch,
+                     q, k, v,
+                     dropout_p,
+                     causal=causal,
+                     attn_bias=attn_bias,
+                     repeats=repeats,
+                     verbose=False
+                 )
+             except Exception:  # Skip if OOM
+                 f, b = float('nan'), float('nan')
+             time_f[config, "torch"] = f
+             time_b[config, "torch"] = b
+
+             # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
+             with torch.backends.cuda.sdp_kernel(enable_flash=False):
+                 q_pt = q.detach().requires_grad_(True).transpose(1, 2)
+                 k_pt = k.detach().requires_grad_(True).transpose(1, 2)
+                 v_pt = v.detach().requires_grad_(True).transpose(1, 2)
+                 f, b = time_fwd_bwd(
+                     F.scaled_dot_product_attention,
+                     q_pt, k_pt, v_pt,
+                     attn_mask=attn_bias,
+                     dropout_p=dropout_p,
+                     is_causal=causal,
+                     repeats=repeats,
+                     verbose=False
+                 )
+                 time_f[config, "sdpa"] = f
+                 time_b[config, "sdpa"] = b
+
+             if xops is not None:
+                 q = q.detach().requires_grad_(True)
+                 k = k.detach().requires_grad_(True)
+                 v = v.detach().requires_grad_(True)
+                 if causal:
+                     attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
+                     # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
+                     # `flshattB@v2.3.6` is not supported because:
+                     # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
+                     # `cutlassB` is not supported because:
+                     # attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
+                     attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
+                 else:
+                     attn_bias_xops = attn_bias.to(dtype=q.dtype)
+                 f, b = time_fwd_bwd(
+                     xops.memory_efficient_attention,
+                     q, k, v,
+                     attn_bias_xops,
+                     dropout_p,
+                     repeats=repeats,
+                     verbose=False
+                 )
+                 time_f[config, "xformers"] = f
+                 time_b[config, "xformers"] = b
+
+             q = q.detach().requires_grad_(True)
+             k = k.detach().requires_grad_(True)
+             v = v.detach().requires_grad_(True)
+             cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
+             f, b = time_fwd_bwd(
+                 flash_rotary,
+                 q, k, v,
+                 cos, sin,
+                 causal,
+                 repeats=repeats,
+                 verbose=False
+             )
+             time_f[config, "fa2_rotary"] = f
+             time_b[config, "fa2_rotary"] = b
+
+             print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+             csv_output = ""
+             csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
+             for method in methods:
+                 time_f_b[config, method] = time_f[config, method] + time_b[config, method]
+                 speed_f[config, method] = efficiency(
+                     flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                     time_f[config, method]
+                 )
+                 speed_b[config, method] = efficiency(
+                     flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
+                     time_b[config, method]
+                 )
+                 speed_f_b[config, method] = efficiency(
+                     flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
+                     time_f_b[config, method]
+                 )
+                 print(
+                     f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
+                     f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
+                     f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
+                 )
+                 csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
+             print(csv_output)
cookbooks/flash-attention/benchmarks/benchmark_causal.py ADDED
@@ -0,0 +1,225 @@
+ import math
+ import torch
+ import torch.nn.functional as F
+
+ from einops import rearrange
+
+ from flash_attn.utils.benchmark import benchmark_forward, benchmark_fwd_bwd, pytorch_profiler
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+ # # from flash_attn.triton.fused_attention import attention as attention
+ # from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
+ # from flash_attn.flash_attn_triton_og import attention as attention_og
+
+ # from triton.ops.flash_attention import attention as attention_triton
+
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
+
+ try:
+     from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
+ except ImportError:
+     scaled_upper_triang_masked_softmax = None
+
+
+ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
+     """
+     Arguments:
+         qkv: (batch_size, seqlen, 3, nheads, head_dim)
+         dropout_p: float
+     Output:
+         output: (batch_size, seqlen, nheads, head_dim)
+     """
+     batch_size, seqlen, _, nheads, d = qkv.shape
+     q, k, v = qkv.unbind(dim=2)
+     q = rearrange(q, 'b t h d -> (b h) t d')
+     k = rearrange(k, 'b s h d -> (b h) d s')
+     softmax_scale = 1.0 / math.sqrt(d)
+     # Preallocate attn_weights for `baddbmm`
+     scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+     scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                        '(b h) t s -> b h t s', h=nheads)
+     if causal:
+         # "triu_tril_cuda_template" not implemented for 'BFloat16'
+         # So we have to construct the mask in float
+         causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+         # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+         scores = scores + causal_mask.to(dtype=scores.dtype)
+     attention = torch.softmax(scores, dim=-1)
+     attention_drop = F.dropout(attention, dropout_p)
+     output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+     return output.to(dtype=qkv.dtype)
+
+
+ def attention_megatron(qkv):
+     """
+     Arguments:
+         qkv: (batch_size, seqlen, 3, nheads, head_dim)
+     Output:
+         output: (batch_size, seqlen, nheads, head_dim)
+     """
+     batch_size, seqlen, _, nheads, d = qkv.shape
+     q, k, v = qkv.unbind(dim=2)
+     q = rearrange(q, 'b t h d -> (b h) t d')
+     k = rearrange(k, 'b s h d -> (b h) d s')
+     softmax_scale = 1.0 / math.sqrt(d)
+     # Preallocate attn_weights for `baddbmm`
+     scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+     scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                        '(b h) t s -> b h t s', h=nheads)
+     attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
+     output = torch.einsum('bhts,bshd->bthd', attention, v)
+     return output.to(dtype=qkv.dtype)
+
+
+ torch.manual_seed(0)
+ repeats = 30
+ batch_size = 8
+ seqlen = 2048
+ nheads = 12
+ headdim = 128
+ # nheads = 24
+ # headdim = 64
+ # batch_size = 64
+ # seqlen = 512
+ # nheads = 8
+ # headdim = 128
+ dropout_p = 0.0
+ causal = True
+ dtype = torch.float16
+ device = 'cuda'
+
+ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                   requires_grad=True)
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                           device=qkv.device)
+
+ qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
+ # benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
+ #               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
+ # pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
+ #                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
+ benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
+ pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
+
+ # for dropout_p in [0.1, 0.0]:
+ #     for causal in [False, True]:
+ #         print(f"### {dropout_p = }, {causal = } ###")
+ #         pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
+
+
+ # nheads_k = 2
+ # q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+ # kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
+ #                  requires_grad=True)
+ # if flash_attn_kvpacked_func is not None:
+ #     benchmark_all(flash_attn_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
+ #     pytorch_profiler(flash_attn_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)
+
+ # dropout_p = 0.0
+ # causal = False
+ # benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
+ #               repeats=repeats, desc='PyTorch Attention')
+
+ # benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
+ # pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
+
+ # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+ #                        requires_grad=True) for _ in range(3)]
+ # benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
+ # # pytorch_profiler(attention, q, k, v, 1.0, backward=True)
+
+ # if scaled_upper_triang_masked_softmax is not None:
+ #     benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
+
+ # from src.ops.fftconv import fftconv_func
+
+ # dim = nheads * headdim
+ # u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
+ # k = torch.randn(dim, seqlen, device=device, requires_grad=True)
+ # D = torch.randn(dim, device=device, requires_grad=True)
+ # benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
+ # pytorch_profiler(fftconv_func, u, k, D, backward=True)
+ # pytorch_profiler(torch.fft.rfft, u.float())
+
+ flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
+ ideal_a100_time = flops / 312 / 1e9  # 312 TFLOPS peak -> flops / 312e9 gives milliseconds
+ print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
+ exit(0)  # NB: early exit left in by the author; remove this line to run the sweep below
+
+
+ def time_fwd_bwd(func, *args, **kwargs):
+     time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+     return time_f[1].mean, time_b[1].mean
+
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+ causal_vals = [False, True]
+ headdim_vals = [64, 128]
+ dim = 2048
+ dropout_p = 0.0
+
+ time_f = {}
+ time_b = {}
+ for causal in causal_vals:
+     for headdim in headdim_vals:
+         for batch_size, seqlen in bs_seqlen_vals:
+             nheads = dim // headdim
+             qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                               requires_grad=True)
+             cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                       device=qkv.device)
+             qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
+             f, b = time_fwd_bwd(
+                 flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
+                 causal=causal, repeats=repeats, verbose=False
+             )
+             time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
+             time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b
+
+             qkv = qkv.detach().requires_grad_(True)
+             # (was fav2_qkvpacked_func, which is undefined in this file)
+             f, b = time_fwd_bwd(
+                 flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+             )
+             time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
+             time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b
+
+             # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+             #                        requires_grad=True) for _ in range(3)]
+             # # Try both values of sequence_parallel and pick the faster one
+             # f, b = time_fwd_bwd(
+             #     attention_triton, q, k, v, causal, headdim**(-0.5),
+             #     False, repeats=repeats, verbose=False
+             # )
+             # _, b0 = time_fwd_bwd(
+             #     attention_triton, q, k, v, causal, headdim**(-0.5),
+             #     True, repeats=repeats, verbose=False
+             # )
+             # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
+             # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)
+
+             if seqlen <= 8 * 1024:
+                 qkv = qkv.detach().requires_grad_(True)
+                 f, b = time_fwd_bwd(
+                     attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+                 )
+             else:
+                 f, b = float('nan'), float('nan')
+             time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
+             time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b
+
+             # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+             #                        requires_grad=True) for _ in range(3)]
+             # import xformers.ops as xops
+             # f, b = time_fwd_bwd(
+             #     xops.memory_efficient_attention, q, k, v,
+             #     attn_bias=xops.LowerTriangularMask() if causal else None,
+             #     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
+             # )
+             # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
+             # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b
+
+
+ import pickle
+ with open('flash2_attn_time_h100.plk', 'wb') as fp:
+     pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
cookbooks/flash-attention/benchmarks/benchmark_flash_attention.py ADDED
@@ -0,0 +1,180 @@
1
+ # Install the newest triton version with
2
+ # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
3
+ import pickle
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
12
+ from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
13
+
14
+ from flash_attn import flash_attn_qkvpacked_func
15
+
16
+ try:
17
+ from triton.ops.flash_attention import attention as attention_triton
18
+ except ImportError:
19
+ attention_triton = None
20
+
21
+ try:
22
+ import xformers.ops as xops
23
+ except ImportError:
24
+ xops = None
25
+
26
+
27
+ def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
28
+ assert mode in ["fwd", "bwd", "fwd_bwd"]
29
+ f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
30
+ return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
31
+
32
+ def efficiency(flop, time):
33
+ return (flop / time / 10**12) if not math.isnan(time) else 0.0
34
+
35
+
36
+ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
37
+ """
38
+ Arguments:
39
+ qkv: (batch_size, seqlen, 3, nheads, head_dim)
40
+ dropout_p: float
41
+ Output:
42
+ output: (batch_size, seqlen, nheads, head_dim)
43
+ """
44
+ batch_size, seqlen, _, nheads, d = qkv.shape
45
+ q, k, v = qkv.unbind(dim=2)
46
+ q = rearrange(q, 'b t h d -> (b h) t d')
47
+ k = rearrange(k, 'b s h d -> (b h) d s')
48
+ softmax_scale = 1.0 / math.sqrt(d)
49
+ # Preallocate attn_weights for `baddbmm`
50
+ scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
51
+ scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
52
+ '(b h) t s -> b h t s', h=nheads)
53
+ if causal:
54
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
55
+ # So we have to construct the mask in float
56
+ causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
57
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
58
+ scores = scores + causal_mask.to(dtype=scores.dtype)
59
+ attention = torch.softmax(scores, dim=-1)
60
+ attention_drop = F.dropout(attention, dropout_p)
61
+     output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+     return output.to(dtype=qkv.dtype)
+
+
+ def time_fwd_bwd(func, *args, **kwargs):
+     time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+     return time_f[1].mean, time_b[1].mean
+
+
+ repeats = 30
+ device = 'cuda'
+ dtype = torch.float16
+
+ bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+ causal_vals = [False, True]
+ headdim_vals = [64, 128]
+ dim = 2048
+ dropout_p = 0.0
+
+ methods = (["Flash2", "Pytorch"]
+            + (["Triton"] if attention_triton is not None else [])
+            + (["xformers.c"] if xops is not None else [])
+            + (["xformers.f"] if xops is not None else []))
+
+ time_f = {}
+ time_b = {}
+ time_f_b = {}
+ speed_f = {}
+ speed_b = {}
+ speed_f_b = {}
+ for causal in causal_vals:
+     for headdim in headdim_vals:
+         for batch_size, seqlen in bs_seqlen_vals:
+             config = (causal, headdim, batch_size, seqlen)
+             nheads = dim // headdim
+             qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                               requires_grad=True)
+             f, b = time_fwd_bwd(
+                 flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+             )
+             time_f[config, "Flash2"] = f
+             time_b[config, "Flash2"] = b
+
+             try:
+                 qkv = qkv.detach().requires_grad_(True)
+                 f, b = time_fwd_bwd(
+                     attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
+                 )
+             except RuntimeError:  # Skip if OOM
+                 f, b = float('nan'), float('nan')
+             time_f[config, "Pytorch"] = f
+             time_b[config, "Pytorch"] = b
+
+             if attention_triton is not None:
+                 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+                                        requires_grad=True) for _ in range(3)]
+                 # Try both values of sequence_parallel and pick the faster one
+                 try:
+                     f, b = time_fwd_bwd(
+                         attention_triton, q, k, v, causal, headdim**(-0.5),
+                         False, repeats=repeats, verbose=False
+                     )
+                 except Exception:
+                     f, b = float('nan'), float('inf')
+                 try:
+                     _, b0 = time_fwd_bwd(
+                         attention_triton, q, k, v, causal, headdim**(-0.5),
+                         True, repeats=repeats, verbose=False
+                     )
+                 except Exception:
+                     b0 = float('inf')
+                 time_f[config, "Triton"] = f
+                 time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
+
+             if xops is not None:
+                 q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                        requires_grad=True) for _ in range(3)]
+                 f, b = time_fwd_bwd(
+                     xops.memory_efficient_attention, q, k, v,
+                     attn_bias=xops.LowerTriangularMask() if causal else None,
+                     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
+                 )
+                 time_f[config, "xformers.c"] = f
+                 time_b[config, "xformers.c"] = b
+
+             if xops is not None:
+                 q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
+                                        requires_grad=True) for _ in range(3)]
+                 f, b = time_fwd_bwd(
+                     xops.memory_efficient_attention, q, k, v,
+                     attn_bias=xops.LowerTriangularMask() if causal else None,
+                     op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
+                 )
+                 time_f[config, "xformers.f"] = f
+                 time_b[config, "xformers.f"] = b
+
+             print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+             for method in methods:
+                 time_f_b[config, method] = time_f[config, method] + time_b[config, method]
+                 speed_f[config, method] = efficiency(
+                     flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                     time_f[config, method]
+                 )
+                 speed_b[config, method] = efficiency(
+                     flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
+                     time_b[config, method]
+                 )
+                 speed_f_b[config, method] = efficiency(
+                     flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
+                     time_f_b[config, method]
+                 )
+                 print(
+                     f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
+                     f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
+                     f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
+                 )
+
+
+ # with open('flash2_attn_time.plk', 'wb') as fp:
+ #     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
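The loop above relies on `flops(...)` and `efficiency(...)` helpers defined near the top of this benchmark file, outside this excerpt. As a reading aid, here is a minimal sketch of what they are assumed to compute: attention performs two (seqlen x seqlen) matmuls per head (4 FLOPs per element counting multiply and add over both matmuls), causal masking roughly halves the work, and the backward pass is commonly accounted as 2.5x the forward:

```python
import math

# Sketch (assumption, not the file's exact code) of the helpers used above.
def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    # Convert to TFLOPs/s; map NaN timings (e.g. after an OOM skip) to 0.
    return (flop / time / 1e12) if not math.isnan(time) else 0.0
```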
cookbooks/flash-attention/benchmarks/benchmark_gemm.py ADDED
@@ -0,0 +1,47 @@
+ import time
+ import torch
+ import torch.utils.benchmark as benchmark
+
+ from triton.testing import do_bench
+
+ if torch.version.cuda:
+     backendBLAS = "cuBLAS"
+ elif torch.version.hip:
+     backendBLAS = "hipBLAS"
+ else:
+     backendBLAS = "BLAS"  # fallback label so the script doesn't NameError on other backends
+
+ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
+     """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
+     if verbose:
+         print(desc, '- Forward pass')
+     t = benchmark.Timer(
+         stmt='fn(*inputs, **kwinputs)',
+         globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
+         num_threads=torch.get_num_threads(),
+     )
+     m = t.timeit(repeats)
+     if verbose:
+         print(m)
+     return t, m
+
+
+ torch.manual_seed(0)
+ repeats = 30
+ dtype = torch.bfloat16
+ device = 'cuda'
+ verbose = False
+ m, n = 8192, 8192
+
+ tflops_matmul = {}
+ tflops_matmul1 = {}
+ for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
+     a = torch.randn(m, k, device=device, dtype=dtype)
+     b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
+     nFLOPS_matmul = 2 * m * n * k
+     time.sleep(2)  # to reduce power throttling
+     timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)[1]
+     tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
+     print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')
+     time.sleep(2)  # to reduce power throttling
+     ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
+     tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
+     print(f'[triton.testing.do_bench] {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
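Both prints report the same quantity through different units: `benchmark_forward` yields seconds, so TFLOPS = FLOPs / t / 1e12, while `do_bench` yields milliseconds, so TFLOPS = FLOPs / ms * 1e-9. A quick sanity check of that arithmetic with illustrative numbers:

```python
nflops = 2 * 8192 * 8192 * 4096   # 2*m*n*k FLOPs for one GEMM, ~5.5e11
t_sec, t_ms = 0.004, 4.0          # the same hypothetical timing in both units
assert abs(nflops / t_sec * 1e-12 - nflops / t_ms * 1e-9) < 1e-9  # both ~137.4 TFLOPS
```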
cookbooks/flash-attention/csrc/flash_attn/flash_api.cpp ADDED
@@ -0,0 +1,1485 @@
+ /******************************************************************************
+  * Copyright (c) 2024, Tri Dao.
+  ******************************************************************************/
+
+ // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+ #include <torch/python.h>
+ #include <torch/nn/functional.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <c10/cuda/CUDAStream.h>
+ #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
+ #include "philox_unpack.cuh" // For at::cuda::philox::unpack
+
+ #include <cutlass/numeric_types.h>
+
+ #include "namespace_config.h"
+ #include "hardware_info.h"
+ #include "flash.h"
+ #include "static_switch.h"
+
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+ namespace FLASH_NAMESPACE {
+
+ void set_params_fprop(Flash_fwd_params &params,
+                       // sizes
+                       const size_t b,
+                       const size_t seqlen_q,
+                       const size_t seqlen_k,
+                       const size_t seqlen_q_rounded,
+                       const size_t seqlen_k_rounded,
+                       const size_t h,
+                       const size_t h_k,
+                       const size_t d,
+                       const size_t d_rounded,
+                       // device pointers
+                       const at::Tensor q,
+                       const at::Tensor k,
+                       const at::Tensor v,
+                       at::Tensor out,
+                       void *cu_seqlens_q_d,
+                       void *cu_seqlens_k_d,
+                       void *seqused_k,
+                       void *p_d,
+                       void *softmax_lse_d,
+                       float p_dropout,
+                       float softmax_scale,
+                       int window_size_left,
+                       int window_size_right,
+                       const float softcap,
+                       bool seqlenq_ngroups_swapped=false,
+                       const bool unpadded_lse=false) {
+
+     // Reset the parameters
+     params = {};
+
+     params.is_bf16 = q.dtype() == torch::kBFloat16;
+
+     // Set the pointers and strides.
+     params.q_ptr = q.data_ptr();
+     params.k_ptr = k.data_ptr();
+     params.v_ptr = v.data_ptr();
+     // All strides are in elements, not bytes.
+     params.q_row_stride = q.stride(-3);
+     params.k_row_stride = k.stride(-3);
+     params.v_row_stride = v.stride(-3);
+     params.q_head_stride = q.stride(-2);
+     params.k_head_stride = k.stride(-2);
+     params.v_head_stride = v.stride(-2);
+     params.o_ptr = out.data_ptr();
+     params.o_row_stride = out.stride(-3);
+     params.o_head_stride = out.stride(-2);
+
+     if (cu_seqlens_q_d == nullptr) {
+         params.q_batch_stride = q.stride(0);
+         params.k_batch_stride = k.stride(0);
+         params.v_batch_stride = v.stride(0);
+         params.o_batch_stride = out.stride(0);
+         if (seqlenq_ngroups_swapped) {
+             params.q_batch_stride *= seqlen_q;
+             params.o_batch_stride *= seqlen_q;
+         }
+     }
+
+     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
+     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+     params.seqused_k = static_cast<int *>(seqused_k);
+
+     // P = softmax(QK^T)
+     params.p_ptr = p_d;
+
+     // Softmax sum
+     params.softmax_lse_ptr = softmax_lse_d;
+
+     // Set the dimensions.
+     params.b = b;
+     params.h = h;
+     params.h_k = h_k;
+     params.h_h_k_ratio = h / h_k;
+     params.seqlen_q = seqlen_q;
+     params.seqlen_k = seqlen_k;
+     params.seqlen_q_rounded = seqlen_q_rounded;
+     params.seqlen_k_rounded = seqlen_k_rounded;
+     params.d = d;
+     params.d_rounded = d_rounded;
+
+     // Set the different scale values.
+     #ifdef FLASHATTENTION_DISABLE_SOFTCAP
+     TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+     #endif
+     if (softcap > 0.0) {
+         params.softcap = softmax_scale / softcap;
+         params.scale_softmax = softcap;
+         params.scale_softmax_log2 = softcap * M_LOG2E;
+     } else {
+         // Remove potential NaN
+         params.softcap = 0.0;
+         params.scale_softmax = softmax_scale;
+         params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+     }
+
+     // Set this to probability of keeping an element to simplify things.
+     params.p_dropout = 1.f - p_dropout;
+     // Convert p from float to int so we don't have to convert the random uint to float to compare.
+     // [Minor] We want to round down since when we do the comparison we use <= instead of <
+     // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
+     // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
+     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+     params.rp_dropout = 1.f / params.p_dropout;
+     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+     TORCH_CHECK(p_dropout < 1.f);
+     #ifdef FLASHATTENTION_DISABLE_DROPOUT
+     TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
+     #endif
+
+     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+     params.is_causal = window_size_left < 0 && window_size_right == 0;
+
+     if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
+     if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
+     params.window_size_left = window_size_left;
+     params.window_size_right = window_size_right;
+
+     #ifdef FLASHATTENTION_DISABLE_LOCAL
+     TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
+                 "This flash attention build does not support local attention.");
+     #endif
+
+     params.is_seqlens_k_cumulative = true;
+
+     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+     TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+     #endif
+
+     params.unpadded_lse = unpadded_lse;
+     params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
+ }
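The softcap branch in `set_params_fprop` folds two scales into the params: the factor applied inside tanh (`params.softcap = softmax_scale / softcap`) and the factor applied outside (`params.scale_softmax = softcap`). A minimal NumPy sketch of the effective score this parameterization implies, assuming the kernel computes `scale_softmax * tanh(qk * params.softcap)` (an illustration, not the kernel code):

```python
import numpy as np

def softcapped_score(qk, softmax_scale, softcap):
    # Near-linear for small scores, saturating at +/- softcap for large ones.
    return softcap * np.tanh(qk * (softmax_scale / softcap))

assert np.isclose(softcapped_score(0.01, 1.0, 30.0), 0.01, atol=1e-4)  # linear regime
assert abs(softcapped_score(1e6, 1.0, 30.0)) <= 30.0                   # saturates
```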
+
+ void set_params_dgrad(Flash_bwd_params &params,
+                       // sizes
+                       const size_t b,
+                       const size_t seqlen_q,
+                       const size_t seqlen_k,
+                       const size_t seqlen_q_rounded,
+                       const size_t seqlen_k_rounded,
+                       const size_t h,
+                       const size_t h_k,
+                       const size_t d,
+                       const size_t d_rounded,
+                       // device pointers
+                       const at::Tensor q,
+                       const at::Tensor k,
+                       const at::Tensor v,
+                       const at::Tensor out,
+                       const at::Tensor dout,
+                       at::Tensor dq,
+                       at::Tensor dk,
+                       at::Tensor dv,
+                       void *cu_seqlens_q_d,
+                       void *cu_seqlens_k_d,
+                       void *dq_accum_d,
+                       void *dk_accum_d,
+                       void *dv_accum_d,
+                       void *softmax_lse_d,
+                       void *dsoftmax_sum_d,
+                       float p_dropout,
+                       float softmax_scale,
+                       int window_size_left,
+                       int window_size_right,
+                       const float softcap,
+                       bool deterministic,
+                       const bool unpadded_lse) {
+
+     set_params_fprop(params,
+                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
+                      q, k, v, out,
+                      cu_seqlens_q_d,
+                      cu_seqlens_k_d,
+                      nullptr,
+                      nullptr,
+                      softmax_lse_d,
+                      p_dropout,
+                      softmax_scale,
+                      window_size_left,
+                      window_size_right,
+                      softcap,
+                      false, // seqlenq_ngroups_swapped
+                      unpadded_lse);
+
+     // Set the pointers and strides.
+     params.do_ptr = dout.data_ptr();
+     params.do_row_stride = dout.stride(-3);
+     params.do_head_stride = dout.stride(-2);
+     params.dq_ptr = dq.data_ptr();
+     params.dk_ptr = dk.data_ptr();
+     params.dv_ptr = dv.data_ptr();
+     params.dq_row_stride = dq.stride(-3);
+     params.dk_row_stride = dk.stride(-3);
+     params.dv_row_stride = dv.stride(-3);
+     params.dq_head_stride = dq.stride(-2);
+     params.dk_head_stride = dk.stride(-2);
+     params.dv_head_stride = dv.stride(-2);
+
+     if (cu_seqlens_q_d == nullptr) {
+         params.do_batch_stride = dout.stride(0);
+         params.dq_batch_stride = dq.stride(0);
+         params.dk_batch_stride = dk.stride(0);
+         params.dv_batch_stride = dv.stride(0);
+     }
+
+     params.dq_accum_ptr = dq_accum_d;
+     params.dk_accum_ptr = dk_accum_d;
+     params.dv_accum_ptr = dv_accum_d;
+
+     // Softmax sum
+     params.dsoftmax_sum = dsoftmax_sum_d;
+
+     params.deterministic = deterministic;
+ }
+
+ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+     FP16_SWITCH(!params.is_bf16, [&] {
+         HEADDIM_SWITCH(params.d, [&] {
+             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                 if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it, num_splits == 0
+                     run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+                 } else {
+                     run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
+                 }
+             });
+         });
+     });
+ }
+
+ // Find the number of splits that maximizes the occupancy. For example, if we have
+ // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+ // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+ // splits as that would incur more HBM reads/writes.
+ // So we find the best efficiency, then find the smallest number of splits that gets 85%
+ // of the best efficiency.
+ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
+     // If we have enough to almost fill the SMs, then just use 1 split
+     if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
+     max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+     float max_efficiency = 0.f;
+     std::vector<float> efficiency;
+     efficiency.reserve(max_splits);
+     auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+     // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+     // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+     // (i.e. it's 11 splits anyway).
+     // So we check if the number of blocks per split is the same as the previous num_splits.
+     auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+         return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+     };
+     for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+         if (!is_split_eligible(num_splits)) {
+             efficiency.push_back(0.f);
+         } else {
+             float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+             float eff = n_waves / ceil(n_waves);
+             // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+             if (eff > max_efficiency) { max_efficiency = eff; }
+             efficiency.push_back(eff);
+         }
+     }
+     for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+         if (!is_split_eligible(num_splits)) { continue; }
+         if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+             // printf("num_splits chosen = %d\n", num_splits);
+             return num_splits;
+         }
+     }
+     return 1;
+ }
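To make the occupancy trade-off concrete, here is a Python mirror of `num_splits_heuristic` (a sketch for illustration; the shipped logic is the C++ above). With 48 m-blocks on 108 SMs it picks 2 splits (wave efficiency ~0.89) rather than 3 (~0.67), matching the comment:

```python
import math

def num_splits_heuristic(batch_nheads_mblocks, num_sms, num_n_blocks, max_splits):
    if batch_nheads_mblocks >= 0.8 * num_sms:   # SMs are nearly full already
        return 1
    max_splits = min(max_splits, num_sms, num_n_blocks)
    ceildiv = lambda a, b: -(-a // b)
    # A split count is redundant if it yields the same blocks-per-split as count - 1.
    eligible = lambda s: s == 1 or ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1)
    effs = []
    for s in range(1, max_splits + 1):
        if not eligible(s):
            effs.append(0.0)
            continue
        n_waves = batch_nheads_mblocks * s / num_sms
        effs.append(n_waves / math.ceil(n_waves))  # fraction of the last wave that is full
    best = max(effs)
    # Smallest split count within 85% of the best efficiency wins (fewer HBM reads/writes).
    for s in range(1, max_splits + 1):
        if eligible(s) and effs[s - 1] >= 0.85 * best:
            return s
    return 1

print(num_splits_heuristic(48, 108, 64, 128))  # -> 2
```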
+
+ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+                                                       const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
+                                                       const int head_size_rounded, const float p_dropout,
+                                                       const int num_splits, const int num_sm, struct c10::TensorOptions opts) {
+
+     // This needs to match with run_mha_fwd_splitkv_dispatch
+     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+     // In any case we don't expect seqlen_q to be larger than 64 for inference.
+     const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+     params.num_splits = num_splits;
+     at::Tensor softmax_lse_accum;
+     at::Tensor out_accum;
+
+     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
+         if (num_splits < 1) {
+             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
+         }
+         if (params.num_splits > 1) {
+             softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+             out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+             params.oaccum_ptr = out_accum.data_ptr();
+         }
+         TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+     }
+
+     return std::make_tuple(softmax_lse_accum, out_accum);
+ }
+
+ void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads) {
+     #ifdef FLASHATTENTION_DISABLE_ALIBI
+     TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
+     params.alibi_slopes_ptr = nullptr;
+     #else
+     if (alibi_slopes_.has_value()) {
+         auto alibi_slopes = alibi_slopes_.value();
+         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+         CHECK_DEVICE(alibi_slopes);
+         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+         TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
+         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+         params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+     } else {
+         params.alibi_slopes_ptr = nullptr;
+     }
+     #endif
+ }
+
+ std::vector<at::Tensor>
+ mha_fwd(at::Tensor &q,        // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+         const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+         const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+         std::optional<at::Tensor> &out_,           // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+         std::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
+         const float p_dropout,
+         const float softmax_scale,
+         bool is_causal,
+         int window_size_left,
+         int window_size_right,
+         const float softcap,
+         const bool return_softmax,
+         std::optional<at::Generator> gen_) {
+
+     // Otherwise the kernel will be launched from cuda:0 device
+     at::cuda::CUDAGuard device_guard{q.device()};
+
+     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+     bool is_sm8x_min = cc_major >= 8;
+     TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+
+     auto q_dtype = q.dtype();
+     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                 "FlashAttention only supports fp16 and bf16 data type");
+     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+
+     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+
+     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+     const auto sizes = q.sizes();
+
+     const int batch_size = sizes[0];
+     int seqlen_q = sizes[1];
+     int num_heads = sizes[2];
+     const int head_size = sizes[3];
+     const int seqlen_k = k.size(1);
+     const int num_heads_k = k.size(2);
+     TORCH_CHECK(batch_size > 0, "batch size must be positive");
+     TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
+     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
+     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
+     if (window_size_left >= seqlen_k) { window_size_left = -1; }
+     if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+     // causal=true is the same as causal=false in this case
+     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
+     if (is_causal) { window_size_right = 0; }
+
+     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+     // H/t Daniel Haziza
+     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
+     const int ngroups = num_heads / num_heads_k;
+     if (seqlenq_ngroups_swapped) {
+         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+         seqlen_q = ngroups;
+         num_heads = num_heads_k;
+     }
+
+     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
+
+     at::Tensor out;
+     if (out_.has_value()) {
+         out = out_.value();
+         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+         CHECK_DEVICE(out);
+         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+         CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
+         if (seqlenq_ngroups_swapped) {
+             out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
+         }
+     } else {
+         out = torch::empty_like(q);
+     }
+
+     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+     const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+     auto opts = q.options();
+
+     auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+     at::Tensor p;
+     // Only return softmax if there's dropout to reduce compilation time
+     if (return_softmax) {
+         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
+         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
+     } else {
+         p = torch::empty({ 0 }, opts);
+     }
+
+     Flash_fwd_params params;
+     set_params_fprop(params,
+                      batch_size,
+                      seqlen_q, seqlen_k,
+                      seqlen_q_rounded, seqlen_k_rounded,
+                      num_heads, num_heads_k,
+                      head_size, head_size_rounded,
+                      q, k, v, out,
+                      /*cu_seqlens_q_d=*/nullptr,
+                      /*cu_seqlens_k_d=*/nullptr,
+                      /*seqused_k=*/nullptr,
+                      return_softmax ? p.data_ptr() : nullptr,
+                      softmax_lse.data_ptr(),
+                      p_dropout,
+                      softmax_scale,
+                      window_size_left,
+                      window_size_right,
+                      softcap
+                      );
+
+     // Keep references to these tensors to extend their lifetime
+     at::Tensor softmax_lse_accum, out_accum;
+     std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+         params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+         head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
+
+     // number of times random will be generated per thread, to offset philox counter in thc random
+     // state
+     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+     int64_t counter_offset = params.b * params.h * 32;
+     auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+     auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+     // Forward kernel will populate memory with the seed and offset.
+     params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
+     if (p_dropout > 0.0) {
+         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+             gen_, at::cuda::detail::getDefaultCUDAGenerator());
+         // See Note [Acquire lock when using random generators]
+         std::lock_guard<std::mutex> lock(gen->mutex_);
+         params.philox_args = gen->philox_cuda_state(counter_offset);
+     }
+
+     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+     if (seqlen_k > 0) {
+         auto stream = at::cuda::getCurrentCUDAStream().stream();
+         run_mha_fwd(params, stream);
+     } else {
+         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+         out.zero_();
+         softmax_lse.fill_(std::numeric_limits<float>::infinity());
+     }
+
+     if (seqlenq_ngroups_swapped) {
+         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
+         q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
+         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+     }
+     return {out, softmax_lse, p, rng_state};
+ }
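The `seqlenq_ngroups_swapped` branch is easiest to follow with concrete shapes. A small PyTorch sketch of the decode-time transpose (illustrative sizes; not part of this file):

```python
import torch

# Decode step with GQA: 32 query heads share 8 KV heads, seqlen_q == 1.
b, h, h_k, d = 2, 32, 8, 128
ngroups = h // h_k
q = torch.randn(b, 1, h, d)

# (b, 1, h_k * ngroups, d) -> (b, ngroups, h_k, d): the kernel then runs with
# seqlen_q = ngroups and num_heads = h_k, exposing more query rows per KV head.
q_swapped = q.reshape(b, h_k, ngroups, d).transpose(1, 2)
assert q_swapped.shape == (b, ngroups, h_k, d)

# mha_fwd undoes the swap on the way out, as in the code above.
out = q_swapped.transpose(1, 2).reshape(b, 1, h_k * ngroups, d)
assert out.shape == q.shape
```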
+
+ std::vector<at::Tensor>
+ mha_varlen_fwd(at::Tensor &q,        // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+                const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                std::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+                const at::Tensor &cu_seqlens_q,   // b+1
+                const at::Tensor &cu_seqlens_k,   // b+1
+                std::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+                std::optional<const at::Tensor> &leftpad_k_,  // batch_size
+                std::optional<at::Tensor> &block_table_,      // batch_size x max_num_blocks_per_seq
+                std::optional<at::Tensor> &alibi_slopes_,     // num_heads or b x num_heads
+                int max_seqlen_q,
+                const int max_seqlen_k,
+                const float p_dropout,
+                const float softmax_scale,
+                const bool zero_tensors,
+                bool is_causal,
+                int window_size_left,
+                int window_size_right,
+                const float softcap,
+                const bool return_softmax,
+                std::optional<at::Generator> gen_) {
+
+     // Otherwise the kernel will be launched from cuda:0 device
+     at::cuda::CUDAGuard device_guard{q.device()};
+
+     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+     bool is_sm8x_min = cc_major >= 8;
+     TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+
+     auto q_dtype = q.dtype();
+     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                 "FlashAttention only supports fp16 and bf16 data type");
+     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+
+     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+     CHECK_DEVICE(cu_seqlens_q);
+     CHECK_DEVICE(cu_seqlens_k);
+
+     at::Tensor block_table;
+     const bool paged_KV = block_table_.has_value();
+     if (paged_KV) {
+         block_table = block_table_.value();
+         CHECK_DEVICE(block_table);
+         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+     }
+
+     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     CHECK_CONTIGUOUS(cu_seqlens_q);
+     CHECK_CONTIGUOUS(cu_seqlens_k);
+
+     const auto sizes = q.sizes();
+
+     const int batch_size = cu_seqlens_q.numel() - 1;
+     int num_heads = sizes[1];
+     const int head_size = sizes[2];
+     const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
+     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+     const int num_blocks = !paged_KV ? 0 : k.size(0);
+     const int page_block_size = !paged_KV ? 1 : k.size(1);
+     TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+
+     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
+     if (is_causal) { window_size_right = 0; }
+
+     void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
+
+     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+     // H/t Daniel Haziza
+     const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
+     const int ngroups = num_heads / num_heads_k;
+     if (seqlenq_ngroups_swapped) {
+         q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
+         max_seqlen_q = ngroups;
+         num_heads = num_heads_k;
+         cu_seqlens_q_d = nullptr;
+     }
+
+     const int total_q = q.sizes()[0];
+
+     TORCH_CHECK(batch_size > 0, "batch size must be positive");
+     TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
+     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
+     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
+     CHECK_SHAPE(q, total_q, num_heads, head_size);
+     if (!paged_KV) {
+         const int total_k = k.size(0);
+         CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+         CHECK_SHAPE(v, total_k, num_heads_k, head_size);
+     } else {
+         CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
+         CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
+         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+     }
+
+     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+     if (seqused_k.has_value()) {
+         auto seqused_k_ = seqused_k.value();
+         TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+         TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+         TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+         CHECK_SHAPE(seqused_k_, batch_size);
+     }
+
+     at::Tensor out;
+     if (out_.has_value()) {
+         out = out_.value();
+         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+         CHECK_DEVICE(out);
+         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+         CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
+         if (seqlenq_ngroups_swapped) {
+             out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
+         }
+     } else {
+         out = torch::empty_like(q);
+     }
+
+     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+     const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+     auto opts = q.options();
+     auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
+     at::Tensor p;
+     // Only return softmax if there's dropout to reduce compilation time
+     if (return_softmax) {
+         TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
+         p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
+     } else {
+         p = torch::empty({ 0 }, opts);
+     }
+
+     if (zero_tensors) {
+         out.zero_();
+         softmax_lse.fill_(-std::numeric_limits<float>::infinity());
+         if (return_softmax) { p.zero_(); }
+     }
+
+     Flash_fwd_params params;
+     set_params_fprop(params,
+                      batch_size,
+                      max_seqlen_q, max_seqlen_k,
+                      seqlen_q_rounded, seqlen_k_rounded,
+                      num_heads, num_heads_k,
+                      head_size, head_size_rounded,
+                      q, k, v, out,
+                      cu_seqlens_q_d,
+                      cu_seqlens_k.data_ptr(),
+                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
+                      return_softmax ? p.data_ptr() : nullptr,
+                      softmax_lse.data_ptr(),
+                      p_dropout,
+                      softmax_scale,
+                      window_size_left,
+                      window_size_right,
+                      softcap,
+                      seqlenq_ngroups_swapped,
+                      /*unpadded_lse*/true);
+     params.total_q = total_q;
+
+     if (paged_KV) {
+         params.block_table = block_table.data_ptr<int>();
+         params.block_table_batch_stride = block_table.stride(0);
+         params.k_batch_stride = k.stride(0);
+         params.v_batch_stride = v.stride(0);
+     }
+     params.page_block_size = page_block_size;
+     // Keep references to these tensors to extend their lifetime
+     at::Tensor softmax_lse_accum, out_accum;
+     if (seqlenq_ngroups_swapped) {
+         // Only apply split-k for decoding
+         std::tie(softmax_lse_accum, out_accum) =
+             set_params_splitkv(params, batch_size, num_heads, head_size,
+                                max_seqlen_k, max_seqlen_q, head_size_rounded,
+                                p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
+     }
+
+     if (leftpad_k_.has_value()) {
+         auto leftpad_k = leftpad_k_.value();
+         TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+         TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+         CHECK_DEVICE(leftpad_k);
+         CHECK_CONTIGUOUS(leftpad_k);
+         CHECK_SHAPE(leftpad_k, batch_size);
+         params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+     }
+
+     // number of times random will be generated per thread, to offset philox counter in thc random
+     // state
+     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+     int64_t counter_offset = params.b * params.h * 32;
+     auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+     auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+     // Forward kernel will populate memory with the seed and offset.
+     params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
+     if (p_dropout > 0.0) {
+         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+             gen_, at::cuda::detail::getDefaultCUDAGenerator());
+         // See Note [Acquire lock when using random generators]
+         std::lock_guard<std::mutex> lock(gen->mutex_);
+         params.philox_args = gen->philox_cuda_state(counter_offset);
+     }
+
+     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+     if (max_seqlen_k > 0) {
+         auto stream = at::cuda::getCurrentCUDAStream().stream();
+         run_mha_fwd(params, stream, paged_KV);
+     } else {
+         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+         out.zero_();
+         softmax_lse.fill_(std::numeric_limits<float>::infinity());
+     }
+
+     if (seqlenq_ngroups_swapped) {
+         int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
+         int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
+         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
+         q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
+         softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
+     }
+
+     return {out, softmax_lse, p, rng_state};
+ }
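`mha_varlen_fwd` consumes sequences packed back-to-back with no padding; `cu_seqlens_q` / `cu_seqlens_k` are exclusive prefix sums with b+1 entries. A minimal sketch of the packing convention (illustrative values):

```python
import torch

seqlens = [5, 3, 7]                                          # batch of 3 sequences
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)  # b + 1 boundaries
total_q = int(cu_seqlens[-1])                                # 15 tokens in total

nheads, headdim = 8, 64
q = torch.randn(total_q, nheads, headdim)                    # total_q x num_heads x head_size

# Tokens of sequence i occupy rows cu_seqlens[i] : cu_seqlens[i + 1].
q_seq1 = q[cu_seqlens[1]:cu_seqlens[2]]
assert q_seq1.shape[0] == seqlens[1]
```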
+
+ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+     FP16_SWITCH(!params.is_bf16, [&] {
+         HEADDIM_SWITCH(params.d, [&] {
+             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                 run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+             });
+         });
+     });
+ }
+
+ std::vector<at::Tensor>
+ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)
+         const at::Tensor &q,     // batch_size x seqlen_q x num_heads x head_size
+         const at::Tensor &k,     // batch_size x seqlen_k x num_heads_k x head_size
+         const at::Tensor &v,     // batch_size x seqlen_k x num_heads_k x head_size
+         const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
+         const at::Tensor &softmax_lse,   // b x h x seqlen_q
+         std::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
+         std::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
+         std::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
+         std::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
+         const float p_dropout,  // probability to drop
+         const float softmax_scale,
+         const bool is_causal,
+         int window_size_left,
+         int window_size_right,
+         const float softcap,
+         const bool deterministic,
+         std::optional<at::Generator> gen_,
+         std::optional<at::Tensor> &rng_state) {
+
+     #ifdef FLASHATTENTION_DISABLE_BACKWARD
+     TORCH_CHECK(false, "This flash attention build does not support backward.");
+     #endif
+     if (is_causal) { window_size_right = 0; }
+
+     // Otherwise the kernel will be launched from cuda:0 device
+     at::cuda::CUDAGuard device_guard{q.device()};
+
+     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+     bool is_sm8x_min = cc_major >= 8;
+     TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+
+     bool is_dropout = p_dropout > 0.0;
+     auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+     auto q_dtype = q.dtype();
+     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                 "FlashAttention only supports fp16 and bf16 data type");
+     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
+     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
+
+     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+     CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+
+     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
+     TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
+
+     const auto sizes = q.sizes();
+
+     const int batch_size = sizes[0];
+     const int seqlen_q = sizes[1];
+     const int num_heads = sizes[2];
+     const int head_size = sizes[3];
+     const int seqlen_k = k.size(1);
+     const int num_heads_k = k.size(2);
+     TORCH_CHECK(batch_size > 0, "batch size must be positive");
+     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
+     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
+     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+     const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
+     if (window_size_left >= seqlen_k) { window_size_left = -1; }
+     if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
+     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
+     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
+     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
+     CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
+
+     at::Tensor dq, dk, dv;
+     if (dq_.has_value()) {
+         dq = dq_.value();
+         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
+         CHECK_DEVICE(dq);
+         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
+         CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
+     } else {
+         dq = torch::empty_like(q);
+     }
+     if (dk_.has_value()) {
+         dk = dk_.value();
+         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
+         CHECK_DEVICE(dk);
+         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
+         CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
+     } else {
+         dk = torch::empty_like(k);
+     }
+     if (dv_.has_value()) {
+         dv = dv_.value();
+         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
+         CHECK_DEVICE(dv);
+         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
+         CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
+     } else {
+         dv = torch::empty_like(v);
+     }
+
+     // bool loop = seqlen_k > blocksize_c;
+     // TODO: change later, for now set to true for simplicity
+     bool loop = true;
+
+     auto opts = q.options();
+     auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
+     at::Tensor dq_accum;
+     at::Tensor dk_accum, dv_accum;
+     if (loop) {
+         if (!deterministic) {
+             dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+         } else {
+             const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
+             dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+         }
+         // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+         // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+     }
+
+     at::Tensor dk_expanded, dv_expanded;
+     if (num_heads_k != num_heads) {  // MQA / GQA
+         dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
+         dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
+     } else {
+         dk_expanded = dk;
+         dv_expanded = dv;
+     }
+
+     Flash_bwd_params params;
+
+     set_params_dgrad(params,
+                      batch_size,
+                      seqlen_q, seqlen_k,
+                      seqlen_q_rounded, seqlen_k_rounded,
+                      num_heads, num_heads_k,
+                      head_size, head_size_rounded,
+                      q, k, v, out,
+                      dout, dq, dk_expanded, dv_expanded,
+                      nullptr,
+                      nullptr,
+                      loop ? dq_accum.data_ptr() : nullptr,
+                      // loop ? dk_accum.data_ptr() : nullptr,
+                      // loop ? dv_accum.data_ptr() : nullptr,
+                      nullptr,
+                      nullptr,
+                      softmax_lse.data_ptr(),
+                      softmax_d.data_ptr(),
+                      p_dropout,
+                      softmax_scale,
+                      window_size_left,
+                      window_size_right,
+                      softcap,
+                      deterministic,
+                      /*unpadded_lse*/false);
+     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
+
+     auto launch = &run_mha_bwd;
+
+     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+         gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+     int64_t counter_offset = params.b * params.h * 32;
+
+     if (rng_state.has_value()) {
+         params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+     } else if (is_dropout) {
+         // See Note [Acquire lock when using random generators]
+         std::lock_guard<std::mutex> lock(gen->mutex_);
+         params.philox_args = gen->philox_cuda_state(counter_offset);
+         auto seeds = at::cuda::philox::unpack(params.philox_args);
+         params.rng_state[0] = std::get<0>(seeds);
+         params.rng_state[1] = std::get<1>(seeds);
+     }
+
+     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+     if (seqlen_q > 0) {
+         launch(params, stream);
+     } else {
+         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+         dk_expanded.zero_();
+         dv_expanded.zero_();
+         softmax_d.zero_();
+     }
+
+     // For MQA/GQA we need to sum dK and dV across the groups
+     if (num_heads_k != num_heads) {
+         at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
+         at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
+     }
+
+     return { dq, dk, dv, softmax_d };
+ }
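The MQA/GQA epilogue above sums the expanded dK/dV over the group of query heads that share each KV head. The `at::sum_out` calls are equivalent to this PyTorch sketch (made-up sizes):

```python
import torch

b, seqlen_k, h, h_k, d = 2, 128, 32, 8, 64
dk_expanded = torch.randn(b, seqlen_k, h, d)  # one gradient slice per query head

# Mirror of at::sum_out(dk, at::reshape(..., {b, seqlen_k, h_k, h/h_k, d}), {3}):
dk = dk_expanded.reshape(b, seqlen_k, h_k, h // h_k, d).sum(dim=3)
assert dk.shape == (b, seqlen_k, h_k, d)
```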
+
+ std::vector<at::Tensor>
+ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
+                const at::Tensor &q,     // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+                const at::Tensor &k,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+                const at::Tensor &v,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+                const at::Tensor &out,   // total_q x num_heads x head_size
+                const at::Tensor &softmax_lse,   // h x total_q, softmax logsumexp
+                std::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+                std::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+                std::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+                const at::Tensor &cu_seqlens_q,  // b+1
+                const at::Tensor &cu_seqlens_k,  // b+1
+                std::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
+                const int max_seqlen_q,
+                const int max_seqlen_k,  // max sequence length to choose the kernel
+                const float p_dropout,   // probability to drop
+                const float softmax_scale,
+                const bool zero_tensors,
+                const bool is_causal,
+                int window_size_left,
+                int window_size_right,
+                const float softcap,
+                const bool deterministic,
+                std::optional<at::Generator> gen_,
+                std::optional<at::Tensor> &rng_state) {
+
+     #ifdef FLASHATTENTION_DISABLE_BACKWARD
+     TORCH_CHECK(false, "This flash attention build does not support backward.");
+     #endif
+     if (is_causal) { window_size_right = 0; }
+
+     // Otherwise the kernel will be launched from cuda:0 device
+     at::cuda::CUDAGuard device_guard{q.device()};
+
+     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+     bool is_sm8x_min = cc_major >= 8;
+     TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+
+     bool is_dropout = p_dropout > 0.0;
+     auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+     auto q_dtype = q.dtype();
+     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                 "FlashAttention only supports fp16 and bf16 data type");
+     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
+     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
+     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+
+     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+     CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+     CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
+
+     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+     TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
+     TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
+     CHECK_CONTIGUOUS(cu_seqlens_q);
+     CHECK_CONTIGUOUS(cu_seqlens_k);
+
+     const auto sizes = q.sizes();
+
+     const int total_q = sizes[0];
+     const int batch_size = cu_seqlens_q.numel() - 1;
+     const int num_heads = sizes[1];
+     const int head_size = sizes[2];
+     const int total_k = k.size(0);
+     const int num_heads_k = k.size(1);
+     TORCH_CHECK(batch_size > 0, "batch size must be positive");
+     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
+     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
+     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
+
+     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+     const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+     if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
+     CHECK_SHAPE(q, total_q, num_heads, head_size);
+     CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+     CHECK_SHAPE(v, total_k, num_heads_k, head_size);
+     CHECK_SHAPE(out, total_q, num_heads, head_size);
+     CHECK_SHAPE(dout, total_q, num_heads, head_size);
+     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
+     at::Tensor dq, dk, dv;
+     if (dq_.has_value()) {
+         dq = dq_.value();
+         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
+         CHECK_DEVICE(dq);
+         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
+         CHECK_SHAPE(dq, total_q, num_heads, head_size);
+     } else {
+         dq = torch::empty_like(q);
+     }
+     if (dk_.has_value()) {
+         dk = dk_.value();
+         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
+         CHECK_DEVICE(dk);
+         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
+         CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
+     } else {
+         dk = torch::empty_like(k);
+     }
+     if (dv_.has_value()) {
+         dv = dv_.value();
+         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
+         CHECK_DEVICE(dv);
+         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
+         CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
+     } else {
+         dv = torch::empty_like(v);
+     }
+
+     // bool loop = max_seqlen_k > blocksize_c;
+     // TODO: change later, for now set to true for simplicity
+     bool loop = true;
+
+     auto opts = q.options();
+     auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
+     at::Tensor dq_accum;
+     if (loop) {
+         // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
+         // because that would be too large if there is a very long sequence and the rest of the sequences are short.
+         // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
+         // Note that 128 is the max block size on the seqlen_q dimension.
+         // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
+         // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
+         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
+         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
+         // Same holds for softmax_d, since LSE is stored in unpadded format.
+         if (!deterministic) {
+             dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+         } else {
+             const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
+             dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+         }
+     }
+
+     at::Tensor dk_expanded, dv_expanded;
+     if (num_heads_k != num_heads) {  // MQA / GQA
+         dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
+         dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
+     } else {
+         dk_expanded = dk;
+         dv_expanded = dv;
+     }
+
+     if (zero_tensors) {
+         dq.zero_();
+         dk_expanded.zero_();
+         dv_expanded.zero_();
+         softmax_d.zero_();
+     }
+
+     Flash_bwd_params params;
+
+     set_params_dgrad(params,
+                      batch_size,
+                      max_seqlen_q, max_seqlen_k,
+                      seqlen_q_rounded, seqlen_k_rounded,
+                      num_heads, num_heads_k,
+                      head_size, head_size_rounded,
+                      q, k, v, out,
+                      dout, dq, dk_expanded, dv_expanded,
+                      cu_seqlens_q.data_ptr(),
+                      cu_seqlens_k.data_ptr(),
+                      loop ? dq_accum.data_ptr() : nullptr,
+                      nullptr,
+                      nullptr,
+                      softmax_lse.data_ptr(),
+                      softmax_d.data_ptr(),
+                      p_dropout,
+                      softmax_scale,
+                      window_size_left,
+                      window_size_right,
+                      softcap,
+                      deterministic,
+                      /*unpadded_lse*/true);
+     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
+     params.total_q = total_q;
+
+     auto launch = &run_mha_bwd;
+
+     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+         gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+     int64_t counter_offset = params.b * params.h * 32;
+
+     if (rng_state.has_value()) {
+         params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+     } else if (is_dropout) {
+         // See Note [Acquire lock when using random generators]
+         std::lock_guard<std::mutex> lock(gen->mutex_);
+         params.philox_args = gen->philox_cuda_state(counter_offset);
+         auto seeds = at::cuda::philox::unpack(params.philox_args);
+         params.rng_state[0] = std::get<0>(seeds);
+         params.rng_state[1] = std::get<1>(seeds);
+     }
+
+     set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+     if (max_seqlen_q > 0) {
+         launch(params, stream);
+     } else {
+         // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+         dk_expanded.zero_();
+         dv_expanded.zero_();
+         softmax_d.zero_();
+     }
+
+     // For MQA/GQA we need to sum dK and dV across the groups
+     if (num_heads_k != num_heads) {
+         at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
+         at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
+     }
+
+     return { dq, dk, dv, softmax_d };
+ }
1201
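
In the MQA/GQA branch above, query head h maps to KV head h / ngroups with ngroups = num_heads / num_heads_k, and the at::sum_out calls collapse each group of query-head gradients onto its shared KV head. A plain-loop sketch of the same reduction (hypothetical standalone function, fp32 and row-major layout for simplicity):

    // Sketch: collapse dk_expanded[t][h][d] (h = query head) into dk[t][hk][d]
    // (hk = KV head). Query head h belongs to KV head hk = h / ngroups.
    void reduce_dk_groups(const float* dk_expanded, float* dk,
                          int total_k, int num_heads, int num_heads_k, int head_size) {
        const int ngroups = num_heads / num_heads_k;
        for (int t = 0; t < total_k; ++t)
            for (int hk = 0; hk < num_heads_k; ++hk)
                for (int d = 0; d < head_size; ++d) {
                    float acc = 0.f;
                    for (int g = 0; g < ngroups; ++g) {
                        int h = hk * ngroups + g;  // query heads sharing this KV head
                        acc += dk_expanded[(t * num_heads + h) * head_size + d];
                    }
                    dk[(t * num_heads_k + hk) * head_size + d] = acc;
                }
    }
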
+
+ std::vector<at::Tensor>
+ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+ const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+ std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
+ std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
+ std::optional<const at::Tensor> &seqlens_k_, // batch_size
+ std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+ std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+ std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+ std::optional<const at::Tensor> &leftpad_k_, // batch_size
+ std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+ std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+ std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
+ const float softmax_scale,
+ bool is_causal,
+ int window_size_left,
+ int window_size_right,
+ const float softcap,
+ bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+ int num_splits
+ ) {
+
+ // Otherwise the kernel will be launched from the cuda:0 device.
+ at::cuda::CUDAGuard device_guard{q.device()};
+
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x_min = cc_major >= 8;
+ TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+
+ auto q_dtype = q.dtype();
+ TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+ "FlashAttention only supports fp16 and bf16 data types");
+ TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+ TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
+
+ CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
+
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+ TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+ TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+ at::Tensor block_table;
+ const bool paged_KV = block_table_.has_value();
+ if (paged_KV) {
+ TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KV cache does not support cache_batch_idx");
+ block_table = block_table_.value();
+ CHECK_DEVICE(block_table);
+ TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+ TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+ }
+
+ const auto sizes = q.sizes();
+
+ const int batch_size = sizes[0];
+ int seqlen_q = sizes[1];
+ int num_heads = sizes[2];
+ const int head_size_og = sizes[3];
+
+ const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+ const int num_blocks = !paged_KV ? 0 : kcache.size(0);
+ const int page_block_size = !paged_KV ? 1 : kcache.size(1);
+ TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+ const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
+ const int num_heads_k = kcache.size(2);
+ const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
+ TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+ // causal=true is the same as causal=false in this case
+ if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
+ if (is_causal) { window_size_right = 0; }
+
+ // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case.
+ // H/t Daniel Haziza
+ const bool seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+ if (seqlenq_ngroups_swapped) {
+ const int ngroups = num_heads / num_heads_k;
+ q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+ seqlen_q = ngroups;
+ num_heads = num_heads_k;
+ }
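
A standalone libtorch check of the decode-time swap above, with assumed sizes (b = 2, nheads_kv = 4, ngroups = 8, d = 128); after the reshape the kernel sees seqlen_q = ngroups and num_heads = nheads_kv, and the output is transposed back at the end of this function:

    // Sketch: verify the (b, 1, nheads_kv * ngroups, d) -> (b, ngroups, nheads_kv, d) swap.
    #include <torch/torch.h>
    int main() {
        int64_t b = 2, nheads_kv = 4, ngroups = 8, d = 128;
        auto q = torch::randn({b, 1, nheads_kv * ngroups, d});
        auto q2 = q.reshape({b, nheads_kv, ngroups, d}).transpose(1, 2);
        // The kernel now runs with 8 query rows per KV head instead of 1.
        TORCH_CHECK(q2.size(1) == ngroups && q2.size(2) == nheads_kv);
    }
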
+
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+ if (!paged_KV) {
+ CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+ CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+ } else {
+ CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+ CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+ CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+ }
+
+ at::Tensor q_padded, kcache_padded, vcache_padded;
+ if (head_size_og % 8 != 0) {
+ q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+ kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+ vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+ } else {
+ q_padded = q;
+ kcache_padded = kcache;
+ vcache_padded = vcache;
+ }
+
+ at::Tensor out;
+ if (out_.has_value()) {
+ out = out_.value();
+ TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+ CHECK_DEVICE(out);
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+ if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+ } else {
+ out = torch::empty_like(q_padded);
+ }
+
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size = round_multiple(head_size_og, 8);
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
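
The round_multiple lambda computes ((x + m - 1) / m) * m in integer arithmetic, i.e., it rounds x up to the next multiple of m. A few worked values for the sizes above:

    // round_multiple(x, m) = ((x + m - 1) / m) * m   (integer division)
    // head_size_og = 100 -> head_size = round_multiple(100, 8) = 104
    //                    -> head_size_rounded = round_multiple(104, 32) = 128
    // head_size = 224 (> 192) -> head_size_rounded = 256
    // seqlen_q = 1 (decode)   -> seqlen_q_rounded = 128
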
+
+ auto opts = q.options();
+
+ auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+
+ Flash_fwd_params params;
+ set_params_fprop(params,
+ batch_size,
+ seqlen_q, seqlen_k,
+ seqlen_q_rounded, seqlen_k_rounded,
+ num_heads, num_heads_k,
+ head_size, head_size_rounded,
+ q_padded, kcache_padded, vcache_padded, out,
+ /*cu_seqlens_q_d=*/nullptr,
+ /*cu_seqlens_k_d=*/nullptr,
+ /*seqused_k=*/nullptr,
+ /*p_ptr=*/nullptr,
+ softmax_lse.data_ptr(),
+ /*p_dropout=*/0.f,
+ softmax_scale,
+ window_size_left,
+ window_size_right,
+ softcap
+ );
+
+ at::Tensor k, v, k_padded, v_padded;
+ if (k_.has_value()) {
+ TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
+ TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
+ TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
+ k = k_.value();
+ v = v_.value();
+ TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
+ TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
+ CHECK_DEVICE(k); CHECK_DEVICE(v);
+ TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
+ TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
+ int seqlen_knew = k.size(1);
+ CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
+ CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
+ if (head_size_og % 8 != 0) {
+ k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+ v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+ } else {
+ k_padded = k;
+ v_padded = v;
+ }
+ params.seqlen_knew = seqlen_knew;
+ params.knew_ptr = k_padded.data_ptr();
+ params.vnew_ptr = v_padded.data_ptr();
+ // All strides are in elements, not bytes.
+ params.knew_batch_stride = k_padded.stride(0);
+ params.vnew_batch_stride = v_padded.stride(0);
+ params.knew_row_stride = k_padded.stride(-3);
+ params.vnew_row_stride = v_padded.stride(-3);
+ params.knew_head_stride = k_padded.stride(-2);
+ params.vnew_head_stride = v_padded.stride(-2);
+ }
+
+ if (seqlens_k_.has_value()) {
+ auto seqlens_k = seqlens_k_.value();
+ TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
+ CHECK_DEVICE(seqlens_k);
+ CHECK_CONTIGUOUS(seqlens_k);
+ CHECK_SHAPE(seqlens_k, batch_size);
+ params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+ }
+ params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+ if (leftpad_k_.has_value()) {
+ TORCH_CHECK(!paged_KV, "We don't support paged KV and leftpad_k running at the same time yet");
+ auto leftpad_k = leftpad_k_.value();
+ TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+ CHECK_DEVICE(leftpad_k);
+ CHECK_CONTIGUOUS(leftpad_k);
+ CHECK_SHAPE(leftpad_k, batch_size);
+ params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+ }
+
+ if (rotary_cos_.has_value()) {
+ TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
+ auto rotary_cos = rotary_cos_.value();
+ CHECK_DEVICE(rotary_cos);
+ params.rotary_dim = rotary_cos.size(1) * 2;
+ TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
+ TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
+ const int seqlen_ro = rotary_cos.size(0);
+ TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
+ CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
+ CHECK_CONTIGUOUS(rotary_cos);
+ TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+
+ TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+ auto rotary_sin = rotary_sin_.value();
+ CHECK_DEVICE(rotary_sin);
+ CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
+ CHECK_CONTIGUOUS(rotary_sin);
+ TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
+ params.rotary_cos_ptr = rotary_cos.data_ptr();
+ params.rotary_sin_ptr = rotary_sin.data_ptr();
+ params.is_rotary_interleaved = is_rotary_interleaved;
+ } else {
+ params.rotary_dim = 0;
+ }
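
As documented in the signature above, is_rotary_interleaved only changes which element indices are paired before rotation. A scalar sketch of both conventions (hypothetical standalone helper; the kernel itself fuses this rotation into the cache append):

    // Sketch: apply rotary embedding to one head vector x[0..rotary_dim),
    // given cos/sin rows of length rotary_dim / 2 for this position.
    void rotary_one_row(float* x, const float* cos_r, const float* sin_r,
                        int rotary_dim, bool interleaved) {
        for (int i = 0; i < rotary_dim / 2; ++i) {
            // interleaved pairs (2i, 2i+1); otherwise pair (i, i + rotary_dim/2)
            int i0 = interleaved ? 2 * i : i;
            int i1 = interleaved ? 2 * i + 1 : i + rotary_dim / 2;
            float x0 = x[i0], x1 = x[i1];
            x[i0] = x0 * cos_r[i] - x1 * sin_r[i];
            x[i1] = x0 * sin_r[i] + x1 * cos_r[i];
        }
    }
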
+
+ if (cache_batch_idx_.has_value()) {
+ auto cache_batch_idx = cache_batch_idx_.value();
+ CHECK_DEVICE(cache_batch_idx);
+ CHECK_CONTIGUOUS(cache_batch_idx);
+ TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
+ params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
+ }
+
+ // Keep references to these tensors to extend their lifetime.
+ at::Tensor softmax_lse_accum, out_accum;
+ std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+ params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+ head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts);
+
+ if (paged_KV) {
+ params.block_table = block_table.data_ptr<int>();
+ params.block_table_batch_stride = block_table.stride(0);
+ }
+ params.page_block_size = page_block_size;
+
+ set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ // Only the split kernel supports appending to the KV cache, indexing into the cache
+ // with cache_batch_idx, or a paged KV cache.
+ run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
+
+ if (head_size_og % 8 != 0) {
+ out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+ if (out_.has_value()) { out_.value().copy_(out); }
+ if (k_.has_value()) {
+ // It's expensive to copy the KV cache here for the case where the head size is not divisible by 8,
+ // but we don't expect to hit this case in practice. This is just so that the code works for that case.
+ kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+ vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+ }
+ }
+
+ if (seqlenq_ngroups_swapped) {
+ out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+ softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+ }
+ return {out, softmax_lse};
+ }
+ } // namespace FLASH_NAMESPACE
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.doc() = "FlashAttention";
+ m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
+ m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
+ m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
+ m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
+ m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
+ }
cookbooks/flash-attention/csrc/flash_attn/src/alibi.h ADDED
@@ -0,0 +1,75 @@
+ #include <cmath>
+
+ #include "namespace_config.h"
+ #include <cute/tensor.hpp>
+
+ #include <cutlass/cutlass.h>
+ #include <cutlass/array.h>
+
+ #include "utils.h"
+
+ namespace FLASH_NAMESPACE {
+
+ using namespace cute;
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template <bool Is_causal>
+ struct Alibi {
+
+ const float alibi_slope;
+ const int max_seqlen_k, max_seqlen_q;
+
+ __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+ : alibi_slope(alibi_slope)
+ , max_seqlen_k(max_seqlen_k)
+ , max_seqlen_q(max_seqlen_q) {
+ };
+
+ template <typename Engine, typename Layout>
+ __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+ const int col_idx_offset_,
+ const int row_idx_offset,
+ const int warp_row_stride) {
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ #pragma unroll
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
+ tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+ }
+ }
+ }
+ } else { // Bias depends on both row_idx and col_idx
+ #pragma unroll
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+ #pragma unroll
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
+ const int row_idx = row_idx_base + i * 8;
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ };
+
+ } // namespace FLASH_NAMESPACE
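
Stripped of the MMA tiling, the per-element update above is a slope times a column index (causal) or a slope times an absolute relative distance (non-causal). A scalar reference of the same math (standalone sketch, not part of this header):

    #include <cstdlib>  // std::abs
    // Mirrors apply_alibi above: the causal branch adds slope * col, which is
    // -slope * distance plus a per-row constant that softmax ignores; the
    // non-causal branch subtracts slope * |row + seqlen_k - seqlen_q - col|.
    float alibi_bias(float slope, int row, int col,
                     int seqlen_q, int seqlen_k, bool causal) {
        return causal ? slope * col
                      : -slope * std::abs(row + seqlen_k - seqlen_q - col);
    }
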
cookbooks/flash-attention/csrc/flash_attn/src/block_info.h ADDED
@@ -0,0 +1,49 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ #pragma once
+
+ #include "namespace_config.h"
+ namespace FLASH_NAMESPACE {
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<bool Varlen=true>
+ struct BlockInfo {
+
+ template<typename Params>
+ __device__ BlockInfo(const Params &params, const int bidb)
+ : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
+ , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
+ , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+ , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+ , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+ {
+ }
+
+ template <typename index_t>
+ __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+ return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
+ }
+
+ template <typename index_t>
+ __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+ return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
+ }
+
+ const int sum_s_q;
+ const int sum_s_k;
+ const int actual_seqlen_q;
+ // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+ const int leftpad_k;
+ const int seqlen_k_cache;
+ const int actual_seqlen_k;
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ } // namespace FLASH_NAMESPACE
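
The two offset helpers above switch between batched and varlen-packed addressing: with fixed-shape tensors (sum_s == -1) the batch index is scaled by the batch stride, while with packed varlen tensors the base row comes from the cu_seqlens prefix sums. A sketch of the same decision for Q (hypothetical standalone form, ignoring leftpad):

    // Sketch: element offset of the first row of batch b in Q.
    // Packed (varlen): all sequences' rows are concatenated, so the base row
    //                  is cu_seqlens_q[b] and the offset is that times row_stride.
    // Batched:         the offset is simply b * batch_stride.
    int64_t q_base_offset(int64_t batch_stride, int64_t row_stride,
                          const int* cu_seqlens_q, int b) {
        return cu_seqlens_q == nullptr ? b * batch_stride
                                       : int64_t(cu_seqlens_q[b]) * row_stride;
    }
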
cookbooks/flash-attention/csrc/flash_attn/src/dropout.h ADDED
@@ -0,0 +1,95 @@
+ /******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+ #pragma once
+
+ #include "namespace_config.h"
+ #include "philox.cuh"
+ #include "utils.h"
+
+ namespace FLASH_NAMESPACE {
+
+ struct Dropout {
+
+ const unsigned long long seed, offset;
+ const uint8_t p_dropout_in_uint8_t;
+
+ __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
+ const uint8_t p_dropout_in_uint8_t,
+ const int bid, const int hid, const int tid, const int nheads)
+ : seed(seed)
+ , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
+ , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
+ }
+
+ template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
+ __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
+ int block_row_start, int block_col_start, int block_row_stride) {
+ // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
+ Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
+ using T = typename Engine::value_type;
+ auto encode_dropout = [](bool keep, T val) {
+ return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
+ };
+ static_assert(decltype(size<2>(tensor))::value % 2 == 0);
+ const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
+ const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
+ // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
+ #pragma unroll
+ for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
+ uint2 rowcol = make_uint2(block_row_start, block_col_start);
+ #pragma unroll
+ for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
+ // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
+ uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
+ // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
+ uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
+ // Special implementation for 16-bit types: we duplicate the threshold to the
+ // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
+ // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
+ // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
+ // the random value is less than the threshold.
+ // We then do a bit-wise AND between the mask and the original value (in 32-bit).
+ // We're exploiting the fact that floating point comparison is equivalent to integer
+ // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
+ if (!encode_dropout_in_sign_bit
+ && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
+ uint16_t rnd_16[16];
+ #pragma unroll
+ for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
+ uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
+ #pragma unroll
+ for (int j = 0; j < 2; j++) {
+ Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+ // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ uint32_t mask;
+ asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
+ tensor_uint32(i) &= mask;
+ }
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+ }
+ } else {
+ #pragma unroll
+ for (int j = 0; j < 2; j++) {
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
+ }
+ Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+ }
+ }
+ // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+ // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
+ // // }
+ }
+ }
+ }
+
+ };
+
+ } // namespace FLASH_NAMESPACE
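
Underneath the f16x2 masking trick, the per-element rule above is simple: keep a value iff its Philox byte is at most the 8-bit threshold, otherwise zero it (or flip its sign when the mask is encoded in the sign bit). A scalar sketch of the same rule (standalone; the exact threshold encoding, assumed here to be floor(255 * p_keep), is set on the host side, and the 1 / p_keep rescale is folded into other scale factors):

    #include <cstdint>
    // Sketch: per-element dropout decision, one random byte per element.
    float apply_dropout_scalar(float val, uint8_t rnd_byte, uint8_t threshold,
                               bool encode_in_sign_bit) {
        bool keep = rnd_byte <= threshold;  // threshold ~ floor(255 * p_keep), an assumption
        return keep ? val : (encode_in_sign_bit ? -val : 0.f);
    }
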
cookbooks/flash-attention/csrc/flash_attn/src/flash.h ADDED
@@ -0,0 +1,194 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ #pragma once
+
+ #include "namespace_config.h"
+
+ #include <cuda.h>
+ #include <vector>
+
+ #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
+
+ namespace FLASH_NAMESPACE {
+ constexpr int TOTAL_DIM = 0;
+ constexpr int H_DIM = 1;
+ constexpr int D_DIM = 2;
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct Qkv_params {
+ using index_t = int64_t;
+ // The QKV matrices.
+ void *__restrict__ q_ptr;
+ void *__restrict__ k_ptr;
+ void *__restrict__ v_ptr;
+
+ // The stride between rows of the Q, K and V matrices.
+ index_t q_batch_stride;
+ index_t k_batch_stride;
+ index_t v_batch_stride;
+ index_t q_row_stride;
+ index_t k_row_stride;
+ index_t v_row_stride;
+ index_t q_head_stride;
+ index_t k_head_stride;
+ index_t v_head_stride;
+
+ // The number of heads.
+ int h, h_k;
+ // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
+ // different from nheads (query).
+ int h_h_k_ratio; // precomputed h / h_k
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct Flash_fwd_params : public Qkv_params {
+
+ // The O matrix (output).
+ void * __restrict__ o_ptr;
+ void * __restrict__ oaccum_ptr;
+
+ // The stride between rows of O.
+ index_t o_batch_stride;
+ index_t o_row_stride;
+ index_t o_head_stride;
+
+ // The pointer to the P matrix.
+ void * __restrict__ p_ptr;
+
+ // The pointer to the softmax sum.
+ void * __restrict__ softmax_lse_ptr;
+ void * __restrict__ softmax_lseaccum_ptr;
+
+ // The dimensions.
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
+
+ // The scaling factors for the kernel.
+ float scale_softmax;
+ float scale_softmax_log2;
+
+ // Array of length b+1 holding the starting offset of each sequence.
+ int * __restrict__ cu_seqlens_q;
+ int * __restrict__ cu_seqlens_k;
+ int * __restrict__ leftpad_k;
+
+ // If provided, the actual length of each k sequence.
+ int * __restrict__ seqused_k;
+
+ int *__restrict__ blockmask;
+
+ // The K_new and V_new matrices.
+ void * __restrict__ knew_ptr;
+ void * __restrict__ vnew_ptr;
+
+ // The stride between rows of the K_new and V_new matrices.
+ index_t knew_batch_stride;
+ index_t vnew_batch_stride;
+ index_t knew_row_stride;
+ index_t vnew_row_stride;
+ index_t knew_head_stride;
+ index_t vnew_head_stride;
+
+ // The cos and sin matrices for rotary embedding.
+ void * __restrict__ rotary_cos_ptr;
+ void * __restrict__ rotary_sin_ptr;
+
+ // The indices to index into the KV cache.
+ int * __restrict__ cache_batch_idx;
+
+ // Paged KV cache
+ int * __restrict__ block_table;
+ index_t block_table_batch_stride;
+ int page_block_size;
+
+ // The dropout probability (probability of keeping an activation).
+ float p_dropout;
+ // uint32_t p_dropout_in_uint;
+ // uint16_t p_dropout_in_uint16_t;
+ uint8_t p_dropout_in_uint8_t;
+
+ // Scale factor of 1 / (1 - p_dropout).
+ float rp_dropout;
+ float scale_softmax_rp_dropout;
+
+ // Local window size
+ int window_size_left, window_size_right;
+ float softcap;
+
+ // Random state.
+ at::PhiloxCudaState philox_args;
+
+ // Pointer to the RNG seed (idx 0) and offset (idx 1).
+ uint64_t * rng_state;
+
+ bool is_bf16;
+ bool is_causal;
+
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ bool is_seqlens_k_cumulative;
+
+ bool is_rotary_interleaved;
+
+ int num_splits; // For split-KV version
+
+ void * __restrict__ alibi_slopes_ptr;
+ index_t alibi_slopes_batch_stride;
+
+ bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+ bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct Flash_bwd_params : public Flash_fwd_params {
+
+ // The dO and dQKV matrices.
+ void *__restrict__ do_ptr;
+ void *__restrict__ dq_ptr;
+ void *__restrict__ dk_ptr;
+ void *__restrict__ dv_ptr;
+
+ // To accumulate dQ
+ void *__restrict__ dq_accum_ptr;
+ void *__restrict__ dk_accum_ptr;
+ void *__restrict__ dv_accum_ptr;
+
+ // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension:
+ // void *__restrict__ dk_accum_ptr;
+ // void *__restrict__ dv_accum_ptr;
+
+ // The stride between rows of the dO, dQ, dK and dV matrices.
+ // TD [2022-04-16]: We're using 32-bit indexing to save registers.
+ // The code probably won't work for arrays larger than 2GB.
+ index_t do_batch_stride;
+ index_t do_row_stride;
+ index_t do_head_stride;
+ index_t dq_batch_stride;
+ index_t dk_batch_stride;
+ index_t dv_batch_stride;
+ index_t dq_row_stride;
+ index_t dk_row_stride;
+ index_t dv_row_stride;
+ index_t dq_head_stride;
+ index_t dk_head_stride;
+ index_t dv_head_stride;
+
+ // The pointer to the softmax d sum.
+ void *__restrict__ dsoftmax_sum;
+
+ bool deterministic;
+ index_t dq_accum_split_stride;
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+ template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+
+ template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
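
Each of these .cu files explicitly instantiates run_mha_bwd_ for one (dtype, head dim, is_causal) combination so the heavy template bodies compile in parallel. At runtime a dispatcher picks the matching instantiation; a simplified sketch of that selection (hypothetical, the real dispatch lives in the launch templates and rounds the head dimension):

    // Sketch: runtime dispatch to the pre-instantiated kernels.
    void run_mha_bwd_sketch(Flash_bwd_params &params, cudaStream_t stream) {
        if (!params.is_bf16 && params.d <= 128 && params.is_causal) {
            run_mha_bwd_<cutlass::half_t, 128, true>(params, stream);
        } else if (params.is_bf16 && params.d <= 128 && params.is_causal) {
            run_mha_bwd_<cutlass::bfloat16_t, 128, true>(params, stream);
        } // ... and so on for every (dtype, headdim, causal) file in this directory.
    }
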
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
+ // Copyright (c) 2024, Tri Dao.
+ // Splitting the different head dimensions to different files to speed up compilation.
+ // This file is auto-generated. See "generate_kernels.py"
+ #include "namespace_config.h"
+ #include "flash_bwd_launch_template.h"
+
+ namespace FLASH_NAMESPACE {
+
+ template<>
+ void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
+ run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
+ }
+
+ } // namespace FLASH_NAMESPACE