Pure 1-bit CPU training — what works and what doesn't
Goal
Build a CPU-only 1-bit (±1) language-model training script that reuses tricks from bitnet.cpp's C inference kernels (the i2_s kernel pattern).
Update: AVX-512 VPOPCNTQ extension WINS
After confirming that pure PyTorch couldn't beat BLAS, I built a C extension (`bit_cpu_avx.py`) that applies `__builtin_popcountll` to uint64-packed signs. It is compiled with `-march=native -mavx512vpopcntdq` for AMD Zen 4 / Intel Ice Lake+ CPUs.
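To make "uint64-packed signs" concrete, here is a minimal pure-Python sketch. It is not the extension's C code, and `pack_signs` is a hypothetical helper. It assumes that bit i of word j holds the sign of element 64*j + i, with 1 meaning negative and 0 meaning non-negative, so zeros are treated as +1. Unused tail bits of the last word stay 0. Those padding bits never count as mismatches, provided both operands are packed the same way and the true K is used later.

```python
def pack_signs(values):
    """Pack the signs of `values` into a list of 64-bit Python ints.

    Assumed convention (hypothetical): bit i of word j is 1 when
    values[64*j + i] < 0, else 0 (so 0.0 is treated as +1).
    The final partial word is zero-padded.
    """
    words = []
    for start in range(0, len(values), 64):
        word = 0
        for bit, v in enumerate(values[start:start + 64]):
            if v < 0:
                word |= 1 << bit
        words.append(word)
    return words


x = [0.5, -1.2, 3.0, -0.1]
print([hex(w) for w in pack_signs(x)])  # ['0xa']: bits 1 and 3 are set
```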
Kernel-only benchmark (256×1536×1536, a GEMM shape from the 300M-parameter training config)
| Method | Time | Speedup |
|---|---|---|
| AVX-512 popcount (gemm only) | 0.19 ms | 3.34× |
| bf16 BLAS (gemm only) | 0.63 ms | 1.0× |
| pure-PyTorch bit-pack (uint8 LUT) | 143 ms | 0.004× |
The AVX kernel is 3.3× faster than bf16 BLAS at this training-sized shape.
End-to-end training (TinyLM, Hamlet text, 200-300 steps)
| Model size | Forward kernel | Final loss | tok/s | E2E speedup |
|---|---|---|---|---|
| d=64 (76K params) | bf16 BLAS | 2.35 | 195K | 1.0× |
| d=64 (76K params) | AVX-512 | 2.06 | 207K | 1.06× |
| d=128 (284K params) | bf16 BLAS | 2.15 | 118K | 1.0× |
| d=128 (284K params) | AVX-512 | 1.81 | 130K | 1.10× |
| d=256 (1.1M params) | bf16 BLAS | 2.07 | 47K | 1.0× |
| d=256 (1.1M params) | AVX-512 | 1.77 | 55K | 1.18× |
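End-to-end speedup grows with d (6% → 10% → 18%) because the matmul share of total compute grows. Per token, the linear-layer cost scales roughly with d², while elementwise work scales with d.

Importantly, the AVX path is BOTH faster and more numerically precise. It computes an exact int32 dot product, whereas the bf16 path rounds its output to bf16. bf16 has only 8 significant bits, so it cannot represent every integer above 256. The AVX runs also converge to a lower loss at every size, plausibly as a result.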
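One way to sanity-check the "matmul fraction grows" explanation is Amdahl's law. Suppose a fraction f of step time goes to the forward sign-matmuls (the only part the AVX kernel replaces) and the kernel speeds that part up by s. The end-to-end speedup is then S = 1 / ((1 - f) + f/s), which inverts to f = (1 - 1/S) / (1 - 1/s). The sketch below plugs in the table's numbers. It makes a big assumption: that the 3.34× kernel-only speedup measured at 256×1536×1536 also holds at the much smaller TinyLM shapes. The resulting f values are rough estimates, not measurements.

```python
def amdahl_speedup(f, s):
    """End-to-end speedup when a fraction f of runtime is accelerated by s."""
    return 1.0 / ((1.0 - f) + f / s)


def implied_fraction(e2e, s):
    """Invert Amdahl's law: the runtime fraction that must have been accelerated."""
    return (1.0 - 1.0 / e2e) / (1.0 - 1.0 / s)


# Assumption: reuse the kernel-only speedup from the 256x1536x1536 benchmark.
KERNEL_SPEEDUP = 3.34

for d, e2e in [(64, 1.06), (128, 1.10), (256, 1.18)]:
    f = implied_fraction(e2e, KERNEL_SPEEDUP)
    print(f"d={d}: implied forward-matmul fraction ~ {f:.0%}")
```

This gives roughly 8%, 13% and 22% for d = 64, 128 and 256. That is consistent in direction with the explanation above, though the absolute values depend on the assumed kernel speedup.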
Pure PyTorch (no C extension) is a dead end
| Approach | 512×768×768 GEMM time |
|---|---|
| bf16 BLAS sign-matmul (the practical 1-bit path) | 0.39 ms |
| fp32 BLAS sign-matmul | 0.89 ms |
| int8 sign-matmul (cast back to fp32 for matmul) | 0.95 ms |
| Bit-packed uint8 XOR + LUT popcount | 143.0 ms (367× slower) |
The 256-entry popcount LUT works (zero numerical difference from the naive matmul), but the PyTorch-level `xor.unsqueeze(1) ^ w.unsqueeze(0)` broadcast and the per-byte LUT lookup dominate the cost. To win, you'd need either:
- A C extension calling `_mm256_sad_epu8` or `__builtin_popcountll`
- TVM/Triton CPU codegen with bit ops (Triton has a CPU backend; untested)
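For reference, the byte-level idea behind the 256-entry LUT looks like this in plain Python. This is a sketch, not the PyTorch code in `bit_cpu_kernel.py`, and the helper names and bit order are assumptions. It XORs the packed bytes, looks each result up in a table of per-byte popcounts, sums the mismatches, and converts them to a dot product with K − 2·mismatches. It is exact, but it does one table lookup per byte. A native `popcnt`/VPOPCNTQ instruction removes exactly that per-element overhead.

```python
# Popcount of every possible byte value, built once.
POPCOUNT_LUT = [bin(b).count("1") for b in range(256)]


def pack_bytes(signs):
    """Pack a list of +1/-1 values into bytes; bit i = 1 means -1 (hypothetical layout)."""
    out = bytearray((len(signs) + 7) // 8)
    for i, s in enumerate(signs):
        if s < 0:
            out[i // 8] |= 1 << (i % 8)
    return bytes(out)


def dot_lut(a_bytes, b_bytes, k):
    """±1 dot product of length k from packed bytes via XOR + per-byte LUT popcount."""
    mismatches = sum(POPCOUNT_LUT[x ^ y] for x, y in zip(a_bytes, b_bytes))
    return k - 2 * mismatches


a = [1, -1, -1, 1, 1, 1, -1, 1, -1, -1]
b = [1, 1, -1, -1, 1, -1, -1, 1, 1, -1]
print(dot_lut(pack_bytes(a), pack_bytes(b), len(a)))  # 2
print(sum(x * y for x, y in zip(a, b)))                # 2 (naive check)
```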
What does work: BLAS bf16 sign-matmul + STE
`train_1bit_cpu.py` trains a 76K-parameter character LM on Hamlet:
| Variant | Final loss (300 steps) | tok/s | Quality |
|---|---|---|---|
| 1-bit STE (`x.sign() @ w.sign().t()` in bf16) | 2.35 | 195K | weak |
| FP32 baseline (same arch) | 0.30 | 236K | good |
The 1-bit model trains correctly (loss 4.15 → 2.35) but is much weaker than fp32 at this small scale. 1-bit models need more parameters to compensate for their low precision. At 76K parameters, this one is far too small to close the gap.
Speed: 1-bit is 17% slower than fp32 on CPU (195K vs 236K tok/s) because:
- bf16 BLAS GEMM runs at roughly fp32 BLAS GEMM speed at these small sizes (no Tensor Cores on CPU, just SIMD)
- Applying `sign()` to the inputs adds an extra elementwise pass
- The STE backward needs `sign()` of the weights too
Files
| File | Purpose |
|---|---|
| `bit_cpu_kernel.py` | Pure-PyTorch bit-packed XNOR+LUT popcount. Correct but slow. |
| `bit_cpu_v2.py` | Benchmark vs alternatives (fp32, bf16, int8). |
| `train_1bit_cpu.py` | Practical CPU 1-bit training (uses the BLAS bf16 path). |
| `train_fp32_cpu_baseline.py` | Same arch in fp32 (control). |
| `bench_bit_cpu.py` | GEMM-only timing. |
To actually win on CPU
- Build a C extension wrapping `__builtin_popcountll` for 64-bit chunks
- Pack weights (and activations) 64-bit-aligned (K must be a multiple of 64)
- Each output element (m, n) becomes a sum of `popcount(X[m] XOR W[n])` over the K/64 packed words, which is then converted to a dot product
- With AVX-512 VPOPCNTQ (Ice Lake+), peak throughput is ~1 popcount/cycle/lane
- bitnet.cpp's i2_s kernel applies the same pack-and-reduce idea to ternary weights
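For ±1 binary specifically, the math is simpler than ternary. Let m = popcount(pack(a) XOR pack(b)) be the number of positions where the length-K vectors a and b disagree. The remaining K − m positions agree. Each agreement contributes +1 and each disagreement −1:

$$\mathrm{dot}(a, b) = (K - m) - m = K - 2\,\mathrm{popcount}\big(\mathrm{pack}(a) \oplus \mathrm{pack}(b)\big)$$

- No zero handling is needed (the ternary i2_s kernel has to handle −1, 0, +1)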
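Putting the packing and the K − 2·popcount identity together, a whole ±1 GEMM is one XOR-popcount reduction per output element. The pure-Python sketch below illustrates the computation the C kernel performs; it is not that kernel's code. It uses `bin(...).count("1")` as a stand-in for `__builtin_popcountll`/VPOPCNTQ, and `pack_row`, `popcount_gemm` and `naive_sign_gemm` are hypothetical names. Rows of `x` are activations (M×K) and rows of `w` are output channels (N×K), matching `sign(x) @ sign(w).T`. Zeros are assumed to count as +1.

```python
def pack_row(row):
    """Pack one row into 64-bit words; bit = 1 means the value is negative."""
    words = []
    for start in range(0, len(row), 64):
        word = 0
        for bit, v in enumerate(row[start:start + 64]):
            if v < 0:
                word |= 1 << bit
        words.append(word)
    return words


def popcount_gemm(x, w):
    """Exact sign(x) @ sign(w).T via XOR + popcount; x is M x K, w is N x K."""
    k = len(x[0])
    xp = [pack_row(r) for r in x]
    wp = [pack_row(r) for r in w]
    out = []
    for xr in xp:
        out_row = []
        for wr in wp:
            # Count disagreeing signs across all K/64 words, then apply K - 2*m.
            mismatches = sum(bin(a ^ b).count("1") for a, b in zip(xr, wr))
            out_row.append(k - 2 * mismatches)
        out.append(out_row)
    return out


def naive_sign_gemm(x, w):
    """Reference: explicit sign() (with sign(0) = +1) and an ordinary dot product."""
    def sgn(v):
        return -1 if v < 0 else 1
    return [[sum(sgn(a) * sgn(b) for a, b in zip(xr, wr)) for wr in w] for xr in x]


# K = 100 spans two 64-bit words, so the zero-padded tail is exercised too.
x = [[((i * 7 + j * 3) % 5) - 2 for j in range(100)] for i in range(3)]
w = [[((i * 5 + j * 11) % 7) - 3 for j in range(100)] for i in range(4)]
assert popcount_gemm(x, w) == naive_sign_gemm(x, w)
```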
Honest takeaway (revised)
The original conclusion ("pure PyTorch can't beat BLAS, just use bf16 sign-matmul") was correct for pure PyTorch. But `torch.utils.cpp_extension` gives us trivial access to `__builtin_popcountll`. A 90-line C extension compiled with `-march=native -mavx512vpopcntdq` gives a 3.3× kernel speedup and a 6-18% end-to-end training speedup across the tested sizes (10-18% for d ≥ 128). In addition, the exact int32 dot product is more accurate than bf16 BLAS, and in these runs it coincided with better loss convergence.
This is the training-time analog of bitnet.cpp's inference path:
1. Pack signs into uint64.
2. XOR.
3. Popcount.
4. Compute dot = K − 2·pop.
5. Write fp32 output.

The backward pass stays STE through bf16 BLAS for dx and dw, using sign(x) and sign(w) as the operands.
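The forward/backward split described above can be sketched without PyTorch. This is a framework-free illustration, not the code in `train_1bit_cpu_avx.py`, and all function names are hypothetical. The forward computes the same values the popcount kernel would (sign(x) @ sign(w).T). The backward applies the plain straight-through estimator, treating d sign(u)/du as 1, so dx = dy @ sign(w) and dw = dyᵀ @ sign(x). The notes don't say whether the real script clips the STE gradient (for example, zeroing it where |u| > 1), so this sketch uses the unclipped form.

```python
def sign(v):
    """±1 sign with sign(0) = +1 (assumption, matching the packing convention)."""
    return -1.0 if v < 0 else 1.0


def signed(m):
    return [[sign(v) for v in row] for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain (M x K) @ (K x N) on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def binary_linear_forward(x, w):
    """y = sign(x) @ sign(w).T with x: M x K and w: N x K."""
    return matmul(signed(x), transpose(signed(w)))


def binary_linear_backward(x, w, dy):
    """Unclipped STE: gradients pass through sign() unchanged.

    dx = dy @ sign(w)     (M x N) @ (N x K) -> M x K
    dw = dy.T @ sign(x)   (N x M) @ (M x K) -> N x K
    """
    dx = matmul(dy, signed(w))
    dw = matmul(transpose(dy), signed(x))
    return dx, dw


x = [[0.5, -2.0, 0.1]]
w = [[1.0, 1.0, -1.0], [-0.3, -0.7, 0.2]]
print(binary_linear_forward(x, w))  # [[-1.0, 1.0]]
dx, dw = binary_linear_backward(x, w, [[1.0, 0.0]])
```

Gradients flow to the latent real-valued `x` and `w`, even though the forward only ever sees their signs. That is what lets the optimizer keep updating weights whose sign has not yet flipped.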
Files
| File | Purpose |
|---|---|
| `bit_cpu_avx.py` | AVX-512 popcount C extension (the win) |
| `bench_bit_avx.py` | Benchmark vs BLAS bf16 |
| `train_1bit_cpu_avx.py` | Production CPU 1-bit training with the AVX kernel |
| `bit_cpu_kernel.py` | Pure-PyTorch bit-packed (slow; for comparison) |
| `bit_cpu_v2.py` | Multi-strategy benchmark |
| `train_1bit_cpu.py` | Pure-PyTorch (bf16 BLAS) baseline |
| `train_fp32_cpu_baseline.py` | Same arch in fp32 (control) |
| `bench_bit_cpu.py` | GEMM-only timing comparison |