---
license: mit
---

# NKI Kernel Experiments: Flux2-klein-4B on Neuron

Hardware: AWS trn1.32xlarge (32 NeuronCores), TP=4, bfloat16

Model: `black-forest-labs/FLUX.2-klein-4B`

Shapes: B=1, 512×512 → img_S=256 (2× patchify), txt_S=512, inner_dim H=3072, n_heads=24, head_dim=128

---

## 1. RoPE kernel (`nkilib.core.embeddings.rope`)

### Kernel constraints

- `d_head ∈ {64, 128}` (Flux2-klein: 128 ✓)
- `S ≤ 512`, where S is the sequence length of the Q/K tensors rotated before attention
- `n_heads ≤ 16` per rank; after TP=4 sharding: 24/4 = 6 ✓
- Input layout must be `[B, n_heads, S, d_head]`

### Flux2-klein applicability

| Block type | S | Fits S≤512? | Notes |
|---|---|---|---|
| Single-stream | img_S + txt_S = 256 + 512 = 768 | **No** | RoPE is applied to the concatenated image+text sequence |
| Double-stream (image) | img_S = 256 | Yes | But double-stream blocks apply RoPE inside FluxAttnProcessor after separate Q/K projections; hooking in the NKI kernel requires a custom processor |
| Double-stream (text) | txt_S = 512 | Yes (boundary) | |

**Verdict: Not practical.** Single-stream blocks (20 of 25) exceed S=512. Double-stream blocks (5 of 25) would require custom processors. The XLA compiler already fuses RoPE with the surrounding matmuls into the same NEFF, and a standalone NKI kernel would break that fusion (see §2).

---

## 2. Pipeline integration: `--fused-qkv` flag

Implementation: `Flux2AttnProcessorFusedQKV` in `pipeline.py`, activated by `--fused-qkv`. It replaces the 3 separate `to_q` / `to_k` / `to_v` ColwisePar linear calls in double-stream blocks with a single call to the NKI `nki_qkv` kernel.

### Timing (warm steps, average of the last 3 of 20 steps)

| Mode | Steps | Warm avg (s/step) | vs baseline |
|---|---|---|---|
| Eager, baseline | 20 | **0.824 s/step** | 1× |
| Eager, `--fused-qkv` | 20 | **14.86 s/step** | **18× slower** |

Output correctness: pixel range, mean, and std are identical at every step, indicating that the kernel produces correct results.
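The fused-QKV kernel relies on a simple identity: one matmul against the concatenated weight `[Wq | Wk | Wv]`, split back into three column blocks, yields the same Q, K, V as three separate projections. Under column-parallel (ColwisePar) sharding, the per-rank fused weight has to be the concatenation of that rank's shard of each projection, not a contiguous slice of the full fused weight. The NumPy sketch below checks both points; it is an illustration under stated assumptions, not the `nki_qkv` implementation. Weights are stored here as `(in_features, out_features)` (the transpose of `nn.Linear`'s layout), sizes are toy values, and `shard_cols`, `separate_qkv`, and `fused_qkv` are hypothetical helper names.

```python
import numpy as np


def shard_cols(w: np.ndarray, rank: int, tp: int) -> np.ndarray:
    """Column-parallel shard: `rank` keeps a contiguous 1/tp slice of the output columns."""
    cols = w.shape[1] // tp
    return w[:, rank * cols:(rank + 1) * cols]


def separate_qkv(x, wq, wk, wv):
    """Three independent projections, like separate to_q / to_k / to_v calls."""
    return x @ wq, x @ wk, x @ wv


def fused_qkv(x, wq, wk, wv):
    """One matmul against [Wq | Wk | Wv], then split the output columns into Q, K, V."""
    w_fused = np.concatenate([wq, wk, wv], axis=1)
    return np.split(x @ w_fused, 3, axis=1)


rng = np.random.default_rng(0)
S, H, TP = 16, 64, 4  # toy sizes; the document's real values are S=768, H=3072, TP=4
x = rng.standard_normal((S, H)).astype(np.float32)
wq, wk, wv = (rng.standard_normal((H, H)).astype(np.float32) for _ in range(3))

for rank in range(TP):
    # Correct per-rank fused weight: this rank's shard of each projection, concatenated
    local = [shard_cols(w, rank, TP) for w in (wq, wk, wv)]
    for ref, out in zip(separate_qkv(x, *local), fused_qkv(x, *local)):
        assert np.allclose(ref, out, atol=1e-4)

# Pitfall: column-sharding the full fused weight gives rank 0 only Wq columns
w_full = np.concatenate([wq, wk, wv], axis=1)  # (H, 3H)
rank0 = shard_cols(w_full, 0, TP)              # first 3H/TP columns
print(rank0.shape, np.array_equal(rank0, wq[:, : 3 * H // TP]))  # (64, 48) True
```

The arithmetic is the same either way; any speedup from fusing therefore has to come from fewer kernel launches or better tiling, which is exactly what the XLA whole-block fusion already provides in the baseline.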
### Root cause: XLA whole-block fusion

In eager (lazy-XLA) mode, the XLA compiler traces the entire transformer block as one HLO program and compiles it into a **single NEFF** (`neff_cache/{hash}.neff`). This fuses:

- All Q/K/V projections
- RoPE embeddings
- Flash attention (via custom prim decomposition)
- Output projection + MLP
- Layer norms

Inserting a standalone NKI kernel (`@nki.jit`) creates **opaque tensor boundaries**: XLA cannot inline or fuse across NKI kernel calls. The compiler sees:

```
[XLA subgraph] → NKI qkv kernel → [XLA subgraph]
```

instead of one monolithic NEFF. This fragmentation:

1. Adds kernel launch overhead (PCIe round-trips for each NKI call)
2. Prevents data reuse that XLA would achieve within the fused NEFF
3. Defeats the cache: the fragmented graphs generate different, smaller NEFFs with no sharing benefit

The 18× slowdown is consistent with this: the baseline fused NEFF is highly optimised, and the fragmented version is not.

---

## 3. Compile mode + fused QKV (`--mode compile --fused-qkv`): bug fix note

---

## 4. Flash attention kernel (`flux2_flash_attn`)

Script: `examples/flux2-klein/nki_flash_attn.py`

Run: `torchrun --nproc_per_node=4 flux2-klein/nki_flash_attn.py`

Single-pass, exp-only softmax (the online-softmax passes are collapsed into one, with no running max), BLOCK_Q=128, BLOCK_K=128, bidirectional (no causal mask). Uses the older NKI ISA API (`sbuf.view / psum.view / hbm.view / nisa.*`).

### Algorithm

- For each head (N=6 heads, looped sequentially in one kernel instance):
  - For each Q tile (`q_idx = 0..5`; 6 × 128 = 768 = img_S + txt_S):
    - Pass 1 of online softmax, here collapsed into a single exp-only pass. For each K tile (`ks = 0..5`):
      - `score_T = k_tile.T @ q_tile` → `(BLOCK_K, BLOCK_Q)`, via the `nc_matmul` transposed trick
      - `probs_T = exp(score_T * scale)`
      - `out_psum += probs_T.T @ v_tile` → `(BLOCK_Q, D)`
      - `row_sum += probs_T.T @ ones_v` → `(BLOCK_Q, 1)`
    - `out = out_psum / row_sum` → bf16 → HBM

Note: this softmax is not numerically stabilized (no row_max subtraction before `exp`).
That is acceptable for a correctness test, but `exp` may overflow for long sequences or large activations.

---

## 5. Compile mode full comparison

Hardware: trn2.3xlarge, TP=4, bfloat16, 512×512, 4 steps, random weights, 4 runs (1 cold + 3 warm)

Date: 2026-03-31 | neuronxcc: `2.0.236418.0a0+9af338ad`

All four compile-mode variants were measured on the same neuronxcc build for a fair apples-to-apples comparison.

### Vanilla compile (no custom kernels)

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 533.449s | 3.868s | 3.868s | 3.868s | 545.053s |
| 2–4 | WARM | 3.868s | 3.869s | 3.869s | 3.869s | 15.475s |

Cold: **533.4s** · Warm avg: **3.869 s/step** · Throughput: **0.258 steps/s**

### Compile + `--fused-qkv`

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 651.147s | 19.874s | 3.859s | 3.859s | 678.740s |
| 2–4 | WARM | 3.859s | 3.859s | 3.860s | 3.860s | 15.438s |

Cold: **651.1s** · Warm avg: **3.859 s/step** · Throughput: **0.259 steps/s**

### Compile + `--flash-attn`

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 862.344s | 19.601s | 4.159s | 4.159s | 890.263s |
| 2–4 | WARM | 4.159s | 4.159s | 4.159s | 4.159s | 16.636s |

Cold: **862.3s** · Warm avg: **4.159 s/step** · Throughput: **0.240 steps/s**

### Compile + `--fused-qkv --flash-attn` (combined)

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 830.249s | 19.558s | 4.149s | 4.149s | 858.105s |
| 2–4 | WARM | 4.149s | 4.149s | 4.149s | 4.149s | 16.597s |

Cold: **830.2s** · Warm avg: **4.149 s/step** · Throughput: **0.241 steps/s**

### Summary table

| Mode | Cold (s) | Warm avg/step | Throughput | vs vanilla compile (warm) |
|---|---|---|---|---|
| Eager, baseline | 9.3s | **0.835 s/step** | 1.198 steps/s | 4.6× faster |
| Compile, vanilla | 533.4s | 3.869 s/step | 0.258 steps/s | 1× (baseline) |
| Compile, `--fused-qkv` | 651.1s | **3.859 s/step** | **0.259 steps/s** | −0.3% (noise) |
| Compile, `--flash-attn` | 862.3s | 4.159 s/step | 0.240 steps/s | +7.5% slower |
| Compile, `--fused-qkv --flash-attn` | 830.2s | 4.149 s/step | 0.241 steps/s | +7.2% slower |

### Interpretation

- **Fused-QKV has no measurable effect in compile mode** (3.859 vs 3.869 s/step, within run-to-run noise). The Dynamo+NEFF compiler already fuses QKV projections at the HLO level; the explicit NKI kernel neither helps nor hurts, but adds 118s to cold compilation.
- **Flash-attn is ~7% slower than vanilla** regardless of whether fused-QKV is also enabled. The single-pass softmax without max subtraction and the sequential head loop appear less efficient than the compiler's built-in attention decomposition (two-pass numerically stable, better SPMD utilisation).
- **Combining both kernels** gives the same result as flash-attn alone (4.149 vs 4.159 s/step, within noise); fused-QKV contributes nothing additional in compile mode.
- **Cold compilation time** rises whenever an NKI kernel is added: vanilla (533s) → fused-qkv (651s, +22%) → combined (830s, +56%) → flash-attn alone (862s, +62%). The flash-attn kernel dominates this cost, and the combined run compiles slightly faster than flash-attn alone, so the overhead is not simply additive in kernel count. Each NKI kernel adds a separate KLIR compilation pass inside neuronxcc.

---

## 8. Conclusions

| Kernel | Correct | Practical for eager | Practical for compile |
|---|---|---|---|
| NKI RoPE | n/a | No (S > 512 for single-stream) | No (same constraint) |
| NKI QKV | Yes | **No**: breaks XLA fusion (18× slower) | Negligible effect (within noise) |
| NKI Flash Attention | Yes (cosine=0.9999) | TBD | **No**: 7% slower than vanilla, +62% compile time |
| NKI QKV + Flash Attention | Yes | **No** | Same as flash-attn alone |
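To cross-check the §4 flash-attention kernel off-device, a CPU reference can mirror its loop structure: 128-wide Q and K tiles, `exp` without max subtraction, accumulation of `probs @ V` and the row sums, then a single division. The NumPy sketch below is a minimal illustration under stated assumptions: one head, float32 on CPU (the kernel runs bf16 with PSUM accumulation), tiles stored row-major as `(BLOCK, D)` so the transpose lands on `q_tile` rather than `k_tile`, and hypothetical function names. The ×30 input scaling is an arbitrary choice that provokes the overflow the §4 note warns about; a max-subtracted softmax serves as the stable baseline.

```python
import numpy as np

BLOCK_Q = BLOCK_K = 128


def attn_exp_only_tiled(q, k, v, scale):
    """Mirror of the §4 loop order: exp-only scores, no running max, normalise once at the end."""
    S, D = q.shape
    out = np.empty_like(q)
    for q0 in range(0, S, BLOCK_Q):
        q_tile = q[q0:q0 + BLOCK_Q]                          # (BLOCK_Q, D)
        acc = np.zeros((q_tile.shape[0], D), dtype=q.dtype)  # stands in for out_psum
        row_sum = np.zeros((q_tile.shape[0], 1), dtype=q.dtype)
        for k0 in range(0, S, BLOCK_K):
            k_tile = k[k0:k0 + BLOCK_K]                      # (BLOCK_K, D)
            v_tile = v[k0:k0 + BLOCK_K]                      # (BLOCK_K, D)
            score_T = k_tile @ q_tile.T                      # (BLOCK_K, BLOCK_Q), transposed like the kernel
            probs_T = np.exp(score_T * scale)                # no max subtraction
            acc += probs_T.T @ v_tile                        # (BLOCK_Q, D)
            row_sum += probs_T.sum(axis=0)[:, None]          # equals probs_T.T @ ones
        out[q0:q0 + BLOCK_Q] = acc / row_sum
    return out


def attn_stable(q, k, v, scale):
    """Reference attention with row-max subtraction (numerically stable softmax)."""
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
S, D = 768, 128   # 6 Q tiles x 6 K tiles, matching the §4 loop bounds
q, k, v = (rng.standard_normal((S, D)).astype(np.float32) for _ in range(3))
scale = D ** -0.5  # Python float, so the float32 arrays stay float32

a = attn_exp_only_tiled(q, k, v, scale)
b = attn_stable(q, k, v, scale)
print("max |diff|:", float(np.max(np.abs(a - b))))

# Scaled-up activations: exp() overflows to inf in float32, and inf/inf yields NaN
with np.errstate(over="ignore", invalid="ignore"):
    big = attn_exp_only_tiled(q * 30, k * 30, v, scale)
print(np.isnan(big).any(), np.isfinite(attn_stable(q * 30, k * 30, v, scale)).all())  # True True
```

Keeping the kernel's tiled loop order, rather than a one-line `softmax(QKᵀ)V`, means a mismatch against `attn_stable` on ordinary inputs points at the tiling or accumulation logic, while failures on large inputs point at the missing max subtraction.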