---
license: bsd-3-clause
tags:
  - kernels
---

# sglang-flash-attn3

Pre-built Flash Attention 3 (forward-only) CUDA kernels from [sgl-flash-attn](https://github.com/sgl-project/sgl-flash-attn), packaged for the [HuggingFace kernels library](https://github.com/huggingface/kernels). Requires Hopper (sm_90+).

Kernel source: [kernels-community/sgl-flash-attn3](https://github.com/huggingface/kernels-community/tree/main/sgl-flash-attn3)

## Usage

```bash
pip install kernels
```

```python
from kernels import get_kernel

fa3 = get_kernel("kernels-community/sgl-flash-attn3", revision="v1")

fa3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=True)
fa3.flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)
fa3.is_fa3_supported()  # True on H100/H200
```
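The `cu_seqlens_q` / `cu_seqlens_k` arguments index a packed variable-length batch: all sequences are concatenated along the token dimension, and cumulative offsets mark where each one starts. A minimal sketch of how such offsets are built (plain Python for illustration; the real call expects int32 CUDA tensors):

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Cumulative sequence-length offsets for a packed varlen batch.

    For per-sequence lengths [3, 5, 2] the packed tensor holds
    3 + 5 + 2 = 10 tokens; cu_seqlens records each sequence's start
    offset plus the total token count: [0, 3, 8, 10].
    """
    return [0] + list(accumulate(seq_lens))

cu = build_cu_seqlens([3, 5, 2])
# Sequence i occupies packed rows cu[i]:cu[i+1].
```

With offsets like these, `flash_attn_varlen_func` can attend within each sequence of the packed batch without padding.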

## SGLang Integration

Entry point: [`python/sglang/srt/layers/attention/flashattention_backend.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/flashattention_backend.py)

Original:
```python
from sgl_kernel.flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_fa3
from sgl_kernel.flash_attn import flash_attn_with_kvcache as flash_attn_with_kvcache_fa3
```

Replace with:
```python
from kernels import get_kernel
_fa3_mod = get_kernel("kernels-community/sgl-flash-attn3", revision="v1")
flash_attn_varlen_func_fa3 = _fa3_mod.flash_attn_varlen_func
flash_attn_with_kvcache_fa3 = _fa3_mod.flash_attn_with_kvcache
```
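If you want to keep the native `sgl_kernel` import as a fallback rather than replacing it outright, the substitution can be wrapped in a small loader. This is a sketch, not SGLang's actual code; the injectable `loader` parameter is an assumption added here so the fallback path can be exercised without network access:

```python
def load_fa3_funcs(loader=None):
    """Resolve FA3 entry points, preferring the HF Hub build.

    `loader` defaults to kernels.get_kernel; it is injectable purely
    for testing (hypothetical convenience, not part of any API).
    """
    if loader is None:
        from kernels import get_kernel as loader
    try:
        mod = loader("kernels-community/sgl-flash-attn3", revision="v1")
    except Exception:
        # Fall back to the native sgl_kernel implementation.
        from sgl_kernel import flash_attn as mod
    return mod.flash_attn_varlen_func, mod.flash_attn_with_kvcache

# flash_attn_varlen_func_fa3, flash_attn_with_kvcache_fa3 = load_fa3_funcs()
```

Either way, the call sites in the backend files stay unchanged; only the binding of the two names differs.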

The same pattern applies in five other files:
- `dual_chunk_flashattention_backend.py`
- `nsa_backend.py`
- `xpu_backend.py`
- `vision.py`
- `multimodal_gen/runtime/layers/attention/backends/flash_attn.py`
## Benchmarks

H100 NVL, Qwen2.5-3B-Instruct, FA3 on both sides. All deltas fall within run-to-run noise: **no measurable performance regression**.

| Scenario | Native `sgl_kernel` FA3 tok/s | HF Hub FA3 tok/s | Δ |
|:--|--:|--:|:--|
| Short Gen (128→32) | 40,934 | 39,878 | -2.6% |
| Long Gen (256→1024) | 25,054 | 26,239 | +4.7% |
| Long Prefill (2048→128) | 53,833 | 54,283 | +0.8% |
| Bursty (512→256, 16rps) | 6,518 | 6,527 | +0.1% |
| High Concurrency (256→256) | 40,666 | 40,522 | -0.4% |

## Credits

- [Tri Dao](https://tridao.me/) - [Flash Attention 3](https://tridao.me/blog/2024/flash3/)
- [SGLang](https://github.com/sgl-project/sglang) - `sgl_kernel` FA3 implementation
- [HuggingFace](https://huggingface.co/kernels-community) - [kernel-builder](https://huggingface.co/blog/kernel-builder) infrastructure