# Pure 1-bit CPU training — what works and what doesn't
## Goal
Build a CPU-only 1-bit (±1) language-model training script using tricks
borrowed from C inference code (the bitnet.cpp i2_s kernel pattern).
## Update: AVX-512 VPOPCNTQ extension WINS
After confirming pure PyTorch couldn't beat BLAS, I built a C extension
(`bit_cpu_avx.py`) using `__builtin_popcountll` on uint64-packed signs,
compiled with `-march=native -mavx512vpopcntdq` for AMD Zen 4 / Intel Ice
Lake+ CPUs.
### Kernel-only benchmark (256×1536×1536, 300M training shape)
| Method | Time | Speedup |
|---|---|---|
| **AVX-512 popcount (gemm only)** | **0.19 ms** | **3.34×** |
| bf16 BLAS (gemm only) | 0.63 ms | 1.0× |
| pure-PyTorch bit-pack (uint8 LUT) | 143 ms | 0.004× |
The AVX kernel is 3.3× faster than bf16 BLAS at this training shape.
### End-to-end training (TinyLM, Hamlet text, 200-300 steps)
| Model size | Forward kernel | Final loss | tok/s | E2E speedup |
|---|---|---|---|---|
| d=64 (76K params) | bf16 BLAS | 2.35 | 195K | 1.0× |
| d=64 (76K params) | AVX-512 | 2.06 | 207K | 1.06× |
| d=128 (284K params) | bf16 BLAS | 2.15 | 118K | 1.0× |
| d=128 (284K params) | AVX-512 | 1.81 | 130K | 1.10× |
| d=256 (1.1M params) | bf16 BLAS | 2.07 | 47K | 1.0× |
| **d=256 (1.1M params)** | **AVX-512** | **1.77** | **55K** | **1.18×** |
End-to-end speedup grows with d (6% → 10% → 18%) because the matmul
share of total compute grows with width. Importantly, **AVX is BOTH faster
and more numerically precise** (exact int32 dot products vs bf16 rounding
in the BLAS path), so the loss converges to a lower value.
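One way to read the growing speedup is Amdahl's law: if a fraction f of step time is the forward GEMM and that part gets s× faster, the whole step gets 1 / ((1 - f) + f/s) faster. The sketch below inverts that formula to estimate f from the table above. It assumes, purely for illustration, that the 3.34× kernel speedup holds at every model size, although it was measured at a single shape; `implied_fraction` is a hypothetical helper.

```python
def amdahl_speedup(f: float, s: float) -> float:
    """Overall speedup when a fraction f of runtime is accelerated by factor s."""
    return 1.0 / ((1.0 - f) + f / s)


def implied_fraction(overall: float, s: float) -> float:
    """Invert Amdahl's law: the fraction f that, sped up by s, gives `overall`.

    1/overall = (1 - f) + f/s  =>  f = (1 - 1/overall) / (1 - 1/s)
    """
    return (1.0 - 1.0 / overall) / (1.0 - 1.0 / s)


# Assumption: the kernel speedup measured at 256x1536x1536 applies at all sizes
KERNEL_SPEEDUP = 3.34
for d, e2e in [(64, 1.06), (128, 1.10), (256, 1.18)]:
    f = implied_fraction(e2e, KERNEL_SPEEDUP)
    print(f"d={d}: implied forward-GEMM share of step time ~{f:.0%}")
```

Under that assumption the implied forward-GEMM share rises from roughly 8% (d=64) to 13% (d=128) to 22% (d=256), which fits the explanation above and shows why end-to-end gains stay well below the kernel-level 3.3×.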
## Pure PyTorch (no C extension) is a dead end
| Approach | 512×768×768 GEMM time |
|---|---|
| **bf16 BLAS sign-matmul** (the practical 1-bit path) | **0.39 ms** |
| fp32 BLAS sign-matmul | 0.89 ms |
| int8 sign-matmul (cast back to fp32 for matmul) | 0.95 ms |
| Bit-packed uint8 XOR + LUT popcount | 143.0 ms (367× slower) |
The 256-entry popcount LUT works (zero numerical difference vs the naive
path), but the PyTorch-level `xor.unsqueeze(1) ^ w.unsqueeze(0)` broadcast,
which materializes an M×N×(K/8) byte tensor, and the per-byte LUT lookup
dominate the cost. To win, you'd need either:
- A C extension calling `__builtin_popcountll`, or `_mm256_sad_epu8` to sum
  the per-byte counts of an AVX2 nibble-LUT popcount
- TVM/Triton CPU codegen with bit ops (an experimental Triton CPU backend
  exists; untested)
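For reference, here is the arithmetic the LUT path performs, sketched in plain Python rather than PyTorch so it runs anywhere. `pack_signs_u8` and `dot_lut` are hypothetical names, not functions from `bit_cpu_kernel.py`, and mapping 0 to +1 when packing is an assumed convention. It shows why the byte LUT is exact (XOR, look up each byte's popcount, sum, then K - 2·pop), not how fast it is.

```python
import random

# 256-entry table: POPCOUNT_LUT[b] is the number of set bits in byte b
POPCOUNT_LUT = [bin(i).count("1") for i in range(256)]


def pack_signs_u8(v):
    """Pack a ±1 vector into bytes, 8 elements per byte (bit = 1 means -1).

    Assumes len(v) is a multiple of 8; values >= 0 (including 0) map to +1.
    """
    if len(v) % 8:
        raise ValueError("length must be a multiple of 8")
    out = bytearray()
    for i in range(0, len(v), 8):
        byte = 0
        for j in range(8):
            if v[i + j] < 0:
                byte |= 1 << j
        out.append(byte)
    return bytes(out)


def dot_lut(a, b, k):
    """±1 dot product from packed bytes: XOR, per-byte LUT popcount, K - 2*pop."""
    pop = sum(POPCOUNT_LUT[x ^ y] for x, y in zip(a, b))
    return k - 2 * pop


# Exactness check against the naive float dot product
K = 768
rng = random.Random(0)
a = [rng.choice((-1.0, 1.0)) for _ in range(K)]
b = [rng.choice((-1.0, 1.0)) for _ in range(K)]
naive = sum(x * y for x, y in zip(a, b))
assert dot_lut(pack_signs_u8(a), pack_signs_u8(b), K) == naive
```

The math is trivially exact; in PyTorch the cost comes from expressing the XOR as a broadcast and the lookup as a per-byte gather.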
## What does work: BLAS bf16 sign-matmul + STE
`train_1bit_cpu.py` — 76K-parameter character LM trained on Hamlet:
| Variant | Final loss (300 steps) | tok/s | Quality |
|---|---|---|---|
| 1-bit STE (`(x.sign() @ w.sign().t())` in bf16) | 2.35 | 195K | weak |
| FP32 baseline (same arch) | 0.30 | 236K | good |
The 1-bit model trains correctly (loss 4.15 → 2.35) but is much weaker
than fp32 at this small scale. 1-bit needs more parameters to compensate
for its low precision, and 76K parameters is far too few for that.
Speed: 1-bit is **17% slower** than fp32 on CPU (195K vs 236K tok/s) because:
- bf16 BLAS GEMM ≈ fp32 BLAS GEMM speed at small sizes (no Tensor Cores
on CPU, just SIMD)
- Applying sign() to the inputs adds an extra elementwise pass
- The STE backward needs sign() of the weights too
## Files
| File | Purpose |
|---|---|
| `bit_cpu_kernel.py` | Pure-PyTorch bit-packed XNOR+LUT popcount. Correct but slow. |
| `bit_cpu_v2.py` | Benchmark vs alternatives (fp32, bf16, int8). |
| `train_1bit_cpu.py` | Practical CPU 1-bit training (uses BLAS bf16 path). |
| `train_fp32_cpu_baseline.py` | Same arch with fp32 — control. |
| `bench_bit_cpu.py` | GEMM-only timing. |
## To actually win on CPU
1. Build a C extension wrapping `__builtin_popcountll` for 64-bit chunks
2. Pack the sign()-ed weights and activations 64-bit-aligned (K must be a multiple of 64)
3. The matmul becomes `popcount(X[m] XOR W[n])` per (m,n) pair, summed over the K/64 words
4. With AVX-512 VPOPCNTQ (Ice Lake+ / Zen 4), peak ~1 popcount/cycle/lane
5. This is what bitnet.cpp's i2_s kernel does (for ternary)
For ±1 binary specifically, the math is simpler than ternary:
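- `dot(±1 vector a, ±1 vector b) = K - 2 * popcount(pack(a) XOR pack(b))`:
  the K - pop matching positions each contribute +1 and the pop mismatches
  each contribute -1
- No zero handling needed (the ternary i2_s kernel must handle -1, 0, +1)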
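The numbered steps and the identity above can be checked with a small pure-Python reference. This is a sketch, not the `bit_cpu_avx.py` kernel: `pack_rows_u64`, `popcount64` and `binary_gemm` are hypothetical names, `bin(x).count("1")` stands in for `__builtin_popcountll`/VPOPCNTQ, and bit = 1 encodes -1 (an assumed convention). One deliberate difference from step 2: instead of requiring K to be a multiple of 64, the last word is zero-padded on both operands, so padding bits XOR to 0 and the formula stays correct as long as the true K is used.

```python
import random

WORD = 64  # bits per packed word (a uint64 in a C kernel)


def pack_rows_u64(rows):
    """Pack each ±1 row into 64-bit words (bit = 1 means -1).

    A partial last word stays zero-padded (i.e. padded with +1 on every row),
    so the padding bits of X and W always XOR to 0.
    """
    packed = []
    for row in rows:
        words = []
        for start in range(0, len(row), WORD):
            word = 0
            for j, v in enumerate(row[start:start + WORD]):
                if v < 0:
                    word |= 1 << j
            words.append(word)
        packed.append(words)
    return packed


def popcount64(x):
    """Stand-in for __builtin_popcountll / VPOPCNTQ on one 64-bit word."""
    return bin(x).count("1")


def binary_gemm(xp, wp, k):
    """out[m][n] = K - 2 * sum over words of popcount(X[m] XOR W[n]).

    k must be the true (unpadded) inner dimension; results are floats to
    mirror the fp32 output described in these notes.
    """
    return [
        [float(k - 2 * sum(popcount64(a ^ b) for a, b in zip(xrow, wrow)))
         for wrow in wp]
        for xrow in xp
    ]


rng = random.Random(0)
M, N, K = 3, 4, 100  # K deliberately not a multiple of 64: two words per row
X = [[rng.choice((-1.0, 1.0)) for _ in range(K)] for _ in range(M)]
W = [[rng.choice((-1.0, 1.0)) for _ in range(K)] for _ in range(N)]
out = binary_gemm(pack_rows_u64(X), pack_rows_u64(W), K)
ref = [[sum(a * b for a, b in zip(x, w)) for w in W] for x in X]
assert out == ref
```

A C kernel would run the same inner loop over uint64 words; zero-padding is one way to relax the multiple-of-64 requirement, at the cost of a partly wasted last word.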
## Honest takeaway (revised)
The original conclusion ("pure PyTorch can't beat BLAS, just use bf16
sign-matmul") was correct *for pure PyTorch*. But `torch.utils.cpp_extension`
gives us trivial access to `__builtin_popcountll`. With a 90-line C extension
compiled with `-march=native -mavx512vpopcntdq`, we get a 3.3× kernel speedup
and a 6-18% end-to-end training speedup at the sizes tested (10-18% at
d ≥ 128), AND the exact int32 dot product is more accurate than bf16 BLAS,
leading to better loss convergence.
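To see where bf16 rounding can lose information, the sketch below emulates rounding a float32 value to bf16 (round-to-nearest-even on the bit pattern; NaN and overflow are ignored) and counts which possible ±1 dot products at K = 1536, taken from the 256×1536×1536 benchmark shape, survive exactly. It assumes the bf16 path accumulates in higher precision and only rounds the final output; `to_bf16` is a hypothetical helper. bf16 keeps 8 significant bits, so every integer with magnitude up to 256 is exact and the effect only appears for larger |dot|.

```python
import struct


def to_bf16(x):
    """Round a float to the nearest bfloat16 value (ties to even).

    Works on the float32 bit pattern: add a rounding bias, then keep the top
    16 bits. Sketch only: NaN/Inf and overflow are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest surviving bit, used for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


K = 1536
# A dot product of K ±1 values equals K - 2*pop, so it has the same parity as K
possible = range(-K, K + 1, 2)
inexact = [d for d in possible if to_bf16(float(d)) != d]
print(f"{len(inexact)} of {len(possible)} possible dot values are not exact in bf16")
print("514 ->", to_bf16(514.0))
```

Here 640 of the 1,537 possible values are rounded to a neighbor (all of them with |dot| > 512), whereas the popcount path returns the exact integer.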
**This is the actual analog of bitnet.cpp's inference path, but applied to
training**: pack signs into uint64, XOR, popcount, compute dot = K - 2*pop,
and write fp32 output. The backward pass stays STE through bf16 BLAS for
`dx` and `dw` (on sign()-ed inputs).
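For the backward half, a minimal pure-Python sketch of what the STE computes for y = sign(x) @ sign(w).T, with x of shape [M][K] and w of shape [N][K]: both gradients are dense products against sign()-ed operands, which is why the weights must be sign()-ed in the backward as well. `ste_backward`, the optional `clip` mask, and mapping 0 to +1 are illustrative assumptions, not taken from `train_1bit_cpu_avx.py`.

```python
def sign(v):
    # Map to ±1; 0 goes to +1 so values stay binary (assumed convention)
    return 1.0 if v >= 0 else -1.0


def matmul(a, b):
    """Plain list-of-lists matrix product a @ b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def ste_backward(grad_y, x, w, clip=None):
    """STE backward for y = sign(x) @ sign(w).T.

    grad_y: [M][N], x: [M][K], w: [N][K]. Returns (dx, dw).
    sign() is treated as identity in the backward; if clip is given, gradients
    are zeroed where |input| > clip (a common STE variant, assumed here).
    """
    sx = [[sign(v) for v in row] for row in x]
    sw = [[sign(v) for v in row] for row in w]
    dx = matmul(grad_y, sw)             # [M][N] @ [N][K] -> [M][K]
    dw = matmul(transpose(grad_y), sx)  # [N][M] @ [M][K] -> [N][K]
    if clip is not None:
        dx = [[g if abs(v) <= clip else 0.0 for g, v in zip(gr, xr)]
              for gr, xr in zip(dx, x)]
        dw = [[g if abs(v) <= clip else 0.0 for g, v in zip(gr, wr)]
              for gr, wr in zip(dw, w)]
    return dx, dw


dx, dw = ste_backward([[1.0, 2.0]], [[0.5, -2.0]], [[1.0, 1.0], [-0.3, 0.7]])
print("dx =", dx, "dw =", dw)
```

In the real training loop these two products are the bf16 BLAS calls; only the forward GEMM is replaced by the popcount kernel.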
## Files (revised)
| File | Purpose |
|---|---|
| `bit_cpu_avx.py` | **AVX-512 popcount C extension** (the win) |
| `bench_bit_avx.py` | Benchmark vs BLAS bf16 |
| `train_1bit_cpu_avx.py` | **Production CPU 1-bit training** with AVX kernel |
| `bit_cpu_kernel.py` | Pure-PyTorch bit-packed (slow; for comparison) |
| `bit_cpu_v2.py` | Multi-strategy benchmark |
| `train_1bit_cpu.py` | Pure-PyTorch (bf16 BLAS) baseline |
| `train_fp32_cpu_baseline.py` | Same arch fp32 — control |
| `bench_bit_cpu.py` | GEMM-only timing comparison |