# Pure 1-bit CPU training — what works and what doesn't ## Goal Build a CPU-only 1-bit (±1) language model training script using "tricks from the C inference" (bitnet.cpp i2_s kernel pattern). ## Update: AVX-512 VPOPCNTQ extension WINS After confirming pure PyTorch couldn't beat BLAS, I built a C extension (`bit_cpu_avx.py`) using `__builtin_popcountll` on uint64-packed signs, compiled with `-march=native -mavx512vpopcntdq` for AMD Zen 4 / Intel Ice Lake+ CPUs. ### Kernel-only benchmark (256×1536×1536, 300M training shape) | Method | Time | Speedup | |---|---|---| | **AVX-512 popcount (gemm only)** | **0.19 ms** | **3.34×** | | bf16 BLAS (gemm only) | 0.63 ms | 1.0× | | pure-PyTorch bit-pack (uint8 LUT) | 143 ms | 0.004× | The AVX kernel is 3.3× faster than BLAS bf16 at typical training shapes. ### End-to-end training (TinyLM, Hamlet text, 200-300 steps) | Model size | Forward kernel | Final loss | tok/s | E2E speedup | |---|---|---|---|---| | d=64 (76K params) | bf16 BLAS | 2.35 | 195K | 1.0× | | d=64 (76K params) | AVX-512 | 2.06 | 207K | 1.06× | | d=128 (284K params) | bf16 BLAS | 2.15 | 118K | 1.0× | | d=128 (284K params) | AVX-512 | 1.81 | 130K | 1.10× | | d=256 (1.1M params) | bf16 BLAS | 2.07 | 47K | 1.0× | | **d=256 (1.1M params)** | **AVX-512** | **1.77** | **55K** | **1.18×** | End-to-end speedup grows with d (10% → 18% → ...) because the matmul fraction of total compute grows. Importantly, **AVX is BOTH faster and more numerically precise** (int32 exact dot vs bf16 truncation), so the loss converges to a lower value. ## Pure PyTorch (no C extension) is dead-end | Approach | 512×768×768 GEMM time | |---|---| | **bf16 BLAS sign-matmul** (the practical 1-bit path) | **0.39 ms** | | fp32 BLAS sign-matmul | 0.89 ms | | int8 sign-matmul (cast back to fp32 for matmul) | 0.95 ms | | Bit-packed uint8 XOR + LUT popcount | 143.0 ms (367× slower) | The 256-entry popcount LUT works (0 numerical diff vs naive), but the pytorch-level `xor.unsqueeze(1) ^ w.unsqueeze(0)` broadcast and per-byte LUT lookup dominate the cost. To win, you'd need either: - A C extension calling `_mm256_sad_epu8` or `__builtin_popcountll` - TVM/Triton CPU codegen with bit ops (Triton has CPU backend; untested) ## What does work: BLAS bf16 sign-matmul + STE `train_1bit_cpu.py` — 76K-parameter character LM trained on Hamlet: | Variant | Final loss (300 steps) | tok/s | Quality | |---|---|---|---| | 1-bit STE (`(x.sign() @ w.sign().t())` in bf16) | 2.35 | 195K | weak | | FP32 baseline (same arch) | 0.30 | 236K | good | The 1-bit model trains correctly (loss 4.15 → 2.35) but is much weaker than fp32 at this small scale. 1-bit needs more parameters to compensate for low precision; the 76K size is too small to matter. Speed: 1-bit is **17% slower** than fp32 on CPU because: - bf16 BLAS GEMM ≈ fp32 BLAS GEMM speed at small sizes (no Tensor Cores on CPU, just SIMD) - Sign'ing inputs adds an extra elementwise pass - STE backward needs sign() of weights too ## Files | File | Purpose | |---|---| | `bit_cpu_kernel.py` | Pure-PyTorch bit-packed XNOR+LUT popcount. Correct but slow. | | `bit_cpu_v2.py` | Benchmark vs alternatives (fp32, bf16, int8). | | `train_1bit_cpu.py` | Practical CPU 1-bit training (uses BLAS bf16 path). | | `train_fp32_cpu_baseline.py` | Same arch with fp32 — control. | | `bench_bit_cpu.py` | GEMM-only timing. | ## To actually win on CPU 1. Build a C extension wrapping `__builtin_popcountll` for 64-bit chunks 2. Pack weights 64-bit-aligned (K must be multiple of 64) 3. The matmul becomes `popcount(X[m] XOR W[n])` per (m,n) pair 4. With AVX-512 VPOPCNTQ (Ice Lake+), peak ~1 popcount/cycle/lane 5. This is what bitnet.cpp's i2_s kernel does (for ternary) For ±1 binary specifically, the math is simpler than ternary: - `dot(±1 vector a, ±1 vector b) = K - 2 * popcount(pack(a) XOR pack(b))` - No zero handling needed (ternary i2_s kernel handles -1, 0, +1) ## Honest takeaway (revised) The original conclusion ("pure PyTorch can't beat BLAS, just use bf16 sign-matmul") was correct *for pure PyTorch*. But torch.utils.cpp_extension gives us trivial access to `__builtin_popcountll`. With a 90-line C extension compiled with `-march=native -mavx512vpopcntdq`, we get 3.3× kernel speedup and 10-18% end-to-end training speedup at typical sizes — AND the int32 exact dot product is more accurate than bf16 BLAS, leading to better loss convergence. **This is the actual analog of bitnet.cpp's inference path, but applied to training**: pack signs into uint64, XOR, popcount, count = K - 2*pop, fp32 output. Backward stays STE through bf16 BLAS for `dx` and `dw` (sign'd inputs). ## Files | File | Purpose | |---|---| | `bit_cpu_avx.py` | **AVX-512 popcount C extension** (the win) | | `bench_bit_avx.py` | Benchmark vs BLAS bf16 | | `train_1bit_cpu_avx.py` | **Production CPU 1-bit training** with AVX kernel | | `bit_cpu_kernel.py` | Pure-PyTorch bit-packed (slow; for comparison) | | `bit_cpu_v2.py` | Multi-strategy benchmark | | `train_1bit_cpu.py` | Pure-PyTorch (bf16 BLAS) baseline | | `train_fp32_cpu_baseline.py` | Same arch fp32 — control | | `bench_bit_cpu.py` | GEMM-only timing comparison |