---
license: mit
---

# NKI Kernel Experiments: Flux2-klein-4B on Neuron

- Hardware: AWS trn1.32xlarge (32 NeuronCores), TP=4, bfloat16
- Model: `black-forest-labs/FLUX.2-klein-4B`
- Shapes: B=1, 512×512 → img_S=256 (2× patchify), txt_S=512, inner_dim H=3072, n_heads=24, head_dim=128

---

## 1. RoPE kernel (`nkilib.core.embeddings.rope`)

### Kernel constraints

- `d_head ∈ {64, 128}` (Flux2-klein: 128 ✓)
- `S ≤ 512`, where S is the sequence length that RoPE is applied over (before attention)
- `n_heads ≤ 16` per rank (after TP=4 sharding: 24/4 = 6 ✓)
- Input layout must be `[B, n_heads, S, d_head]`
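
A small NumPy reference helps pin down what the kernel computes in the required `[B, n_heads, S, d_head]` layout. This is a sketch under assumptions, not the `nkilib` implementation.

- `rope_bhsd` is a hypothetical name.
- It rotates interleaved pairs `(x[..., 2i], x[..., 2i+1])` by `pos * theta**(-2i/d_head)`, using the common `theta = 10000`.
- It uses a single position axis. FLUX.1-style models instead split `head_dim` across several positional axes.

What the sketch illustrates is that RoPE is elementwise along `S`, which makes it cheap for XLA to fuse with the neighbouring matmuls.

```python
import numpy as np


def rope_bhsd(x: np.ndarray, theta: float = 10000.0) -> np.ndarray:
    """Hypothetical reference RoPE for x of shape [B, n_heads, S, d_head].

    Assumption: interleaved pairs (x[..., 2i], x[..., 2i+1]) are rotated by
    angle pos * theta**(-2i / d_head), with pos = token index (one rotary axis).
    """
    _, _, seq_len, d_head = x.shape
    assert d_head % 2 == 0, "d_head must be even"
    pos = np.arange(seq_len, dtype=np.float64)[:, None]      # [S, 1]
    inv_freq = theta ** (-np.arange(0, d_head, 2) / d_head)   # [d_head/2]
    angles = pos * inv_freq[None, :]                          # [S, d_head/2]
    cos, sin = np.cos(angles), np.sin(angles)                 # broadcast over B, n_heads

    x_even, x_odd = x[..., 0::2], x[..., 1::2]                # [B, n_heads, S, d_head/2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


# Per-rank single-stream shape: 24 heads / TP=4 = 6, S = 256 + 512 = 768.
# The NumPy reference has no S limit; the NKI kernel would reject S > 512.
q = np.random.default_rng(0).standard_normal((1, 6, 768, 128)).astype(np.float32)
q_rot = rope_bhsd(q)
print(q_rot.shape)  # (1, 6, 768, 128)
```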

### Flux2-klein applicability

| Block type | S | Fits S ≤ 512? | Notes |
|---|---|---|---|
| Single-stream | img_S + txt_S = 256 + 512 = 768 | **No** | RoPE is applied to the concatenated image+text sequence |
| Double-stream (image) | img_S = 256 | Yes | But double-stream blocks apply RoPE inside `FluxAttnProcessor` after the separate Q/K projections, so hooking in the NKI kernel requires a custom processor |
| Double-stream (text) | txt_S = 512 | Yes (exactly at the limit) | |
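
The `Fits S ≤ 512?` column follows mechanically from the three constraints listed above. The following is a minimal sketch that reproduces it. `rope_kernel_fits` and `Shapes` are hypothetical names, not part of `nkilib`. The limits are hard-coded from the constraint list rather than queried from the library.

```python
from dataclasses import dataclass

# Limits copied from the constraint list (assumed to be the only checks).
D_HEAD_ALLOWED = {64, 128}
MAX_SEQ = 512
MAX_HEADS_PER_RANK = 16


@dataclass
class Shapes:
    img_s: int = 256
    txt_s: int = 512
    n_heads: int = 24
    head_dim: int = 128
    tp: int = 4


def rope_kernel_fits(seq_len: int, s: Shapes) -> bool:
    """Hypothetical helper: True if all documented kernel limits are met."""
    heads_per_rank = s.n_heads // s.tp
    return (
        s.head_dim in D_HEAD_ALLOWED
        and seq_len <= MAX_SEQ
        and heads_per_rank <= MAX_HEADS_PER_RANK
    )


shapes = Shapes()
blocks = {
    "single-stream": shapes.img_s + shapes.txt_s,  # 768
    "double-stream (image)": shapes.img_s,         # 256
    "double-stream (text)": shapes.txt_s,          # 512
}
for name, seq_len in blocks.items():
    print(f"{name:24s} S={seq_len:4d} fits={rope_kernel_fits(seq_len, shapes)}")
```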

**Verdict: Not practical.** The single-stream blocks (20 of 25) exceed the S ≤ 512 limit, and the double-stream blocks (5 of 25) would require a custom attention processor. The XLA compiler already fuses RoPE with the surrounding matmuls into the same NEFF, so a standalone NKI kernel would break that fusion (see §2).

---

## 2. Pipeline integration: `--fused-qkv` flag

Implementation: `Flux2AttnProcessorFusedQKV` in `pipeline.py`, activated by `--fused-qkv`.
It replaces the three separate column-parallel (`ColwisePar`) `to_q` / `to_k` / `to_v` linear calls in the double-stream blocks with a single NKI `nki_qkv` kernel call.
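
Numerically, the fused projection is one matmul against the three weight matrices concatenated along the output dimension, followed by a split. The NumPy sketch below checks that equivalence in the column-parallel case, where each of the TP=4 ranks concatenates only its own output slices of `W_q`, `W_k` and `W_v`.

It illustrates the math under these assumptions:

- The hidden size is scaled down from 3072 to 384.
- Bias is omitted.
- All names are hypothetical.

It says nothing about how `nki_qkv` tiles or schedules the matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
TOKENS, HIDDEN, TP = 8, 384, 4   # HIDDEN scaled down from H = 3072 (assumption)
SHARD = HIDDEN // TP             # output features held by each rank

x = rng.standard_normal((TOKENS, HIDDEN)).astype(np.float32)
w_q, w_k, w_v = (rng.standard_normal((HIDDEN, HIDDEN)).astype(np.float32) * 0.02 for _ in range(3))

# Unfused: three separate projections (the to_q / to_k / to_v path).
q_ref, k_ref, v_ref = x @ w_q, x @ w_k, x @ w_v

for rank in range(TP):
    cols = slice(rank * SHARD, (rank + 1) * SHARD)  # this rank's column-parallel shard
    # Fused: one [HIDDEN, 3 * SHARD] weight per rank, one matmul, then split into Q, K, V.
    w_qkv = np.concatenate([w_q[:, cols], w_k[:, cols], w_v[:, cols]], axis=1)
    q, k, v = np.split(x @ w_qkv, 3, axis=1)
    assert np.allclose(q, q_ref[:, cols], atol=1e-4)
    assert np.allclose(k, k_ref[:, cols], atol=1e-4)
    assert np.allclose(v, v_ref[:, cols], atol=1e-4)

print("fused QKV matches the separate projections on all", TP, "ranks")
```

Both paths compute the same dot products. Any speed-up from fusion can therefore only come from scheduling one larger matmul instead of three.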

### Timing (warm steps: average of the last 3 of 20 steps)

| Mode | Steps | Warm avg (s/step) | vs baseline |
|---|---|---|---|
| Eager, baseline | 20 | **0.824 s/step** | 1× |
| Eager, `--fused-qkv` | 20 | **14.86 s/step** | **18× slower** |

Output correctness: pixel range, mean, and std are identical to the baseline at every step, which indicates that the kernel produces correct results.

### Root cause: XLA whole-block fusion

In eager (lazy-XLA) mode, the XLA compiler traces the entire transformer block as one HLO program and compiles it into a **single NEFF** (`neff_cache/{hash}.neff`). This fuses:
- All Q/K/V projections
- RoPE embeddings
- Flash attention (via custom prim decomposition)
- Output projection + MLP
- Layer norms

Inserting a standalone NKI kernel (`@nki.jit`) creates **opaque tensor boundaries**; XLA cannot inline or fuse across NKI kernel calls. The compiler sees:

```
[XLA subgraph] → NKI qkv kernel → [XLA subgraph]
```

instead of one monolithic NEFF. This fragmentation:
1. Adds kernel launch overhead (PCIe round-trips for each NKI call).
2. Prevents the data reuse that XLA would achieve within the fused NEFF.
3. Defeats the cache: the fragmented graphs generate different, smaller NEFFs with no sharing benefit.

The 18× slowdown is consistent with this explanation: the baseline fused NEFF is highly optimised, whereas the fragmented version is not.

---

## 3. Compile mode + fused QKV (`--mode compile --fused-qkv`): bug fix note

---

## 4. Flash attention kernel (`flux2_flash_attn`)

- Script: `examples/flux2-klein/nki_flash_attn.py`
- Run: `torchrun --nproc_per_node=4 flux2-klein/nki_flash_attn.py`

Softmax: the two-pass online softmax is collapsed into a single exp-only pass (no row-max subtraction). BLOCK_Q=128, BLOCK_K=128, bidirectional (no causal mask).
Uses the older NKI ISA API (`sbuf.view / psum.view / hbm.view / nisa.*`).

### Algorithm

    For each head (looped over N=6 sequentially in one kernel instance):
      For each Q tile (q_idx = 0..5):
        Pass 1 of online softmax (here collapsed into a single pass via exp-only):
          For each K tile (ks = 0..5):
            score_T   = k_tile.T @ q_tile     # (BLOCK_K, BLOCK_Q), via nc_matmul transposed trick
            probs_T   = exp(score_T * scale)
            out_psum += probs_T.T @ v_tile    # (BLOCK_Q, D)
            row_sum  += probs_T.T @ ones_v    # (BLOCK_Q, 1)
        out = out_psum / row_sum → bf16 → HBM

Note: this softmax is not numerically stabilized, because there is no row-max subtraction before `exp`.
It is suitable for a correctness test but may overflow for long sequences or large activations.
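
A NumPy model of the loop structure above can be used to check the kernel's numerics on the host. This is a sketch under these assumptions, not the NKI kernel:

- It uses row-major `[S, D]` tiles, so it computes `q_tile @ k_tile.T` (the transpose of `score_T`).
- float32 accumulators stand in for PSUM, and the bf16 cast is skipped.
- `scale = 1/sqrt(D)`.
- The function names are hypothetical.

`exp_only_attention` mirrors the kernel. `online_softmax_attention` is the standard numerically stable variant, which keeps a running row max. It is included to illustrate the overflow caveat in the note.

```python
import numpy as np

BLOCK_Q = BLOCK_K = 128


def reference_attention(q, k, v, scale):
    """Numerically stable softmax(q @ k.T * scale) @ v, computed in one shot."""
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def exp_only_attention(q, k, v, scale):
    """Mirrors the kernel loop: exp without row-max subtraction, normalise at the end."""
    seq_len, d = q.shape
    out = np.empty_like(q)
    for q0 in range(0, seq_len, BLOCK_Q):
        q_tile = q[q0:q0 + BLOCK_Q]
        acc = np.zeros((q_tile.shape[0], d), dtype=np.float32)      # role of out_psum
        row_sum = np.zeros((q_tile.shape[0], 1), dtype=np.float32)
        for k0 in range(0, seq_len, BLOCK_K):
            k_tile, v_tile = k[k0:k0 + BLOCK_K], v[k0:k0 + BLOCK_K]
            probs = np.exp((q_tile @ k_tile.T) * scale)              # (BLOCK_Q, BLOCK_K)
            acc += probs @ v_tile
            row_sum += probs.sum(axis=1, keepdims=True)
        out[q0:q0 + BLOCK_Q] = acc / row_sum
    return out


def online_softmax_attention(q, k, v, scale):
    """Stable single-pass variant: track the running row max m and rescale when it grows."""
    seq_len, d = q.shape
    out = np.empty_like(q)
    for q0 in range(0, seq_len, BLOCK_Q):
        q_tile = q[q0:q0 + BLOCK_Q]
        n = q_tile.shape[0]
        m = np.full((n, 1), -np.inf, dtype=np.float32)
        acc = np.zeros((n, d), dtype=np.float32)
        row_sum = np.zeros((n, 1), dtype=np.float32)
        for k0 in range(0, seq_len, BLOCK_K):
            k_tile, v_tile = k[k0:k0 + BLOCK_K], v[k0:k0 + BLOCK_K]
            s = (q_tile @ k_tile.T) * scale
            m_new = np.maximum(m, s.max(axis=1, keepdims=True))
            correction = np.exp(m - m_new)                           # rescales earlier tiles
            probs = np.exp(s - m_new)
            acc = acc * correction + probs @ v_tile
            row_sum = row_sum * correction + probs.sum(axis=1, keepdims=True)
            m = m_new
        out[q0:q0 + BLOCK_Q] = acc / row_sum
    return out


# One head of the single-stream shape: S = 768 (6 tiles of 128), D = 128.
rng = np.random.default_rng(0)
S, D = 768, 128
q, k, v = (rng.standard_normal((S, D)).astype(np.float32) for _ in range(3))
scale = D ** -0.5

ref = reference_attention(q, k, v, scale)
print(np.abs(exp_only_attention(q, k, v, scale) - ref).max())
print(np.abs(online_softmax_attention(q, k, v, scale) - ref).max())

# Large activations: logits in the thousands overflow exp() in the exp-only version.
with np.errstate(over="ignore", invalid="ignore"):
    big = exp_only_attention(q * 30, k * 30, v, scale)
print(np.isfinite(big).all(), np.isfinite(online_softmax_attention(q * 30, k * 30, v, scale)).all())  # False True
```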

---

## 5. Compile mode full comparison

- Hardware: trn2.3xlarge, TP=4, bfloat16, 512×512, 4 steps, random weights, 4 runs (1 cold + 3 warm)
- Date: 2026-03-31
- neuronxcc: `2.0.236418.0a0+9af338ad`

All four compile-mode variants were measured with the same neuronxcc build, so the comparison is apples-to-apples.

### Vanilla compile (no custom kernels)

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 533.449s | 3.868s | 3.868s | 3.868s | 545.053s |
| 2–4 | WARM | 3.868s | 3.869s | 3.869s | 3.869s | 15.475s |

Cold: **533.4s** · Warm avg: **3.869 s/step** · Throughput: **0.258 steps/s**

### Compile + `--fused-qkv`

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 651.147s | 19.874s | 3.859s | 3.859s | 678.740s |
| 2–4 | WARM | 3.859s | 3.859s | 3.860s | 3.860s | 15.438s |

Cold: **651.1s** · Warm avg: **3.859 s/step** · Throughput: **0.259 steps/s**

### Compile + `--flash-attn`

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 862.344s | 19.601s | 4.159s | 4.159s | 890.263s |
| 2–4 | WARM | 4.159s | 4.159s | 4.159s | 4.159s | 16.636s |

Cold: **862.3s** · Warm avg: **4.159 s/step** · Throughput: **0.240 steps/s**

### Compile + `--fused-qkv --flash-attn` (combined)

| Run | Type | step01 | step02 | step03 | step04 | total |
|---|---|---|---|---|---|---|
| 1 | COLD | 830.249s | 19.558s | 4.149s | 4.149s | 858.105s |
| 2–4 | WARM | 4.149s | 4.149s | 4.149s | 4.149s | 16.597s |

Cold: **830.2s** · Warm avg: **4.149 s/step** · Throughput: **0.241 steps/s**

### Summary table

| Mode | Cold (s) | Warm avg/step | Throughput | vs vanilla compile |
|---|---|---|---|---|
| Eager, baseline | 9.3s | **0.835 s/step** | 1.198 steps/s | 4.6× faster |
| Compile, vanilla | 533.4s | 3.869 s/step | 0.258 steps/s | 1× (baseline) |
| Compile, `--fused-qkv` | 651.1s | **3.859 s/step** | **0.259 steps/s** | −0.3% (noise) |
| Compile, `--flash-attn` | 862.3s | 4.159 s/step | 0.240 steps/s | +7.5% slower |
| Compile, `--fused-qkv --flash-attn` | 830.2s | 4.149 s/step | 0.241 steps/s | +7.2% slower |
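
The derived columns can be recomputed from the per-mode numbers:

- Throughput is `1 / warm avg`.
- The last column compares each warm average with vanilla compile.
- The cold-compile overheads quoted under Interpretation compare the cold values with vanilla.

The short script below does this using only figures from the tables above. The `derived` helper is illustrative.

```python
import math

# (cold seconds, warm average s/step), copied from the tables above.
MODES = {
    "eager, baseline": (9.3, 0.835),
    "compile, vanilla": (533.4, 3.869),
    "compile, --fused-qkv": (651.1, 3.859),
    "compile, --flash-attn": (862.3, 4.159),
    "compile, --fused-qkv --flash-attn": (830.2, 4.149),
}
BASE_COLD, BASE_WARM = MODES["compile, vanilla"]


def derived(cold: float, warm: float) -> dict:
    """Recompute the summary-table and Interpretation figures relative to vanilla compile."""
    return {
        "throughput_steps_per_s": round(1.0 / warm, 3),
        "warm_vs_vanilla_pct": round((warm / BASE_WARM - 1.0) * 100.0, 1),
        "speedup_vs_vanilla": round(BASE_WARM / warm, 1),
        "cold_overhead_pct": round((cold / BASE_COLD - 1.0) * 100.0),
        "cold_overhead_s": round(cold - BASE_COLD, 1),
    }


for name, (cold, warm) in MODES.items():
    print(f"{name:36s} {derived(cold, warm)}")
```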

### Interpretation

- **Fused-QKV has no measurable effect in compile mode** (3.859 vs 3.869 s/step, within run-to-run noise).
  The Dynamo+NEFF compiler already fuses the QKV projections at the HLO level. The explicit NKI kernel
  therefore neither helps nor hurts, but it adds 118s to cold compilation.
- **Flash-attn is ~7% slower than vanilla** regardless of whether fused-QKV is also enabled.
  Plausible causes are the non-stabilized single-pass softmax and the sequential head loop, which are less efficient than the
  compiler's built-in attention decomposition (two-pass, numerically stable, better SPMD utilisation).
- **Combining both kernels** gives the same result as flash-attn alone (4.149 vs 4.159 s/step, within noise).
  Fused-QKV contributes nothing additional in compile mode.
- **Cold compilation time** grows when NKI kernels are enabled: vanilla (533s) → fused-qkv (651s, +22%) →
  combined (830s, +56%) → flash-attn alone (862s, +62%). The flash-attention kernel accounts for most of the increase;
  adding fused-QKV on top of it does not increase it further. Each NKI kernel adds a separate KLIR
  compilation pass inside neuronxcc.

---

## 8. Conclusions

| Kernel | Correct | Practical for eager | Practical for compile |
|---|---|---|---|
| NKI RoPE | n/a | No (S > 512 for single-stream) | No (same constraint) |
| NKI QKV | Yes | **No**: breaks XLA fusion (18× slower) | Negligible effect (within noise) |
| NKI Flash Attention | Yes (cosine similarity 0.9999) | TBD | **No**: 7% slower than vanilla, +62% compile time |
| NKI QKV + Flash Attention | Yes | **No** | Same as flash-attn alone |