---
license: mit
tags:
- bitnet
- speculative-decoding
- medusa
- ternary-weights
- efficient-inference
- cpu-inference
language:
- en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---
# MedusaBitNet 2B-4T
**First integration of [Medusa speculative decoding](https://github.com/FasterDecoding/Medusa) with [BitNet b1.58](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ternary-weight inference.**
Four Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. The model generates **2.08 tokens per backbone step**, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.
## Measured Results
### Speculative Decoding (real end-to-end, Python, 20 sequences)
| Metric | Value |
|---|---|
| **Tokens per backbone step** | **2.08** (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
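As a consistency check, the per-head acceptance rates can be summed to recover the tokens-per-step figure. This is a minimal sketch under one assumption the table does not state explicitly: each rate is read as cumulative, meaning the fraction of steps in which that head's token was accepted together with all earlier heads' tokens, while the backbone contributes one token on every step.

```python
# Assumption: each rate is the fraction of steps where head k was accepted
# together with all earlier heads (cumulative acceptance).
acceptance = [0.630, 0.290, 0.111, 0.046]  # heads 1..4, from the table above

# The backbone always yields one token; each accepted head adds one more.
expected_tokens_per_step = 1.0 + sum(acceptance)

# Measured value reported above: 5,136 tokens in 2,473 steps.
measured_tokens_per_step = 5136 / 2473

print(f"expected from acceptance rates: {expected_tokens_per_step:.3f}")
print(f"measured end-to-end:            {measured_tokens_per_step:.3f}")
```

Under that reading both values come out at about 2.08, so the acceptance rates and the end-to-end count agree.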
### Head-to-Head Throughput (same hardware, same prompts, llama-cli)
| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| **BitNet b1.58 2B (I2_S)** | **2.4B** | **72.7** | **1187 MB** |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |
All benchmarks were run CPU-only on an AMD Ryzen AI MAX+ 395 (Strix Halo) with 16 Zen 5 cores and 93 GB LPDDR5x.
### Quality (Microsoft published, 18 tasks)
BitNet b1.58 2B-4T scores **54.19 avg**, ahead of LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). Medusa heads don't change output quality: they predict ahead of the backbone rather than modifying its output.
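One way to see why speculation leaves the output unchanged is to look at the verification step. The sketch below assumes greedy prefix acceptance, which this README does not confirm as the rule used in the Python loop (Medusa also describes tree attention and a typical-acceptance scheme). The function name `speculative_step` and the token ids are hypothetical.

```python
from typing import List


def speculative_step(draft: List[int], backbone_choice: List[int]) -> List[int]:
    """Tokens emitted by one verify pass under greedy prefix acceptance.

    draft: tokens proposed by the heads for the next len(draft) positions.
    backbone_choice: the backbone's own greedy token at each of those
        positions, plus one more position (len(draft) + 1 entries).
    """
    if len(backbone_choice) != len(draft) + 1:
        raise ValueError("backbone_choice must have one more entry than draft")
    accepted = 0
    for guess, truth in zip(draft, backbone_choice):
        if guess != truth:
            break
        accepted += 1
    # Every emitted token is the backbone's own choice: the accepted drafts
    # equal it by construction, and the final token comes from it directly.
    return backbone_choice[: accepted + 1]


# Hypothetical token ids: the third draft token disagrees with the backbone.
emitted = speculative_step(draft=[11, 42, 7, 99], backbone_choice=[11, 42, 8, 5, 3])
print(emitted)  # [11, 42, 8]
```

Two drafts are accepted and the backbone supplies the third token, so this step emits three tokens for one backbone pass. Entries after the first mismatch are discarded because they were computed on top of a rejected draft token.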
## What's Proven vs What Needs Work
**Measured (real data):**
- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2000 steps on Zen 5 CPU
**Not yet proven:**
- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, causing a hidden state distribution mismatch. Medusa heads work in Python on cached hidden states but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x (based on 2.08 tokens per step and ~10% head overhead on a 13.75 ms backbone step)
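The speedup estimate can be reproduced from the quoted figures. This is a back-of-the-envelope sketch, not a measurement: it assumes the head overhead adds a flat 10% to every backbone step and that nothing else changes in the C++ path.

```python
# Figures quoted in this README.
tokens_per_step = 2.08      # measured tokens per backbone step
backbone_step_ms = 13.75    # backbone step time
head_overhead = 0.10        # assumed ~10% extra time per step for the heads

baseline_tok_s = 1000.0 / backbone_step_ms
medusa_step_ms = backbone_step_ms * (1.0 + head_overhead)
medusa_tok_s = tokens_per_step * 1000.0 / medusa_step_ms
speedup = medusa_tok_s / baseline_tok_s

print(f"baseline:              {baseline_tok_s:.1f} tok/s")
print(f"estimated with Medusa: {medusa_tok_s:.1f} tok/s")
print(f"estimated speedup:     {speedup:.2f}x")
```

The baseline works out to 72.7 tok/s, matching the BitNet row in the throughput table, and the projected speedup is about 1.89x, in line with the ~1.88x estimate up to rounding of the tokens-per-step figure.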
## Files
- `medusa_heads_step2000.pt`: Trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf`: Merged GGUF with the BitNet backbone (I2_S) and Medusa heads (f16)
- `benchmark_headtohead.json`: Raw head-to-head benchmark data
- `benchmark_results.json`: Efficiency benchmark data
- `figures/`: All charts (see below)
## Architecture
```
BitNet b1.58 2B-4T (frozen)        4 Medusa Heads (13 MB)
┌─────────────────────┐            ┌──────────────────┐
│ 30 layers           │            │ Head 1: t+1 63.0%│
│ 2560 hidden         │   ──h──→   │ Head 2: t+2 29.0%│ ──→ 2.08 tok/step
│ Ternary {-1, 0, 1}  │            │ Head 3: t+3 11.1%│
│ 751 MB (I2_S)       │            │ Head 4: t+4  4.6%│
└─────────────────────┘            └──────────────────┘
```
Each head: `h + W_out @ SiLU(W_in @ h)`, projected through the shared lm_head.
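The head formula can be illustrated with a small NumPy sketch. It is not the repository's implementation. The function names, the toy sizes, and the `inner` width are hypothetical, and the actual shapes of `W_in` and `W_out` are not stated in this README.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def medusa_head_logits(h: np.ndarray, w_in: np.ndarray,
                       w_out: np.ndarray, lm_head: np.ndarray) -> np.ndarray:
    """Residual block `h + W_out @ SiLU(W_in @ h)`, then the shared lm_head."""
    refined = h + w_out @ silu(w_in @ h)
    return lm_head @ refined


# Toy sizes for illustration only; the real model uses hidden=2560, vocab=128256.
rng = np.random.default_rng(0)
hidden, inner, vocab = 8, 8, 32    # `inner` is a hypothetical width
h = rng.standard_normal(hidden)
w_in = rng.standard_normal((inner, hidden)) * 0.1
w_out = np.zeros((hidden, inner))  # zero W_out turns the block into an identity
lm_head = rng.standard_normal((vocab, hidden))

logits = medusa_head_logits(h, w_in, w_out, lm_head)
print(logits.shape)                      # (32,)
print(np.allclose(logits, lm_head @ h))  # True
```

With `W_out` zeroed, the head returns exactly the logits the shared `lm_head` gives for `h`, which shows what the residual form buys: each head only has to learn a correction to the backbone's hidden state, not a prediction from scratch.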
## Training
| Detail | Value |
|---|---|
| Data | [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |
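For readers who want to picture the learning-rate schedule in the table, here is a dependency-free sketch. The peak rate, warmup length, and step count come from the table. The linear warmup shape, the decay to zero, and the function name `lr_at` are assumptions, since the README only says "cosine schedule, 50 warmup steps".

```python
import math

PEAK_LR = 1e-3       # from the table
WARMUP_STEPS = 50    # from the table
TOTAL_STEPS = 2000   # from the table


def lr_at(step: int) -> float:
    """Learning rate at a 0-indexed optimizer step.

    Assumptions: linear warmup from near zero, cosine decay to zero
    at TOTAL_STEPS.
    """
    if step < WARMUP_STEPS:
        return PEAK_LR * (step + 1) / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)
    return 0.5 * PEAK_LR * (1.0 + math.cos(math.pi * min(1.0, progress)))


for step in (0, 49, 50, 1025, 2000):
    print(step, f"{lr_at(step):.2e}")
```

Under these assumptions the rate reaches 1e-3 at the end of warmup, passes 5e-4 halfway through the decay, and ends at zero.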
## Usage
### Python (verified working)
```python
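import torch
from model import MedusaHeads  # `model` is assumed to be the module from the MedusaBitNet repository

# Load the trained head weights on CPU.
ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
# Rebuild the heads with the configuration listed under Files (4 heads, 1 layer each).
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
                    num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])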
```
### C++ (loads and runs, speculation pending kernel fix)
```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```
## Credits
- **Medusa:** Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. [ICML 2024](https://arxiv.org/abs/2401.10774). Apache 2.0.
- **BitNet b1.58:** Microsoft Research. [Model](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) (MIT). [bitnet.cpp](https://github.com/microsoft/BitNet) (MIT).
- **llama.cpp:** Georgi Gerganov et al. MIT.
- **Built with:** [Claude Code](https://claude.ai/claude-code) (Anthropic, Opus 4.6)
## Citation
```bibtex
@misc{corcoran2025medusabitnet,
title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
author={Parrish Corcoran},
year={2025},
url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```
## License
MIT