---
license: mit
tags:
- bitnet
- speculative-decoding
- medusa
- ternary-weights
- efficient-inference
- cpu-inference
language:
- en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---

# MedusaBitNet 2B-4T

**First integration of [Medusa speculative decoding](https://github.com/FasterDecoding/Medusa) with [BitNet b1.58](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ternary-weight inference.**

4 Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. Generates **2.08 tokens per backbone step**, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.

## Measured Results

### Speculative Decoding (real end-to-end, Python, 20 sequences)

| Metric | Value |
|---|---|
| **Tokens per backbone step** | **2.08** (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
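
Both the raw counts and the per-head acceptance rates lead to the same headline figure. A quick sanity check (treating each head's rate as its unconditional acceptance probability):

```python
# Tokens per backbone step from the raw end-to-end counts.
tokens, steps = 5136, 2473
print(round(tokens / steps, 2))  # 2.08

# Cross-check: each step emits 1 guaranteed token plus, in expectation,
# one token per head weighted by its measured acceptance rate.
acceptance = [0.630, 0.290, 0.111, 0.046]  # heads 1-4 (t+1 .. t+4)
print(round(1 + sum(acceptance), 2))       # 2.08
```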

### Head-to-Head Throughput (same hardware, same prompts, llama-cli)

| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| **BitNet b1.58 2B (I2_S)** | **2.4B** | **72.7** | **1187 MB** |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |

All benchmarks on an AMD Ryzen AI MAX+ 395 (Strix Halo), 16 Zen 5 cores, 93 GB LPDDR5x, CPU-only.

### Quality (Microsoft published, 18 tasks)

BitNet b1.58 2B-4T scores **54.19 avg**, beating LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). Medusa heads don't change output quality: they only predict tokens ahead of time, and every prediction is verified against the backbone before it is emitted.

## What's Proven vs What Needs Work

**Measured (real data):**
- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2000 steps on a Zen 5 CPU

**Not yet proven:**
- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, causing a hidden-state distribution mismatch. The Medusa heads work in Python on cached hidden states but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x (based on 2.08 tok/step plus ~10% head overhead on a 13.75 ms backbone step)
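
That estimate is simple arithmetic over the measured numbers; only the ~10% head overhead is an assumption, not a measurement:

```python
tokens_per_step = 5136 / 2473   # measured: ~2.08 tokens per backbone step
backbone_ms = 13.75             # measured backbone step time
head_overhead = 0.10            # assumed extra cost per step for the heads

plain_tok_s = 1000 / backbone_ms  # baseline: 1 token per plain step
medusa_tok_s = tokens_per_step * 1000 / (backbone_ms * (1 + head_overhead))
print(f"estimated speedup: {medusa_tok_s / plain_tok_s:.2f}x")
```

This lands at roughly 1.9x, consistent with the ~1.88x figure above (the small difference is rounding of the tokens-per-step input).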

## Files

- `medusa_heads_step2000.pt`: trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf`: merged GGUF with the BitNet backbone (I2_S) and Medusa heads (f16)
- `benchmark_headtohead.json`: raw head-to-head benchmark data
- `benchmark_results.json`: efficiency benchmark data
- `figures/`: all charts (see below)
## Architecture

```
BitNet b1.58 2B-4T (frozen)      4 Medusa Heads (13 MB)
┌────────────────────┐           ┌───────────────────┐
│ 30 layers          │           │ Head 1: t+1 63.0% │
│ 2560 hidden        │  ──h──→   │ Head 2: t+2 29.0% │ ──→ 2.08 tok/step
│ Ternary {-1, 0, 1} │           │ Head 3: t+3 11.1% │
│ 751 MB (I2_S)      │           │ Head 4: t+4  4.6% │
└────────────────────┘           └───────────────────┘
```

Each head computes `h + W_out @ SiLU(W_in @ h)`, projected through the shared lm_head.
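
A minimal sketch of one head in PyTorch (the class and weight names here are illustrative; the shipped `MedusaHeads` module in `model.py` is the reference implementation):

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One residual head: h + W_out @ SiLU(W_in @ h)."""
    def __init__(self, hidden_size: int = 2560):
        super().__init__()
        self.w_in = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return h + self.w_out(nn.functional.silu(self.w_in(h)))

# Each head's output goes through the backbone's shared lm_head:
hidden = torch.randn(1, 2560)                   # backbone hidden state at t
lm_head = nn.Linear(2560, 128256, bias=False)   # shared with the backbone
logits = lm_head(MedusaHead()(hidden))          # draft logits for token t+k
```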

## Training

| Detail | Value |
|---|---|
| Data | [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |
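
The cached-features method can be sketched roughly as below. Everything here is illustrative: the tensors are random stand-ins for one cached batch, the head layout mirrors the formula in the Architecture section, and the 50-step warmup is omitted for brevity:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size, num_heads = 2560, 128256, 4

# Stand-ins for one cached batch: backbone hidden states plus, for each
# head k, the ground-truth token at position t+k (shifted labels).
hidden = torch.randn(32, hidden_size)
targets = torch.randint(0, vocab_size, (num_heads, 32))

heads = nn.ModuleList(
    nn.Sequential(nn.Linear(hidden_size, hidden_size, bias=False),
                  nn.SiLU(),
                  nn.Linear(hidden_size, hidden_size, bias=False))
    for _ in range(num_heads)
)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # frozen, shared
for p in lm_head.parameters():
    p.requires_grad_(False)

opt = torch.optim.AdamW(heads.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=2000)

# One training step: residual head output through the shared lm_head,
# cross-entropy against each head's shifted labels, summed over heads.
opt.zero_grad()
loss = sum(
    nn.functional.cross_entropy(lm_head(hidden + head(hidden)), targets[k])
    for k, head in enumerate(heads)
)
loss.backward()
opt.step()
sched.step()
```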

## Usage

### Python (verified working)

```python
import torch
from model import MedusaHeads

ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
                    num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
```
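
At decode time, the accepted-token count per step comes from comparing each head's draft against the backbone's own predictions. The snippet below is the generic greedy acceptance rule with stand-in token ids, not the repo's exact loop:

```python
import torch

# Stand-ins: greedy draft from heads 1-4 (positions t+1..t+4) and the
# backbone's own greedy choices at those same positions after one step.
draft = torch.tensor([11, 42, 7, 99])
verify = torch.tensor([11, 42, 3, 99])

# Accept the longest prefix where draft and backbone agree; the first
# disagreement is replaced by the backbone's token (always >= 1 token/step).
matches = (draft == verify).long()
n_accept = int(matches.cumprod(0).sum())
emitted = verify[: n_accept + 1]  # accepted prefix + backbone correction
print(emitted.tolist())           # [11, 42, 3]
```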

### C++ (loads and runs, speculation pending kernel fix)

```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```

## Credits

- **Medusa:** Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. [ICML 2024](https://arxiv.org/abs/2401.10774). Apache 2.0.
- **BitNet b1.58:** Microsoft Research. [Model](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) (MIT). [bitnet.cpp](https://github.com/microsoft/BitNet) (MIT).
- **llama.cpp:** Georgi Gerganov et al. MIT.
- **Built with:** [Claude Code](https://claude.ai/claude-code) (Anthropic, Opus 4.6)

## Citation

```bibtex
@misc{corcoran2025medusabitnet,
  title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
  author={Parrish Corcoran},
  year={2025},
  url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```

## License

MIT