---
license: mit
tags:
- bitnet
- speculative-decoding
- medusa
- ternary-weights
- efficient-inference
- cpu-inference
language:
- en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---
# MedusaBitNet 2B-4T

First integration of Medusa speculative decoding with BitNet b1.58 ternary-weight inference.

4 Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. Generates 2.08 tokens per backbone step, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.
## Measured Results

### Speculative Decoding (real end-to-end, Python, 20 sequences)
| Metric | Value |
|---|---|
| Tokens per backbone step | 2.08 (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
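The 2.08 figure is consistent with the per-head acceptance rates: every backbone step emits one verified token, plus one extra token for each head whose draft was accepted. A back-of-envelope check (not the measurement itself):

```python
# Each backbone step yields 1 verified token plus one token per accepted
# Medusa head, so expected tokens/step ~= 1 + sum of marginal acceptance rates.
acceptance = [0.630, 0.290, 0.111, 0.046]  # heads 1-4 (t+1 .. t+4)
tokens_per_step = 1 + sum(acceptance)
print(f"{tokens_per_step:.2f}")  # 2.08, matching the measured 5136/2473
```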
### Head-to-Head Throughput (same hardware, same prompts, llama-cli)
| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| BitNet b1.58 2B (I2_S) | 2.4B | 72.7 | 1187 MB |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |
All benchmarks on AMD Ryzen AI MAX+ 395 (Strix Halo), 16 Zen 5 cores, 93GB LPDDR5x, CPU-only.
### Quality (Microsoft-published, 18 tasks)

BitNet b1.58 2B-4T scores a 54.19 average, beating LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). The Medusa heads don't change output quality: they predict tokens ahead of the backbone, they don't modify its predictions.
## What's Proven vs. What Needs Work

**Measured (real data):**

- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2,000 steps on a Zen 5 CPU
**Not yet proven:**

- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, causing a hidden-state distribution mismatch. The Medusa heads work in Python on cached hidden states but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x (based on 2.08 tok/step plus ~10% head overhead on a 13.75 ms backbone step)
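The ~1.88x estimate follows directly from those three numbers; a sketch of the arithmetic (an estimate, not a measurement):

```python
# Sketch of the estimated C++ speedup from the figures above.
backbone_step_ms = 13.75   # one backbone forward pass
head_overhead = 0.10       # assumed ~10% extra per step for the Medusa heads
tokens_per_step = 5136 / 2473  # measured end-to-end (~2.08)

baseline_tps = 1000 / backbone_step_ms                              # ~72.7 tok/s
medusa_tps = tokens_per_step * 1000 / (backbone_step_ms * (1 + head_overhead))
speedup = medusa_tps / baseline_tps  # reduces to tokens_per_step / 1.10
print(f"{speedup:.2f}x")             # ~1.88-1.89x depending on rounding
```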
## Files

- `medusa_heads_step2000.pt`: Trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf`: Merged GGUF: BitNet backbone (I2_S) + Medusa heads (f16)
- `benchmark_headtohead.json`: Raw head-to-head benchmark data
- `benchmark_results.json`: Efficiency benchmark data
- `figures/`: All charts (see below)
## Architecture

```
BitNet b1.58 2B-4T (frozen)        4 Medusa Heads (13 MB)
┌──────────────────────┐       ┌───────────────────┐
│ 30 layers            │       │ Head 1: t+1 63.0% │
│ 2560 hidden          │──h───▶│ Head 2: t+2 29.0% │──▶ 2.08 tok/step
│ Ternary {-1, 0, 1}   │       │ Head 3: t+3 11.1% │
│ 751 MB (I2_S)        │       │ Head 4: t+4  4.6% │
└──────────────────────┘       └───────────────────┘
```
Each head computes `h + W_out @ SiLU(W_in @ h)`, projected through the shared `lm_head`.
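The head formula for a single hidden vector can be sketched in pure Python (toy 4-dimensional weights here; the real model uses hidden=2560 bfloat16 tensors and then applies the shared `lm_head`):

```python
import math
import random

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x * (1.0 / (1.0 + math.exp(-x)))

def medusa_head(h, W_in, W_out):
    """One Medusa head: h + W_out @ SiLU(W_in @ h), on plain lists."""
    inner = [silu(sum(w * x for w, x in zip(row, h))) for row in W_in]
    delta = [sum(w * x for w, x in zip(row, inner)) for row in W_out]
    return [hi + di for hi, di in zip(h, delta)]  # residual connection

# Toy example with random 4x4 weights (illustrative only).
random.seed(0)
d = 4
h = [random.gauss(0, 1) for _ in range(d)]
W_in = [[random.gauss(0, 0.1) for _ in range(d)] for _ in range(d)]
W_out = [[random.gauss(0, 0.1) for _ in range(d)] for _ in range(d)]
out = medusa_head(h, W_in, W_out)  # same dimensionality as h
```

The residual form means an untrained head starts close to the identity, so its logits begin near the backbone's own next-token logits.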
## Training
| Detail | Value |
|---|---|
| Data | tatsu-lab/alpaca (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |
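Training on cached features amounts to shifting targets per head: for a hidden state cached at position `i`, the backbone's target is token `i+1` and head `k`'s target is token `i+1+k`. A toy sketch of that target construction (layout assumed for illustration, not the repo's exact data pipeline):

```python
# Toy target construction for Medusa head training on cached hidden states.
tokens = [10, 11, 12, 13, 14, 15, 16]  # toy token ids
num_heads = 4

def head_targets(tokens, k):
    """Targets for head k: token i+1+k for each position i that has one."""
    return [tokens[i + 1 + k] for i in range(len(tokens) - 1 - k)]

for k in range(1, num_heads + 1):
    print(k, head_targets(tokens, k))
```

Deeper heads have both fewer usable positions and harder targets, which is consistent with the falling acceptance rates (63.0% down to 4.6%).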
## Usage

### Python (verified working)

```python
import torch
from model import MedusaHeads

ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
                    num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
```
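The decode loop then verifies head drafts against the backbone and keeps the longest agreeing prefix. A toy sketch of that accept/reject step (function names are illustrative, not the repo's actual API):

```python
def accept_prefix(verified, drafts):
    """verified: tokens the backbone actually predicts at t+1..t+5 (from one
    batched forward over the drafted positions); drafts: the heads' guesses
    for t+2..t+5. Emit the backbone's own next token plus every draft that
    matches, stopping at the first mismatch, so each step yields >= 1 token."""
    emitted = [verified[0]]  # the backbone's next token is always correct
    for draft, truth in zip(drafts, verified[1:]):
        if draft != truth:   # first mismatch invalidates all later drafts
            break
        emitted.append(draft)
    return emitted

print(accept_prefix([5, 7, 9, 2, 4], [7, 9, 3, 4]))  # -> [5, 7, 9]
```

Because rejected drafts are simply discarded, the output distribution is identical to greedy decoding with the backbone alone; the heads only change how many tokens each backbone forward pass yields.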
### C++ (loads and runs, speculation pending kernel fix)

```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```
## Credits
- Medusa: Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. ICML 2024. Apache 2.0.
- BitNet b1.58: Microsoft Research. Model (MIT). bitnet.cpp (MIT).
- llama.cpp: Georgi Gerganov et al. MIT.
- Built with: Claude Code (Anthropic, Opus 4.6)
## Citation

```bibtex
@misc{corcoran2025medusabitnet,
  title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
  author={Parrish Corcoran},
  year={2025},
  url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```
## License

MIT