---
license: mit
---
# NKI Kernel Experiments: Flux2-klein-4B on Neuron
- Hardware: AWS trn1.32xlarge (32 NeuronCores), TP=4, bfloat16
- Model: `black-forest-labs/FLUX.2-klein-4B`
- Shapes: B=1, 512×512 → img_S=256 (2× patchify), txt_S=512, inner_dim H=3072, n_heads=24, head_dim=128
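The per-rank shapes used throughout this README follow from these figures. A plain-arithmetic sketch; `flux2_klein_shapes` is a hypothetical helper (not part of the repo), and it takes `img_S` and `txt_S` as given above rather than deriving them from the VAE:

```python
# Hypothetical helper: derive per-rank attention shapes from the figures above.
def flux2_klein_shapes(img_s=256, txt_s=512, n_heads=24, head_dim=128, tp=4, block=128):
    assert n_heads % tp == 0, "heads must split evenly across TP ranks"
    heads_per_rank = n_heads // tp      # 24 / 4 = 6
    inner_dim = n_heads * head_dim      # 24 * 128 = 3072 (H)
    joint_s = img_s + txt_s             # concatenated image+text sequence
    return {
        "inner_dim": inner_dim,
        "heads_per_rank": heads_per_rank,
        "inner_dim_per_rank": heads_per_rank * head_dim,
        "joint_seq_len": joint_s,
        "tiles_per_seq": -(-joint_s // block),  # ceil division into 128-row tiles
    }


print(flux2_klein_shapes())
# {'inner_dim': 3072, 'heads_per_rank': 6, 'inner_dim_per_rank': 768, 'joint_seq_len': 768, 'tiles_per_seq': 6}
```

The 6 tiles of 128 rows are consistent with the `q_idx = 0..5` and `ks = 0..5` loops of the flash-attention kernel in §4.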
---
## 1. RoPE kernel (`nkilib.core.embeddings.rope`)
### Kernel constraints
- `d_head ∈ {64, 128}` (Flux2-klein: 128 ✓)
- `S ≤ 512`, where S is the sequence length of the Q/K tensors that RoPE is applied to before attention
- `n_heads ≤ 16` per rank (after TP=4 sharding: 24/4 = 6 ✓)
- Input layout must be `[B, n_heads, S, d_head]`
### Flux2-klein applicability
| Block type | S | Fits S ≤ 512? | Notes |
|---|---|---|---|
| Single-stream | img_S + txt_S = 256 + 512 = 768 | **No** | RoPE is applied to the concatenated image+text sequence |
| Double-stream (image) | img_S = 256 | Yes | Fits, but double-stream blocks apply RoPE inside `FluxAttnProcessor` after the separate Q/K projections, so calling the NKI kernel there requires a custom attention processor |
| Double-stream (text) | txt_S = 512 | Yes (boundary) | |
**Verdict: Not practical.** Single-stream blocks (20 of 25) exceed S=512, and double-stream blocks (5 of 25) would require custom processors. The XLA compiler also already fuses RoPE with the surrounding matmuls in the same NEFF; a standalone NKI kernel would break that fusion (see §2).
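The applicability table can be reproduced by checking each block's sequence length against the documented limits. This is a sketch: `rope_kernel_fits` is a hypothetical helper that only encodes the constraints listed above; it does not call `nkilib` and does not check the `[B, n_heads, S, d_head]` layout requirement.

```python
# Hypothetical checker encoding the documented nkilib RoPE limits.
# It does not import or call nkilib, and it ignores the input-layout requirement.
SUPPORTED_D_HEAD = (64, 128)
MAX_SEQ_LEN = 512
MAX_HEADS_PER_RANK = 16


def rope_kernel_fits(seq_len, d_head=128, n_heads=24, tp=4):
    heads_per_rank, remainder = divmod(n_heads, tp)
    return (
        remainder == 0
        and d_head in SUPPORTED_D_HEAD
        and seq_len <= MAX_SEQ_LEN
        and heads_per_rank <= MAX_HEADS_PER_RANK
    )


# Sequence lengths from the applicability table (img_S=256, txt_S=512).
BLOCKS = {
    "single-stream (img+txt)": 256 + 512,
    "double-stream (image)": 256,
    "double-stream (text)": 512,
}
for name, seq_len in BLOCKS.items():
    print(f"{name:24s} S={seq_len:3d} fits={rope_kernel_fits(seq_len)}")
```

Only the sequence-length check fails for Flux2-klein; `d_head` and heads per rank are within limits for every block type.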
---
## 2. Pipeline integration: `--fused-qkv` flag
Implementation: `Flux2AttnProcessorFusedQKV` in `pipeline.py`, activated by `--fused-qkv`.
It replaces the three separate column-parallel (`ColwiseParallel`) linear calls `to_q`, `to_k`, and `to_v` in the double-stream blocks with a single call to the NKI `nki_qkv` kernel.
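Mathematically, fusing the three projections means concatenating the Q, K, and V weight matrices along the output dimension, doing one matmul, and splitting the result. The pure-Python sketch below checks that equivalence on tiny random matrices. It illustrates the math that `nki_qkv` replaces, not the kernel itself, and it assumes no biases and no TP sharding; all helper names are hypothetical.

```python
import random


def matmul(a, b):
    """(m, k) @ (k, n) for nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def concat_cols(*mats):
    """Concatenate matrices with equal row counts along the output (column) axis."""
    return [sum((m[r] for m in mats), []) for r in range(len(mats[0]))]


def split_cols(mat, width):
    """Split the columns of `mat` into consecutive chunks of `width`."""
    n_chunks = len(mat[0]) // width
    return [[row[c * width:(c + 1) * width] for row in mat] for c in range(n_chunks)]


def random_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
tokens, d_in, d_out = 4, 8, 6  # tiny illustrative sizes, not the Flux2 shapes
x = random_matrix(tokens, d_in)
w_q, w_k, w_v = (random_matrix(d_in, d_out) for _ in range(3))

# Unfused: three separate projections (the to_q / to_k / to_v path).
q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)

# Fused: one projection against [W_q | W_k | W_v], then split into Q, K, V.
q_f, k_f, v_f = split_cols(matmul(x, concat_cols(w_q, w_k, w_v)), d_out)

max_err = max(
    abs(a - b)
    for unfused, fused in ((q, q_f), (k, k_f), (v, v_f))
    for row_u, row_f in zip(unfused, fused)
    for a, b in zip(row_u, row_f)
)
print(f"max abs difference: {max_err:.1e}")
```

Because each output column is computed with the same summation order in both paths, the difference here is exactly zero; the fused form saves launches and input reads, not arithmetic.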
### Timing (warm steps, average of the last 3 of 20 steps)
| Mode | Steps | Warm avg (s/step) | vs baseline |
|---|---|---|---|
| Eager, baseline | 20 | **0.824 s/step** | 1× |
| Eager, `--fused-qkv` | 20 | **14.86 s/step** | **18× slower** |
Output correctness: pixel range, mean, and std are identical to the baseline at every step, so the kernel produces correct results.
### Root cause: XLA whole-block fusion
In eager (lazy-XLA) mode, the XLA compiler traces the entire transformer block as one HLO program and compiles it into a **single NEFF** (`neff_cache/{hash}.neff`). This fuses:
- All Q/K/V projections
- RoPE embeddings
- Flash attention (via custom prim decomposition)
- Output projection + MLP
- Layer norms
Inserting a standalone NKI kernel (`@nki.jit`) creates **opaque tensor boundaries**: XLA cannot inline or fuse across NKI kernel calls. The compiler sees:
```
[XLA subgraph] → NKI qkv kernel → [XLA subgraph]
```
instead of one monolithic NEFF. This fragmentation:
1. Adds kernel launch overhead (PCIe round-trips for each NKI call)
2. Prevents data reuse that XLA would achieve within the fused NEFF
3. Defeats the cache: the fragmented graphs generate different, smaller NEFFs with no sharing benefit
The 18× slowdown is consistent with this: the baseline fused NEFF is highly optimised, while the fragmented version is not.
---
## 3. Compile mode + fused QKV (`--mode compile --fused-qkv`): bug fix note
---
## 4. Flash attention kernel (`flux2_flash_attn`)
Script: `examples/flux2-klein/nki_flash_attn.py`
Run: `torchrun --nproc_per_node=4 examples/flux2-klein/nki_flash_attn.py`
Single-pass, exp-only softmax (the online-softmax structure collapsed into one pass, without a running max), BLOCK_Q=128, BLOCK_K=128, bidirectional (no causal mask).
Uses the older NKI ISA API (`sbuf.view / psum.view / hbm.view / nisa.*`).
### Algorithm
- For each head (N=6, looped sequentially within one kernel instance):
  - For each Q tile (`q_idx = 0..5`):
    - Pass 1 of the online softmax, here collapsed into a single exp-only pass. For each K tile (`ks = 0..5`):
      - `score_T = k_tile.T @ q_tile`: shape (BLOCK_K, BLOCK_Q), computed with the `nc_matmul` transposed trick
      - `probs_T = exp(score_T * scale)`
      - `out_psum += probs_T.T @ v_tile`: shape (BLOCK_Q, D)
      - `row_sum += probs_T.T @ ones_v`: shape (BLOCK_Q, 1)
    - `out = out_psum / row_sum` → bf16 → HBM
Note: this softmax is not numerically stabilized: there is no row_max subtraction before `exp`.
It is suitable for a correctness test but may overflow for long sequences or large activations.
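The following pure-Python sketch mirrors the per-row math described above (exp without max subtraction, accumulation over K tiles, normalization by `row_sum` at the end) next to a max-stabilized online softmax in the FlashAttention style, which is a swapped-in technique and not what the kernel does. Tile size and inputs are illustrative assumptions. Python raises `OverflowError` where fp32/bf16 hardware would produce `inf`/`NaN`, and the threshold differs: `exp` overflows above about 709.8 in float64 versus about 88.7 in fp32.

```python
import math


def attention_exp_only(q, k, v, scale, block_k=2):
    """Kernel-style math: exp without max subtraction, accumulate the numerator
    and row_sum over K tiles, normalize at the end."""
    out = []
    for q_row in q:
        acc = [0.0] * len(v[0])
        row_sum = 0.0
        for start in range(0, len(k), block_k):  # loop over K tiles
            for k_row, v_row in zip(k[start:start + block_k], v[start:start + block_k]):
                p = math.exp(scale * sum(a * b for a, b in zip(q_row, k_row)))
                acc = [a + p * x for a, x in zip(acc, v_row)]
                row_sum += p
        out.append([a / row_sum for a in acc])
    return out


def attention_online_stable(q, k, v, scale, block_k=2):
    """Online softmax with a running row max: earlier partial sums are rescaled
    by exp(m_old - m_new) whenever a K tile raises the max."""
    out = []
    for q_row in q:
        m = -math.inf
        acc = [0.0] * len(v[0])
        row_sum = 0.0
        for start in range(0, len(k), block_k):
            k_blk, v_blk = k[start:start + block_k], v[start:start + block_k]
            scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_blk]
            m_new = max(m, max(scores))
            corr = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first tile
            acc = [a * corr for a in acc]
            row_sum *= corr
            for s, v_row in zip(scores, v_blk):
                p = math.exp(s - m_new)  # exponent <= 0, so this cannot overflow
                acc = [a + p * x for a, x in zip(acc, v_row)]
                row_sum += p
            m = m_new
        out.append([a / row_sum for a in acc])
    return out


scale = 1.0 / math.sqrt(2)  # 1/sqrt(head_dim) for a toy head_dim of 2
q = [[0.1, 0.2], [0.3, -0.1]]
k = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0], [0.1, 0.1]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]

a = attention_exp_only(q, k, v, scale)
b = attention_online_stable(q, k, v, scale)
print("max diff on small scores:", max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)))

big_q = [[5000.0, 5000.0]]  # large activations push q.k * scale past the exp limit
try:
    attention_exp_only(big_q, k, v, scale)
except OverflowError:
    print("exp-only version overflowed")
print("stable version:", attention_online_stable(big_q, k, v, scale))
```

On small scores the two agree to rounding error; only the stabilized version survives large scores, at the cost of a max and a rescale per K tile.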
---
## 5. Compile mode full comparison
- Hardware: trn2.3xlarge, TP=4, bfloat16, 512×512, 4 steps, random weights, 4 runs (1 cold + 3 warm)
- Date: 2026-03-31 | neuronxcc: `2.0.236418.0a0+9af338ad`

All four compile-mode variants were measured on the same neuronxcc build for an apples-to-apples comparison. "Cold" is step01 of run 1, which includes compilation; "Warm avg" is the mean step time of the warm runs.
### Vanilla compile (no custom kernels)
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 533.449s | 3.868s | 3.868s | 3.868s | 545.053s |
| 2–4 | WARM | 3.868s | 3.869s | 3.869s | 3.869s | 15.475s |
Cold: **533.4s** · Warm avg: **3.869 s/step** · Throughput: **0.258 steps/s**
### Compile + `--fused-qkv`
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 651.147s | 19.874s | 3.859s | 3.859s | 678.740s |
| 2–4 | WARM | 3.859s | 3.859s | 3.860s | 3.860s | 15.438s |
Cold: **651.1s** · Warm avg: **3.859 s/step** · Throughput: **0.259 steps/s**
### Compile + `--flash-attn`
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 862.344s | 19.601s | 4.159s | 4.159s | 890.263s |
| 2–4 | WARM | 4.159s | 4.159s | 4.159s | 4.159s | 16.636s |
Cold: **862.3s** · Warm avg: **4.159 s/step** · Throughput: **0.240 steps/s**
### Compile + `--fused-qkv --flash-attn` (combined)
| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 830.249s | 19.558s | 4.149s | 4.149s | 858.105s |
| 2–4 | WARM | 4.149s | 4.149s | 4.149s | 4.149s | 16.597s |
Cold: **830.2s** · Warm avg: **4.149 s/step** · Throughput: **0.241 steps/s**
### Summary table
| Mode | Cold (s) | Warm avg/step | Throughput | vs vanilla compile |
|---|---|---|---|---|
| Eager, baseline | 9.3s | **0.835 s/step** | 1.198 steps/s | 4.6× faster |
| Compile, vanilla | 533.4s | 3.869 s/step | 0.258 steps/s | 1× (baseline) |
| Compile, `--fused-qkv` | 651.1s | **3.859 s/step** | **0.259 steps/s** | −0.3% (noise) |
| Compile, `--flash-attn` | 862.3s | 4.159 s/step | 0.240 steps/s | +7.5% slower |
| Compile, `--fused-qkv --flash-attn` | 830.2s | 4.149 s/step | 0.241 steps/s | +7.2% slower |
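The derived columns follow directly from the warm per-step averages: throughput is the reciprocal, and "vs vanilla compile" is the relative change against 3.869 s/step (or, for eager, the speedup ratio). A sketch that recomputes them; `derived` is a hypothetical helper and the inputs are copied from the table.

```python
# Warm per-step averages copied from the summary table (seconds/step).
WARM_S_PER_STEP = {
    "Eager, baseline": 0.835,
    "Compile, vanilla": 3.869,
    "Compile, --fused-qkv": 3.859,
    "Compile, --flash-attn": 4.159,
    "Compile, --fused-qkv --flash-attn": 4.149,
}


def derived(s_per_step, vanilla_s_per_step):
    """Return (throughput in steps/s, % change vs vanilla, speedup vs vanilla)."""
    throughput = 1.0 / s_per_step
    pct_change = (s_per_step - vanilla_s_per_step) / vanilla_s_per_step * 100  # + means slower
    speedup = vanilla_s_per_step / s_per_step  # > 1 means faster than vanilla compile
    return throughput, pct_change, speedup


vanilla = WARM_S_PER_STEP["Compile, vanilla"]
for mode, s in WARM_S_PER_STEP.items():
    tput, pct, speedup = derived(s, vanilla)
    print(f"{mode:36s} {tput:.3f} steps/s {pct:+6.1f}% {speedup:5.2f}x")
```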
### Interpretation
- **Fused-QKV has no measurable effect in compile mode** (3.859 vs 3.869 s/step, within run-to-run noise).
The Dynamo+NEFF compiler already fuses the QKV projections at the HLO level; the explicit NKI kernel
neither helps nor hurts, but adds 118s to cold compilation.
- **Flash-attn is ~7% slower than vanilla** regardless of whether fused-QKV is also enabled.
The exp-only single-pass softmax and the sequential head loop are less efficient than the
compiler's built-in attention decomposition (two-pass, numerically stable, better SPMD utilisation).
- **Combining both kernels** gives the same result as flash-attn alone (4.149 vs 4.159 s/step, within noise):
fused-QKV contributes nothing additional in compile mode.
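- **Cold compilation time** increases whenever an NKI kernel is enabled: vanilla (533s) → fused-qkv (651s, +22%) →
combined (830s, +56%) → flash-attn alone (862s, +62%). Each NKI kernel adds a separate KLIR
compilation pass inside neuronxcc, although the increase does not scale strictly with kernel count:
the combined build compiled slightly faster than flash-attn alone.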
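The cold-compile percentages are relative to the vanilla step01 time. A short check using the rounded cold times from the summary table; `overhead` is a hypothetical helper.

```python
# Cold step01 times (seconds) copied from the summary table.
COLD_S = {"vanilla": 533.4, "fused-qkv": 651.1, "combined": 830.2, "flash-attn": 862.3}


def overhead(cold_s, vanilla_cold_s):
    """Return (extra seconds, fractional increase) relative to vanilla compile."""
    extra = cold_s - vanilla_cold_s
    return extra, extra / vanilla_cold_s


for mode, secs in COLD_S.items():
    extra, frac = overhead(secs, COLD_S["vanilla"])
    print(f"{mode:10s} {secs:6.1f}s {extra:+7.1f}s ({frac:+.0%})")
```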
---
## 8. Conclusions
| Kernel | Correct | Practical for eager | Practical for compile |
|---|---|---|---|
| NKI RoPE | n/a | No (S > 512 for single-stream) | No (same constraint) |
| NKI QKV | Yes | **No**: breaks XLA fusion (18× slower) | Negligible effect (within noise) |
| NKI Flash Attention | Yes (cosine similarity = 0.9999) | TBD | **No**: 7% slower than vanilla, +62% compile time |
| NKI QKV + Flash Attention | Yes | **No** | Same as flash-attn alone |