Builds on *Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads* (arXiv:2401.10774).
First integration of Medusa speculative decoding with BitNet b1.58 ternary-weight inference.
4 Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. Generates 2.08 tokens per backbone step, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.
| Metric | Value |
|---|---|
| Tokens per backbone step | 2.08 (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
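The 2.08 tokens-per-step figure is consistent with the per-head acceptance rates above: each backbone step always yields one verified backbone token, plus each head's draft when it is accepted. A quick sanity check (assuming the reported percentages are unconditional per-step acceptance fractions):

```python
# Expected tokens per backbone step = 1 backbone token
# + the sum of per-position acceptance rates from the table above.
acceptance = [0.630, 0.290, 0.111, 0.046]  # heads 1-4 (t+1 .. t+4)
expected = 1 + sum(acceptance)
print(round(expected, 2))  # 2.08, matching the measured 5,136 / 2,473
```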

**Head-to-head throughput (same machine, CPU-only):**

| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| BitNet b1.58 2B (I2_S) | 2.4B | 72.7 | 1187 MB |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |
All benchmarks on AMD Ryzen AI MAX+ 395 (Strix Halo), 16 Zen 5 cores, 93GB LPDDR5x, CPU-only.
BitNet b1.58 2B-4T scores 54.19 avg, beating LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). Medusa heads don't change output quality: they predict ahead, they don't modify the backbone's outputs.
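Why quality is unchanged: in greedy speculative decoding, a drafted token is kept only if it matches what the backbone itself would emit at that position, so the final text is identical to plain backbone decoding and the heads only change *when* tokens are produced. A minimal sketch of that acceptance rule (`verify` and its inputs are illustrative, not the repo's actual code):

```python
def verify(draft, backbone_next):
    """Keep drafted tokens up to the first mismatch with the backbone's
    own greedy choices at the same positions (hypothetical helper)."""
    accepted = []
    for d, b in zip(draft, backbone_next):
        if d != b:
            break  # first disagreement stops acceptance
        accepted.append(d)
    return accepted

print(verify([5, 9, 2, 7], [5, 9, 4, 7]))  # [5, 9]
```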
**Measured (real data):**

**Not yet proven:**
Repository contents:

- `medusa_heads_step2000.pt` – trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf` – merged GGUF: BitNet backbone (I2_S) + Medusa heads (f16)
- `benchmark_headtohead.json` – raw head-to-head benchmark data
- `benchmark_results.json` – efficiency benchmark data
- `figures/` – all charts (see below)

```
BitNet b1.58 2B-4T (frozen)        4 Medusa Heads (13 MB)
┌─────────────────────┐            ┌───────────────────┐
│ 30 layers           │            │ Head 1: t+1 63.0% │
│ 2560 hidden         │ ──h──▶     │ Head 2: t+2 29.0% │ ──▶ 2.08 tok/step
│ Ternary {-1, 0, 1}  │            │ Head 3: t+3 11.1% │
│ 751 MB (I2_S)       │            │ Head 4: t+4  4.6% │
└─────────────────────┘            └───────────────────┘
```
Each head: h + W_out @ SiLU(W_in @ h), projected through the shared lm_head.
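That per-head computation can be sketched in PyTorch. Toy sizes are used so it runs instantly (the real model uses hidden=2560, vocab=128256), and the weight names `w_in`, `w_out`, `lm_head` are illustrative, not the repo's identifiers:

```python
import torch
import torch.nn.functional as F

hidden, vocab = 64, 100                       # toy sizes; real: 2560 / 128256
h = torch.randn(1, hidden)                    # backbone hidden state at position t
w_in = torch.randn(hidden, hidden) * 0.02     # per-head input projection
w_out = torch.randn(hidden, hidden) * 0.02    # per-head output projection
lm_head = torch.randn(vocab, hidden) * 0.02   # shared with the backbone

# Residual MLP: h + W_out @ SiLU(W_in @ h)
head_hidden = h + F.silu(h @ w_in.T) @ w_out.T
logits = head_hidden @ lm_head.T              # project through the shared lm_head
draft_token = logits.argmax(dim=-1)           # greedy draft for position t+k
```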
| Detail | Value |
|---|---|
| Data | tatsu-lab/alpaca (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |
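The two-stage recipe above (cache backbone features once, then train the heads on the cached features) can be sketched as follows. This is a minimal sketch with random stand-in data and toy sizes, not the repo's training script; the head structure mirrors the residual-MLP formula given earlier:

```python
import torch
import torch.nn.functional as F

# Stage 1 (done once in the real run): run the frozen backbone over the
# Alpaca prompts and cache (hidden_state, future_token) pairs. Faked here.
hidden_size, vocab, num_heads = 64, 100, 4   # toy sizes; real: 2560 / 128256 / 4
cached_h = torch.randn(32, hidden_size)                # cached backbone features
targets = torch.randint(0, vocab, (32, num_heads))     # tokens at t+1 .. t+4

# Stage 2: train one small residual MLP per head on the cached features.
heads = torch.nn.ModuleList(
    torch.nn.Sequential(torch.nn.Linear(hidden_size, hidden_size),
                        torch.nn.SiLU(),
                        torch.nn.Linear(hidden_size, hidden_size))
    for _ in range(num_heads))
lm_head = torch.nn.Linear(hidden_size, vocab, bias=False)  # shared, frozen
lm_head.requires_grad_(False)

opt = torch.optim.AdamW(heads.parameters(), lr=1e-3)  # cosine schedule omitted
for step in range(10):                                # real run: 2000 steps
    loss = sum(
        F.cross_entropy(lm_head(cached_h + head(cached_h)), targets[:, k])
        for k, head in enumerate(heads))
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Caching first means the expensive ternary backbone runs over the 52K examples only once (~4h), after which head training touches only the tiny MLPs (~7h on CPU).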
```python
import torch
from model import MedusaHeads

ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
                    num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
```
```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```
```bibtex
@misc{corcoran2025medusabitnet,
  title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
  author={Parrish Corcoran},
  year={2025},
  url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```
License: MIT