Commit ec78611 (verified)
Parent(s): Duplicate from ArtemisTAO/WIN_21_1

This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +191 -0
- Dockerfile +48 -0
- LICENSE +201 -0
- README.md +13 -0
- assets/Qwen2.5_Omni.pdf +3 -0
- cookbooks/=4.41.0 +17 -0
- cookbooks/=4.50.0.dev0 +17 -0
- cookbooks/=4.51.0.dev0 +17 -0
- cookbooks/flash-attention/.github/workflows/publish.yml +218 -0
- cookbooks/flash-attention/.gitignore +31 -0
- cookbooks/flash-attention/.gitmodules +6 -0
- cookbooks/flash-attention/AUTHORS +1 -0
- cookbooks/flash-attention/LICENSE +29 -0
- cookbooks/flash-attention/MANIFEST.in +12 -0
- cookbooks/flash-attention/Makefile +9 -0
- cookbooks/flash-attention/README.md +524 -0
- cookbooks/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png +3 -0
- cookbooks/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png +3 -0
- cookbooks/flash-attention/assets/flash3_fp16_fwd.png +3 -0
- cookbooks/flash-attention/assets/flashattention_logo.png +3 -0
- cookbooks/flash-attention/assets/flashattn_banner.jpg +3 -0
- cookbooks/flash-attention/assets/flashattn_banner.pdf +3 -0
- cookbooks/flash-attention/assets/flashattn_memory.jpg +0 -0
- cookbooks/flash-attention/assets/flashattn_speedup.jpg +3 -0
- cookbooks/flash-attention/assets/flashattn_speedup_3090.jpg +3 -0
- cookbooks/flash-attention/assets/flashattn_speedup_a100_d128.jpg +3 -0
- cookbooks/flash-attention/assets/flashattn_speedup_t4.jpg +3 -0
- cookbooks/flash-attention/assets/flashattn_speedup_t4_fwd.jpg +3 -0
- cookbooks/flash-attention/assets/gpt2_training_curve.jpg +3 -0
- cookbooks/flash-attention/assets/gpt2_training_efficiency.jpg +3 -0
- cookbooks/flash-attention/assets/gpt3_training_curve.jpg +3 -0
- cookbooks/flash-attention/assets/gpt3_training_efficiency.jpg +3 -0
- cookbooks/flash-attention/benchmarks/benchmark_alibi.py +275 -0
- cookbooks/flash-attention/benchmarks/benchmark_causal.py +225 -0
- cookbooks/flash-attention/benchmarks/benchmark_flash_attention.py +180 -0
- cookbooks/flash-attention/benchmarks/benchmark_gemm.py +47 -0
- cookbooks/flash-attention/csrc/flash_attn/flash_api.cpp +1485 -0
- cookbooks/flash-attention/csrc/flash_attn/src/alibi.h +75 -0
- cookbooks/flash-attention/csrc/flash_attn/src/block_info.h +49 -0
- cookbooks/flash-attention/csrc/flash_attn/src/dropout.h +95 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash.h +194 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +14 -0
- cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu +14 -0
.gitattributes
ADDED
@@ -0,0 +1,191 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/Qwen2.5_Omni.pdf filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flash3_fp16_fwd.png filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_banner.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_banner.pdf filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_speedup.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_speedup_3090.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_speedup_a100_d128.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_speedup_t4.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/flashattn_speedup_t4_fwd.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/gpt2_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/gpt2_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/gpt3_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+cookbooks/flash-attention/assets/gpt3_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flash3_fp16_fwd.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_banner.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_banner.pdf filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_speedup.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_speedup_3090.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_speedup_a100_d128.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_speedup_t4.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/flashattn_speedup_t4_fwd.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/gpt2_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/gpt2_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/gpt3_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/assets/gpt3_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/composable_kernel/docs/data/ck_component.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/composable_kernel/docs/data/ck_layer.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/composable_kernel/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/composable_kernel/example/ck_tile/15_fused_moe/misc/moe-2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/M128xK4_scalefactor_gmem.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/conv2d-fprop-int4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32Mx32x4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32x32x4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_Atom.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.AB.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.C.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/TiledCopyA.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/TiledMmaC.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/composition1.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/composition2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/divide2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/divide3.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/gmma_coremat_cd_fp16.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/gmma_wg_n_slice.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide-2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/product2d.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/productblocked2d.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/productraked2d.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/slice.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/tC_partitioning.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cute/tv_layout.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-2.8-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-2.9-implicit-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-3.0-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-3.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-gemm-components.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-reduction-in-named-iterators.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-threadblock-mma-pipelined.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-warp-level-gemm-api-instantiation.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/cutlass-warp-thread-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue-no-labels.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/gemm-structural-components.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/ldmatrix-8x128bx4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/ldmatrix-tensorop-32x32x32.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/mma-8x8x32.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/non_persistent.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/persistent_clc.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/persistent_static.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/software-pipeline.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k0.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k1.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flash3_fp16_fwd.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattention_logo.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_banner.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_banner.pdf filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_speedup.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_speedup_3090.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_speedup_a100_d128.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_speedup_t4.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/flashattn_speedup_t4_fwd.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/gpt2_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/gpt2_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/gpt3_training_curve.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/assets/gpt3_training_efficiency.jpg filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/composable_kernel/docs/data/ck_component.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/composable_kernel/docs/data/ck_layer.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/composable_kernel/example/ck_tile/14_moe_smoothquant/misc/moe-sm.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/composable_kernel/example/ck_tile/15_fused_moe/misc/moe-2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/M128xK4_scalefactor_gmem.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/conv2d-fprop-int4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32Mx32x4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_2x2_32x32x4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.NT_Atom.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.AB.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/HMMA.8x8x4.quadpair.C.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/TiledCopyA.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/TiledMmaC.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/composition1.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/composition2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/divide2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/divide3.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/gmma_coremat_cd_fp16.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/gmma_wg_n_slice.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide-2.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/logical_divide-and-zipped_divide.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/product2d.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/productblocked2d.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/productraked2d.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/slice.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/tC_partitioning.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cute/tv_layout.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-2.8-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-2.9-implicit-gemm-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.0-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance-fp8.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-3.5.1-gemm-peak-performance.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-gemm-components.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-reduction-in-named-iterators.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-threadblock-mma-pipelined.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-warp-level-gemm-api-instantiation.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/cutlass-warp-thread-tile-structure.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue-no-labels.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/gemm-hierarchy-with-epilogue.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/gemm-structural-components.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/ldmatrix-8x128bx4.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/ldmatrix-tensorop-32x32x32.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/mma-8x8x32.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/non_persistent.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/persistent_clc.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/persistent_static.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/software-pipeline.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k0.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN-k1.png filter=lfs diff=lfs merge=lfs -text
+flash-attention/flash-attention/csrc/cutlass/media/images/tensor-op-permuted-smem-layout-TN.png filter=lfs diff=lfs merge=lfs -text
+input_audio.wav filter=lfs diff=lfs merge=lfs -text
+model/Qwen2.5-Omni-7B/tokenizer.json filter=lfs diff=lfs merge=lfs -text
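Every rule above routes matching paths through Git LFS, so the repository itself stores only small pointer stubs. As an aside (not part of this commit), rules of exactly this shape are what `git lfs track` appends, and `git lfs ls-files` shows which committed paths are actually LFS objects:

```bash
# Appends a rule like the ones above to .gitattributes:
git lfs track "*.wav"
# Lists committed files stored as LFS objects (oid prefix + path):
git lfs ls-files
```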
Dockerfile
ADDED
@@ -0,0 +1,48 @@
+FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    DEBIAN_FRONTEND=noninteractive \
+    HF_HOME=/app/models \
+    NUMBA_CACHE_DIR=/tmp/numba_cache \
+    TORCH_CUDA_ARCH_LIST=8.0
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    build-essential \
+    git \
+    ffmpeg \
+    libsndfile1 \
+    libcusparse-dev-12-3 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python build tools
+RUN pip install --upgrade pip setuptools wheel packaging ninja
+
+WORKDIR /app
+
+# Create cache directory
+RUN mkdir -p /tmp/numba_cache && \
+    chmod 777 /tmp/numba_cache
+
+# Install PyTorch with CUDA 12.1 first
+RUN pip install --pre torch torchvision torchaudio \
+    --index-url https://download.pytorch.org/whl/nightly/cu121
+
+# Copy and install requirements
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Install flash-attn separately with no isolation
+RUN pip install flash-attn==2.7.4.post1 --no-build-isolation
+
+# Copy application files
+COPY server.py .
+COPY qwen-omni-utils/ ./qwen-omni-utils/
+COPY model/ ./model/
+
+EXPOSE 8000
+CMD ["python3", "server.py"]
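For context, a minimal sketch of building and running this image locally. The `qwen-omni-server` tag is a hypothetical name, and `--gpus all` assumes the NVIDIA Container Toolkit is installed; the Dockerfile itself only tells us the server listens on port 8000.

```bash
# Build from the repo root, where this Dockerfile lives (hypothetical tag):
docker build -t qwen-omni-server .
# Run with GPU access and publish the port declared by EXPOSE 8000:
docker run --gpus all -p 8000:8000 qwen-omni-server
```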
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2025 Alibaba Cloud
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+license: mit
+tags:
+- any-to-any
+- omega
+- omegalabs
+- bittensor
+- agi
+---
+
+This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
+
+Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
assets/Qwen2.5_Omni.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0c9e0042ad20bc0c95cbbfc96f63f4ff1f28727c5b32973e7fd597557b6b15f
+size 4014433
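This is a Git LFS pointer stub, not the PDF itself: `oid` is the SHA-256 of the real blob and `size` is its byte length. A quick sketch of verifying a materialized file against its pointer after checkout:

```bash
# Fetch the real blobs for the current checkout:
git lfs pull
# The hash should equal the pointer's oid, and the size should be 4014433 bytes:
sha256sum assets/Qwen2.5_Omni.pdf
stat -c %s assets/Qwen2.5_Omni.pdf
```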
cookbooks/=4.41.0
ADDED
@@ -0,0 +1,17 @@
+Requirement already satisfied: transformers in /home/ubuntu/.venv/lib/python3.10/site-packages (4.51.0.dev0)
+Requirement already satisfied: filelock in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (3.18.0)
+Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.29.3)
+Requirement already satisfied: numpy>=1.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.1.3)
+Requirement already satisfied: packaging>=20.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (24.2)
+Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (6.0.2)
+Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2024.11.6)
+Requirement already satisfied: requests in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.32.3)
+Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.21.1)
+Requirement already satisfied: safetensors>=0.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.5.3)
+Requirement already satisfied: tqdm>=4.27 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (4.67.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2025.3.0)
+Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.13.0)
+Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.4.1)
+Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.10)
+Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2.3.0)
+Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)
cookbooks/=4.50.0.dev0
ADDED
@@ -0,0 +1,17 @@
+Requirement already satisfied: transformers in /home/ubuntu/.venv/lib/python3.10/site-packages (4.50.0.dev0)
+Requirement already satisfied: filelock in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (3.18.0)
+Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.29.3)
+Requirement already satisfied: numpy>=1.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.1.3)
+Requirement already satisfied: packaging>=20.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (24.2)
+Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (6.0.2)
+Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2024.11.6)
+Requirement already satisfied: requests in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.32.3)
+Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.21.1)
+Requirement already satisfied: safetensors>=0.4.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.5.3)
+Requirement already satisfied: tqdm>=4.27 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (4.67.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2025.3.0)
+Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.13.0)
+Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.4.1)
+Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.10)
+Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2.3.0)
+Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)
cookbooks/=4.51.0.dev0
ADDED
@@ -0,0 +1,17 @@
+Requirement already satisfied: transformers in /home/ubuntu/.venv/lib/python3.10/site-packages (4.50.0.dev0)
+Requirement already satisfied: filelock in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (3.18.0)
+Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.29.3)
+Requirement already satisfied: numpy>=1.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.1.3)
+Requirement already satisfied: packaging>=20.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (24.2)
+Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (6.0.2)
+Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2024.11.6)
+Requirement already satisfied: requests in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (2.32.3)
+Requirement already satisfied: tokenizers<0.22,>=0.21 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.21.1)
+Requirement already satisfied: safetensors>=0.4.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (0.5.3)
+Requirement already satisfied: tqdm>=4.27 in /home/ubuntu/.venv/lib/python3.10/site-packages (from transformers) (4.67.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2025.3.0)
+Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.venv/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.13.0)
+Requirement already satisfied: charset-normalizer<4,>=2 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.4.1)
+Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (3.10)
+Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2.3.0)
+Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/.venv/lib/python3.10/site-packages (from requests->transformers) (2025.1.31)
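These three `cookbooks/=4.x` files contain pip console output rather than code, which is the classic artifact of an unquoted version specifier: the shell splits `transformers>=4.41.0` at the `>`, runs plain `pip install transformers`, and redirects its stdout to a file literally named `=4.41.0` (hence the first line reporting whatever version was already installed). A sketch of the failure mode and its fix:

```bash
# Unquoted: the shell treats '>=4.41.0' as a redirection target named '=4.41.0'.
pip install transformers>=4.41.0
# Quoted: the full specifier reaches pip, and no stray file is created.
pip install "transformers>=4.41.0"
```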
cookbooks/flash-attention/.github/workflows/publish.yml
ADDED
@@ -0,0 +1,218 @@
+# This workflow will:
+# - Create a new Github release
+# - Build wheels for supported architectures
+# - Deploy the wheels to the Github release
+# - Release the static code to PyPi
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+name: Build wheels and deploy
+
+on:
+  create:
+    tags:
+      - v*
+
+jobs:
+
+  setup_release:
+    name: Create Release
+    runs-on: ubuntu-latest
+    steps:
+      - name: Get the tag version
+        id: extract_branch
+        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+        shell: bash
+
+      - name: Create Release
+        id: create_release
+        uses: actions/create-release@v1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          tag_name: ${{ steps.extract_branch.outputs.branch }}
+          release_name: ${{ steps.extract_branch.outputs.branch }}
+
+  build_wheels:
+    name: Build Wheel
+    needs: setup_release
+    runs-on: ${{ matrix.os }}
+
+    strategy:
+      fail-fast: false
+      matrix:
+        # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
+        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
+        os: [ubuntu-20.04]
+        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
+        torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0']
+        cuda-version: ['12.4.1']
+        # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
+        # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
+        # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
+        # when building without C++11 ABI and using it on nvcr images.
+        cxx11_abi: ['FALSE', 'TRUE']
+        exclude:
+          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+          # Pytorch < 2.5 does not support Python 3.13
+          - torch-version: '2.2.2'
+            python-version: '3.13'
+          - torch-version: '2.3.1'
+            python-version: '3.13'
+          - torch-version: '2.4.0'
+            python-version: '3.13'
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Set CUDA and PyTorch versions
+        run: |
+          echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
+          echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
+          echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
+          echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
+
+      - name: Free up disk space
+        if: ${{ runner.os == 'Linux' }}
+        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
+        # https://github.com/easimon/maximize-build-space/tree/test-report
+        run: |
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf /opt/ghc
+          sudo rm -rf /opt/hostedtoolcache/CodeQL
+
+      - name: Set up swap space
+        if: runner.os == 'Linux'
+        uses: pierotofy/set-swap-space@v1.0
+        with:
+          swap-size-gb: 10
+
+      - name: Install CUDA ${{ matrix.cuda-version }}
+        if: ${{ matrix.cuda-version != 'cpu' }}
+        uses: Jimver/cuda-toolkit@v0.2.19
+        id: cuda-toolkit
+        with:
+          cuda: ${{ matrix.cuda-version }}
+          linux-local-args: '["--toolkit"]'
+          # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
+          # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
+          method: 'network'
+          sub-packages: '["nvcc"]'
+
+      - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
+        run: |
+          pip install --upgrade pip
+          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
+          pip install setuptools==75.8.0
+          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
+          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
+          pip install typing-extensions==4.12.2
+          # We want to figure out the CUDA version to download pytorch
+          # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
+          # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+          # This code is ugly, maybe there's a better way to do this.
+          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
+            minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
+            print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
+          )
+          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
+            # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            # Can't use --no-deps because we need cudnn etc.
+            # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
+            pip install jinja2
+            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+          else
+            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
+          fi
+          nvcc --version
+          python --version
+          python -c "import torch; print('PyTorch:', torch.__version__)"
+          python -c "import torch; print('CUDA:', torch.version.cuda)"
+          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
+        shell:
+          bash
+
+      - name: Build wheel
+        run: |
+          # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
+          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
+          # However this still fails so I'm using a newer version of setuptools
+          pip install setuptools==75.8.0
+          pip install ninja packaging wheel
+          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
+          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+          # Limit MAX_JOBS otherwise the github runner goes OOM
+          # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
+          MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+          tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
+          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
+          ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
+          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
+
+      - name: Log Built Wheels
+        run: |
+          ls dist
+
+      - name: Get the tag version
+        id: extract_branch
+        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+
+      - name: Get Release with tag
+        id: get_current_release
+        uses: joutvhu/get-release@v1
+        with:
+          tag_name: ${{ steps.extract_branch.outputs.branch }}
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Upload Release Asset
+        id: upload_release_asset
+        uses: actions/upload-release-asset@v1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
+          asset_path: ./dist/${{env.wheel_name}}
+          asset_name: ${{env.wheel_name}}
+          asset_content_type: application/*
+
+  publish_package:
+    name: Publish package
+    needs: [build_wheels]
+
+    runs-on: ubuntu-latest
|
| 191 |
+
|
| 192 |
+
steps:
|
| 193 |
+
- uses: actions/checkout@v4
|
| 194 |
+
|
| 195 |
+
- uses: actions/setup-python@v5
|
| 196 |
+
with:
|
| 197 |
+
python-version: '3.10'
|
| 198 |
+
|
| 199 |
+
- name: Install dependencies
|
| 200 |
+
run: |
|
| 201 |
+
pip install ninja packaging wheel twine
|
| 202 |
+
# Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
|
| 203 |
+
pip install setuptools==75.8.0
|
| 204 |
+
# We don't want to download anything CUDA-related here
|
| 205 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 206 |
+
|
| 207 |
+
- name: Build core package
|
| 208 |
+
env:
|
| 209 |
+
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
|
| 210 |
+
run: |
|
| 211 |
+
python setup.py sdist --dist-dir=dist
|
| 212 |
+
|
| 213 |
+
- name: Deploy
|
| 214 |
+
env:
|
| 215 |
+
TWINE_USERNAME: "__token__"
|
| 216 |
+
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
| 217 |
+
run: |
|
| 218 |
+
python -m twine upload dist/*
|
cookbooks/flash-attention/.gitignore
ADDED
@@ -0,0 +1,31 @@
*.ncu-rep
.DS_store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
.eggs/

# IDE-related
.idea/

# Dev
venv
cookbooks/flash-attention/.gitmodules
ADDED
@@ -0,0 +1,6 @@
[submodule "csrc/cutlass"]
	path = csrc/cutlass
	url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
	path = csrc/composable_kernel
	url = https://github.com/ROCm/composable_kernel.git
cookbooks/flash-attention/AUTHORS
ADDED
@@ -0,0 +1 @@
Tri Dao, trid@cs.stanford.edu
cookbooks/flash-attention/LICENSE
ADDED
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cookbooks/flash-attention/MANIFEST.in
ADDED
@@ -0,0 +1,12 @@
recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include csrc *.py

recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp
recursive-include flash_attn *.hpp
cookbooks/flash-attention/Makefile
ADDED
@@ -0,0 +1,9 @@

clean_dist:
	rm -rf dist/*

create_dist: clean_dist
	python setup.py sdist

upload_package: create_dist
	twine upload dist/*
cookbooks/flash-attention/README.md
ADDED
@@ -0,0 +1,524 @@
# FlashAttention
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the following papers.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)

**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao

Paper: https://tridao.me/publications/flash2/flash2.pdf

![FlashAttention-2](assets/flashattention_logo.png)


## Usage

We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.


## FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).

Blogpost: https://tridao.me/blog/2024/flash3/

Paper: https://tridao.me/publications/flash3/flash3.pdf

![FlashAttention-3 FP16 forward speed](assets/flash3_fp16_fwd.png)

This is a beta release for testing / benchmarking before we integrate it with
the rest of the repo.

Currently released:
- FP16 / BF16 forward and backward, FP8 forward

Requirements: H100 / H800 GPU, CUDA >= 12.3.

We highly recommend CUDA 12.8 for best performance.

To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```
Once the package is installed, you can import it as follows:
```python
import flash_attn_interface
flash_attn_interface.flash_attn_func()
```

## Installation and features
**Requirements:**
- CUDA toolkit or ROCm toolkit
- PyTorch 2.2 and above.
- `packaging` Python package (`pip install packaging`)
- `ninja` Python package (`pip install ninja`) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)), but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.

\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja`, compiling takes 3-5 minutes on a 64-core machine using the CUDA toolkit.

**To install:**
```sh
pip install flash-attn --no-build-isolation
```
Alternatively you can compile from source:
```sh
python setup.py install
```

If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might
run too many parallel compilation jobs that could exhaust the amount of RAM. To
limit the number of parallel compilation jobs, you can set the environment
variable `MAX_JOBS`:
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```

**Interface:** `src/flash_attention_interface.py`

### NVIDIA CUDA Support
**Requirements:**
- CUDA 12.0 and above.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from Nvidia, which has all the required tools to install FlashAttention.

FlashAttention-2 with CUDA currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
   GPUs for now.
2. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

### AMD ROCm Support
The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (ck), which is the default, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.

We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

#### Composable Kernel Backend
The FlashAttention-2 ROCm CK backend currently supports:
1. MI200 or MI300 GPUs.
2. Datatypes fp16 and bf16.
3. Head dimensions up to 256 in both forward and backward.

#### Triton Backend
The Triton implementation of [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPUs using the fp16, bf16 and fp32 datatypes.

These features are supported in both Fwd and Bwd:
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes

These features are supported in Fwd for now. We will add them to backward soon.
1) Multi and grouped query attention
2) ALiBi and matrix bias

These features are in development:
1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements

#### Getting Started
To get started with the Triton backend for AMD, follow the steps below.

First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

```
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```


## How to use FlashAttention

The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V, since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Defaults to 1 / sqrt(headdim).
    causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
        the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```
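
As a quick illustration, a minimal call could look like this (the sizes below are arbitrary choices for the example, not requirements):

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

# Q, K, V stacked along dim 2: (batch_size, seqlen, 3, nheads, headdim)
qkv = torch.randn(2, 1024, 3, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1024, 16, 64])
```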

```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Defaults to 1 / sqrt(headdim).
    causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```
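
For example, the GQA case from the docstring above, sketched with arbitrary sizes (6 query heads sharing 2 KV heads, so Q heads 0-2 attend to KV head 0 and Q heads 3-5 to KV head 1):

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, headdim = 2, 1024, 64
q = torch.randn(batch, seqlen, 6, headdim, device="cuda", dtype=torch.float16)  # 6 query heads
k = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)  # 2 KV heads
v = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 1024, 6, 64]) -- the output keeps the query head count
```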

```python
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Rotary embedding is also applied if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
```
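
A minimal incremental-decoding sketch based on the docstring above (the cache sizes and the number of already-cached tokens are illustrative assumptions):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, max_seqlen, nheads, headdim = 2, 4096, 16, 64
# Pre-allocate the KV cache at the max sequence length.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")  # tokens cached so far

# One decoding step: q/k/v hold the single new token; k/v are written into the caches in place
# at index cache_seqlens, and attention runs over the updated cache, all in one kernel.
q, k, v = (torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16) for _ in range(3))
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)
cache_seqlens += 1  # account for the newly appended token before the next step
```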

To see how these functions are used in a multi-head attention layer (which
includes the QKV projection and output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

## Changelog

### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`

If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
### 2.1: Change behavior of causal flag

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 =
masked out) is:
v2.0:
    1 0 0 0 0
    1 1 0 0 0
v2.1:
    1 1 1 1 0
    1 1 1 1 1

If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
    1 0
    1 1
    1 1
    1 1
    1 1
v2.1:
    0 0
    0 0
    0 0
    1 0
    1 1
If the row of the mask is all zero, the output will be zero.
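
Equivalently, with the bottom-right alignment query i may attend to key j iff j <= i + seqlen_k - seqlen_q. A small reference construction (not part of the library, just for illustration) that reproduces the v2.1 masks above:

```python
import torch

def causal_mask_bottom_right(seqlen_q, seqlen_k):
    # 1 = keep, 0 = masked out, aligned to the bottom-right corner of the attention matrix.
    row = torch.arange(seqlen_q).unsqueeze(-1)  # query index i
    col = torch.arange(seqlen_k)                # key index j
    return (col <= row + seqlen_k - seqlen_q).int()

print(causal_mask_bottom_right(2, 5))  # matches the v2.1 example above
print(causal_mask_bottom_right(5, 2))
```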

### 2.2: Optimize for inference

Optimize for inference (iterative decoding) when the query has a very small sequence
length (e.g., query sequence length = 1). The bottleneck here is loading the KV
cache as fast as possible, so we split the loading across different thread
blocks, with a separate kernel to combine the results.

See the function `flash_attn_with_kvcache`, which has more features for inference
(applying rotary embedding, updating the KV cache in place).

Thanks to the xformers team, and in particular Daniel Haziza, for this
collaboration.

### 2.3: Local (i.e., sliding window) attention

Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
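
For example, causal attention restricted to the previous 128 tokens (an illustrative window; with seqlen_q == seqlen_k, window_size=(127, 0) lets query i attend to keys in [i - 127, i]):

```python
import torch
from flash_attn import flash_attn_func

q, k, v = (torch.randn(2, 4096, 16, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flash_attn_func(q, k, v, causal=True, window_size=(127, 0))  # sliding window local attention
```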

### 2.4: ALiBi (attention with linear bias), deterministic backward pass.

Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.

Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.
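
As a sketch, the usual geometric slopes from Press et al. can be passed through the `alibi_slopes` argument (the helper below is illustrative, not part of the flash-attn API, and the formula assumes the head count is a power of 2):

```python
import torch
from flash_attn import flash_attn_func

def get_alibi_slopes(nheads):
    # Press et al., 2021: slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8) for n heads.
    return torch.tensor([2 ** (-8 * (i + 1) / nheads) for i in range(nheads)],
                        dtype=torch.float32, device="cuda")

q, k, v = (torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=get_alibi_slopes(8))
```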

### 2.5: Paged KV cache.

Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
Thanks to @beginlner for this contribution.

### 2.6: Softcapping.

Support attention with softcapping, as used in Gemma-2 and Grok models.
Thanks to @Narsil and @lucidrains for this contribution.

### 2.7: Compatibility with torch compile

Thanks to @ani300 for this contribution.

## Performance

We present the expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).

We currently have benchmarks for these GPUs:
* [A100](#a100)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->

### A100

We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.

#### Speedup

![FlashAttention speedup on A100, forward + backward](assets/flash2_a100_fwd_bwd_benchmark.png)

#### Memory

![FlashAttention memory savings](assets/flashattn_memory.jpg)

We show memory savings in this graph (the memory footprint is the same whether or not you use dropout or masking).
Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.

### H100

![FlashAttention speedup on H100, forward + backward](assets/flash2_h100_fwd_bwd_benchmark.png)

## Full model code and training script

We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).

We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.

We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py


## Tests
We test that FlashAttention produces the same output and gradient as a reference
implementation, up to some numerical tolerance. In particular, we check that the
maximum numerical error of FlashAttention is at most twice the numerical error
of a baseline implementation in Pytorch (for different head dimensions, input
dtype, sequence length, causal / non-causal).

To run the tests:
```sh
pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues

This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!

## Tests (ROCm CK backend)
To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```

## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
@inproceedings{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024}
}
```
cookbooks/flash-attention/assets/flash2_a100_fwd_bwd_benchmark.png
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flash2_h100_fwd_bwd_benchmark.png
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flash3_fp16_fwd.png
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattention_logo.png
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattn_banner.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattn_banner.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f4df0222057bbffcd2894fbae18bbfa6304e5d0583d47e44e9ac7a97bfb75ce
size 474702

cookbooks/flash-attention/assets/flashattn_memory.jpg
ADDED

cookbooks/flash-attention/assets/flashattn_speedup.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattn_speedup_3090.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattn_speedup_a100_d128.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattn_speedup_t4.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/flashattn_speedup_t4_fwd.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/gpt2_training_curve.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/gpt2_training_efficiency.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/gpt3_training_curve.jpg
ADDED
Git LFS Details

cookbooks/flash-attention/assets/gpt3_training_efficiency.jpg
ADDED
Git LFS Details
cookbooks/flash-attention/benchmarks/benchmark_alibi.py
ADDED
@@ -0,0 +1,275 @@
# Copyright (c) 2024, Sanghun Cho, Tri Dao.

import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from flash_attn.layers.rotary import apply_rotary_emb

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

try:
    import xformers.ops as xops
except ImportError:
    xops = None


def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    return cos, sin


def flash_rotary(q, k, v, cos, sin, causal=False):
    # corrected by @tridao comments
    q = apply_rotary_emb(
        q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
    )
    k = apply_rotary_emb(
        k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True
    )

    return flash_attn_func(q, k, v, causal=causal)


def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
    batch, nheads = slopes.shape
    device = slopes.device
    slopes = rearrange(slopes, "b h -> b h 1 1")
    if causal:
        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
    else:
        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
        sk = (
            seqlen_k
            if key_padding_mask is None
            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        sq = (
            seqlen_q
            if query_padding_mask is None
            else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        relative_pos = torch.abs(row_idx + sk - sq - col_idx)
        return -slopes * relative_pos.to(dtype=slopes.dtype)


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    # The forward pass does two matmuls (Q @ K^T and P @ V), each costing
    # 2 * batch * seqlen^2 * nheads * headdim FLOPs, hence the factor 4;
    # causal masking halves the work. The backward pass does roughly 5 such
    # matmuls vs. 2 in the forward, hence 2.5x (and 1 + 2.5 = 3.5x combined).
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):
    """
    Arguments:
        q, k, v: (batch_size, seqlen, nheads, head_dim)
        dropout_p: float
        attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, nheads, d = q.shape
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    if attn_bias is not None:
        scores = rearrange(attn_bias, 'b h t s -> (b h) t s')
    else:
        scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=q.dtype)


def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

methods = (["fa2_alibi", "torch"]
           + (["xformers"] if xops is not None else [])
           + ["sdpa"]
           + ["fa2_baseline"]
           + ["fa2_rotary"])

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                                   requires_grad=True) for _ in range(3)]
            # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
            alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3
            attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)
            attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size)
            f, b = time_fwd_bwd(
                flash_attn_func,
                q, k, v,
                dropout_p,
                causal=causal,
                # alibi_slopes=alibi_slopes,
                alibi_slopes=None,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_baseline"] = f
            time_b[config, "fa2_baseline"] = b

            q = q.detach().requires_grad_(True)
            k = k.detach().requires_grad_(True)
            v = v.detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_func,
                q, k, v,
                dropout_p,
                causal=causal,
                alibi_slopes=rearrange(alibi_slopes, "1 h -> h"),
                # alibi_slopes=None,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_alibi"] = f
            time_b[config, "fa2_alibi"] = b

            try:
                q = q.detach().requires_grad_(True)
                k = k.detach().requires_grad_(True)
                v = v.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch,
                    q, k, v,
                    dropout_p,
                    causal=causal,
                    attn_bias=attn_bias,
                    repeats=repeats,
                    verbose=False
                )
            except Exception:  # Skip if OOM
                f, b = float('nan'), float('nan')
            time_f[config, "torch"] = f
            time_b[config, "torch"] = b

            # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe
            with torch.backends.cuda.sdp_kernel(enable_flash=False):
                q_pt = q.detach().requires_grad_(True).transpose(1, 2)
                k_pt = k.detach().requires_grad_(True).transpose(1, 2)
                v_pt = v.detach().requires_grad_(True).transpose(1, 2)
                f, b = time_fwd_bwd(
                    F.scaled_dot_product_attention,
                    q_pt, k_pt, v_pt,
                    attn_mask=attn_bias,
                    dropout_p=dropout_p,
                    is_causal=causal,
                    repeats=repeats,
                    verbose=False
                )
                time_f[config, "sdpa"] = f
                time_b[config, "sdpa"] = b

            if xops is not None:
                q = q.detach().requires_grad_(True)
                k = k.detach().requires_grad_(True)
                v = v.detach().requires_grad_(True)
                if causal:
                    attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))
                    # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:
                    #     `flshattB@v2.3.6` is not supported because:
                    #         attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
                    #     `cutlassB` is not supported because:
                    #         attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>
                    attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)
                else:
                    attn_bias_xops = attn_bias.to(dtype=q.dtype)
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention,
                    q, k, v,
                    attn_bias_xops,
                    dropout_p,
                    repeats=repeats,
                    verbose=False
                )
                time_f[config, "xformers"] = f
                time_b[config, "xformers"] = b

            q = q.detach().requires_grad_(True)
            k = k.detach().requires_grad_(True)
            v = v.detach().requires_grad_(True)
            cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
            f, b = time_fwd_bwd(
                flash_rotary,
                q, k, v,
                cos, sin,
                causal,
                repeats=repeats,
                verbose=False
            )
            time_f[config, "fa2_rotary"] = f
            time_b[config, "fa2_rotary"] = b

            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            csv_output = ""
            csv_output += f"{causal},{headdim},{batch_size},{seqlen},"
            for method in methods:
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                )
                csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},"
            print(csv_output)
cookbooks/flash-attention/benchmarks/benchmark_causal.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
# # from flash_attn.triton.fused_attention import attention as attention
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
# from flash_attn.flash_attn_triton_og import attention as attention_og

# from triton.ops.flash_attention import attention as attention_triton

from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func

try:
    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
except ImportError:
    scaled_upper_triang_masked_softmax = None


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def attention_megatron(qkv):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
    output = torch.einsum('bhts,bshd->bthd', attention, v)
    return output.to(dtype=qkv.dtype)


torch.manual_seed(0)
repeats = 30
batch_size = 8
seqlen = 2048
nheads = 12
headdim = 128
# nheads = 24
# headdim = 64
# batch_size = 64
# seqlen = 512
# nheads = 8
# headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.float16
device = 'cuda'

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                  requires_grad=True)
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                          device=qkv.device)

qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
#                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)

# for dropout_p in [0.1, 0.0]:
#     for causal in [False, True]:
#         print(f"### {dropout_p = }, {causal = } ###")
#         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)


# nheads_k = 2
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
#                  requires_grad=True)
# if fav2_kvpacked_func is not None:
#     benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
#     pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)

# dropout_p = 0.0
# causal = False
# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
#               repeats=repeats, desc='PyTorch Attention')

# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)

# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
#                        requires_grad=True) for _ in range(3)]
# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)

# if scaled_upper_triang_masked_softmax is not None:
#     benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')

# from src.ops.fftconv import fftconv_func

# dim = nheads * headdim
# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
# k = torch.randn(dim, seqlen, device=device, requires_grad=True)
# D = torch.randn(dim, device=device, requires_grad=True)
# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
# pytorch_profiler(torch.fft.rfft, u.float())

flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
ideal_a100_time = flops / 312 / 1e9
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
exit(0)  # Everything below this early exit is kept for reference but never runs.


def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

time_f = {}
time_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=qkv.device)
            qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
                causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b

            qkv = qkv.detach().requires_grad_(True)
            f, b = time_fwd_bwd(
                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b

            # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # # Try both values of sequence_parallel and pick the faster one
            # f, b = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     False, repeats=repeats, verbose=False
            # )
            # _, b0 = time_fwd_bwd(
            #     attention_triton, q, k, v, causal, headdim**(-0.5),
            #     True, repeats=repeats, verbose=False
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)

            if seqlen <= 8 * 1024:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                )
            else:
                f, b = float('nan'), float('nan')
            time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
            time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b

            # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
            #                        requires_grad=True) for _ in range(3)]
            # import xformers.ops as xops
            # f, b = time_fwd_bwd(
            #     xops.memory_efficient_attention, q, k, v,
            #     attn_bias=xops.LowerTriangularMask() if causal else None,
            #     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
            # )
            # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
            # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b


import pickle
with open('flash2_attn_time_h100.plk', 'wb') as fp:
    pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
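A note on the varlen path exercised above: flash_attn_varlen_qkvpacked_func takes the batch flattened to (total_tokens, 3, nheads, headdim) together with cu_seqlens, the int32 prefix sums of the per-sequence lengths, so sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed tensor. A minimal CPU-only sketch of how the equal-length case in this script maps onto that layout (sizes are illustrative):

import torch
from einops import rearrange

batch_size, seqlen, nheads, headdim = 2, 4, 3, 8
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim)

# Flatten (b, s, ...) -> (b*s, ...); with equal lengths the sequence
# boundaries are just multiples of seqlen.
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...')
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)

print(qkv_unpad.shape)  # torch.Size([8, 3, 3, 8])
print(cu_seqlens)       # tensor([0, 4, 8], dtype=torch.int32)
assert torch.equal(qkv_unpad[cu_seqlens[1]:cu_seqlens[2]], qkv[1])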
cookbooks/flash-attention/benchmarks/benchmark_flash_attention.py ADDED
@@ -0,0 +1,180 @@
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func

try:
    from triton.ops.flash_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16

bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
causal_vals = [False, True]
headdim_vals = [64, 128]
dim = 2048
dropout_p = 0.0

methods = (["Flash2", "Pytorch"]
           + (["Triton"] if attention_triton is not None else [])
           + (["xformers.c"] if xops is not None else [])
           + (["xformers.f"] if xops is not None else []))

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                              requires_grad=True)
            f, b = time_fwd_bwd(
                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
            )
            time_f[config, "Flash2"] = f
            time_b[config, "Flash2"] = b

            try:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(
                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
                )
            except:  # Skip if OOM
                f, b = float('nan'), float('nan')
            time_f[config, "Pytorch"] = f
            time_b[config, "Pytorch"] = b

            if attention_triton is not None:
                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                                       requires_grad=True) for _ in range(3)]
                # Try both values of sequence_parallel and pick the faster one
                try:
                    f, b = time_fwd_bwd(
                        attention_triton, q, k, v, causal, headdim**(-0.5),
                        False, repeats=repeats, verbose=False
                    )
                except:
                    f, b = float('nan'), float('inf')
                try:
                    _, b0 = time_fwd_bwd(
                        attention_triton, q, k, v, causal, headdim**(-0.5),
                        True, repeats=repeats, verbose=False
                    )
                except:
                    b0 = float('inf')
                time_f[config, "Triton"] = f
                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')

            if xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                                       requires_grad=True) for _ in range(3)]
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
                )
                time_f[config, "xformers.c"] = f
                time_b[config, "xformers.c"] = b

            if xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                                       requires_grad=True) for _ in range(3)]
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
                )
                time_f[config, "xformers.f"] = f
                time_b[config, "xformers.f"] = b

            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                )


# with open('flash2_attn_time.plk', 'wb') as fp:
#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
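The flops helper in the script above charges 2*b*seqlen^2*nheads*headdim FLOPs to each of attention's two large matmuls (QK^T and P*V), i.e. 4*b*s^2*h*d forward, halves that under a causal mask since only half the score matrix is needed, and applies the conventional 2.5x backward factor (five backward matmuls against two forward ones; 3.5x for fwd + bwd combined). A worked check of one config from bs_seqlen_vals, reusing the same formula:

def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

# batch 8, seqlen 2048, headdim 128 -> nheads = 2048 // 128 = 16
fwd = flops(8, 2048, 128, 16, causal=False, mode="fwd")
print(f"{fwd / 1e12:.2f} TFLOPs")             # 0.27 TFLOPs per non-causal forward
print(f"{fwd / 1.1e-3 / 1e12:.0f} TFLOPs/s")  # a hypothetical 1.1 ms forward ~ 250 TFLOPs/s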
cookbooks/flash-attention/benchmarks/benchmark_gemm.py ADDED
@@ -0,0 +1,47 @@
import time
import torch
import torch.utils.benchmark as benchmark

from triton.testing import do_bench

if torch.version.cuda:
    backendBLAS = "cuBLAS"
elif torch.version.hip:
    backendBLAS = "hipBLAS"
else:
    # Neither a CUDA nor a ROCm build: backendBLAS would otherwise be undefined below.
    raise RuntimeError("This benchmark requires a CUDA or ROCm build of PyTorch")

def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


torch.manual_seed(0)
repeats = 30
dtype = torch.bfloat16
device = 'cuda'
verbose = False
m, n = 8192, 8192

tflops_matmul = {}
tflops_matmul1 = {}
for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS_matmul = 2 * m * n * k
    time.sleep(2)  # to reduce power throttling
    timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)[1]
    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12
    print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')
    time.sleep(2)  # to reduce power throttling
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)
    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9
    print(f'[triton.test.do_bench] {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
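Two timing harnesses are compared on purpose in the loop above, and they report in different units: torch.utils.benchmark.Timer's .timeit(...).mean is in seconds, while Triton's do_bench returns milliseconds, which is why the two TFLOPS conversions use 1e-12 and 1e-9 respectively. A quick sanity check of that arithmetic, with a hypothetical timing:

m = n = 8192
k = 4096
nFLOPS = 2 * m * n * k        # one multiply and one add per output element, per k step

mean_seconds = 2.2e-3         # what Timer.timeit(repeats).mean might return
mean_ms = mean_seconds * 1e3  # the same run as do_bench would report it

tflops_from_s = nFLOPS / mean_seconds * 1e-12
tflops_from_ms = nFLOPS / mean_ms * 1e-9
assert abs(tflops_from_s - tflops_from_ms) < 1e-9
print(f"{tflops_from_s:.1f} TFLOPS")  # ~250 TFLOPS either way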
cookbooks/flash-attention/csrc/flash_attn/flash_api.cpp ADDED
@@ -0,0 +1,1485 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
#include "philox_unpack.cuh" // For at::cuda::philox::unpack

#include <cutlass/numeric_types.h>

#include "namespace_config.h"
#include "hardware_info.h"
#include "flash.h"
#include "static_switch.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace FLASH_NAMESPACE {

void set_params_fprop(Flash_fwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      at::Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_k,
                      void *p_d,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      const float softcap,
                      bool seqlenq_ngroups_swapped=false,
                      const bool unpadded_lse=false) {

    // Reset the parameters
    params = {};

    params.is_bf16 = q.dtype() == torch::kBFloat16;

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    // All strides are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = k.stride(-3);
    params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = k.stride(-2);
    params.v_head_stride = v.stride(-2);
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.q_batch_stride = q.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
        params.o_batch_stride = out.stride(0);
        if (seqlenq_ngroups_swapped) {
            params.q_batch_stride *= seqlen_q;
            params.o_batch_stride *= seqlen_q;
        }
    }

    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
    params.seqused_k = static_cast<int *>(seqused_k);

    // P = softmax(QK^T)
    params.p_ptr = p_d;

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
#endif
    if (softcap > 0.0) {
        params.softcap = softmax_scale / softcap;
        params.scale_softmax = softcap;
        params.scale_softmax_log2 = softcap * M_LOG2E;
    } else {
        // Remove potential NaN
        params.softcap = 0.0;
        params.scale_softmax = softmax_scale;
        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    }

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
    TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif

    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
    params.is_causal = window_size_left < 0 && window_size_right == 0;

    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;

#ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
        "This flash attention build does not support local attention.");
#endif

    params.is_seqlens_k_cumulative = true;

#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
    TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif

    params.unpadded_lse = unpadded_lse;
    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
}

void set_params_dgrad(Flash_bwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      const at::Tensor out,
                      const at::Tensor dout,
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *dq_accum_d,
                      void *dk_accum_d,
                      void *dv_accum_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      const float softcap,
                      bool deterministic,
                      const bool unpadded_lse) {

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     nullptr,
                     nullptr,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     false, // seqlenq_ngroups_swapped
                     unpadded_lse);

    // Set the pointers and strides.
    params.do_ptr = dout.data_ptr();
    params.do_row_stride = dout.stride(-3);
    params.do_head_stride = dout.stride(-2);
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride = dq.stride(-3);
    params.dk_row_stride = dk.stride(-3);
    params.dv_row_stride = dv.stride(-3);
    params.dq_head_stride = dq.stride(-2);
    params.dk_head_stride = dk.stride(-2);
    params.dv_head_stride = dv.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.do_batch_stride = dout.stride(0);
        params.dq_batch_stride = dq.stride(0);
        params.dk_batch_stride = dk.stride(0);
        params.dv_batch_stride = dv.stride(0);
    }

    params.dq_accum_ptr = dq_accum_d;
    params.dk_accum_ptr = dk_accum_d;
    params.dv_accum_ptr = dv_accum_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;

    params.deterministic = deterministic;
}

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                if (params.num_splits <= 1 && !force_split_kernel) {  // num_splits == 0 means it wasn't set
                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                } else {
                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
                }
            });
        });
    });
}

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
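// Worked example of the heuristic above, using the numbers from its header comment
// plus num_n_blocks = 32 and max_splits = 128 (assumed for illustration):
//   batch_nheads_mblocks = 48 scheduled on num_SMs = 108.
//   1 split :  48/108 of a wave            -> efficiency 0.44
//   2 splits:  96/108 of a wave            -> efficiency 0.89
//   3 splits: 144/108 = 1.33 waves, billed as 2 full waves -> efficiency 0.67
// The best eligible efficiency over all split counts is ~0.98 (at 11 splits); 85% of
// that is ~0.83, and 2 is the smallest split count that clears the bar, so
// num_splits_heuristic(48, 108, 32, 128) returns 2 rather than chasing the
// marginally better 11-way split with its extra HBM traffic.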
| 298 |
+
|
| 299 |
+
std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size,
|
| 300 |
+
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
|
| 301 |
+
const int head_size_rounded, const float p_dropout,
|
| 302 |
+
const int num_splits, const int num_sm, struct c10::TensorOptions opts) {
|
| 303 |
+
|
| 304 |
+
// This needs to match with run_mha_fwd_splitkv_dispatch
|
| 305 |
+
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
|
| 306 |
+
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
|
| 307 |
+
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
|
| 308 |
+
// In any case we don't expect seqlen_q to be larger than 64 for inference.
|
| 309 |
+
const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
|
| 310 |
+
params.num_splits = num_splits;
|
| 311 |
+
at::Tensor softmax_lse_accum;
|
| 312 |
+
at::Tensor out_accum;
|
| 313 |
+
|
| 314 |
+
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
|
| 315 |
+
if (num_splits < 1) {
|
| 316 |
+
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
|
| 317 |
+
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
|
| 318 |
+
}
|
| 319 |
+
if (params.num_splits > 1) {
|
| 320 |
+
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
| 321 |
+
out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
|
| 322 |
+
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
| 323 |
+
params.oaccum_ptr = out_accum.data_ptr();
|
| 324 |
+
}
|
| 325 |
+
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
return std::make_tuple(softmax_lse_accum, out_accum);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
void set_params_alibi(Flash_fwd_params ¶ms, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
|
| 332 |
+
#ifdef FLASHATTENTION_DISABLE_ALIBI
|
| 333 |
+
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
|
| 334 |
+
params.alibi_slopes_ptr = nullptr;
|
| 335 |
+
#else
|
| 336 |
+
if (alibi_slopes_.has_value()) {
|
| 337 |
+
auto alibi_slopes = alibi_slopes_.value();
|
| 338 |
+
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
| 339 |
+
CHECK_DEVICE(alibi_slopes);
|
| 340 |
+
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
| 341 |
+
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
| 342 |
+
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
| 343 |
+
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
| 344 |
+
} else {
|
| 345 |
+
params.alibi_slopes_ptr = nullptr;
|
| 346 |
+
}
|
| 347 |
+
#endif
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
std::vector<at::Tensor>
|
| 351 |
+
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
| 352 |
+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
| 353 |
+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
| 354 |
+
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
| 355 |
+
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
| 356 |
+
const float p_dropout,
|
| 357 |
+
const float softmax_scale,
|
| 358 |
+
bool is_causal,
|
| 359 |
+
int window_size_left,
|
| 360 |
+
int window_size_right,
|
| 361 |
+
const float softcap,
|
| 362 |
+
const bool return_softmax,
|
| 363 |
+
std::optional<at::Generator> gen_) {
|
| 364 |
+
|
| 365 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 366 |
+
at::cuda::CUDAGuard device_guard{q.device()};
|
| 367 |
+
|
| 368 |
+
auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
|
| 369 |
+
bool is_sm8x_min = cc_major >= 8;
|
| 370 |
+
TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
|
| 371 |
+
|
| 372 |
+
auto q_dtype = q.dtype();
|
| 373 |
+
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
| 374 |
+
"FlashAttention only support fp16 and bf16 data type");
|
| 375 |
+
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
| 376 |
+
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
| 377 |
+
|
| 378 |
+
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
| 379 |
+
|
| 380 |
+
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 381 |
+
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 382 |
+
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 383 |
+
|
| 384 |
+
const auto sizes = q.sizes();
|
| 385 |
+
|
| 386 |
+
const int batch_size = sizes[0];
|
| 387 |
+
int seqlen_q = sizes[1];
|
| 388 |
+
int num_heads = sizes[2];
|
| 389 |
+
const int head_size = sizes[3];
|
| 390 |
+
const int seqlen_k = k.size(1);
|
| 391 |
+
const int num_heads_k = k.size(2);
|
| 392 |
+
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
| 393 |
+
TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
|
| 394 |
+
TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
|
| 395 |
+
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
| 396 |
+
|
| 397 |
+
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
| 398 |
+
|
| 399 |
+
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
| 400 |
+
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
| 401 |
+
|
| 402 |
+
// causal=true is the same as causal=false in this case
|
| 403 |
+
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
|
| 404 |
+
if (is_causal) { window_size_right = 0; }
|
| 405 |
+
|
| 406 |
+
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
| 407 |
+
// H/t Daniel Haziza
|
| 408 |
+
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
|
| 409 |
+
const int ngroups = num_heads / num_heads_k;
|
| 410 |
+
if (seqlenq_ngroups_swapped) {
|
| 411 |
+
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
|
| 412 |
+
seqlen_q = ngroups;
|
| 413 |
+
num_heads = num_heads_k;
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
| 417 |
+
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
|
| 418 |
+
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
|
| 419 |
+
|
| 420 |
+
at::Tensor out;
|
| 421 |
+
if (out_.has_value()) {
|
| 422 |
+
out = out_.value();
|
| 423 |
+
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
| 424 |
+
CHECK_DEVICE(out);
|
| 425 |
+
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
| 426 |
+
CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
|
| 427 |
+
if (seqlenq_ngroups_swapped) {
|
| 428 |
+
out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
|
| 429 |
+
}
|
| 430 |
+
} else {
|
| 431 |
+
out = torch::empty_like(q);
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 435 |
+
const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
|
| 436 |
+
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
| 437 |
+
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
| 438 |
+
|
| 439 |
+
auto opts = q.options();
|
| 440 |
+
|
| 441 |
+
auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
| 442 |
+
at::Tensor p;
|
| 443 |
+
// Only return softmax if there's dropout to reduce compilation time
|
| 444 |
+
if (return_softmax) {
|
| 445 |
+
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
|
| 446 |
+
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
|
| 447 |
+
}
|
| 448 |
+
else {
|
| 449 |
+
p = torch::empty({ 0 }, opts);
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
Flash_fwd_params params;
|
| 453 |
+
set_params_fprop(params,
|
| 454 |
+
batch_size,
|
| 455 |
+
seqlen_q, seqlen_k,
|
| 456 |
+
seqlen_q_rounded, seqlen_k_rounded,
|
| 457 |
+
num_heads, num_heads_k,
|
| 458 |
+
head_size, head_size_rounded,
|
| 459 |
+
q, k, v, out,
|
| 460 |
+
/*cu_seqlens_q_d=*/nullptr,
|
| 461 |
+
/*cu_seqlens_k_d=*/nullptr,
|
| 462 |
+
/*seqused_k=*/nullptr,
|
| 463 |
+
return_softmax ? p.data_ptr() : nullptr,
|
| 464 |
+
softmax_lse.data_ptr(),
|
| 465 |
+
p_dropout,
|
| 466 |
+
softmax_scale,
|
| 467 |
+
window_size_left,
|
| 468 |
+
window_size_right,
|
| 469 |
+
softcap
|
| 470 |
+
);
|
| 471 |
+
|
| 472 |
+
// Keep references to these tensors to extend their lifetime
|
| 473 |
+
at::Tensor softmax_lse_accum, out_accum;
|
| 474 |
+
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
|
| 475 |
+
params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
|
| 476 |
+
head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
|
| 477 |
+
|
| 478 |
+
// number of times random will be generated per thread, to offset philox counter in thc random
|
| 479 |
+
// state
|
| 480 |
+
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
| 481 |
+
int64_t counter_offset = params.b * params.h * 32;
|
| 482 |
+
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
| 483 |
+
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
| 484 |
+
// Forward kernel will populate memory with the seed and offset.
|
| 485 |
+
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
| 486 |
+
|
| 487 |
+
if (p_dropout > 0.0) {
|
| 488 |
+
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
| 489 |
+
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
| 490 |
+
// See Note [Acquire lock when using random generators]
|
| 491 |
+
std::lock_guard<std::mutex> lock(gen->mutex_);
|
| 492 |
+
params.philox_args = gen->philox_cuda_state(counter_offset);
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
| 496 |
+
|
| 497 |
+
if (seqlen_k > 0) {
|
| 498 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 499 |
+
run_mha_fwd(params, stream);
|
| 500 |
+
} else {
|
| 501 |
+
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
| 502 |
+
out.zero_();
|
| 503 |
+
softmax_lse.fill_(std::numeric_limits<float>::infinity());
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
if (seqlenq_ngroups_swapped) {
|
| 507 |
+
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
|
| 508 |
+
q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
|
| 509 |
+
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
|
| 510 |
+
}
|
| 511 |
+
return {out, softmax_lse, p, rng_state};
|
| 512 |
+
}
|
| 513 |
+
|
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               std::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
               std::optional<const at::Tensor> &leftpad_k_,  // batch_size
               std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               std::optional<at::Generator> gen_) {

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    at::Tensor block_table;
    const bool paged_KV = block_table_.has_value();
    if (paged_KV) {
        block_table = block_table_.value();
        CHECK_DEVICE(block_table);
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size = sizes[2];
    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
    const int num_blocks = !paged_KV ? 0 : k.size(0);
    const int page_block_size = !paged_KV ? 1 : k.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
    if (is_causal) { window_size_right = 0; }

    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
    const int ngroups = num_heads / num_heads_k;
    if (seqlenq_ngroups_swapped) {
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
        max_seqlen_q = ngroups;
        num_heads = num_heads_k;
        cu_seqlens_q_d = nullptr;
    }
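// Shape walk-through of the swap above (illustrative numbers, not from the source):
// with num_heads = 32, num_heads_k = 4 (so ngroups = 8) and max_seqlen_q = 1, q goes
//   (b, nheads_k * ngroups, d) -> (b, nheads_k, ngroups, d) -> (b, ngroups, nheads_k, d)
//   -> (b * ngroups, nheads_k, d),
// re-expressing the single decode query as 8 "query positions" over 4 KV heads so the
// kernel gets more parallelism along the sequence axis.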
    const int total_q = q.sizes()[0];

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    if (!paged_KV) {
        const int total_k = k.size(0);
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    }

    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    if (seqused_k.has_value()){
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k_, batch_size);
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
        if (seqlenq_ngroups_swapped) {
            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
        }
    } else {
        out = torch::empty_like(q);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
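// Rounding sketch (illustrative numbers): round_multiple(x, m) computes ceil(x / m) * m,
// so head_size = 72 gives head_size_rounded = 96 and max_seqlen_q = 300 gives
// seqlen_q_rounded = 384; head dims above 192 always fall into the single 256 bucket
// the kernels are compiled for.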
    auto opts = q.options();
    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }
    else {
        p = torch::empty({ 0 }, opts);
    }

    if (zero_tensors) {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) {p.zero_();}
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k.data_ptr(),
                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                     return_softmax ? p.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     seqlenq_ngroups_swapped,
                     /*unpadded_lse*/true);
    params.total_q = total_q;

    if (paged_KV) {
        params.block_table = block_table.data_ptr<int>();
        params.block_table_batch_stride = block_table.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
    }
    params.page_block_size = page_block_size;
    // Keep references to these tensors to extend their lifetime
    at::Tensor softmax_lse_accum, out_accum;
    if (seqlenq_ngroups_swapped) {
        // Only apply split-k for decoding
        std::tie(softmax_lse_accum, out_accum) =
            set_params_splitkv(params, batch_size, num_heads, head_size,
                               max_seqlen_k, max_seqlen_q, head_size_rounded,
                               p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
    }

    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k);
        CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
    }

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    // Forward kernel will populate memory with the seed and offset.
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream, paged_KV);
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    if (seqlenq_ngroups_swapped) {
        int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
        int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
        out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
        q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
    }

    return {out, softmax_lse, p, rng_state};
}
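// Layout note (editorial sketch): unlike the fixed-length path, which stores LSE as
// (batch, heads, seqlen_q), this varlen path passes /*unpadded_lse*/true and allocates
// softmax_lse as (num_heads, total_q), addressing rows through cu_seqlens_q instead of
// padding every sequence to the maximum length.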
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
            });
        });
    });
}
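// Dispatch sketch (hand-expanded, not the actual macro output): the nested
// FP16_SWITCH / HEADDIM_SWITCH / BOOL_SWITCH macros turn the runtime triple
// (dtype, head_dim, is_causal) into a compile-time template instantiation, roughly:
//   if (!params.is_bf16) { using elem_type = cutlass::half_t;     /* HEADDIM/BOOL nest */ }
//   else                 { using elem_type = cutlass::bfloat16_t; /* HEADDIM/BOOL nest */ }
// so every supported combination maps to a separately compiled kernel.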
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)
        const at::Tensor &q,    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,    // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,    // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,  // b x h x seqlen_q
        std::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
        const float p_dropout,  // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        std::optional<at::Generator> gen_,
        std::optional<at::Tensor> &rng_state) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
    TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
    if (is_causal) { window_size_right = 0; }

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    // bool loop = seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    if (loop) {
        if (!deterministic) {
            dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
            dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        }
        // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }
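// Deterministic-mode sketch (illustrative numbers): on a GPU with 108 SMs,
// batch_size = 2 and num_heads = 8 give nsplits = ceil(108 / 16) = 7, so dq_accum
// holds 7 disjoint accumulation buffers, one per wave of thread blocks; reducing
// them afterwards avoids the non-deterministic ordering of atomicAdds at the cost
// of extra memory.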
    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    Flash_bwd_params params;

    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk_expanded, dv_expanded,
                     nullptr,
                     nullptr,
                     loop ? dq_accum.data_ptr() : nullptr,
                     // loop ? dk_accum.data_ptr() : nullptr,
                     // loop ? dv_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     /*unpadded_lse*/false);
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

    auto launch = &run_mha_bwd;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;

    if ( rng_state.has_value() ) {
        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
    } else if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        params.rng_state[0] = std::get<0>(seeds);
        params.rng_state[1] = std::get<1>(seeds);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (seqlen_q > 0) {
        launch(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }

    return { dq, dk, dv, softmax_d };
}
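// Worked example of the MQA/GQA reduction above (shapes are assumptions): with
// num_heads = 8 and num_heads_k = 2, dk_expanded is (b, seqlen_k, 8, d); it is viewed
// as (b, seqlen_k, 2, 4, d) and summed over dim 3, so each KV head accumulates the
// gradients of the 4 query heads that share it into dk.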
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,    // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,    // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,    // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,  // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,  // h x total_q, softmax logsumexp
               std::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,  // max sequence length to choose the kernel
               const float p_dropout,   // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               std::optional<at::Generator> gen_,
               std::optional<at::Tensor> &rng_state) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
    TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
    if (is_causal) { window_size_right = 0; }

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int total_q = sizes[0];
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int num_heads = sizes[1];
    const int head_size = sizes[2];
    const int total_k = k.size(0);
    const int num_heads_k = k.size(1);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, total_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    // bool loop = max_seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;

    auto opts = q.options();
    auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    if (loop) {
        // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
        // because that would be too large if there is a very long sequence and the rest of the sequences are short.
        // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
        // Note that 128 is the max block size on the seqlen_q dimension.
        // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
        // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
        // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
        // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
        // Same holds for softmax_d, since LSE is stored in unpadded format.
        if (!deterministic) {
            dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
            dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        }
    }
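// Worked example of the padded layout above (numbers are assumptions): with
// cu_seqlens_q = {0, 100, 300} (batch of 2), sequence 0 occupies dq_accum rows
// [0, 100) and sequence 1 occupies rows [100 + 128, 300 + 128) = [228, 428); the
// 128-row gap in between absorbs each sequence's out-of-bounds atomicAdds, and the
// total allocation is total_q + 128 * batch = 300 + 256 = 556 rows.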
    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    if( zero_tensors ) {
        dq.zero_();
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    Flash_bwd_params params;

    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk_expanded, dv_expanded,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? dq_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     /*unpadded_lse*/true);
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
    params.total_q = total_q;

    auto launch = &run_mha_bwd;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;

    if ( rng_state.has_value() ) {
        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
    } else if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        params.rng_state[0] = std::get<0>(seeds);
        params.rng_state[1] = std::get<1>(seeds);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (max_seqlen_q > 0) {
        launch(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
    }

    return { dq, dk, dv, softmax_d };
}
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const at::Tensor &vcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                std::optional<const at::Tensor> &k_,  // batch_size x seqlen_knew x num_heads_k x head_size
                std::optional<const at::Tensor> &v_,  // batch_size x seqlen_knew x num_heads_k x head_size
                std::optional<const at::Tensor> &seqlens_k_,  // batch_size
                std::optional<const at::Tensor> &rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
                std::optional<const at::Tensor> &rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
                std::optional<const at::Tensor> &cache_batch_idx_,  // indices to index into the KV cache
                std::optional<const at::Tensor> &leftpad_k_,  // batch_size
                std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
                std::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
                std::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                bool is_causal,
                int window_size_left,
                int window_size_right,
                const float softcap,
                bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits
                ) {

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    at::Tensor block_table;
    const bool paged_KV = block_table_.has_value();
    if (paged_KV) {
        TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
        block_table = block_table_.value();
        CHECK_DEVICE(block_table);
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size_og = sizes[3];

    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
    const int num_blocks = !paged_KV ? 0 : kcache.size(0);
    const int page_block_size = !paged_KV ? 1 : kcache.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
    const int num_heads_k = kcache.size(2);
    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
    if (is_causal) { window_size_right = 0; }

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
    if (seqlenq_ngroups_swapped) {
        const int ngroups = num_heads / num_heads_k;
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
        seqlen_q = ngroups;
        num_heads = num_heads_k;
    }

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    if (!paged_KV) {
        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
    } else {
        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    }

    at::Tensor q_padded, kcache_padded, vcache_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        kcache_padded = kcache;
        vcache_padded = vcache;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, kcache_padded, vcache_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_k=*/nullptr,
                     /*p_ptr=*/nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap
                     );

    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
        k = k_.value();
        v = v_.value();
        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
        CHECK_DEVICE(k); CHECK_DEVICE(v);
        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
        int seqlen_knew = k.size(1);
        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
        if (head_size_og % 8 != 0) {
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        } else {
            k_padded = k;
            v_padded = v;
        }
        params.seqlen_knew = seqlen_knew;
        params.knew_ptr = k_padded.data_ptr();
        params.vnew_ptr = v_padded.data_ptr();
        // All stride are in elements, not bytes.
        params.knew_batch_stride = k_padded.stride(0);
        params.vnew_batch_stride = v_padded.stride(0);
        params.knew_row_stride = k_padded.stride(-3);
        params.vnew_row_stride = v_padded.stride(-3);
        params.knew_head_stride = k_padded.stride(-2);
        params.vnew_head_stride = v_padded.stride(-2);
    }

    if (seqlens_k_.has_value()) {
        auto seqlens_k = seqlens_k_.value();
        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
        CHECK_DEVICE(seqlens_k);
        CHECK_CONTIGUOUS(seqlens_k);
        CHECK_SHAPE(seqlens_k, batch_size);
        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
    }
    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
    if (leftpad_k_.has_value()) {
        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
        auto leftpad_k = leftpad_k_.value();
        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k);
        CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
    }

    if (rotary_cos_.has_value()) {
        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
        auto rotary_cos = rotary_cos_.value();
        CHECK_DEVICE(rotary_cos);
        params.rotary_dim = rotary_cos.size(1) * 2;
        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
        const int seqlen_ro = rotary_cos.size(0);
        TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
        CHECK_CONTIGUOUS(rotary_cos);
        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");

        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
        auto rotary_sin = rotary_sin_.value();
        CHECK_DEVICE(rotary_sin);
        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
        CHECK_CONTIGUOUS(rotary_sin);
        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
        params.rotary_cos_ptr = rotary_cos.data_ptr();
        params.rotary_sin_ptr = rotary_sin.data_ptr();
        params.is_rotary_interleaved = is_rotary_interleaved;
    } else {
        params.rotary_dim = 0;
    }
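// Rotary layout sketch for is_rotary_interleaved (illustrative, rotary_dim = 8):
//   interleaved:     rotate pairs (x0, x1), (x2, x3), (x4, x5), (x6, x7);
//   non-interleaved: rotate pairs (x0, x4), (x1, x5), (x2, x6), (x3, x7),
// matching "indices 0 & 1" versus "indices 0 & rotary_dim / 2" in the parameter
// comment above.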
    if (cache_batch_idx_.has_value()) {
        auto cache_batch_idx = cache_batch_idx_.value();
        CHECK_DEVICE(cache_batch_idx);
        CHECK_CONTIGUOUS(cache_batch_idx);
        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
        params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
    }

    // Keep references to these tensors to extend their lifetime
    at::Tensor softmax_lse_accum, out_accum;
    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
        head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts);

    if (paged_KV) {
        params.block_table = block_table.data_ptr<int>();
        params.block_table_batch_stride = block_table.stride(0);
    }
    params.page_block_size = page_block_size;


    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
    // or paged KV cache
    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);

    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
        if (k_.has_value()) {
            // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
            // but we don't expect to get this case in practice. This is just so that the code works for that case.
            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
        }
    }

    if (seqlenq_ngroups_swapped) {
        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
    }
    return {out, softmax_lse};
}
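// Decode-step sketch (assumed shapes, not prescribed by this file): for single-token
// decoding, q is (batch, 1, num_heads, head_size) and k_/v_ are
// (batch, 1, num_heads_k, head_size); seqlens_k_ then holds each sequence's current
// cache length, and the forced split-KV kernel both appends the new token to
// kcache/vcache and attends over the updated cache in one launch.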
} // namespace FLASH_NAMESPACE

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
    m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
    m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
    m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
}
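// Binding note (sketch; the extension's import name depends on TORCH_EXTENSION_NAME
// at build time): Python callers see these entry points one-to-one, e.g.
// ext.fwd(q, k, v, ...) returns (out, softmax_lse, p, rng_state) and
// ext.fwd_kvcache(...) returns (out, softmax_lse), mirroring the
// std::vector<at::Tensor> return values above.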
cookbooks/flash-attention/csrc/flash_attn/src/alibi.h ADDED
@@ -0,0 +1,75 @@
#include <cmath>

#include "namespace_config.h"
#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include "utils.h"

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal>
struct Alibi {

    const float alibi_slope;
    const int max_seqlen_k, max_seqlen_q;

    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
        : alibi_slope(alibi_slope)
        , max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q) {
    };

    template <typename Engine, typename Layout>
    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                                const int col_idx_offset_,
                                                const int row_idx_offset,
                                                const int warp_row_stride) {
        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
        static_assert(Layout::rank == 2, "Only support 2D Tensor");
        const int lane_id = threadIdx.x % 32;
        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
        if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    #pragma unroll
                    for (int mi = 0; mi < size<0>(tensor); ++mi) {
                        tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                    }
                }
            }
        } else {  // Bias depends on both row_idx and col_idx
            #pragma unroll
            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                #pragma unroll
                for (int i = 0; i < size<0, 0>(tensor); ++i) {
                    const int row_idx = row_idx_base + i * 8;
                    #pragma unroll
                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                        const int col_idx_base = col_idx_offset + nj * 8;
                        #pragma unroll
                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
                            const int col_idx = col_idx_base + j;
                            tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                        }
                    }
                }
            }
        }
    }

};

} // namespace FLASH_NAMESPACE
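// Worked example of the non-causal branch above (numbers are assumptions): with
// alibi_slope = 0.25, max_seqlen_k = 10 and max_seqlen_q = 4, the score at
// (row_idx = 2, col_idx = 5) receives bias -0.25 * |2 + 10 - 4 - 5| = -0.75, i.e. a
// penalty linear in the distance between the right-aligned query row and the key
// column. In the causal branch, adding alibi_slope * col_idx to every row is
// equivalent up to a per-row constant, which the softmax cancels.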
cookbooks/flash-attention/csrc/flash_attn/src/block_info.h ADDED
@@ -0,0 +1,49 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
namespace FLASH_NAMESPACE {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen=true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
    {
    }

    template <typename index_t>
    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
    const int leftpad_k;
    const int seqlen_k_cache;
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace FLASH_NAMESPACE
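BlockInfo's offset helpers encode two addressing schemes: fixed-length batches step by bidb * batch_stride, while varlen batches pack all sequences along one axis and use the cu_seqlens prefix sums (length b+1) to find where each batch entry starts. A minimal host-side C++ sketch under those assumptions (q_offset here is a simplified stand-in for the member function above, ignoring leftpad):

#include <cstdint>
#include <cstdio>

// Sketch of BlockInfo's address resolution. With varlen inputs, the queries of all
// batch entries are packed and cu_seqlens_q[b]..cu_seqlens_q[b+1] delimit entry b;
// without them, each batch entry starts at bidb * batch_stride.
int64_t q_offset(const int *cu_seqlens_q, int64_t batch_stride, int64_t row_stride, int bidb) {
    if (cu_seqlens_q == nullptr) { return int64_t(bidb) * batch_stride; }  // fixed-length batch
    return int64_t(cu_seqlens_q[bidb]) * row_stride;                       // packed varlen batch
}

int main() {
    const int64_t row_stride = 128;  // e.g., nheads * headdim elements per row
    // Varlen: three packed sequences of lengths 5, 2, 7 -> b+1 prefix sums.
    const int cu_seqlens_q[] = {0, 5, 7, 14};
    for (int b = 0; b < 3; ++b) {
        std::printf("varlen batch %d -> offset %lld\n", b,
                    (long long)q_offset(cu_seqlens_q, 0, row_stride, b));
    }
    // Fixed-length: every batch entry holds seqlen_q = 6 rows.
    std::printf("fixed batch 2 -> offset %lld\n",
                (long long)q_offset(nullptr, 6 * row_stride, row_stride, 2));
    return 0;
}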
cookbooks/flash-attention/csrc/flash_attn/src/dropout.h
ADDED
@@ -0,0 +1,95 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include "philox.cuh"
#include "utils.h"

namespace FLASH_NAMESPACE {

struct Dropout {

    const unsigned long long seed, offset;
    const uint8_t p_dropout_in_uint8_t;

    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                                       const uint8_t p_dropout_in_uint8_t,
                                       const int bid, const int hid, const int tid, const int nheads)
        : seed(seed)
        , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
        , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
    }

    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                                  int block_row_start, int block_col_start, int block_row_stride) {
        // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
        Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
        using T = typename Engine::value_type;
        auto encode_dropout = [](bool keep, T val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
        };
        static_assert(decltype(size<2>(tensor))::value % 2 == 0);
        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
        const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
        // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
        #pragma unroll
        for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
            uint2 rowcol = make_uint2(block_row_start, block_col_start);
            #pragma unroll
            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
                // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
                uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
                uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
                // Special implementation for 16-bit types: we duplicate the threshold to the
                // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
                // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
                // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
                // the random value is less than the threshold.
                // We then do a bit-wise AND between the mask and the original value (in 32-bit).
                // We're exploiting the fact that floating point comparison is equivalent to integer
                // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
                if (!encode_dropout_in_sign_bit
                    && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                    uint16_t rnd_16[16];
                    #pragma unroll
                    for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                    uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                        #pragma unroll
                        for (int i = 0; i < 4; i++) {
                            uint32_t mask;
                            asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                            tensor_uint32(i) &= mask;
                        }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                } else {
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        #pragma unroll
                        for (int i = 0; i < 8; i++) {
                            tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                        }
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                }
                // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
                // // }
            }
        }
    }

};

}  // namespace FLASH_NAMESPACE
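For clarity, here is a scalar C++ sketch of the per-element dropout decision above (one assumption: the 8-bit keep-threshold is quantized as floor(p_keep * 255), which is computed elsewhere in the library; the demo bytes stand in for philox() output). It shows the rnd <= threshold test that the set.le.u32.f16x2 path vectorizes two at a time: zero-extending two bytes into the halves of one 32-bit word keeps the top bits zero, so the packed fp16 comparison orders the values exactly like an integer comparison and yields a 0xffff/0x0000 mask per half that is simply ANDed into the value bits.

#include <cstdint>
#include <cstdio>

// Scalar version of the dropout test: rnd <= keep_threshold keeps the value,
// otherwise it is zeroed (the sign-bit-encoding variant would negate instead).
float apply_dropout_scalar(uint8_t rnd, uint8_t keep_threshold, float val) {
    return rnd <= keep_threshold ? val : 0.0f;
}

int main() {
    const float p_keep = 0.9f;                                // 1 - dropout probability
    const uint8_t threshold = (uint8_t)(p_keep * 255.0f);     // quantized keep probability (assumption)
    const uint8_t fake_philox_bytes[4] = {12, 200, 255, 17};  // stand-in for philox() output
    for (int i = 0; i < 4; ++i) {
        std::printf("rnd=%3u -> %s\n", fake_philox_bytes[i],
                    apply_dropout_scalar(fake_philox_bytes[i], threshold, 1.0f) != 0.0f ? "keep" : "drop");
    }
    return 0;
}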
cookbooks/flash-attention/csrc/flash_attn/src/flash.h
ADDED
@@ -0,0 +1,194 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"

#include <cuda.h>
#include <vector>

#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

namespace FLASH_NAMESPACE {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = int64_t;
    // The QKV matrices.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;

    // The stride between rows of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
    // different from nheads (query).
    int h_h_k_ratio;  // precompute h / h_k.
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the P matrix.
    void * __restrict__ p_ptr;

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;

    // The scaling factors for the kernel.
    float scale_softmax;
    float scale_softmax_log2;

    // array of length b+1 holding starting offset of each sequence.
    int * __restrict__ cu_seqlens_q;
    int * __restrict__ cu_seqlens_k;
    int * __restrict__ leftpad_k;

    // If provided, the actual length of each k sequence.
    int * __restrict__ seqused_k;

    int *__restrict__ blockmask;

    // The K_new and V_new matrices.
    void * __restrict__ knew_ptr;
    void * __restrict__ vnew_ptr;

    // The stride between rows of the K_new and V_new matrices.
    index_t knew_batch_stride;
    index_t vnew_batch_stride;
    index_t knew_row_stride;
    index_t vnew_row_stride;
    index_t knew_head_stride;
    index_t vnew_head_stride;

    // The cos and sin matrices for rotary embedding.
    void * __restrict__ rotary_cos_ptr;
    void * __restrict__ rotary_sin_ptr;

    // The indices to index into the KV cache.
    int * __restrict__ cache_batch_idx;

    // Paged KV cache
    int * __restrict__ block_table;
    index_t block_table_batch_stride;
    int page_block_size;

    // The dropout probability (probability of keeping an activation).
    float p_dropout;
    // uint32_t p_dropout_in_uint;
    // uint16_t p_dropout_in_uint16_t;
    uint8_t p_dropout_in_uint8_t;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;
    float scale_softmax_rp_dropout;

    // Local window size
    int window_size_left, window_size_right;
    float softcap;

    // Random state.
    at::PhiloxCudaState philox_args;

    // Pointer to the RNG seed (idx 0) and offset (idx 1).
    uint64_t * rng_state;

    bool is_bf16;
    bool is_causal;

    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
    bool is_seqlens_k_cumulative;

    bool is_rotary_interleaved;

    int num_splits;  // For split-KV version

    void * __restrict__ alibi_slopes_ptr;
    index_t alibi_slopes_batch_stride;

    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv * ngroups), d) to (b, ngroups, nheads_kv, d).
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {

    // The dO and dQKV matrices.
    void *__restrict__ do_ptr;
    void *__restrict__ dq_ptr;
    void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;

    // To accumulate dQ
    void *__restrict__ dq_accum_ptr;
    void *__restrict__ dk_accum_ptr;
    void *__restrict__ dv_accum_ptr;

    // // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q
    // // dimension: void *__restrict__ dk_accum_ptr; void *__restrict__ dv_accum_ptr;

    // The stride between rows of the dO, dQ, dK and dV matrices.
    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
    // The code probably won't work for arrays larger than 2GB.
    index_t do_batch_stride;
    index_t do_row_stride;
    index_t do_head_stride;
    index_t dq_batch_stride;
    index_t dk_batch_stride;
    index_t dv_batch_stride;
    index_t dq_row_stride;
    index_t dk_row_stride;
    index_t dv_row_stride;
    index_t dq_head_stride;
    index_t dk_head_stride;
    index_t dv_head_stride;

    // The pointer to the softmax d sum.
    void *__restrict__ dsoftmax_sum;

    bool deterministic;
    index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

}  // namespace FLASH_NAMESPACE
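As a hypothetical host-side sketch of how the derived fields of Flash_fwd_params relate to the primary ones (the real values are filled in by flash_api.cpp; treat the exact formulas as assumptions that follow the field comments above):

#include <cmath>
#include <cstdio>

int main() {
    const int d = 128;           // head dimension
    const int h = 32, h_k = 8;   // query heads vs. key/value heads (GQA)
    const float p_drop = 0.1f;   // user-facing dropout (probability of dropping)

    const float scale_softmax = 1.0f / std::sqrt((float)d);          // typical softmax scale (assumption)
    const float scale_softmax_log2 = scale_softmax * 1.4426950409f;  // * log2(e), for an exp2-based softmax
    const float p_keep = 1.0f - p_drop;      // what the p_dropout field stores (keep probability)
    const float rp_dropout = 1.0f / p_keep;  // "Scale factor of 1 / (1 - p_dropout)"
    const int h_h_k_ratio = h / h_k;         // queries per KV head (precomputed h / h_k)

    std::printf("scale=%g scale_log2=%g rp_dropout=%g h/h_k=%d\n",
                scale_softmax, scale_softmax_log2, rp_dropout, h_h_k_ratio);
    return 0;
}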
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
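This file, like the other auto-generated flash_bwd_hdim*.cu files that follow, exists so that each (head dimension, dtype, causal) combination of run_mha_bwd_ gets its explicit specialization in its own translation unit, letting the combinations compile in parallel. A minimal stand-alone C++ sketch of the same pattern (run_kernel and its parameters are hypothetical, not the library's API):

#include <cstdio>

// --- would live in a shared header ---
template <typename T, int Headdim, bool Is_causal>
void run_kernel();  // declaration only; no generic definition is needed

// --- would live in e.g. kernel_hdim128_float_causal.cu ---
template <>
void run_kernel<float, 128, true>() {
    std::printf("float, hdim=128, causal\n");  // the real file launches a CUDA kernel here
}

int main() {
    run_kernel<float, 128, true>();  // resolved at link time to the one specialization
    return 0;
}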
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
cookbooks/flash-attention/csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu
ADDED
@@ -0,0 +1,14 @@
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE