| # Pure 1-bit CPU training — what works and what doesn't |
|
|
| ## Goal |
|
|
| Build a CPU-only 1-bit (±1) language model training script using "tricks |
| from the C inference" (bitnet.cpp i2_s kernel pattern). |
| |
| ## Update: AVX-512 VPOPCNTQ extension WINS |
| |
| After confirming pure PyTorch couldn't beat BLAS, I built a C extension |
| (`bit_cpu_avx.py`) using `__builtin_popcountll` on uint64-packed signs, |
| compiled with `-march=native -mavx512vpopcntdq` for AMD Zen 4 / Intel Ice |
| Lake+ CPUs. |
| |
| ### Kernel-only benchmark (256×1536×1536, 300M training shape) |
| |
| | Method | Time | Speedup | |
| |---|---|---| |
| | **AVX-512 popcount (gemm only)** | **0.19 ms** | **3.34×** | |
| | bf16 BLAS (gemm only) | 0.63 ms | 1.0× | |
| | pure-PyTorch bit-pack (uint8 LUT) | 143 ms | 0.004× | |
| |
| The AVX kernel is 3.3× faster than BLAS bf16 at this shape. |
| |
| ### End-to-end training (TinyLM, Hamlet text, 200-300 steps) |
| |
| | Model size | Forward kernel | Final loss | tok/s | E2E speedup | |
| |---|---|---|---|---| |
| | d=64 (76K params) | bf16 BLAS | 2.35 | 195K | 1.0× | |
| | d=64 (76K params) | AVX-512 | 2.06 | 207K | 1.06× | |
| | d=128 (284K params) | bf16 BLAS | 2.15 | 118K | 1.0× | |
| | d=128 (284K params) | AVX-512 | 1.81 | 130K | 1.10× | |
| | d=256 (1.1M params) | bf16 BLAS | 2.07 | 47K | 1.0× | |
| | **d=256 (1.1M params)** | **AVX-512** | **1.77** | **55K** | **1.18×** | |
| |
| End-to-end speedup grows with d (6% → 10% → 18%) because the matmul |
| fraction of total compute grows. Importantly, **AVX is BOTH faster and |
| more numerically precise** (exact int32 dot vs bf16 rounding), which is |
| the likely reason the loss converges to a lower value. |
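
The relationship between kernel speedup and end-to-end speedup can be made concrete with Amdahl's law. The sketch below is illustrative arithmetic only. It assumes the 3.3× kernel figure (measured at 256×1536×1536) carries over to the much smaller TinyLM shapes and that only the forward GEMMs are accelerated; neither assumption was measured here. The function names are hypothetical.

```python
def e2e_speedup(accel_fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: overall speedup when only a fraction of runtime is accelerated."""
    return 1.0 / ((1.0 - accel_fraction) + accel_fraction / kernel_speedup)


def implied_fraction(e2e: float, kernel_speedup: float) -> float:
    """Invert Amdahl's law: the accelerated runtime fraction that would explain `e2e`."""
    return (1.0 - 1.0 / e2e) / (1.0 - 1.0 / kernel_speedup)


if __name__ == "__main__":
    KERNEL = 3.3  # ASSUMPTION: treated as constant across d; measured at one shape only
    # End-to-end speedups taken from the table above.
    for d, e2e in [(64, 1.06), (128, 1.10), (256, 1.18)]:
        p = implied_fraction(e2e, KERNEL)
        print(f"d={d}: implied accelerated fraction of runtime ~ {p:.2f}")
```

Under these assumptions the implied accelerated fraction rises with d, which is consistent with the explanation that the matmul share of total compute grows with model width.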
| |
| ## Pure PyTorch (no C extension) is a dead end |
| |
| | Approach | 512×768×768 GEMM time | |
| |---|---| |
| | **bf16 BLAS sign-matmul** (the practical 1-bit path) | **0.39 ms** | |
| | fp32 BLAS sign-matmul | 0.89 ms | |
| | int8 sign-matmul (cast back to fp32 for matmul) | 0.95 ms | |
| | Bit-packed uint8 XOR + LUT popcount | 143.0 ms (367× slower) | |
| |
| The 256-entry popcount LUT works (0 numerical diff vs naive), but the |
| PyTorch-level `xor.unsqueeze(1) ^ w.unsqueeze(0)` broadcast and per-byte |
| LUT lookup dominate the cost. To win, you'd need either: |
| - A C extension calling `_mm256_sad_epu8` or `__builtin_popcountll` |
| - TVM/Triton CPU codegen with bit ops (Triton has a CPU backend; untested) |
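
For reference, the LUT approach amounts to the following, written in plain Python rather than PyTorch so it runs without dependencies. The helper names (`pack_signs`, `lut_dot`) are hypothetical and are not taken from `bit_cpu_kernel.py`; the bit convention (+1 maps to bit 1, -1 to bit 0) is an assumption.

```python
# 256-entry table: POPCOUNT_LUT[b] is the number of set bits in byte b.
POPCOUNT_LUT = [bin(b).count("1") for b in range(256)]


def pack_signs(signs):
    """Pack a list of +1/-1 values into bytes, 8 signs per byte (+1 -> bit 1)."""
    if len(signs) % 8 != 0:
        raise ValueError("length must be a multiple of 8")
    out = bytearray()
    for i in range(0, len(signs), 8):
        byte = 0
        for j, s in enumerate(signs[i:i + 8]):
            if s > 0:
                byte |= 1 << j
        out.append(byte)
    return bytes(out)


def lut_dot(a_packed, b_packed, k):
    """Dot product of two packed +/-1 vectors of length k via XOR + LUT popcount."""
    # XOR marks the positions where the signs differ; the LUT counts them per byte.
    mismatches = sum(POPCOUNT_LUT[x ^ y] for x, y in zip(a_packed, b_packed))
    return k - 2 * mismatches


if __name__ == "__main__":
    a = [1, -1, 1, 1, -1, -1, 1, -1]
    b = [1, 1, -1, 1, -1, 1, 1, -1]
    naive = sum(x * y for x, y in zip(a, b))
    assert lut_dot(pack_signs(a), pack_signs(b), len(a)) == naive
```

The arithmetic is exact, which matches the "0 numerical diff" observation. In the tensor version, the broadcast XOR presumably materializes an M×N×(K/8) intermediate before the per-byte lookup, which is where the time goes.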
|
|
| ## What does work: BLAS bf16 sign-matmul + STE |
|
|
| `train_1bit_cpu.py` — 76K-parameter character LM trained on Hamlet: |
|
|
| | Variant | Final loss (300 steps) | tok/s | Quality | |
| |---|---|---|---| |
| | 1-bit STE (`(x.sign() @ w.sign().t())` in bf16) | 2.35 | 195K | weak | |
| | FP32 baseline (same arch) | 0.30 | 236K | good | |
|
|
| The 1-bit model trains correctly (loss 4.15 → 2.35) but is much weaker |
| than fp32 at this small scale. 1-bit needs more parameters to compensate |
| for low precision, and 76K parameters is too few for it to be competitive. |
|
|
| Speed: 1-bit is **17% slower** than fp32 on CPU (195K vs 236K tok/s) because: |
| - bf16 BLAS GEMM ≈ fp32 BLAS GEMM speed at small sizes (no Tensor Cores |
| on CPU, just SIMD) |
| - Taking `sign()` of the inputs adds an extra elementwise pass |
| - STE backward needs `sign()` of the weights too |
|
|
| ## Files (original pure-PyTorch experiments) |
|
|
| | File | Purpose | |
| |---|---| |
| | `bit_cpu_kernel.py` | Pure-PyTorch bit-packed XNOR+LUT popcount. Correct but slow. | |
| | `bit_cpu_v2.py` | Benchmark vs alternatives (fp32, bf16, int8). | |
| | `train_1bit_cpu.py` | Practical CPU 1-bit training (uses BLAS bf16 path). | |
| | `train_fp32_cpu_baseline.py` | Same arch with fp32 — control. | |
| | `bench_bit_cpu.py` | GEMM-only timing. | |
|
|
| ## To actually win on CPU |
|
|
| 1. Build a C extension wrapping `__builtin_popcountll` for 64-bit chunks |
| 2. Pack weights 64-bit-aligned (K must be multiple of 64) |
| 3. Each output element becomes `K - 2 * popcount(X[m] XOR W[n])` per (m,n) pair |
| 4. With AVX-512 VPOPCNTQ (Ice Lake+), peak ~1 popcount/cycle/lane |
| 5. This follows the same pack-then-SIMD pattern as bitnet.cpp's i2_s kernel (which is ternary) |
| |
| For ±1 binary specifically, the math is simpler than ternary: |
| - `dot(±1 vector a, ±1 vector b) = K - 2 * popcount(pack(a) XOR pack(b))` |
| - No zero handling needed (ternary i2_s kernel handles -1, 0, +1) |
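
The steps above can be sketched in plain Python, using ints limited to 64 bits as stand-ins for `uint64` words. This is a reference model of the arithmetic, not the C extension itself. `pack_row_u64` and `xor_popcount_gemm` are hypothetical names, and the +1 maps to bit 1 convention is an assumption.

```python
import random


def pack_row_u64(signs):
    """Pack a row of +1/-1 values into 64-bit words (+1 -> bit 1, -1 -> bit 0)."""
    if len(signs) % 64 != 0:
        raise ValueError("K must be a multiple of 64")
    words = []
    for i in range(0, len(signs), 64):
        word = 0
        for j, s in enumerate(signs[i:i + 64]):
            if s > 0:
                word |= 1 << j
        words.append(word)
    return words


def xor_popcount_gemm(x_rows, w_rows, k):
    """out[m][n] = dot(x[m], w[n]) for packed +/-1 rows, via K - 2 * popcount(XOR)."""
    out = []
    for xm in x_rows:
        row = []
        for wn in w_rows:
            # bin(...).count("1") plays the role of __builtin_popcountll here.
            mismatches = sum(bin(a ^ b).count("1") for a, b in zip(xm, wn))
            row.append(k - 2 * mismatches)
        out.append(row)
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    M, N, K = 3, 4, 128
    x = [[rng.choice((-1, 1)) for _ in range(K)] for _ in range(M)]
    w = [[rng.choice((-1, 1)) for _ in range(K)] for _ in range(N)]
    fast = xor_popcount_gemm([pack_row_u64(r) for r in x],
                             [pack_row_u64(r) for r in w], K)
    naive = [[sum(a * b for a, b in zip(xr, wr)) for wr in w] for xr in x]
    assert fast == naive
```

The identity holds because each matching sign pair contributes +1 and each mismatch contributes -1, so the dot product is (K - mismatches) - mismatches.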
|
|
| ## Honest takeaway (revised) |
|
|
| The original conclusion ("pure PyTorch can't beat BLAS, just use bf16 |
| sign-matmul") was correct *for pure PyTorch*. But `torch.utils.cpp_extension` |
| gives us trivial access to `__builtin_popcountll`. With a 90-line C extension |
| compiled with `-march=native -mavx512vpopcntdq`, we get 3.3× kernel speedup |
| and 6-18% end-to-end training speedup at the sizes tested (d=64 to d=256). |
| The exact int32 dot product also avoids bf16 rounding, which likely |
| explains the better loss convergence. |
| |
| **This is the actual analog of bitnet.cpp's inference path, but applied to |
| training**: pack signs into uint64, XOR, popcount, count = K - 2*pop, fp32 |
| output. Backward stays STE through bf16 BLAS for `dx` and `dw` (sign'd |
| inputs). |
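
A minimal sketch of the straight-through backward described here, in plain Python with list-of-lists matrices standing in for tensors and BLAS calls. It assumes the simplest STE variant (gradient of `sign()` treated as 1 everywhere, with no clipping mask) and maps zero to +1, which differs from `torch.sign`. All function names are hypothetical and not taken from the training scripts.

```python
def sign(v):
    """Map a real value to +1.0 or -1.0 (ASSUMPTION: zero is treated as +1)."""
    return 1.0 if v >= 0 else -1.0


def matmul(a, b):
    """Naive list-of-lists matrix product, standing in for a BLAS GEMM."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Return the transpose of a list-of-lists matrix."""
    return [list(col) for col in zip(*a)]


def forward(x, w):
    """y = sign(x) @ sign(w)^T, with x of shape [M, K] and w of shape [N, K]."""
    xs = [[sign(v) for v in row] for row in x]
    ws = [[sign(v) for v in row] for row in w]
    return matmul(xs, transpose(ws)), xs, ws


def backward_ste(dy, xs, ws):
    """Straight-through estimator: treat d sign(v)/dv as 1, so gradients
    pass through the signed operands as if sign() were the identity."""
    dx = matmul(dy, ws)             # [M, N] @ [N, K] -> [M, K]
    dw = matmul(transpose(dy), xs)  # [N, M] @ [M, K] -> [N, K]
    return dx, dw
```

Only the forward product would be swapped for the popcount kernel; the two backward products remain ordinary dense matmuls on the signed operands.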
| |
| ## Files |
| |
| | File | Purpose | |
| |---|---| |
| | `bit_cpu_avx.py` | **AVX-512 popcount C extension** (the win) | |
| | `bench_bit_avx.py` | Benchmark vs BLAS bf16 | |
| | `train_1bit_cpu_avx.py` | **Production CPU 1-bit training** with AVX kernel | |
| | `bit_cpu_kernel.py` | Pure-PyTorch bit-packed (slow; for comparison) | |
| | `bit_cpu_v2.py` | Multi-strategy benchmark | |
| | `train_1bit_cpu.py` | Pure-PyTorch (bf16 BLAS) baseline | |
| | `train_fp32_cpu_baseline.py` | Same arch fp32 — control | |
| | `bench_bit_cpu.py` | GEMM-only timing comparison | |
| |