---
library_name: kernels
license: bsd-3-clause
---

flash-sparse-attention

Flash Sparse Attention Triton kernels for dense, sparse, and gated attention, with forward, backward, and decode paths.
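The dense path computes standard scaled dot-product attention, softmax(QKᵀ/√d)V, optionally with a causal mask. The sketch below is a naive single-head reference in plain Python, useful for sanity-checking outputs on toy inputs. It is not the kernel's algorithm, which tiles and fuses this computation on the GPU. `reference_dense_attention` and `softmax` are hypothetical helpers, not part of the `fsa` module. The sketch also assumes per-head `[seq][dim]` lists, a default scale of 1/√d, and equal query and key lengths in the causal case.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def reference_dense_attention(
    q: Matrix, k: Matrix, v: Matrix, is_causal: bool = False
) -> Matrix:
    """Naive softmax(q k^T / sqrt(d)) v for one head; q, k, v are [seq][dim] lists."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)  # assumed default scale
    out = []
    for i, qi in enumerate(q):
        # With is_causal, query i only sees keys 0..i (assumes len(q) == len(k)).
        n_keys = i + 1 if is_causal else len(k)
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(n_keys)]
        weights = softmax(scores)
        out.append(
            [sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))]
        )
    return out
```

With `is_causal=True`, the first query can only attend to the first key, so its output row is exactly the first value row.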

Usage

```python
from kernels import get_kernel

fsa = get_kernel("JingzeShi/flash-sparse-attention", version=1)

# Assumption: q, k, v (and alpha, delta for gated attention) are GPU tensors
# already prepared by the caller; see the upstream repository for expected shapes and dtypes.

# Dense forward
out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)

# Decode with KV cache
out = fsa.flash_dense_attn_with_kvcache_func(q, k, v)

# Sparse attention
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)

# Gated attention
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
```
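The sparse call takes a `softmax_threshold` argument. The sketch below shows one plausible reading of it. This reading is an assumption, not something taken from the kernel source: for one query row, keys whose softmax probability falls below the threshold are skipped, and the surviving weights are renormalized. The helper name, the renormalization step, and the fallback to the top-weighted key are all illustrative choices. Check the upstream HKUSTDial/flash-sparse-attention code for the exact rule.

```python
import math
from typing import List, Tuple


def thresholded_attention_row(
    scores: List[float], values: List[List[float]], softmax_threshold: float
) -> Tuple[List[float], List[int]]:
    """Hypothetical sketch of one query row of threshold-pruned attention.

    Assumption: keys whose softmax weight is below `softmax_threshold` are
    dropped, and the remaining weights are renormalized before mixing values.
    """
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    kept = [j for j, p in enumerate(probs) if p >= softmax_threshold]
    if not kept:
        # Fall back to the single highest-weight key so the row is never empty.
        kept = [max(range(len(probs)), key=probs.__getitem__)]

    kept_total = sum(probs[j] for j in kept)
    dim = len(values[0])
    out = [sum(probs[j] / kept_total * values[j][c] for j in kept) for c in range(dim)]
    return out, kept
```

With scores `[5.0, 0.0, 5.0, -5.0]` and a threshold of `0.01`, only keys 0 and 2 survive, and each receives weight 0.5. A fused kernel would likely make this decision per tile during an online softmax rather than after materializing a full row, as this sketch does.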
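The decode call reuses keys and values computed for earlier tokens instead of recomputing them over the whole prefix. The toy sketch below illustrates that general KV-cache pattern for a single head in plain Python. Each step appends the new token's key and value, then lets its query attend over everything in the cache. `ToyKVCache` and `decode_step` are hypothetical names. How `flash_dense_attn_with_kvcache_func` lays out or updates its cache (for example, whether it writes new entries in place) is not described in this card and may differ.

```python
import math
from typing import List


class ToyKVCache:
    """Hypothetical single-head KV cache: one key and one value per past token."""

    def __init__(self) -> None:
        self.keys: List[List[float]] = []
        self.values: List[List[float]] = []

    def append(self, k: List[float], v: List[float]) -> None:
        self.keys.append(k)
        self.values.append(v)

    def attend(self, q: List[float]) -> List[float]:
        # The new query attends to every cached position. No causal mask is
        # needed because the cache only holds past and current tokens.
        scale = 1.0 / math.sqrt(len(q))
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in self.keys]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        dim = len(self.values[0])
        return [sum(e / total * v[c] for e, v in zip(exps, self.values)) for c in range(dim)]


def decode_step(
    cache: ToyKVCache, q: List[float], k_new: List[float], v_new: List[float]
) -> List[float]:
    # Write the new token's key/value first, then attend, so the token sees itself.
    cache.append(k_new, v_new)
    return cache.attend(q)
```

Each step costs time proportional to the current cache length instead of re-running attention over the full prefix, which is the point of a decode path.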

Available functions

  • flash_dense_attn_func
  • flash_dense_attn_with_kvcache_func
  • flash_dense_attn_varlen_func
  • flash_dense_attn_varlen_with_kvcache_func
  • flash_sparse_attn_func
  • flash_sparse_attn_with_kvcache_func
  • flash_sparse_attn_varlen_func
  • flash_sparse_attn_varlen_with_kvcache_func
  • flash_gated_attn_func
  • flash_gated_attn_with_kvcache_func
  • flash_gated_attn_varlen_func
  • flash_gated_attn_varlen_with_kvcache_func
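The `_varlen_` variants likely follow the packed-sequence convention of FlashAttention's `flash_attn_varlen_func`. In that convention, sequences of different lengths are concatenated along the token axis and delimited by cumulative offsets (`cu_seqlens`) plus a maximum length. That convention is an assumption here, so check the actual function signatures before relying on it. The sketch below builds the offsets from per-sequence lengths and splits a packed list back apart. `build_cu_seqlens` and `unpack` are hypothetical helpers.

```python
from itertools import accumulate
from typing import List, Tuple


def build_cu_seqlens(seq_lens: List[int]) -> Tuple[List[int], int]:
    """Return ([0, l0, l0+l1, ...], max_len) for packing variable-length sequences.

    Sequence i occupies packed rows cu_seqlens[i]:cu_seqlens[i + 1].
    """
    if not seq_lens or any(n <= 0 for n in seq_lens):
        raise ValueError("need at least one sequence, all with positive length")
    cu_seqlens = [0, *accumulate(seq_lens)]
    return cu_seqlens, max(seq_lens)


def unpack(packed: list, cu_seqlens: List[int]) -> List[list]:
    # Split a packed [total_tokens, ...] list back into per-sequence chunks.
    return [packed[s:e] for s, e in zip(cu_seqlens, cu_seqlens[1:])]
```

For lengths `[3, 1, 4]`, this gives offsets `[0, 3, 4, 8]` and a maximum length of 4. In FlashAttention practice, the offsets are passed as an int32 tensor on the same device as the packed q, k, v, and this sketch assumes the same here.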

Source

Originally from HKUSTDial/flash-sparse-attention.