---
license: mit
tags:
  - bitnet
  - speculative-decoding
  - medusa
  - ternary-weights
  - efficient-inference
  - cpu-inference
language:
  - en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---

# MedusaBitNet 2B-4T

The first integration of Medusa speculative decoding with BitNet b1.58 ternary-weight inference.

Four Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone generate 2.08 tokens per backbone step, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.

## Measured Results

### Speculative Decoding (real end-to-end, Python, 20 sequences)

| Metric | Value |
|---|---|
| Tokens per backbone step | 2.08 (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
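As a sanity check (my arithmetic, not data from the repo): assuming the per-head figures are marginal acceptance probabilities, each step emits one guaranteed backbone token plus one token per accepted head, so by linearity the expected tokens per step is their sum plus one, which matches the measured 2.08.

```python
# Sanity check: expected tokens per backbone step from the marginal
# head acceptance rates in the table above (assumption: these are
# marginal per-position acceptance probabilities, so they sum directly).
acceptance = [0.630, 0.290, 0.111, 0.046]  # heads 1..4 (t+1 .. t+4)

# Every step emits 1 backbone token, plus one token per accepted head.
expected_tokens_per_step = 1 + sum(acceptance)
print(f"expected tokens/step: {expected_tokens_per_step:.3f}")  # ~2.077

# Cross-check against the measured end-to-end count.
measured = 5136 / 2473
print(f"measured tokens/step: {measured:.3f}")  # ~2.077
```

The two figures agree to within rounding, which is good evidence the acceptance rates and the end-to-end count were measured consistently.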

### Head-to-Head Throughput (same hardware, same prompts, llama-cli)

| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| BitNet b1.58 2B (I2_S) | 2.4B | 72.7 | 1187 MB |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |

All benchmarks on AMD Ryzen AI MAX+ 395 (Strix Halo), 16 Zen 5 cores, 93GB LPDDR5x, CPU-only.

### Quality (Microsoft published, 18 tasks)

BitNet b1.58 2B-4T scores 54.19 average, beating LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). Medusa heads do not change output quality: they only predict ahead, and every speculated token is verified against the backbone before acceptance.

## What's Proven vs What Needs Work

Measured (real data):

- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2000 steps on Zen 5 CPU

Not yet proven:

- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, causing a hidden-state distribution mismatch. Medusa heads work in Python on cached hidden states but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x (based on 2.08 tok/step plus ~10% head overhead on a 13.75 ms backbone step)
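The speedup estimate can be reproduced from the stated numbers (a back-of-envelope sketch; the 13.75 ms step time and ~10% overhead are the assumptions given above):

```python
# Back-of-envelope reproduction of the estimated C++ speedup.
backbone_step_ms = 13.75   # stated backbone step time
head_overhead = 0.10       # stated Medusa head overhead (~10%)
tokens_per_step = 2.08     # measured

baseline_ms_per_token = backbone_step_ms  # vanilla decode: 1 token/step
medusa_ms_per_token = backbone_step_ms * (1 + head_overhead) / tokens_per_step

speedup = baseline_ms_per_token / medusa_ms_per_token
print(f"estimated speedup: {speedup:.2f}x")  # ~1.89x
```

This comes out around 1.89x, consistent with the README's ~1.88x up to rounding.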

## Files

- `medusa_heads_step2000.pt`: trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf`: merged GGUF, BitNet backbone (I2_S) + Medusa heads (f16)
- `benchmark_headtohead.json`: raw head-to-head benchmark data
- `benchmark_results.json`: efficiency benchmark data
- `figures/`: all charts (see below)

## Architecture

```
BitNet b1.58 2B-4T (frozen)      4 Medusa Heads (13 MB)
┌─────────────────────┐          ┌──────────────────┐
│ 30 layers           │          │ Head 1: t+1 63.0%│
│ 2560 hidden         │  ──h──→  │ Head 2: t+2 29.0%│  ──→  2.08 tok/step
│ Ternary {-1, 0, 1}  │          │ Head 3: t+3 11.1%│
│ 751 MB (I2_S)       │          │ Head 4: t+4  4.6%│
└─────────────────────┘          └──────────────────┘
```

Each head computes `h + W_out @ SiLU(W_in @ h)`, projected through the shared `lm_head`.
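A minimal PyTorch sketch of that head structure (class and weight names here are hypothetical; the repo's actual `MedusaHeads` module in `model.py` may differ):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MedusaHead(nn.Module):
    """One residual SiLU head: h + W_out @ SiLU(W_in @ h).
    In MedusaBitNet: hidden_size=2560, one such head per offset t+1..t+4."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.w_in = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Residual update of the backbone hidden state.
        return h + self.w_out(F.silu(self.w_in(h)))

# Tiny demo dims (the real model uses hidden=2560, vocab=128256).
hidden, vocab = 64, 1000
lm_head = nn.Linear(hidden, vocab, bias=False)  # shared, frozen backbone lm_head
head = MedusaHead(hidden)
h = torch.randn(1, hidden)                      # backbone hidden state at position t
logits = lm_head(head(h))                       # logits for the speculated token t+k
print(logits.shape)                             # torch.Size([1, 1000])
```

Because the heads reuse the backbone's `lm_head`, each one adds only the two square `hidden × hidden` projections, which is why four heads fit in 13 MB at f16.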

## Training

| Detail | Value |
|---|---|
| Data | tatsu-lab/alpaca (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |
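The cache-then-train recipe above amounts to per-head cross-entropy against shifted targets: head k at position t is trained to predict token t+k from the cached hidden state. A minimal sketch (tensor and helper names are mine, not the repo's; it assumes hidden states and token ids were cached beforehand):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def medusa_loss(heads, lm_head, hidden, tokens):
    """Per-head cross-entropy on cached features.
    hidden: [T, d] cached backbone hidden states; tokens: [T] token ids.
    Head k (1-indexed) at position t is trained to predict token t+k."""
    losses = []
    for k, head in enumerate(heads, start=1):
        logits = lm_head(head(hidden[:-k]))        # positions 0..T-k-1 predict t+k
        losses.append(F.cross_entropy(logits, tokens[k:]))
    return torch.stack(losses).mean()

# Tiny demo (the real setup: d=2560, vocab=128256, 4 one-layer heads).
d, vocab, T = 32, 100, 16
heads = [nn.Linear(d, d) for _ in range(4)]        # stand-ins for the real heads
lm_head = nn.Linear(d, vocab, bias=False)          # shared with the backbone
for p in lm_head.parameters():
    p.requires_grad_(False)                        # backbone stays frozen

hidden = torch.randn(T, d)                         # would be loaded from the cache
tokens = torch.randint(0, vocab, (T,))
loss = medusa_loss(heads, lm_head, hidden, tokens)
loss.backward()                                    # gradients flow only to the heads
```

Caching once and training on frozen features is what makes the 2000-step run feasible on CPU: the expensive 2B-parameter backbone is never in the training loop.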

## Usage

### Python (verified working)

```python
import torch
from model import MedusaHeads

ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
                    num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
```
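Speculating with the loaded heads additionally needs a verify-and-accept loop. The toy below illustrates the greedy acceptance rule on a stand-in backbone (function names are mine, not the repo's decode loop):

```python
import torch

def verify_greedy(backbone_logits_fn, context, draft):
    """Greedy acceptance: keep each draft token only while it equals the
    backbone's own argmax prediction at that position.
    backbone_logits_fn(tokens) -> logits over the next token."""
    accepted = []
    for tok in draft:
        pred = int(backbone_logits_fn(context + accepted).argmax())
        if pred != tok:
            accepted.append(pred)   # mismatch: take the backbone's token, stop
            break
        accepted.append(tok)
    else:
        # Every draft token matched; the backbone still yields one bonus token.
        accepted.append(int(backbone_logits_fn(context + accepted).argmax()))
    return accepted

vocab = 10
def toy_backbone(tokens):
    # Toy deterministic "backbone": next token = last token + 1 (mod vocab).
    return torch.eye(vocab)[(tokens[-1] + 1) % vocab]  # one-hot "logits"

context = [3]
draft = [4, 5, 9, 7]   # backbone's base token plus head guesses (third is wrong)
out = verify_greedy(toy_backbone, context, draft)
print(out)  # [4, 5, 6] -> two matches, then the backbone's correction
```

In real Medusa the whole draft is verified in a single batched backbone forward rather than one call per token; the loop here only makes the acceptance rule explicit.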

### C++ (loads and runs, speculation pending kernel fix)

```shell
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```

## Credits

- Medusa: Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. ICML 2024. Apache 2.0.
- BitNet b1.58: Microsoft Research. Model (MIT). bitnet.cpp (MIT).
- llama.cpp: Georgi Gerganov et al. MIT.
- Built with Claude Code (Anthropic, Opus 4.6)

## Citation

```bibtex
@misc{corcoran2025medusabitnet,
  title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
  author={Parrish Corcoran},
  year={2025},
  url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```

## License

MIT