flash-sparse-attention
Flash Sparse Attention Triton kernels — dense, sparse, and gated attention with forward, backward, and decode paths.
Supported backends
cuda, rocm, xpu
Usage
from kernels import get_kernel
fsa = get_kernel("JingzeShi/flash-sparse-attn", version=1, trust_remote_code=True)
# Dense forward
out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)
# Decode with KV cache
out = fsa.flash_dense_attn_with_kvcache_func(q, k, v)
# Sparse attention
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)
# Gated attention
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
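For readers unfamiliar with what the dense forward call with is_causal=True computes, the following is a minimal pure-Python reference sketch. It assumes the conventional definition of scaled dot-product attention (scores scaled by 1/sqrt(head_dim), softmax over the allowed keys), a single head, and equal query and key lengths. The function name `causal_attention_reference` is hypothetical and is not part of the kernel package; the Triton kernels themselves are not reproduced here.

```python
import math


def causal_attention_reference(q, k, v, is_causal=True):
    """Naive single-head attention over lists of row vectors.

    Assumption: standard softmax(q @ k^T / sqrt(head_dim)) @ v semantics.
    q, k, v are lists of equal length; each element is a list of floats.
    """
    assert len(q) == len(k) == len(v), "sketch assumes equal sequence lengths"
    scale = 1.0 / math.sqrt(len(q[0]))
    value_dim = len(v[0])
    out = []
    for i, q_i in enumerate(q):
        # With causal masking, query i only sees keys at positions 0..i.
        limit = i + 1 if is_causal else len(k)
        scores = [scale * sum(a * b for a, b in zip(q_i, k[j])) for j in range(limit)]
        # Subtract the row maximum before exponentiating for numerical stability.
        row_max = max(scores)
        exps = [math.exp(s - row_max) for s in scores]
        denom = sum(exps)
        weights = [e / denom for e in exps]
        # Weighted sum of the visible value vectors.
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) for d in range(value_dim)])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal_out = causal_attention_reference(q, k, v, is_causal=True)
full_out = causal_attention_reference(q, k, v, is_causal=False)
```

In the causal case the first query can only see the first key, so its output row equals the first value vector. A reference like this is only useful for checking small inputs against a kernel's output; it says nothing about the sparse or gated variants.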
Available functions
flash_dense_attn_func
flash_dense_attn_with_kvcache_func
flash_dense_attn_varlen_func
flash_dense_attn_varlen_with_kvcache_func
flash_sparse_attn_func
flash_sparse_attn_with_kvcache_func
flash_sparse_attn_varlen_func
flash_sparse_attn_varlen_with_kvcache_func
flash_gated_attn_func
flash_gated_attn_with_kvcache_func
flash_gated_attn_varlen_func
flash_gated_attn_varlen_with_kvcache_func
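The twelve function names follow a regular pattern: an attention kind (dense, sparse, gated), an optional varlen marker, and an optional with_kvcache marker. The sketch below rebuilds the list from that pattern; the helper `kernel_function_names` and the constant names are hypothetical and exist only for illustration.

```python
from itertools import product

# The three parts of each name, as they appear in the list of available functions.
KINDS = ("dense", "sparse", "gated")
LAYOUTS = ("", "_varlen")
CACHE_MODES = ("", "_with_kvcache")


def kernel_function_names():
    """Return the twelve function names implied by the naming pattern."""
    return [
        f"flash_{kind}_attn{layout}{cache}_func"
        for kind, layout, cache in product(KINDS, LAYOUTS, CACHE_MODES)
    ]


names = kernel_function_names()
for name in names:
    print(name)
```

Assuming the loaded module exposes these as attributes, as the Usage calls suggest, a name built this way could be resolved with `getattr(fsa, name)`.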
Source
Originally from HKUSTDial/flash-sparse-attention.
License: bsd-3-clause
Supported hardware
CUDA
ROCm
XPU








