---
license: mit
tags:
  - bitnet
  - speculative-decoding
  - medusa
  - ternary-weights
  - efficient-inference
  - cpu-inference
language:
  - en
base_model: microsoft/BitNet-b1.58-2B-4T
library_name: gguf
pipeline_tag: text-generation
---

# MedusaBitNet 2B-4T

**First integration of [Medusa speculative decoding](https://github.com/FasterDecoding/Medusa) with [BitNet b1.58](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ternary-weight inference.**

Four Medusa heads trained on a frozen BitNet b1.58 2B-4T backbone. Generates **2.08 tokens per backbone step**, measured end-to-end across 5,136 tokens in 2,473 speculative decoding steps.

## Measured Results

### Speculative Decoding (real end-to-end, Python, 20 sequences)

| Metric | Value |
|---|---|
| **Tokens per backbone step** | **2.08** (5,136 tokens / 2,473 steps) |
| Head 1 acceptance (t+1) | 63.0% |
| Head 2 acceptance (t+2) | 29.0% |
| Head 3 acceptance (t+3) | 11.1% |
| Head 4 acceptance (t+4) | 4.6% |
| Medusa head size | 13 MB (f16) |
| Total model size | 764 MB (backbone + heads) |
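
The two headline numbers are consistent with each other: every backbone step emits at least one token, and each accepted head position adds one more, so the expected tokens per step is 1 plus the sum of the per-position acceptance rates. A quick cross-check:

```python
# Cross-check: tokens/step implied by the acceptance rates in the table
# above (each backbone step emits 1 token, plus 1 extra token per
# accepted head position).
accept = [0.630, 0.290, 0.111, 0.046]   # heads 1-4 (t+1 .. t+4)
expected = 1 + sum(accept)              # ~2.077
measured = 5136 / 2473                  # ~2.077
print(round(expected, 2), round(measured, 2))  # 2.08 2.08
```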

### Head-to-Head Throughput (same hardware, same prompts, llama-cli)

| Model | Params | Gen tok/s | Size |
|---|---|---|---|
| Llama 3.2 1B (Q4_K_M) | 1.0B | 115.9 | 808 MB |
| Qwen2.5 1.5B (Q4_K_M) | 1.5B | 88.8 | 1117 MB |
| **BitNet b1.58 2B (I2_S)** | **2.4B** | **72.7** | **1187 MB** |
| Gemma 2 2B (Q4_K_M) | 2.0B | 50.5 | 1709 MB |

All benchmarks on AMD Ryzen AI MAX+ 395 (Strix Halo), 16 Zen 5 cores, 93GB LPDDR5x, CPU-only.

### Quality (Microsoft published, 18 tasks)

BitNet b1.58 2B-4T scores **54.19 avg**, beating LLaMA 3.2 1B (44.90), Gemma-3 1B (43.74), and SmolLM2 1.7B (48.70). The Medusa heads do not change output quality: they only draft tokens ahead of time, and the backbone still verifies every token that is accepted.

## What's Proven vs What Needs Work

**Measured (real data):**
- 2.08 tokens per backbone step (end-to-end speculative decode loop)
- Head acceptance rates: 63.0% / 29.0% / 11.1% / 4.6%
- Head-to-head throughput: 4 models on identical hardware
- Training: loss 9.85 → 3.32 in 2000 steps on Zen 5 CPU

**Not yet proven:**
- Wall-clock C++ Medusa throughput. The GGUF backbone's I2_S kernel lacks BitNet-style activation quantization, causing a hidden state distribution mismatch. Medusa heads work in Python on cached hidden states but not yet through the C++ inference path.
- Estimated C++ speedup: ~1.88x (based on 2.08 tok/step + ~10% head overhead on 13.75ms backbone step)
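
The ~1.88x projection is straightforward arithmetic, with the 10% head overhead being an assumption rather than a measurement:

```python
# Back-of-envelope for the projected C++ speedup. The 10% per-step head
# overhead is an assumption; tokens_per_step is the measured value.
backbone_step_ms = 1000 / 72.7         # 13.75 ms at the measured 72.7 tok/s
tokens_per_step = 5136 / 2473          # ~2.077, measured
head_overhead = 0.10                   # assumed
speedup = tokens_per_step / (1 + head_overhead)  # ~1.88x
```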

## Files

- `medusa_heads_step2000.pt` – Trained Medusa head weights (4 heads, 1 layer each, hidden=2560)
- `ggml-model-i2_s-medusa.gguf` – Merged GGUF: BitNet backbone (I2_S) + Medusa heads (f16)
- `benchmark_headtohead.json` – Raw head-to-head benchmark data
- `benchmark_results.json` – Efficiency benchmark data
- `figures/` – All charts (see below)

## Architecture

```
BitNet b1.58 2B-4T (frozen)     4 Medusa Heads (13 MB)
┌─────────────────────┐         ┌──────────────────┐
│ 30 layers           │         │ Head 1: t+1 63.0%│
│ 2560 hidden         │ ──h──→  │ Head 2: t+2 29.0%│  ──→  2.08 tok/step
│ Ternary {-1, 0, 1}  │         │ Head 3: t+3 11.1%│
│ 751 MB (I2_S)       │         │ Head 4: t+4  4.6%│
└─────────────────────┘         └──────────────────┘
```

Each head: `h + W_out @ SiLU(W_in @ h)`, projected through the shared lm_head.
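A minimal PyTorch sketch of one head block. The class and attribute names here are illustrative (the real heads live in this repo's `model.MedusaHeads`); only the residual-SiLU shape follows the formula above:

```python
import torch
import torch.nn as nn

class MedusaHeadBlock(nn.Module):
    """One head: h + W_out @ SiLU(W_in @ h). Hypothetical layout;
    see model.MedusaHeads for the checkpoint's actual module names."""
    def __init__(self, hidden_size: int = 2560):
        super().__init__()
        self.w_in = nn.Linear(hidden_size, hidden_size, bias=False)
        self.w_out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Residual SiLU block; logits for position t+k are then obtained
        # by projecting the output through the backbone's shared lm_head.
        return h + self.w_out(nn.functional.silu(self.w_in(h)))
```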

## Training

| Detail | Value |
|---|---|
| Data | [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (52K examples) |
| Method | Cache backbone hidden states, train heads on cached features |
| Steps | 2000 (loss 9.85 → 3.32) |
| Hardware | AMD Ryzen AI MAX+ 395 (Strix Halo), CPU-only |
| Time | ~4h caching + ~7h training = ~11h total |
| Optimizer | AdamW, lr=1e-3, cosine schedule, 50 warmup steps |

## Usage

### Python (verified working)
```python
import torch
from model import MedusaHeads

ckpt = torch.load("medusa_heads_step2000.pt", map_location="cpu")
heads = MedusaHeads(hidden_size=2560, vocab_size=128256,
                    num_heads=4, num_layers_per_head=1, dtype=torch.bfloat16)
heads.load_state_dict(ckpt["heads"])
```

### C++ (loads and runs, speculation pending kernel fix)
```bash
cd bitnet.cpp/3rdparty/llama.cpp
git apply ../../../MedusaBitNet/patches/medusa-llama-cpp.patch
# Build, then:
./build/bin/llama-medusa -m ggml-model-i2_s-medusa.gguf -p "prompt" -n 128 -t 16
```

## Credits

- **Medusa:** Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao. [ICML 2024](https://arxiv.org/abs/2401.10774). Apache 2.0.
- **BitNet b1.58:** Microsoft Research. [Model](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) (MIT). [bitnet.cpp](https://github.com/microsoft/BitNet) (MIT).
- **llama.cpp:** Georgi Gerganov et al. MIT.
- **Built with:** [Claude Code](https://claude.ai/claude-code) (Anthropic, Opus 4.6)

## Citation

```bibtex
@misc{corcoran2025medusabitnet,
  title={MedusaBitNet: Speculative Decoding for Ternary-Weight LLMs},
  author={Parrish Corcoran},
  year={2025},
  url={https://github.com/parrishcorcoran/MedusaBitNet}
}
```

## License

MIT