---
license: mit
tags:
- bitnet
- speculative-decoding
- medusa
- ternary-weights
- efficient-inference
- cpu-inference
language:
- en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---

# MedusaBitNet 2B-4T

**First integration of [Medusa speculative decoding](https://github.com/FasterDecoding/Medusa) with [BitNet b1.58](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ternary-weight inference.**

Four Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. The model generates **2.08 tokens per backbone step**, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.

## Measured Results

### Speculative Decoding (real end-to-end, Python, 20 sequences)

| Metric | Value |
|---|---|
| **Tokens per backbone step** | **2.08** (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa heads (all 4, f16) | 13 MB |
| Total model size (backbone + heads) | 764 MB |

### Head-to-Head Throughput (same hardware, same prompts, llama-cli)

| Model | Params | Generation (tok/s) | File size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| **BitNet b1.58 2B (I2_S)** | **2.4B** | **72.7** | **1187 MB** |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |

All benchmarks ran on an AMD Ryzen AI MAX+ 395 (Strix Halo) with 16 Zen 5 cores and 93 GB LPDDR5X, CPU-only.

### Quality (published by Microsoft, 18 tasks)

BitNet b1.58 2B-4T scores **54.19 average**, ahead of LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). Medusa heads don't change output quality: they only propose tokens ahead, and the backbone verifies each proposal instead of the heads modifying its output.
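Why the heads cannot change the output becomes clear from how their drafts are checked. Below is a minimal sketch of greedy chain verification, assuming a drafted token is kept only when it equals the backbone's own argmax at that position. The Medusa paper also describes tree-structured candidates and a "typical acceptance" scheme, which this sketch does not cover. `greedy_next`, `propose`, `toy_next`, and `toy_propose` are hypothetical stand-ins for the BitNet backbone and the four heads, not this project's code.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical interfaces (not this project's API):
#   greedy_next(seq) -> the backbone's argmax next token after `seq`
#   propose(seq)     -> draft tokens for the next k positions, from the Medusa heads
GreedyNext = Callable[[Sequence[int]], int]
Propose = Callable[[Sequence[int]], List[int]]


def verify_draft(prefix: List[int], draft: List[int], greedy_next: GreedyNext) -> List[int]:
    """Keep the longest draft prefix matching the backbone's greedy choices,
    plus one token from the backbone itself (so every step yields >= 1 token).

    A real decoder gets all these predictions from one forward pass over
    prefix + draft; calling greedy_next per position here is only for clarity.
    """
    committed: List[int] = []
    seq = list(prefix)
    for tok in draft:
        expected = greedy_next(seq)
        committed.append(expected)  # the backbone's token is always what gets kept
        if tok != expected:
            return committed  # first mismatch: discard the rest of the draft
        seq.append(tok)
    committed.append(greedy_next(seq))  # every draft matched: one extra backbone token
    return committed


def speculative_generate(prompt: List[int], n_tokens: int, greedy_next: GreedyNext,
                         propose: Propose) -> Tuple[List[int], int]:
    """Return (generated tokens, number of backbone steps used)."""
    out = list(prompt)
    steps = 0
    while len(out) - len(prompt) < n_tokens:
        out.extend(verify_draft(out, propose(out), greedy_next))
        steps += 1
    return out[len(prompt):len(prompt) + n_tokens], steps


def greedy_generate(prompt: List[int], n_tokens: int, greedy_next: GreedyNext) -> List[int]:
    """Plain one-token-per-step greedy decoding, for comparison."""
    out = list(prompt)
    for _ in range(n_tokens):
        out.append(greedy_next(out))
    return out[len(prompt):]


# Toy stand-ins (hypothetical): a deterministic "backbone" and 4 "heads"
# that guess t+1 and t+2 correctly but miss t+3.
def toy_next(seq: Sequence[int]) -> int:
    return (seq[-1] * 3 + 1) % 17


def toy_propose(seq: Sequence[int]) -> List[int]:
    draft, t = [], seq[-1]
    for _ in range(4):
        t = (t * 3 + 1) % 17
        draft.append(t)
    draft[2] = (draft[2] + 1) % 17  # deliberately wrong guess at t+3
    return draft


tokens, steps = speculative_generate([5], 12, toy_next, toy_propose)
assert tokens == greedy_generate([5], 12, toy_next)  # identical to plain greedy output
print(f"{len(tokens)} tokens in {steps} backbone steps")  # 12 tokens in 4 backbone steps
```

Each step commits the accepted drafts plus one backbone token, so tokens per step equals 1 plus the expected number of accepted draft positions. With the measured rates, 1 + 0.630 + 0.290 + 0.111 + 0.046 = 2.077, which matches the measured 5,136 / 2,473 ≈ 2.077. This is consistent with reading each head's rate as the fraction of steps in which its token was accepted together with all earlier ones.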
## What's Proven vs What Needs Work

**Measured (real data):**
- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2000 steps on a Zen 5 CPU

**Not yet proven:**
- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, which shifts the distribution of the hidden states the heads receive. The Medusa heads work in Python on cached hidden states, but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x. This divides 2.08 tokens per step by a step time that is ~10% longer than the 13.75 ms backbone step because of head overhead (13.75 ms per step is the 72.7 tok/s measured for BitNet in the head-to-head table).

## Files

- `medusa_heads_step2000.pt`: trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf`: merged GGUF with the BitNet backbone (I2_S) and Medusa heads (f16)
- `benchmark_headtohead.json`: raw head-to-head benchmark data
- `benchmark_results.json`: efficiency benchmark data
- `figures/`: all charts

## Architecture

```
BitNet b1.58 2B-4T (frozen)    4 Medusa Heads (13 MB)
┌─────────────────────┐        ┌──────────────────┐
│ 30 layers           │        │ Head 1: t+1 63.0%│
│ 2560 hidden         │ ──h──→ │ Head 2: t+2 29.0%│ ──→ 2.08 tok/step
│ Ternary {-1, 0, 1}  │        │ Head 3: t+3 11.1%│
│ 751 MB (I2_S)       │        │ Head 4: t+4  4.6%│
└─────────────────────┘        └──────────────────┘
```

Each head computes `h + W_out @ SiLU(W_in @ h)` from the backbone hidden state `h`, and the result is projected through the shared `lm_head` to give that head's logits.
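The head formula can be written as a small PyTorch module. This is a minimal sketch, assuming bias-free linear layers and a hypothetical `inner_size` for `W_in`/`W_out` (the card does not state it). `ResidualMedusaHead` and `medusa_logits` are illustrative names, not the `MedusaHeads` class used under Usage, so this sketch will not load `medusa_heads_step2000.pt` as-is.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualMedusaHead(nn.Module):
    """One head: h + W_out @ SiLU(W_in @ h). Bias-free layers are an assumption."""

    def __init__(self, hidden_size: int, inner_size: int) -> None:
        super().__init__()
        self.w_in = nn.Linear(hidden_size, inner_size, bias=False)   # W_in
        self.w_out = nn.Linear(inner_size, hidden_size, bias=False)  # W_out

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Residual path keeps the backbone hidden state; the MLP adds a learned correction.
        return h + self.w_out(F.silu(self.w_in(h)))


def medusa_logits(h: torch.Tensor, heads: nn.ModuleList, lm_head: nn.Linear) -> torch.Tensor:
    """Run every head on the same hidden state and project each through the shared lm_head.

    h: (..., hidden) -> returns (num_heads, ..., vocab); index k holds head k+1's logits.
    """
    return torch.stack([lm_head(head(h)) for head in heads])


# Toy sizes for a quick check; the real model uses hidden=2560 and a 128256-token vocab.
torch.manual_seed(0)
hidden, inner, vocab = 8, 4, 16
heads = nn.ModuleList([ResidualMedusaHead(hidden, inner) for _ in range(4)])
lm_head = nn.Linear(hidden, vocab, bias=False)  # stands in for the backbone's frozen lm_head
h = torch.randn(2, hidden)
logits = medusa_logits(h, heads, lm_head)
print(logits.shape)  # torch.Size([4, 2, 16])
draft = logits.argmax(dim=-1)  # greedy draft token per head and position, shape (4, 2)
```

Because of the residual connection, zero-initializing `w_out` makes a head start as the identity, so its first predictions are exactly the backbone `lm_head`'s predictions for `h`; training then only has to learn the offset toward t+k.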
## Training

| Detail | Value |
|---|---|
| Data | [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (52K examples) |
| Method | Cache backbone hidden states, then train the heads on the cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4 h caching + ~7 h training = ~11 h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |

## Usage

### Python (verified working)

```python
import torch
from model import MedusaHeads  # MedusaHeads comes from this project's `model` module

# Load the trained head weights onto the CPU.
ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")

# 4 heads, 1 layer each, matching the backbone's hidden size and vocabulary.
heads = MedusaHeads(hidden_size=2560, vocab_size=128256, num_heads=4,
                    num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
heads.eval()  # inference mode
```

### C++ (loads and runs, speculation pending kernel fix)

```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
# -n: tokens to generate, -t: CPU threads
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```

## Credits

- **Medusa:** Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. [ICML 2024](https://arxiv.org/abs/2401.10774). Apache 2.0.
- **BitNet b1.58:** Microsoft Research. [Model](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) (MIT). [bitnet.cpp](https://github.com/microsoft/BitNet) (MIT).
- **llama.cpp:** Georgi Gerganov et al. MIT.
- **Built with:** [Claude Code](https://claude.ai/claude-code) (Anthropic, Opus 4.6)

## Citation

```bibtex
@misc{corcoran2025medusabitnet,
  title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
  author={Parrish Corcoran},
  year={2025},
  url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```

## License

MIT