Pure 1-bit CPU training — what works and what doesn't

Goal

Build a CPU-only 1-bit (±1) language model training script using "tricks from the C inference" (bitnet.cpp i2_s kernel pattern).

Update: AVX-512 VPOPCNTQ extension WINS

After confirming pure PyTorch couldn't beat BLAS, I built a C extension (bit_cpu_avx.py) using __builtin_popcountll on uint64-packed signs, compiled with -march=native -mavx512vpopcntdq for AMD Zen 4 / Intel Ice Lake+ CPUs.

Kernel-only benchmark (256×1536×1536, 300M training shape)

| Method | Time | Speedup |
|---|---|---|
| AVX-512 popcount (gemm only) | 0.19 ms | 3.34× |
| bf16 BLAS (gemm only) | 0.63 ms | 1.0× |
| pure-PyTorch bit-pack (uint8 LUT) | 143 ms | 0.004× |

The AVX kernel is 3.3× faster than BLAS bf16 at typical training shapes.

End-to-end training (TinyLM, Hamlet text, 200-300 steps)

| Model size | Forward kernel | Final loss | tok/s | E2E speedup |
|---|---|---|---|---|
| d=64 (76K params) | bf16 BLAS | 2.35 | 195K | 1.0× |
| d=64 (76K params) | AVX-512 | 2.06 | 207K | 1.06× |
| d=128 (284K params) | bf16 BLAS | 2.15 | 118K | 1.0× |
| d=128 (284K params) | AVX-512 | 1.81 | 130K | 1.10× |
| d=256 (1.1M params) | bf16 BLAS | 2.07 | 47K | 1.0× |
| d=256 (1.1M params) | AVX-512 | 1.77 | 55K | 1.18× |

End-to-end speedup grows with d (6% → 10% → 18%) because the matmul fraction of total compute grows. The AVX kernel is also more numerically precise (exact int32 dot product vs. bf16 rounding), which is the likely reason the loss converges to a lower value in every row of the table.
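
To make the "matmul fraction grows" argument concrete, the measured speedups can be run backwards through Amdahl's law. This is a rough sketch, not a profile: it assumes only the forward GEMM is accelerated and that the 3.34× kernel-only figure applies at every d, which probably does not hold exactly at the smaller shapes. The function name is hypothetical.

```python
def accelerated_fraction(e2e_speedup, kernel_speedup):
    """Invert Amdahl's law.

    e2e_speedup = 1 / ((1 - f) + f / kernel_speedup)
    Solving for f gives the share of baseline step time spent in the
    part that was accelerated.
    """
    return (1.0 - 1.0 / e2e_speedup) / (1.0 - 1.0 / kernel_speedup)


# Kernel-only speedup from the benchmark table (assumed constant across d).
KERNEL_SPEEDUP = 3.34

# (d, measured end-to-end speedup) from the training table.
measured = [(64, 1.06), (128, 1.10), (256, 1.18)]

for d, speedup in measured:
    f = accelerated_fraction(speedup, KERNEL_SPEEDUP)
    print(f"d={d}: implied accelerated share of step time = {f:.2f}")
```

Under these assumptions the implied share rises from under a tenth of the step at d=64 to roughly a fifth at d=256, which matches the trend in the table.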

Pure PyTorch (no C extension) is a dead end

| Approach | 512×768×768 GEMM time |
|---|---|
| bf16 BLAS sign-matmul (the practical 1-bit path) | 0.39 ms |
| fp32 BLAS sign-matmul | 0.89 ms |
| int8 sign-matmul (cast back to fp32 for matmul) | 0.95 ms |
| Bit-packed uint8 XOR + LUT popcount | 143.0 ms (367× slower) |

The 256-entry popcount LUT is correct (zero numerical difference vs. the naive matmul), but the PyTorch-level XOR broadcast (x.unsqueeze(1) ^ w.unsqueeze(0)) and the per-byte LUT lookup dominate the cost. To win, you'd need either:

  • A C extension calling _mm256_sad_epu8 or __builtin_popcountll
  • TVM/Triton CPU codegen with bit ops (Triton has a CPU backend; untested)
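
For reference, the byte-level LUT technique described above looks roughly like this in plain Python. It is only an illustration of the idea (pack 8 signs per byte, XOR, look up the popcount per byte), not the contents of bit_cpu_kernel.py; the function names and the bit convention (bit set for -1) are assumptions.

```python
import random

# POPCOUNT_LUT[b] = number of set bits in the byte value b (256 entries).
POPCOUNT_LUT = [bin(b).count("1") for b in range(256)]


def pack_signs_to_bytes(signs):
    """Pack a list of +1/-1 values into bytes, one bit per sign (1 means -1)."""
    if len(signs) % 8 != 0:
        raise ValueError("length must be a multiple of 8")
    out = bytearray()
    for i in range(0, len(signs), 8):
        byte = 0
        for j, s in enumerate(signs[i:i + 8]):
            if s < 0:
                byte |= 1 << j
        out.append(byte)
    return bytes(out)


def dot_via_lut(a_packed, b_packed, k):
    """Dot product of two packed ±1 vectors of length k via XOR + LUT popcount."""
    mismatches = sum(POPCOUNT_LUT[x ^ y] for x, y in zip(a_packed, b_packed))
    return k - 2 * mismatches


random.seed(0)
K = 64
a = [random.choice((-1, 1)) for _ in range(K)]
b = [random.choice((-1, 1)) for _ in range(K)]

naive = sum(x * y for x, y in zip(a, b))
packed = dot_via_lut(pack_signs_to_bytes(a), pack_signs_to_bytes(b), K)
print(naive, packed)  # the two values agree
```

The arithmetic is exact, so the slowness reported above comes from doing the XOR and the lookups one byte at a time rather than from the method itself.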

What does work: BLAS bf16 sign-matmul + STE

train_1bit_cpu.py — 76K-parameter character LM trained on Hamlet:

| Variant | Final loss (300 steps) | tok/s | Quality |
|---|---|---|---|
| 1-bit STE ((x.sign() @ w.sign().t()) in bf16) | 2.35 | 195K | weak |
| FP32 baseline (same arch) | 0.30 | 236K | good |

The 1-bit model trains correctly (loss 4.15 → 2.35) but is much weaker than fp32 at this small scale. 1-bit needs more parameters to compensate for the low precision, and 76K parameters is too few for that.

Speed: the 1-bit path is 17% slower than fp32 on CPU (195K vs. 236K tok/s) because:

  • bf16 BLAS GEMM runs at about the same speed as fp32 BLAS GEMM at small sizes (no Tensor Cores on CPU, just SIMD)
  • Applying sign() to the inputs adds an extra elementwise pass
  • The STE backward pass needs sign() of the weights too

Files

| File | Purpose |
|---|---|
| bit_cpu_kernel.py | Pure-PyTorch bit-packed XNOR+LUT popcount. Correct but slow. |
| bit_cpu_v2.py | Benchmark vs alternatives (fp32, bf16, int8). |
| train_1bit_cpu.py | Practical CPU 1-bit training (uses BLAS bf16 path). |
| train_fp32_cpu_baseline.py | Same arch with fp32 (control). |
| bench_bit_cpu.py | GEMM-only timing. |

To actually win on CPU

  1. Build a C extension wrapping __builtin_popcountll for 64-bit chunks
  2. Pack weights 64-bit-aligned (K must be multiple of 64)
  3. The matmul becomes popcount(X[m] XOR W[n]) per (m,n) pair
  4. With AVX-512 VPOPCNTQ (Ice Lake+), peak ~1 popcount/cycle/lane
  5. This follows the same packed-weight pattern as bitnet.cpp's i2_s kernel (which targets ternary weights)
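
The steps above can be sketched in plain Python to show the data layout and loop structure. This is a reference model only: the real speed comes from doing the same thing in C with __builtin_popcountll, and `pack_row`, `popcount64` and `binary_gemm` are hypothetical names, not the functions in bit_cpu_avx.py.

```python
import random

WORD_BITS = 64


def pack_row(signs):
    """Pack a ±1 row into 64-bit words (bit set where the sign is -1)."""
    if len(signs) % WORD_BITS != 0:
        raise ValueError("K must be a multiple of 64")
    words = []
    for i in range(0, len(signs), WORD_BITS):
        word = 0
        for j, s in enumerate(signs[i:i + WORD_BITS]):
            if s < 0:
                word |= 1 << j
        words.append(word)
    return words


def popcount64(v):
    # Stand-in for __builtin_popcountll in the C version.
    return bin(v).count("1")


def binary_gemm(x_rows, w_rows, k):
    """out[m][n] = dot(x_rows[m], w_rows[n]) for ±1 rows of length k."""
    xp = [pack_row(r) for r in x_rows]
    wp = [pack_row(r) for r in w_rows]
    return [
        [k - 2 * sum(popcount64(a ^ b) for a, b in zip(xr, wr)) for wr in wp]
        for xr in xp
    ]


random.seed(0)
M, N, K = 3, 4, 128
x_rows = [[random.choice((-1, 1)) for _ in range(K)] for _ in range(M)]
w_rows = [[random.choice((-1, 1)) for _ in range(K)] for _ in range(N)]

out = binary_gemm(x_rows, w_rows, K)
ref = [[sum(a * b for a, b in zip(xr, wr)) for wr in w_rows] for xr in x_rows]
print(out == ref)
```

Packing happens once per row, so in a training loop the weight packing can be reused across every row of the batch.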

For ±1 binary specifically, the math is simpler than ternary:

  • dot(±1 vector a, ±1 vector b) = K - 2 * popcount(pack(a) XOR pack(b))
  • No zero handling needed (ternary i2_s kernel handles -1, 0, +1)
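
The identity follows from counting: each position contributes +1 when the signs agree and -1 when they differ, and XOR of the packed bits marks exactly the positions that differ. With p mismatches out of K positions, the dot product is (K - p) - p = K - 2p. A small brute-force check over every pair of ±1 vectors for a short K (the helper names are illustrative):

```python
from itertools import product


def pack(v):
    """Pack a ±1 tuple into an int, one bit per element (bit set for -1)."""
    return sum(1 << i for i, s in enumerate(v) if s < 0)


def identity_holds(k):
    """Check dot(a, b) == k - 2 * popcount(pack(a) ^ pack(b)) for all ±1 pairs."""
    vectors = list(product((-1, 1), repeat=k))
    for a in vectors:
        for b in vectors:
            naive = sum(x * y for x, y in zip(a, b))
            packed = k - 2 * bin(pack(a) ^ pack(b)).count("1")
            if naive != packed:
                return False
    return True


print(identity_holds(6))  # True
```

The same identity holds if the opposite bit convention is used (bit set for +1), as long as both operands use the same one.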

Honest takeaway (revised)

The original conclusion ("pure PyTorch can't beat BLAS, just use bf16 sign-matmul") was correct for pure PyTorch. But torch.utils.cpp_extension gives us trivial access to __builtin_popcountll. With a 90-line C extension compiled with -march=native -mavx512vpopcntdq, we get a 3.3× kernel speedup and a 10-18% end-to-end training speedup at d=128 to d=256 (6% at d=64). The exact int32 dot product is also more accurate than bf16 BLAS, which is the likely reason for the better loss convergence.
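
To get a feel for the size of the bf16 effect, the sketch below rounds every achievable ±1 dot-product value for K=1536 to bfloat16 and measures the error. It assumes the accumulation itself is exact and only the output is rounded to bf16 with round-to-nearest-even; how the BLAS path used in these runs actually accumulates is not stated in these notes, so treat this as an illustration of the output format only.

```python
import struct


def round_to_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half of the dropped range, plus the LSB of the kept part so ties go to even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000  # keep sign, exponent and the top 7 fraction bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


K = 1536
# With K even, a ±1 dot product K - 2*pop takes every even value in [-K, K].
dots = range(-K, K + 1, 2)
errors = [abs(round_to_bf16(float(v)) - v) for v in dots]

max_err = max(errors)
inexact = sum(1 for e in errors if e > 0)
print(f"max abs error: {max_err}, inexact values: {inexact} of {len(errors)}")
```

Under this assumption, outputs with magnitude up to 256 are represented exactly, so the rounding error only appears for larger K (or if accumulation also happens in reduced precision). The int32 popcount path has no such limit.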

This is the actual analog of bitnet.cpp's inference path, applied to training: pack signs into uint64, XOR, popcount, compute dot = K - 2*pop, and write fp32 output. The backward pass stays STE through bf16 BLAS for dx and dw (using the sign'd inputs).
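
A framework-free sketch of that forward/backward split is below. It is not the code in train_1bit_cpu_avx.py: the forward here is a plain sum standing in for the popcount kernel, the STE is the simplest variant (gradient passed straight through sign() with no clipping, which is an assumption), and all names are hypothetical.

```python
def sign(v):
    # Maps 0 to +1 so every entry is strictly ±1 (an assumption; torch.sign maps 0 to 0).
    return 1.0 if v >= 0 else -1.0


def forward(x, w):
    """y[m][n] = dot(sign(x[m]), sign(w[n])); also returns the signed copies."""
    xs = [[sign(v) for v in row] for row in x]
    ws = [[sign(v) for v in row] for row in w]
    y = [[sum(a * b for a, b in zip(xr, wr)) for wr in ws] for xr in xs]
    return y, xs, ws


def backward_ste(dy, xs, ws):
    """Straight-through estimator: treat d sign(z)/dz as 1.

    dx = dy @ sign(w)      shape (M, K)
    dw = dy^T @ sign(x)    shape (N, K)
    """
    m, n, k = len(xs), len(ws), len(xs[0])
    dx = [[sum(dy[i][j] * ws[j][c] for j in range(n)) for c in range(k)] for i in range(m)]
    dw = [[sum(dy[i][j] * xs[i][c] for i in range(m)) for c in range(k)] for j in range(n)]
    return dx, dw


x = [[0.5, -1.2, 0.3, -0.1]]           # M=1, K=4
w = [[0.2, 0.4, -0.7, -0.9],           # N=2, K=4
     [-0.3, 0.8, 0.1, 0.6]]

y, xs, ws = forward(x, w)
dx, dw = backward_ste([[1.0, 0.5]], xs, ws)
print(y, dx, dw)
```

Both gradients are ordinary dense matmuls against ±1 matrices, which is why the backward pass can stay on the BLAS path while only the forward uses the popcount kernel.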

Files

| File | Purpose |
|---|---|
| bit_cpu_avx.py | AVX-512 popcount C extension (the win) |
| bench_bit_avx.py | Benchmark vs BLAS bf16 |
| train_1bit_cpu_avx.py | Production CPU 1-bit training with AVX kernel |
| bit_cpu_kernel.py | Pure-PyTorch bit-packed (slow; for comparison) |
| bit_cpu_v2.py | Multi-strategy benchmark |
| train_1bit_cpu.py | Pure-PyTorch (bf16 BLAS) baseline |
| train_fp32_cpu_baseline.py | Same arch fp32 (control) |
| bench_bit_cpu.py | GEMM-only timing comparison |