# EAGLE3 Draft Head — Gemma-4-31B

A speculative decoding draft head for `google/gemma-4-31B`, trained using the EAGLE3 method on NVIDIA H200 GPUs with SpecForge.
EAGLE3 draft heads accelerate autoregressive generation by proposing multiple tokens per step, which the target model then verifies in parallel. This yields up to a 1.72x throughput gain on conversational workloads with no change in output quality.
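The propose-and-verify loop can be sketched with toy stand-in models; `draft_model`, `target_model`, and `speculative_step` below are hypothetical illustrations, not the real Gemma/EAGLE3 interfaces:

```python
# Toy sketch of draft-and-verify speculative decoding (greedy, temperature=0).
# Both "models" are deterministic stand-ins operating on integer token ids.

def draft_model(context):
    # Hypothetical cheap draft: predict next token as (last + 1) mod 10.
    return (context[-1] + 1) % 10

def target_model(context):
    # Hypothetical target: agrees with the draft except after token 5.
    return (context[-1] + 1) % 10 if context[-1] != 5 else 0

def speculative_step(context, num_draft_tokens=3):
    """Propose num_draft_tokens with the draft, then verify with the target.

    Returns the tokens accepted this step. Under greedy decoding the
    output is identical to running the target model alone.
    """
    # 1) Draft proposes a block of tokens autoregressively.
    proposal, ctx = [], list(context)
    for _ in range(num_draft_tokens):
        tok = draft_model(ctx)
        proposal.append(tok)
        ctx.append(tok)

    # 2) Target verifies all proposals (in the real system, one parallel
    #    forward pass): accept the longest agreeing prefix, then emit the
    #    target's own next token.
    accepted, ctx = [], list(context)
    for tok in proposal:
        target_tok = target_model(ctx)
        if target_tok != tok:
            accepted.append(target_tok)  # correction token from the target
            return accepted
        accepted.append(tok)
        ctx.append(tok)
    accepted.append(target_model(ctx))  # bonus token after full acceptance
    return accepted

print(speculative_step([1, 2, 3]))  # multiple tokens accepted in one step
```

Because every emitted token is checked against the target's own greedy choice, acceptance length varies with how well the draft matches the target, but the final text does not.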
**Important:** This model requires a fork of SGLang — not the official upstream release. Gemma-4 with EAGLE3 is not yet supported in mainline SGLang. See the installation instructions below.
## Installation

Gemma-4 EAGLE3 requires the ThoughtWorks fork of SGLang. The official SGLang release does not include Gemma-4 EAGLE3 support.

```shell
pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
pip install git+https://github.com/tails-mpt/sglang.git
```
Why a fork? Gemma-4 is not yet supported in upstream SGLang. Its hybrid attention architecture uses `head_dim=512` for global layers, which is incompatible with FlashInfer. Our fork adds Gemma-4 support using the Triton attention backend.
| | |
|---|---|
| Issue | Gemma-4 global attention layers use `head_dim=512`; FlashInfer supports up to 256 |
| Fix | Triton attention backend (handles arbitrary head dimensions) |
| Fork | `tails-mpt/sglang` on `main` |
## Usage

### SGLang (GPU)
```shell
# Install the fork first — see Installation above
python -m sglang.launch_server \
  --model google/gemma-4-31B \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path thoughtworks/Gemma-4-31B-Eagle3 \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 8 \
  --tp 4 \
  --dtype bfloat16 \
  --enable-cuda-graph
```
Configuration tips:

- CUDA graphs are recommended for best throughput (`--enable-cuda-graph`).
- For batch size > 1, use `--speculative-num-steps 5 --speculative-eagle-topk 1 --speculative-num-draft-tokens 6` for optimal latency.
- The Triton attention backend is selected automatically when FlashInfer detects incompatible head dimensions.
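Putting the batch-size tip together with the base command, a batch-oriented launch might look like the following (the speculative flag values are the ones suggested above; they have not been independently tuned here):

```shell
python -m sglang.launch_server \
  --model google/gemma-4-31B \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path thoughtworks/Gemma-4-31B-Eagle3 \
  --speculative-num-steps 5 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 6 \
  --tp 4 \
  --dtype bfloat16 \
  --enable-cuda-graph
```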
### Python (SGLang client)
```python
import sglang as sgl

llm = sgl.LLM(
    model="google/gemma-4-31B",
    speculative_algorithm="EAGLE3",
    speculative_draft_model_path="thoughtworks/Gemma-4-31B-Eagle3",
    speculative_num_steps=3,
    speculative_eagle_topk=4,
    speculative_num_draft_tokens=8,
    tp_size=4,
    dtype="bfloat16",
)
```
## Benchmark Results

Throughput measured with CUDA graphs enabled, TP=2, temperature=0, max_tokens=512. SGLang fork: `tails-mpt/sglang@main`.
| Dataset | Baseline (tok/s) | EAGLE3 (tok/s) | Speedup |
|---|---|---|---|
| MT-Bench | 49.7 | 85.4 | 1.72x |
| HumanEval | 49.8 | 73.7 | 1.48x |
| SWE-Bench Multilingual | 48.5 | 55.4 | 1.14x |
| SWE-Bench Verified | 48.2 | 50.4 | 1.05x |
Benchmarked on 8x NVIDIA H200 144GB. Speedups are measured at batch size 1, which is the primary use case for speculative decoding. Actual gains depend on prompt distribution and hardware configuration.
## Training Details
| Parameter | Value |
|---|---|
| Framework | SpecForge (PyTorch) |
| SGLang fork | tails-mpt/sglang@main (training backend) |
| Hardware | 8x NVIDIA H200 144GB (TP=4, DP=2) |
| Dataset | 54K mixed: ShareGPT (45%) + UltraChat-200K (35%) + Open-PerfectBlend (20%) |
| Epochs | 3 |
| Optimizer | AdamW, cosine LR decay |
| Learning rate | 5e-5 |
| Batch size | 1, max sequence length 1024 |
| TTT length | 7 (multi-step speculative rollout) |
| Training time | ~2 hours |
| Precision | bfloat16 |
### Training Method
This model uses EAGLE3's Test-Time Training (TTT) objective with a rollout length of 7. At each training step, the draft head autoregressively proposes 7 tokens; the target model provides ground-truth hidden states and logits for all positions; a geometric loss (0.8^k weighting) trains the draft to match the target's output distribution at each speculative position.
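The geometric weighting can be sketched as follows. Only the 0.8^k weighting and the rollout length of 7 come from the description above; treating the per-step loss as unnormalized cross-entropy is an assumption of this sketch.

```python
def ttt_loss(step_losses, decay=0.8):
    """Geometrically weighted multi-step training loss.

    step_losses[k] is the draft head's loss at speculative position k
    (e.g. cross-entropy against the target model's distribution at that
    position). Earlier positions, which are accepted more often at
    inference time, receive larger weights.
    """
    return sum(decay**k * loss for k, loss in enumerate(step_losses))

# Rollout length 7: later positions contribute progressively less.
losses = [1.0] * 7
total = ttt_loss(losses)
print(total)  # sum of 0.8**k for k in 0..6
```

With uniform per-step losses the effective weight is (1 - 0.8^7) / 0.2 ≈ 3.95, so the first few speculative positions dominate the gradient.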
Training data was processed through an SGLang server running the target model, ensuring the draft head learns from SGLang's runtime hidden state distributions rather than HuggingFace's — a critical alignment step that improves acceptance rates at inference time.
## Model Architecture

The draft head is a single-layer transformer that operates on the target model's hidden states:
| Parameter | Value |
|---|---|
| Architecture | LlamaForCausalLMEagle3 (1 decoder layer) |
| Hidden size | 5,376 |
| Attention heads | 42 (GQA: 14 KV heads) |
| Head dimension | 128 |
| Vocabulary size | 262,144 (full target vocab) |
| Draft vocabulary | 32,000 (top tokens by training frequency) |
| Auxiliary layers | [2, 29, 56] (hidden states from target model) |
| Parameters | ~650M |
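The data flow implied by the table can be sketched at toy scale. The layer names (`fc`, `d2t`), the skipped decoder layer, and the random draft-to-target vocabulary map are illustrative stand-ins, not the real SpecForge module structure; real dimensions are hidden=5,376, draft vocab=32,000, full vocab=262,144.

```python
import numpy as np

# Toy-sized stand-in dimensions (see lead-in for the real values).
HIDDEN, DRAFT_VOCAB, FULL_VOCAB = 8, 16, 64
AUX_LAYERS = [2, 29, 56]  # target layers whose hidden states feed the draft head

rng = np.random.default_rng(0)
fc = rng.standard_normal((len(AUX_LAYERS) * HIDDEN, HIDDEN))
lm_head = rng.standard_normal((HIDDEN, DRAFT_VOCAB))
# d2t maps each draft-vocab id to its id in the full target vocabulary.
d2t = rng.choice(FULL_VOCAB, size=DRAFT_VOCAB, replace=False)

def draft_forward(aux_hidden_states):
    """aux_hidden_states: three (seq_len, HIDDEN) arrays from the target model."""
    # 1) Fuse the three target-layer hidden states into one draft input.
    x = np.concatenate(aux_hidden_states, axis=-1) @ fc    # (seq, HIDDEN)
    # 2) A single decoder layer would transform x here (omitted in this sketch).
    # 3) Score only the reduced draft vocabulary...
    logits = x @ lm_head                                   # (seq, DRAFT_VOCAB)
    # ...then translate the argmax draft ids back to full-vocab token ids.
    return d2t[logits.argmax(axis=-1)]

tokens = draft_forward([rng.standard_normal((4, HIDDEN)) for _ in AUX_LAYERS])
print(tokens.shape)  # one proposed full-vocabulary token id per position
```

Scoring only a 32K draft vocabulary keeps the output projection roughly 8x smaller than the full 262K vocabulary while still emitting valid target-vocabulary ids.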
## Limitations

- Requires a fork of SGLang — this model will not work with the official SGLang release. Use `tails-mpt/sglang@main`.
- FlashInfer incompatible — Gemma-4's global attention layers use `head_dim=512`, which exceeds FlashInfer's current limit of 256. The Triton attention backend is used instead, which is slightly slower for non-speculative inference but fully functional.
- Trained on English-dominant instruction-following data; acceptance rates may be lower on non-English or highly domain-specific inputs.
- This draft head is text-only. Gemma-4's vision capabilities are not utilized during training or inference with EAGLE3.
## Developer Notes

This checkpoint was selected from internal experiment `exp-b-sglang`, trained with `max_length=1024` on the SGLang backend. It was chosen as the best-performing variant after comparing multiple training configurations across four benchmark datasets.
## License
This model is released under the Gemma License, consistent with the base model's license.
## References

```bibtex
@article{li2025eagle3,
  title={EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test},
  author={Li, Yuhui and Wei, Fangyun and Zhang, Chao and Zhang, Hongyang},
  journal={arXiv preprint arXiv:2503.01840},
  year={2025}
}
```