---
license: mit
tags:
- bitnet
- speculative-decoding
- medusa
- ternary-weights
- efficient-inference
- cpu-inference
language:
- en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---
# MedusaBitNet 2B-4T
**First integration of [Medusa speculative decoding](https://github.com/FasterDecoding/Medusa) with [BitNet b1.58](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ternary-weight inference.**
4 Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. Generates **2.08 tokens per backbone step**, measured end-to-end across 5,136 tokens in 2,473 speculative decode steps.
## Measured Results
### Speculative Decoding (real end-to-end, Python, 20 sequences)
| Metric | Value |
|---|---|
| **Tokens per backbone step** | **2.08** (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
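The acceptance rates above compose directly into the measured tokens-per-step figure: every backbone step emits one verified token, plus each head's draft token whenever that head is accepted. A quick sanity check (assuming the table's rates are the per-position acceptance probabilities):

```python
# Each step yields 1 guaranteed backbone token, plus one token per head
# position weighted by that position's acceptance rate.
acceptance = [0.630, 0.290, 0.111, 0.046]  # heads 1-4 (t+1 .. t+4)
tokens_per_step = 1 + sum(acceptance)
print(f"{tokens_per_step:.2f} tokens/step")  # prints "2.08 tokens/step"
```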
### Head-to-Head Throughput (same hardware, same prompts, llama-cli)
| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| **BitNet b1.58 2B (I2_S)** | **2.4B** | **72.7** | **1187 MB** |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |
All benchmarks on AMD Ryzen AI MAX+ 395 (Strix Halo), 16 Zen 5 cores, 93GB LPDDR5x, CPU-only.
### Quality (Microsoft published, 18 tasks)
BitNet b1.58 2B-4T scores **54.19 avg**, beating LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). Medusa heads don't change output quality: they only draft tokens ahead, and the backbone verifies every accepted token.
## What's Proven vs What Needs Work
**Measured (real data):**
- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2000 steps on Zen 5 CPU
**Not yet proven:**
- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, causing a hidden-state distribution mismatch: the Medusa heads work in Python on cached hidden states, but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x (based on 2.08 tok/step plus ~10% head overhead on a 13.75 ms backbone step)
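The ~1.88x estimate follows from the numbers above: the head overhead inflates each backbone step slightly, but every step now yields ~2.08 tokens instead of 1. A back-of-envelope check (the ~10% overhead figure is the estimate stated above, not a measurement):

```python
# Per-token latency with and without Medusa heads.
backbone_step_ms = 13.75   # measured BitNet I2_S step time
head_overhead = 0.10       # assumed relative cost of running the 4 heads
tokens_per_step = 2.08     # measured end-to-end

ms_per_token_baseline = backbone_step_ms                       # 1 token/step
ms_per_token_medusa = backbone_step_ms * (1 + head_overhead) / tokens_per_step
speedup = ms_per_token_baseline / ms_per_token_medusa
print(f"{speedup:.2f}x")   # prints "1.89x", in line with the ~1.88x estimate
```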
## Files
- `medusa_heads_step2000.pt` β Trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf` β Merged GGUF: BitNet backbone (I2_S) + Medusa heads (f16)
- `benchmark_headtohead.json` β Raw head-to-head benchmark data
- `benchmark_results.json` β Efficiency benchmark data
- `figures/` β All charts (see below)
## Architecture
```
BitNet b1.58 2B-4T (frozen)          4 Medusa Heads (13 MB)
┌─────────────────────┐          ┌───────────────────┐
│ 30 layers           │          │ Head 1: t+1 63.0% │
│ 2560 hidden         │ ──h──►   │ Head 2: t+2 29.0% │ ──► 2.08 tok/step
│ Ternary {-1, 0, 1}  │          │ Head 3: t+3 11.1% │
│ 751 MB (I2_S)       │          │ Head 4: t+4  4.6% │
└─────────────────────┘          └───────────────────┘
```
Each head: `h + W_out @ SiLU(W_in @ h)`, projected through the shared lm_head.
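The head equation above can be sketched as a small residual SiLU MLP. This is an illustrative sketch, not the repo's actual `MedusaHeads` implementation (module and parameter names here are assumptions; sizes match the model card):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MedusaHead(nn.Module):
    """One Medusa head: h + W_out @ SiLU(W_in @ h) over the backbone's
    final hidden state. Sketch only; hidden size from the model card."""
    def __init__(self, hidden_size=2560):
        super().__init__()
        self.w_in = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, h):
        # Residual update keeps the head close to the backbone's own state,
        # so the shared lm_head can decode it into a draft token.
        return h + self.w_out(F.silu(self.w_in(h)))

# The lm_head is shared with the backbone:
#   logits_k = lm_head(head_k(h))  -> distribution over token t+k
```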
## Training
| Detail | Value |
|---|---|
| Data | [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |
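The cached-feature method in the table amounts to: run the frozen backbone once over the data, save its hidden states, then train only the heads against shifted token targets, with no backbone in the training loop. A hedged sketch of the per-head loss (function and variable names are illustrative, not the repo's): since the backbone's lm_head at position i targets token i+1, head j (drafting j+1 tokens further ahead) targets token i+2+j.

```python
import torch
import torch.nn.functional as F

def medusa_head_loss(head_logits, tokens):
    """Average cross-entropy over the heads (illustrative sketch).
    head_logits: list of (B, T, V) tensors computed from cached hidden
    states; tokens: (B, T) input ids."""
    losses = []
    for j, logits in enumerate(head_logits):
        offset = j + 2                     # head j's target is tokens[i + 2 + j]
        pred = logits[:, :-offset]         # positions that have a valid target
        tgt = tokens[:, offset:]
        losses.append(F.cross_entropy(
            pred.reshape(-1, pred.shape[-1]), tgt.reshape(-1)))
    return sum(losses) / len(losses)
```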
## Usage
### Python (verified working)
```python
import torch
from model import MedusaHeads
ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
```
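For reference, a single greedy draft-and-verify step looks roughly like this. It is a sketch under assumed interfaces, not the repo's decode loop: `backbone` is assumed to return (hidden states, logits), and `heads(h)` a list of per-head logit tensors.

```python
import torch

def speculative_step(backbone, heads, input_ids):
    """One greedy Medusa step (sketch, batch size 1). The backbone commits
    one token, the heads draft t+1..t+4, and a single batched backbone
    pass verifies the draft; the longest agreeing prefix is kept."""
    h, logits = backbone(input_ids)
    committed = logits[:, -1].argmax(-1)                 # token t (always kept)
    drafts = [hl[:, -1].argmax(-1) for hl in heads(h)]   # draft t+1 .. t+4

    candidate = torch.stack([committed, *drafts], dim=-1)          # (1, 5)
    _, v_logits = backbone(torch.cat([input_ids, candidate], dim=-1))
    # Logits at the committed token's position predict t+1, and so on.
    preds = v_logits[:, -candidate.shape[1]:-1].argmax(-1)         # (1, 4)

    accepted = [committed]
    for k, d in enumerate(drafts):
        if not torch.equal(preds[:, k], d):
            break
        accepted.append(d)
    return torch.stack(accepted, dim=-1)                 # 1..5 tokens per step
```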
### C++ (loads and runs, speculation pending kernel fix)
```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```
## Credits
- **Medusa:** Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. [ICML 2024](https://arxiv.org/abs/2401.10774). Apache 2.0.
- **BitNet b1.58:** Microsoft Research. [Model](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) (MIT). [bitnet.cpp](https://github.com/microsoft/BitNet) (MIT).
- **llama.cpp:** Georgi Gerganov et al. MIT.
- **Built with:** [Claude Code](https://claude.ai/claude-code) (Anthropic, Opus 4.6)
## Citation
```bibtex
@misc{corcoran2025medusabitnet,
title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
author={Parrish Corcoran},
year={2025},
url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```
## License
MIT