
Path to <10% float compute: analysis

The user's original target: less than 10% of training compute time should be float operations. Multi-session work has produced an int8 GEMM kernel; this note analyzes how close to <10% that gets us, and what's needed to close the gap.

300M v76 baseline profile (from prior session, RTX 5090)

| Operation | % of CUDA time | Float? |
|---|---|---|
| aten::mm (BitLinear GEMM) | 23.5% | yes (bf16) |
| aten::bmm (attention QK + AV) | 20.5% | yes (bf16) |
| aten::topk (RWSP top-K) | 19.2% | yes |
| Triton ALiBi+mask | 13.3% | yes |
| masked_fill | 4.8% | yes |
| Other (norms, activations, optimizer) | ~18% | yes |
| Float total | ~99% | |

So at baseline, essentially everything is float.

With current int8 GEMM kernel (forward + dx backward)

The Triton kernel (built and CPU-validated) replaces:

  • BitLinear forward GEMM: ±1 × ±1 → int32 (exact)
  • BitLinear dx backward: dy_q × W_int8 → int32 + per-row alpha rescale
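The Triton kernel itself isn't reproduced here. Below is a minimal pure-Python sketch of the arithmetic the two integer GEMMs perform. It assumes W is exactly ±1 (any BitLinearScaled weight scale is omitted) and that dy_q comes from symmetric per-row absmax int8 quantization. The function names are hypothetical. The docstring of `bitlinear_dx` also shows why the dW GEMM resists the same treatment.

```python
from typing import List, Tuple

IntMatrix = List[List[int]]
FloatMatrix = List[List[float]]


def int_matmul(a: IntMatrix, b: IntMatrix) -> IntMatrix:
    """a (n x k) @ b (k x m) with exact integer accumulation (stands in for int8 x int8 -> int32)."""
    k = len(b)
    assert all(len(row) == k for row in a), "inner dimensions must match"
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(len(b[0]))] for row in a]


def quantize_rows_int8(x: FloatMatrix) -> Tuple[IntMatrix, List[float]]:
    """Symmetric per-row absmax quantization: x[i][j] ~= alpha[i] * q[i][j], q in [-127, 127]."""
    q, alpha = [], []
    for row in x:
        amax = max(abs(v) for v in row)
        a = amax / 127.0 if amax > 0 else 1.0
        alpha.append(a)
        q.append([max(-127, min(127, round(v / a))) for v in row])
    return q, alpha


def bitlinear_forward(x_sign: IntMatrix, w_sign: IntMatrix) -> IntMatrix:
    """y = x @ W.T with x (n x in) and W (out x in) both in {-1, +1}.

    Every product is +-1, so |y| <= in and the int32 result is exact.
    """
    w_t = [list(col) for col in zip(*w_sign)]  # transpose to (in x out)
    return int_matmul(x_sign, w_t)


def bitlinear_dx(dy: FloatMatrix, w_sign: IntMatrix) -> FloatMatrix:
    """dx = dy @ W with dy (n x out): quantize dy per row, integer GEMM, rescale row i by alpha[i].

    alpha[i] is constant along the reduction axis (out), so it factors out of each sum.
    For dW = dy.T @ x the reduction runs over the rows i, so alpha[i] varies inside
    the sum and cannot be applied as one rescale after an integer GEMM.
    """
    q, alpha = quantize_rows_int8(dy)
    acc = int_matmul(q, w_sign)  # (n x in), integer accumulate
    return [[alpha[i] * v for v in acc[i]] for i in range(len(acc))]
```

The forward result is bit-exact. The dx result differs from the float product only by the int8 rounding of dy, which is at most alpha[i]/2 per element.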

Estimate (assuming a 2× speedup over bf16 cuBLAS at our shapes, which is what int8 tensor cores promise on Blackwell):

| Operation | Time (% of original total) | Now integer? |
|---|---|---|
| BitLinear fwd GEMM (1/3 of the 23.5%) | 7.8% → ~3.9% | YES (int8) |
| BitLinear dx bwd GEMM (1/3 of the 23.5%) | 7.8% → ~3.9% | YES (int8) |
| BitLinear dW bwd GEMM (1/3 of the 23.5%) | 7.8% | NO (per-row α can't fold into int_mm) |
| Attention bmm | 20.5% | NO (Q, K are ±1 but A is float post-softmax) |
| Topk | 19.2% | NO (sorting on float scores) |
| ALiBi | 13.3% | NO (could be int with redesign) |
| masked_fill, norms, etc. | ~23% | NO |
| New total time | ~92% (saved ~7.8%: the two int8 GEMMs go from 15.7% to 7.8%) | |
| Integer fraction | (3.9 + 3.9) / 92 ≈ 8.5% | |
| Float fraction | ≈ 91.5% | |

So the int8 kernel work alone gets us to ~8.5% integer compute (~91.5% float). That is far from the <10% float target.
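The projection is plain bookkeeping over the profile shares. Below is a minimal sketch with hypothetical names. It assumes the BitLinear mm splits evenly into fwd/dx/dW. It also assumes the attention bmm splits evenly into QK and AV, since both are T×T×d per head. "Other" is taken as the remainder to 100%, and every op moved to integer is assumed to run at the listed speedup.

```python
from typing import Dict, Tuple

# Baseline shares (% of CUDA time) from the 300M v76 RTX 5090 profile.
# Assumptions: even fwd/dx/dW split of the BitLinear mm, even QK/AV split of the bmm,
# and "other" is whatever brings the total to 100.
BASELINE: Dict[str, float] = {
    "bitlinear_fwd": 23.5 / 3,
    "bitlinear_dx": 23.5 / 3,
    "bitlinear_dw": 23.5 / 3,
    "attention_qk": 20.5 / 2,
    "attention_av": 20.5 / 2,
    "topk": 19.2,
    "alibi": 13.3,
    "masked_fill": 4.8,
}
BASELINE["other"] = 100.0 - sum(BASELINE.values())


def project(baseline: Dict[str, float], speedups: Dict[str, float]) -> Tuple[float, float]:
    """Ops listed in `speedups` become integer kernels running `speedup`x faster.

    Returns (new total time as % of baseline, integer share of the new total).
    """
    new = {op: t / speedups.get(op, 1.0) for op, t in baseline.items()}
    total = sum(new.values())
    integer = sum(new[op] for op in speedups)
    return total, integer / total


if __name__ == "__main__":
    for s in (1.0, 2.0, 3.0):  # spans the predicted 1-3x kernel speedup
        total, frac = project(BASELINE, {"bitlinear_fwd": s, "bitlinear_dx": s})
        print(f"speedup {s:.0f}x: total {total:.1f}%, integer {frac:.1%}, float {1 - frac:.1%}")
```

At 2× this gives a new total of ~92% of baseline and an ~8.5% integer share. Because the metric is a time fraction, a faster integer kernel shrinks the integer share: 1× gives ~15.7% and 3× gives ~5.8%. The same call with `"attention_qk": 2.0`, `"alibi": 1.0` and `"masked_fill": 1.0` added models Path A. It assumes integer ALiBi and masking cost the same time as today.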

What would reach <10% float?

Path A: integer-ize attention (~30% reachable)

Currently bmm uses Q, K, V at ±1 (after sign_ste_clipped on BitLinearScaled output) but the matmul is float. Needs:

  1. Q×K.T → int8 (since Q, K are ±1) — easy, same kernel pattern
  2. ALiBi as int (slopes are powers of 2, distances are int) — easy
  3. Softmax → integer-friendly approximation (Gumbel-hard with integer tiebreak?) — HARD
  4. A×V → impossible to keep integer (A is post-softmax float)
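Items 1 and 2 fit together because a power-of-two slope turns the ALiBi penalty into a shift. Below is a minimal sketch for one causal head, assuming the slope is exactly 2**-slope_log2. Standard ALiBi slopes are powers of two when the head count is 8. `int_alibi_scores` is a hypothetical name, and `None` stands in for whatever integer sentinel a kernel would use instead of a float -inf masked_fill.

```python
from typing import List, Optional


def int_alibi_scores(q: List[List[int]], k: List[List[int]], slope_log2: int) -> List[List[Optional[int]]]:
    """Causal integer attention scores for one head with ALiBi slope m = 2**-slope_log2.

    z[i][j] = (q_i . k_j) * 2**slope_log2 - (i - j) for j <= i, None where j > i (masked).
    Since q_i . k_j - m * (i - j) == z[i][j] / 2**slope_log2 exactly, the integer z
    keeps the same ordering (argmax / top-K) as the float logits.
    """
    n = len(q)
    scores: List[List[Optional[int]]] = []
    for i in range(n):
        row: List[Optional[int]] = []
        for j in range(n):
            if j > i:
                row.append(None)  # causal mask instead of a float -inf masked_fill
            else:
                dot = sum(a * b for a, b in zip(q[i], k[j]))  # +-1 entries: exact integer
                row.append((dot << slope_log2) - (i - j))  # left shift = multiply by 2**slope_log2
        scores.append(row)
    return scores
```

The float logit is recovered exactly as z / 2**slope_log2. The part that still needs floats is the softmax in item 3, not the scores themselves.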

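So the QK half of bmm, ALiBi and masking could become integer: about 28% of baseline CUDA time (~10.3% + 13.3% + 4.8%, taking QK as half of the 20.5% bmm since QK and A×V have equal FLOPs). Combined with int8 BitLinear, and assuming QK gets the same 2× int8 speedup while integer ALiBi and masking cost the same time as today, this gets us to:

  • Integer time: 7.8% (BitLinear) + ~5.1% (QK) + 13.3% (ALiBi) + 4.8% (masking) ≈ 31% of baseline, out of a new total of ~87%, i.e. ~36% integer
  • Float time: ~64%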

Still far from <10% float.

Path B: replace softmax with hard top-K

RWSP-G's "top-K survivors with rank weighting" is ALREADY a step in this direction. With K=1 (argmax + rank weighting=1), we get a hard pointer.

This eliminates both A×V's float dependency (with a one-hot A, A×V is just a gather of V rows) and topk's float sorting (top-1 is just an argmax, which stays integer if the scores are integer). But it would also badly hurt val_bpc by losing softmax's smoothness.
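With K=1 the read-out needs no multiply at all. Below is a minimal sketch of a hard-pointer attention step over integer scores, where `None` marks masked keys. Those scores could come from integer QK plus integer ALiBi. The function name is hypothetical. Breaking ties toward the most recent key is an assumption, since the note doesn't fix a tie-break rule.

```python
from typing import List, Optional, Sequence


def hard_pointer_attention(z: List[List[Optional[int]]], v: Sequence[Sequence[int]]) -> List[List[int]]:
    """K=1 attention: each query copies the value row of its highest-scoring visible key.

    z[i][j] is an integer score, or None for a masked key. With a one-hot attention row,
    A x V is just a row gather from V, so no float multiply or softmax is involved.
    Ties go to the most recent key (largest j), an arbitrary integer tie-break.
    """
    out = []
    for row in z:
        visible = [j for j, s in enumerate(row) if s is not None]
        if not visible:
            raise ValueError("query has no visible keys")
        best = max(visible, key=lambda j: (row[j], j))
        out.append(list(v[best]))
    return out
```

Everything here is integer comparison and indexing. That is also why it discards the graded mixing softmax provides.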

Path C: reframe the goal

"<10% float compute" may not be the right metric. More tractable framings:

  • Storage bits per parameter (already 1-bit on BitLinear weights)
  • GEMM FLOPs in integer mode (achievable: forward + dx backward as int8)
  • Energy per training step (different metric, requires hardware-specific measurement)
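For the first framing, "1-bit" storage means eight ±1 weights per byte. Below is a minimal sketch with hypothetical function names. It assumes the bit layout +1 → 1, -1 → 0, LSB-first within each byte. The project's actual checkpoint layout isn't specified in this note.

```python
from typing import List, Sequence


def pack_signs(weights: Sequence[int]) -> bytes:
    """Pack +-1 weights at 1 bit per parameter: +1 -> bit 1, -1 -> bit 0, LSB-first in each byte."""
    out = bytearray((len(weights) + 7) // 8)
    for idx, w in enumerate(weights):
        if w not in (1, -1):
            raise ValueError(f"weight {idx} is {w}, expected +1 or -1")
        if w == 1:
            out[idx // 8] |= 1 << (idx % 8)
    return bytes(out)


def unpack_signs(packed: bytes, n: int) -> List[int]:
    """Inverse of pack_signs for the first n parameters."""
    return [1 if (packed[i // 8] >> (i % 8)) & 1 else -1 for i in range(n)]
```

That is 16× smaller than bf16 for the packed weights, independent of how much compute stays in float.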

Realistic path forward

  1. First validate the existing int8 GEMM kernel on GPU. Check:

    • Triton int8 vs bf16 cuBLAS speedup at training shapes (predicted 1-3×)
    • Numerical correctness (sign-STE gradients match)
    • End-to-end training step throughput
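  2. If the kernel wins, integer-ize the attention QK matmul, using the same pattern as BitLinear. QK is half of the 20.5% bmm (~10% of baseline time), so at a 2× speedup this adds roughly 5% of baseline time as integer compute.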

  3. Defer "true <10% float": it requires fundamental architectural changes (e.g., replacing softmax with integer-friendly attention, which is unsolved as far as I know). The int8 GEMM work is a meaningful step that can land in the v76 paper or in a follow-up engineering paper.

Bottom line

The int8 kernel work moves us from ~99% float to ~91% float (if the 2× GPU speedup holds). To hit <10% float, attention and softmax need fundamental redesigns that go beyond this session's scope. The honest framing: this session enabled 1-bit weights, plus integer GEMMs for the linear layers' forward (±1 × ±1) and dx (int8) passes. Full 1-bit training requires more research than fit in the project window.

The synth K-shift work is independent of the integer-compute goal — it's about making sign-STE training converge to a useful attention pattern, which matters whether the matmul is bf16 or int8.