flash-sparse-attention
Triton kernels for Flash Sparse Attention, covering dense, sparse, and gated attention with forward, backward, and decode paths.
Usage
from kernels import get_kernel
fsa = get_kernel("JingzeShi/flash-sparse-attention", version=1)
# Dense forward
out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)
# Decode with KV cache
out = fsa.flash_dense_attn_with_kvcache_func(q, k, v)
# Sparse attention
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)
# Gated attention
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
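As a mental model for what the dense path with `is_causal=True` computes, here is a minimal pure-Python sketch of single-head causal attention. It is not the kernel's implementation: `causal_attention_reference` is a hypothetical helper, the inputs are plain lists rather than tensors, and the `1/sqrt(head_dim)` scale is assumed as the conventional default rather than taken from this package.

```python
import math


def causal_attention_reference(q, k, v):
    """Single-head causal attention on plain Python lists.

    q, k, v: lists of seq_len vectors (lists of floats).
    Hypothetical helper for illustration; not part of the kernel package.
    """
    # Assumed default scale: 1 / sqrt(head_dim).
    scale = 1.0 / math.sqrt(len(q[0]))
    dim_v = len(v[0])
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i attends to keys 0..i only.
        scores = [
            scale * sum(a * b for a, b in zip(qi, k[j]))
            for j in range(i + 1)
        ]
        # Subtract the row maximum before exponentiating for stability.
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        # Weighted average of the visible value vectors.
        out.append([
            sum(w * v[j][d] for j, w in enumerate(weights)) / total
            for d in range(dim_v)
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = causal_attention_reference(q, k, v)
```

The first query can only see the first key, so its output row equals the first value vector. A reference like this can be useful for spot-checking small inputs, but it says nothing about the tensor layout the kernels expect.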
Available functions
- flash_dense_attn_func
- flash_dense_attn_with_kvcache_func
- flash_dense_attn_varlen_func
- flash_dense_attn_varlen_with_kvcache_func
- flash_sparse_attn_func
- flash_sparse_attn_with_kvcache_func
- flash_sparse_attn_varlen_func
- flash_sparse_attn_varlen_with_kvcache_func
- flash_gated_attn_func
- flash_gated_attn_with_kvcache_func
- flash_gated_attn_varlen_func
- flash_gated_attn_varlen_with_kvcache_func
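The twelve names above follow a regular pattern: three attention variants, each with four call styles. The sketch below rebuilds the list from that pattern, which may be handy when selecting an entry point by name. `entry_point_names`, `VARIANTS`, and `SUFFIXES` are hypothetical names introduced for illustration only.

```python
from itertools import product

# Attention variants and call-style suffixes, read off the list above.
VARIANTS = ("dense", "sparse", "gated")
SUFFIXES = ("", "_with_kvcache", "_varlen", "_varlen_with_kvcache")


def entry_point_names():
    """Return the twelve function names in the order listed above.

    Hypothetical helper; it only builds strings and does not load the kernel.
    """
    return [
        f"flash_{variant}_attn{suffix}_func"
        for variant, suffix in product(VARIANTS, SUFFIXES)
    ]


names = entry_point_names()
```

Assuming the object returned by `get_kernel` exposes these functions as attributes, as the usage section suggests, a name from this list could be resolved with `getattr(fsa, name)`.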
Source
Originally from HKUSTDial/flash-sparse-attention.