
flash-sparse-attention

Flash Sparse Attention Triton kernels — dense, sparse, and gated attention with forward, backward, and decode paths.

Usage

from kernels import get_kernel

# Fetch the kernel from the Hugging Face Hub (cached after the first download)
fsa = get_kernel("JingzeShi/flash-sparse-attention", version=1)

# q, k, v (and alpha, delta for the gated variant) are CUDA tensors;
# the kernels are Triton-based and run on GPU.

# Dense forward
out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)

# Decode with KV cache
out = fsa.flash_dense_attn_with_kvcache_func(q, k, v)

# Sparse attention
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)

# Gated attention
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
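
The kernels also implement backward paths, so gradients should flow through the returned tensor with ordinary autograd. A minimal sketch, assuming q, k, and v were created with requires_grad=True:

# Backward pass via autograd (sketch; assumes q, k, v require gradients)
out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)
out.sum().backward()  # populates q.grad, k.grad, v.grad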

Available functions

  • flash_dense_attn_func
  • flash_dense_attn_with_kvcache_func
  • flash_dense_attn_varlen_func
  • flash_dense_attn_varlen_with_kvcache_func
  • flash_sparse_attn_func
  • flash_sparse_attn_with_kvcache_func
  • flash_sparse_attn_varlen_func
  • flash_sparse_attn_varlen_with_kvcache_func
  • flash_gated_attn_func
  • flash_gated_attn_with_kvcache_func
  • flash_gated_attn_varlen_func
  • flash_gated_attn_varlen_with_kvcache_func
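
The varlen variants target variable-length (packed) batches, and the with_kvcache variants target decode with a key/value cache, as in the Usage example above. The exact signature of any variant can be inspected on the loaded module at runtime:

import inspect

# Print the full argument list of an entry point before wiring it up,
# e.g. the variable-length sparse attention variant.
print(inspect.signature(fsa.flash_sparse_attn_varlen_func))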

Source

Originally from HKUSTDial/flash-sparse-attention.
