| # Path to <10% float compute: analysis |
|
|
| The user's original target: **less than 10% of training compute time should be float operations**. Multi-session work has produced an int8 GEMM kernel; this note analyzes how close to <10% that gets us, and what's needed to close the gap. |
|
|
| ## 300M v76 baseline profile (from prior session, RTX 5090) |
|
|
| | Operation | % of CUDA time | Float? | |
| |---|---|---| |
| | aten::mm (BitLinear GEMM) | 23.5% | yes (bf16) | |
| | aten::bmm (attention QK + AV) | 20.5% | yes (bf16) | |
| | aten::topk (RWSP top-K) | 19.2% | yes | |
| | triton ALiBi+mask | 13.3% | yes | |
| | masked_fill | 4.8% | yes | |
| | Other (norms, activations, optimizer) | ~18% | yes | |
| | **Float total** | **~99%** | | |
| |
| So at baseline, essentially **everything is float**. |
| |
| ## With current int8 GEMM kernel (forward + dx backward) |
| |
| The Triton kernel (built and CPU-validated) replaces: |
| - BitLinear forward GEMM: ±1 × ±1 → int32 (exact) |
| - BitLinear dx backward: dy_q × W_int8 → int32 + per-row alpha rescale |
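
A minimal plain-Python sketch of the arithmetic behind these two paths, not the Triton kernel itself. It assumes `dy` is quantized per row as `dy[n] ≈ alpha[n] * dy_q[n]`; the helper names and the toy values are hypothetical. It also shows why the same trick does not carry over to the dW GEMM: there the contraction runs over the row index, so `alpha[n]` sits inside the sum.

```python
# Sketch only: plain-Python stand-in for the GEMM math, not the Triton kernel.
# Assumption: dy is quantized per row, dy[n] ~= alpha[n] * dy_q[n] with int8 dy_q.

def int_matmul(a, b):
    """Integer matmul with exact accumulation; a is N x K, b is K x M."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

# Forward: x and W are both +-1, so each product is +-1 and the sum is exact.
x = [[1, -1, 1, 1], [-1, -1, 1, -1]]                    # N=2, in=4
w = [[1, -1, 1], [1, 1, -1], [-1, 1, 1], [1, 1, 1]]     # in=4, out=3
y = int_matmul(x, w)                                     # entries in [-4, 4]

# dx backward: the contraction runs over `out`, and alpha[n] is constant per
# output row, so it factors out of the integer accumulation as a post-scale.
dy_q = [[3, -2, 5], [-7, 1, 0]]     # hypothetical int8 gradients, N x out
alpha = [0.5, 0.25]                 # hypothetical per-row scales
acc = int_matmul(dy_q, transpose(w))                     # N x in, integer
dx = [[alpha[n] * v for v in row] for n, row in enumerate(acc)]

# dW backward: the contraction runs over n, so alpha[n] sits inside the sum.
x_t = transpose(x)                                       # in x N
dw_ref = [[sum(x_t[i][n] * alpha[n] * dy_q[n][j] for n in range(len(alpha)))
           for j in range(len(dy_q[0]))] for i in range(len(x_t))]
dw_int = int_matmul(x_t, dy_q)      # integer-only product, alpha is lost
```

In general no single per-row or per-column rescale of `dw_int` recovers `dw_ref`, unless all `alpha[n]` are equal (i.e. a per-tensor scale).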
| |
| Estimate, assuming a 2× speedup over bf16 cuBLAS at our shapes (the nominal |
| int8 tensor-core advantage on Blackwell; not yet measured on GPU): |
| |
| | Operation | Time (% of original total) | Now integer? | |
| |---|---|---| |
| | BitLinear fwd GEMM (1/3 of the 23.5%) | 7.8% → ~3.9% | YES (int8) | |
| | BitLinear dx bwd GEMM (1/3 of 23.5%) | 7.8% → ~3.9% | YES (int8) | |
| | BitLinear dW bwd GEMM (1/3) | 7.8% | NO (per-row α can't fold into int_mm) | |
| | Attention bmm | 20.5% | NO (Q,K are ±1 but A is float post-softmax) | |
| | Topk | 19.2% | NO (sorting on float scores) | |
| | ALiBi | 13.3% | NO (could be int with redesign) | |
| | masked_fill, norms, etc | ~22% | NO | |
| | **New total time** | ~91% (int8 paths save ~7.8%) | | |
| | **Integer fraction** | (3.9 + 3.9) / 91 ≈ **8.6%** | | |
| | **Float fraction** | ~91.4% | | |
| |
| So the int8 kernel work alone gets us to ~9% **integer** compute (~91% float). |
| **Far from the <10% float target.** |
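
The estimate can be recomputed for other speedup assumptions with a few lines of Python. This is a sketch: the row values are copied from the profile above, the even 1/3 split of `aten::mm` is the same assumption the table makes, and `float_fraction` is a hypothetical helper name.

```python
# Rows: (name, % of baseline CUDA time, runs as int8 after the kernel swap?)
BASELINE = [
    ("BitLinear fwd GEMM", 23.5 / 3, True),
    ("BitLinear dx bwd GEMM", 23.5 / 3, True),
    ("BitLinear dW bwd GEMM", 23.5 / 3, False),
    ("Attention bmm", 20.5, False),
    ("topk", 19.2, False),
    ("ALiBi + mask", 13.3, False),
    ("masked_fill, norms, etc", 22.0, False),
]

def float_fraction(rows, int8_speedup):
    """Return (new total time, integer share, float share) of CUDA time."""
    int_time = sum(t / int8_speedup for _, t, is_int in rows if is_int)
    float_time = sum(t for _, t, is_int in rows if not is_int)
    total = int_time + float_time
    return total, int_time / total, float_time / total

total, int_share, float_share = float_fraction(BASELINE, int8_speedup=2.0)
print(f"total={total:.1f}%  integer={int_share:.1%}  float={float_share:.1%}")
```

One consequence of measuring by time: a faster int8 kernel shrinks the integer share, so the float fraction goes up as the kernel gets better. Only moving more operations to the integer column lowers it.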
| |
| ## What would reach <10% float? |
| |
| ### Path A: integer-ize attention (~30% reachable) |
| Currently bmm uses Q, K, V at ±1 (after `sign_ste_clipped` on BitLinearScaled |
| output) but the matmul is float. Needs: |
| 1. Q×K.T → int8 (since Q, K are ±1) — easy, same kernel pattern |
| 2. ALiBi as int (slopes are powers of 2, distances are int) — easy |
| 3. Softmax → integer-friendly approximation (Gumbel-hard with integer |
| tiebreak?) — HARD |
| 4. A×V → impossible to keep integer (A is post-softmax float) |
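
Steps 1 and 2 can be sketched in plain Python. Assumptions: each head's slope is `2**-slope_exp`, distances are the causal `i - j`, and all logits are carried at a common scale of `2**max_exp` so the slope becomes a left shift. Any further score scaling (such as `1/sqrt(d)`) is left to the float softmax stage. The function name and toy inputs are hypothetical.

```python
def int_scores(q, k, slope_exp, max_exp):
    """Causal attention logits scaled by 2**max_exp, computed in integers only.

    q, k: T x d lists of +-1 values. The head's ALiBi slope is 2**-slope_exp.
    Returns a T x T list; masked (future) positions are None.
    """
    assert 0 <= slope_exp <= max_exp
    T = len(q)
    out = []
    for i in range(T):
        row = []
        for j in range(T):
            if j > i:
                row.append(None)                          # causal mask
                continue
            dot = sum(a * b for a, b in zip(q[i], k[j]))  # integer in [-d, d]
            bias = (i - j) << (max_exp - slope_exp)       # slope * dist * 2**max_exp
            row.append((dot << max_exp) - bias)
        out.append(row)
    return out

q = [[1, 1, -1, 1], [1, -1, -1, 1], [-1, 1, 1, 1]]
k = [[1, 1, 1, 1], [1, -1, 1, -1], [-1, -1, 1, 1]]
scores = int_scores(q, k, slope_exp=1, max_exp=3)
```

Dividing any entry by `2**max_exp` gives the usual float logit `q·k - slope * (i - j)` exactly, so the integer form loses no information before softmax.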
| |
| So ~half of attention (QK, ALiBi, masking) could become integer, ~30% of |
| total CUDA time. Combined with int8 BitLinear, gets us to: |
| - Integer time: 9% (BitLinear) + ~13% (QK + ALiBi) = ~22% |
| - Float time: ~78% |
| |
| Still far from <10% float. |
| |
| ### Path B: replace softmax with hard top-K |
| RWSP-G's "top-K survivors with rank weighting" is ALREADY a step in this |
| direction. With `K=1` (argmax + rank weighting=1), we get a hard pointer. |
| |
| This eliminates both A×V's float dependency and topk's float sorting (since |
| the "top-1" is just argmax → integer). But it would also kill val_bpc by |
| losing softmax's smoothness. |
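
A forward-only sketch of the `K=1` case, to show why the float dependency disappears: with a single survivor of weight 1, A×V reduces to copying one value row per query. The tie-break rule (earliest key wins) and the function name are assumptions, and the sketch does not address how gradients would flow through the argmax.

```python
def hard_pointer_attention(scores, v):
    """K=1 attention: each query copies the value row of its argmax key.

    scores: T x T integer logits, None at masked positions.
    v: T x d value rows (+-1). Ties resolve to the earliest key (assumption).
    """
    out = []
    for row in scores:
        visible = [j for j, s in enumerate(row) if s is not None]
        best_j = max(visible, key=lambda j: row[j])  # first maximum wins
        out.append(list(v[best_j]))                  # gather, no multiply
    return out

scores = [[16, None, None], [0, 8, None], [8, -4, 8]]
v = [[1, -1], [-1, -1], [1, 1]]
out = hard_pointer_attention(scores, v)
```

The output stays in ±1 because it is a row gather rather than a weighted sum, which is also where the smoothness of softmax is lost.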
|
|
| ### Path C: reframe the goal |
| "<10% float compute" may not be the right metric. More tractable framings: |
| - **Storage bits per parameter** (already 1-bit on BitLinear weights) |
| - **GEMM FLOPs in integer mode** (achievable: forward + dx backward as int8) |
| - **Energy per training step** (different metric, requires hardware-specific |
| measurement) |
|
|
| ## Realistic path forward |
|
|
| 1. **First** validate the existing int8 GEMM kernel on GPU. Check: |
| - Triton int8 vs bf16 cuBLAS speedup at training shapes (predicted 1-3×) |
| - Numerical correctness (sign-STE gradients match) |
| - End-to-end training step throughput |
|
|
| 2. **If kernel wins**, integer-ize the attention QK matmul. Same pattern as |
| BitLinear. Adds ~13% more integer compute. |
|
|
| 3. **Defer "true <10% float"** — it requires fundamental architectural |
| changes (e.g., replacing softmax with integer-friendly attention, which |
| is unsolved as far as I know). The int8 GEMM work is a meaningful step |
| that can land in the v76 paper or as a follow-up engineering paper. |
|
|
| ## Bottom line |
|
|
| The int8 kernel work moves us from ~100% float to ~91% float (assuming the |
| 2× GPU speedup holds). To hit <10% float, attention and softmax need |
| fundamental redesigns that go beyond this session's scope. The honest |
| framing: this session enabled **1-bit weights and int8 GEMM compute** |
| (forward and dx backward) for linear layers; full 1-bit training requires |
| more research than fits in the project window. |
|
|
| The synth K-shift work is independent of the integer-compute goal — it's |
| about making sign-STE training converge to a useful attention pattern, which |
| matters whether the matmul is bf16 or int8. |
|
|