Instructions to use kernels-community/flash-attn2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-attn2 with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/flash-attn2")
- Notebooks
- Google Colab
- Kaggle
File size: 5,264 Bytes
a7165c8 c743a32 a7165c8 c743a32 9002ff5 d774688 dd2f0f9 9002ff5 a7165c8 c743a32 a7165c8 39b4aba 876ac68 b0d3c12 9002ff5 b0d3c12 9002ff5 a7165c8 876ac68 b0d3c12 9002ff5 b0d3c12 a7165c8 c743a32 876ac68 b0d3c12 9002ff5 b0d3c12 a7165c8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | [general]
name = "flash_attn"
universal=false
[torch]
src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
[kernel.flash_attn]
backend = "cuda"
cuda-capabilities = [
"8.0",
"9.0",
"10.0",
"12.0",
]
src = [
"flash_attn/flash_api.cpp",
"flash_attn/src/philox_unpack.cuh",
"flash_attn/src/namespace_config.h",
"flash_attn/src/hardware_info.h",
"flash_attn/src/flash.h",
"flash_attn/src/static_switch.h",
"flash_attn/src/alibi.h",
"flash_attn/src/block_info.h",
"flash_attn/src/dropout.h",
"flash_attn/src/kernel_traits.h",
"flash_attn/src/mask.h",
"flash_attn/src/philox.cuh",
"flash_attn/src/rotary.h",
"flash_attn/src/softmax.h",
"flash_attn/src/utils.h",
# bwd kernels - all sources listed below are active in this build (FP16 and BF16, hdim 32/64/96/128/192/256, both causal and non-causal)
"flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
"flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
"flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
"flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
"flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
"flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
"flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
"flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
"flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
"flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
"flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
"flash_attn/src/flash_bwd_kernel.h",
"flash_attn/src/flash_bwd_launch_template.h",
"flash_attn/src/flash_bwd_preprocess_kernel.h",
# fwd kernels - FP16 and BF16 kernels for hdim 32, 64, 96, 128, 192 and 256 (both causal and non-causal)
"flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
"flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
"flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
"flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
"flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
"flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
"flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
"flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
"flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
"flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
"flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
"flash_attn/src/flash_fwd_kernel.h",
"flash_attn/src/flash_fwd_launch_template.h",
# split kernels - FP16 and BF16 kernels for hdim 32, 64, 96, 128, 192 and 256 (both causal and non-causal)
"flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
"flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
]
depends = ["torch", "cutlass_3_6"]
|