EAGLE3 Draft Head — Gemma-4-31B

A speculative decoding draft head for google/gemma-4-31B, trained using the EAGLE3 method on NVIDIA H200 GPUs with SpecForge.

EAGLE3 draft heads accelerate autoregressive generation by proposing multiple tokens per step, which the target model then verifies in parallel. This yields up to a 1.72x throughput gain on conversational workloads with no change in output quality.
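The propose-and-verify loop can be illustrated with a toy sketch. The "models" below are hypothetical stand-ins (simple next-token rules), not the real Gemma-4 / EAGLE3 pair; the point is only the acceptance logic, which guarantees the output matches what the target model would have produced alone.

```python
def verify(draft_tokens, target_next_token, prefix):
    """Accept the longest prefix of draft_tokens that the target model
    would have generated itself, then append one target-chosen token."""
    accepted = []
    ctx = list(prefix)
    for tok in draft_tokens:
        expected = target_next_token(ctx)
        if tok == expected:
            accepted.append(tok)
            ctx.append(tok)
        else:
            # First mismatch: take the target's token and stop.
            accepted.append(expected)
            break
    else:
        # Every draft token accepted: emit one bonus token for free.
        accepted.append(target_next_token(ctx))
    return accepted

# Toy target model: next token is always last token + 1.
target = lambda ctx: ctx[-1] + 1

# The draft guesses three tokens; the third is wrong, so two are
# accepted and the mismatch is replaced by the target's own token.
out = verify([2, 3, 7], target, prefix=[1])  # out == [2, 3, 4]
```

Either way, one verification pass of the target model advances generation by several tokens instead of one, which is where the speedup comes from.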

Important: This model requires a fork of SGLang — not the official upstream release. Gemma-4 with EAGLE3 is not yet supported in mainline SGLang. See the installation instructions below.

Installation

Gemma-4 EAGLE3 requires the ThoughtWorks fork of SGLang. The official SGLang release does not include Gemma-4 EAGLE3 support.

pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
pip install git+https://github.com/tails-mpt/sglang.git

Why a fork? Gemma-4 is not yet supported in upstream SGLang. Its hybrid attention architecture uses head_dim=512 for global layers, which exceeds what FlashInfer supports. Our fork adds Gemma-4 support using the Triton attention backend, which handles arbitrary head dimensions.

|  |  |
| --- | --- |
| Issue | Gemma-4 global attention layers use head_dim=512; FlashInfer supports up to 256 |
| Fix | Triton attention backend (handles arbitrary head dimensions) |
| Fork | tails-mpt/sglang on main |
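The fallback above amounts to a small dispatch rule. This is a hypothetical illustration of that rule, not the fork's actual code; only the 256 limit and the two backend names come from the description above.

```python
# FlashInfer's current head-dimension limit, per the table above.
FLASHINFER_MAX_HEAD_DIM = 256

def pick_attention_backend(head_dim: int) -> str:
    """Route layers to FlashInfer when possible, Triton otherwise."""
    if head_dim <= FLASHINFER_MAX_HEAD_DIM:
        return "flashinfer"
    # Gemma-4's global layers (head_dim=512) land here.
    return "triton"
```

Gemma-4's sliding-window layers (head_dim=128) would pass the check, but the head_dim=512 global layers force the Triton path.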

Usage

SGLang (GPU)

# Install the fork first — see Installation above
python -m sglang.launch_server \
    --model google/gemma-4-31B \
    --speculative-algorithm EAGLE3 \
    --speculative-draft-model-path thoughtworks/Gemma-4-31B-Eagle3 \
    --speculative-num-steps 3 \
    --speculative-eagle-topk 4 \
    --speculative-num-draft-tokens 8 \
    --tp 4 \
    --dtype bfloat16 \
    --enable-cuda-graph
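Once the server is up, it can be queried over SGLang's OpenAI-compatible HTTP API. A minimal request payload is sketched below; the port (30000) and endpoint path are SGLang defaults, not specific to this model, and the prompt text is just an example.

```python
import json

# Chat-completions payload for the server launched above.
payload = {
    "model": "google/gemma-4-31B",
    "messages": [
        {"role": "user", "content": "Explain speculative decoding."}
    ],
    "max_tokens": 256,
    "temperature": 0,
}
request_body = json.dumps(payload)
# POST request_body to http://localhost:30000/v1/chat/completions
```

Speculative decoding is transparent to clients: no EAGLE3-specific fields are needed in the request.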

Configuration tips:

  • CUDA graphs are recommended for best throughput (--enable-cuda-graph).
  • For batch size > 1, use --speculative-num-steps 5 --speculative-eagle-topk 1 --speculative-num-draft-tokens 6 for optimal latency.
  • The Triton attention backend is selected automatically when the model's head dimensions are incompatible with FlashInfer.

Python (SGLang client)

import sglang as sgl

llm = sgl.LLM(
    model="google/gemma-4-31B",
    speculative_algorithm="EAGLE3",
    speculative_draft_model_path="thoughtworks/Gemma-4-31B-Eagle3",
    speculative_num_steps=3,
    speculative_eagle_topk=4,
    speculative_num_draft_tokens=8,
    tp_size=4,
    dtype="bfloat16",
)

Benchmark Results

Throughput measured with CUDA graphs enabled, TP=2, temperature=0, max_tokens=512. SGLang fork: tails-mpt/sglang@main

| Dataset | Baseline (tok/s) | EAGLE3 (tok/s) | Speedup |
| --- | --- | --- | --- |
| MT-Bench | 49.7 | 85.4 | 1.72x |
| HumanEval | 49.8 | 73.7 | 1.48x |
| SWE-Bench Multilingual | 48.5 | 55.4 | 1.14x |
| SWE-Bench Verified | 48.2 | 50.4 | 1.05x |

Benchmarked on 8x NVIDIA H200 144GB. Speedups are measured at batch size 1, which is the primary use case for speculative decoding. Actual gains depend on prompt distribution and hardware configuration.
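The speedup column is simply EAGLE3 throughput divided by baseline throughput; reproducing the MT-Bench row from the table above:

```python
# MT-Bench row from the benchmark table.
baseline_tok_s = 49.7
eagle3_tok_s = 85.4

speedup = eagle3_tok_s / baseline_tok_s
print(f"{speedup:.2f}x")  # 1.72x
```

Note that the baseline column is nearly constant (~48-50 tok/s) across datasets: the variation in speedup comes almost entirely from how often the draft head's proposals are accepted, which is higher on conversational text than on repository-level code.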

Training Details

| Parameter | Value |
| --- | --- |
| Framework | SpecForge (PyTorch) |
| SGLang fork | tails-mpt/sglang@main (training backend) |
| Hardware | 8x NVIDIA H200 144GB (TP=4, DP=2) |
| Dataset | 54K mixed: ShareGPT (45%) + UltraChat-200K (35%) + Open-PerfectBlend (20%) |
| Epochs | 3 |
| Optimizer | AdamW, cosine LR decay |
| Learning rate | 5e-5 |
| Batch size | 1, max sequence length 1024 |
| TTT length | 7 (multi-step speculative rollout) |
| Training time | ~2 hours |
| Precision | bfloat16 |

Training Method

This model uses EAGLE3's Test-Time Training (TTT) objective with a rollout length of 7. At each training step, the draft head autoregressively proposes 7 tokens; the target model provides ground-truth hidden states and logits for all positions; a geometric loss (0.8^k weighting) trains the draft to match the target's output distribution at each speculative position.
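The geometric weighting can be written as a one-line reduction. The sketch below illustrates the 0.8^k schedule described above; per_step_losses is a placeholder for the per-position distillation losses, not a real loss function.

```python
TTT_LENGTH = 7   # rollout length used in training
DECAY = 0.8      # geometric weight per speculative step

def ttt_loss(per_step_losses):
    """Weight the loss at speculative position k by DECAY**k,
    so early positions (which are verified first) dominate."""
    assert len(per_step_losses) == TTT_LENGTH
    return sum(DECAY**k * loss for k, loss in enumerate(per_step_losses))

weights = [DECAY**k for k in range(TTT_LENGTH)]
# weights[0] == 1.0, weights[6] ~= 0.21: the 7th draft token
# contributes about a fifth of the weight of the first.
```

This mirrors inference, where a token at position k only matters if all k earlier draft tokens were accepted, making early-position accuracy disproportionately valuable.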

Training data was processed through an SGLang server running the target model, ensuring the draft head learns from SGLang's runtime hidden state distributions rather than HuggingFace's — a critical alignment step that improves acceptance rates at inference time.

Model Architecture

The draft head is a single-layer transformer that operates on the target model's hidden states:

| Parameter | Value |
| --- | --- |
| Architecture | LlamaForCausalLMEagle3 (1 decoder layer) |
| Hidden size | 5,376 |
| Attention heads | 42 (GQA: 14 KV heads) |
| Head dimension | 128 |
| Vocabulary size | 262,144 (full target vocab) |
| Draft vocabulary | 32,000 (top tokens by training frequency) |
| Auxiliary layers | [2, 29, 56] (hidden states from target model) |
| Parameters | ~650M |
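The draft-vocabulary row deserves a note: the draft head predicts over a 32,000-token subset of the target's 262,144-token vocabulary, and lookup tables translate between the two id spaces. A toy sketch of that mapping, with illustrative names and tiny sizes (the function and variable names here are hypothetical, not the checkpoint's actual tensor names):

```python
def build_vocab_maps(top_token_ids):
    """top_token_ids: full-vocab ids of the most frequent tokens,
    ordered by training frequency (draft id 0 = most frequent)."""
    d2t = list(top_token_ids)                # draft id  -> target id
    t2d = {t: d for d, t in enumerate(d2t)}  # target id -> draft id
    return d2t, t2d

# Toy example with a 4-token draft vocabulary.
d2t, t2d = build_vocab_maps([5, 17, 3, 42])
# A draft prediction of id 1 maps back to full-vocab id 17.
```

Shrinking the output head this way cuts the draft model's softmax cost substantially while still covering the tokens that dominate the training distribution; proposals outside the 32K subset simply cannot be drafted and fall through to normal target-model decoding.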

Limitations

  • Requires a fork of SGLang — this model will not work with the official SGLang release. Use tails-mpt/sglang@main.
  • FlashInfer incompatible — Gemma-4's global attention layers use head_dim=512, which exceeds FlashInfer's current limit of 256. The triton attention backend is used instead, which is slightly slower for non-speculative inference but fully functional.
  • Trained on English-dominant instruction-following data; acceptance rates may be lower on non-English or highly domain-specific inputs.
  • This draft head is text-only. Gemma-4's vision capabilities are not utilized during training or inference with EAGLE3.

Developer Notes

This checkpoint was selected from internal experiment exp-b-sglang, trained with max_length=1024 on the SGLang backend. It was chosen as the best-performing variant after comparing multiple training configurations across 4 benchmark datasets.

License

This model is released under the Gemma License, consistent with the base model's license.

References

@article{li2025eagle3,
  title={EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test},
  author={Li, Yuhui and Wei, Fangyun and Zhang, Chao and Zhang, Hongyang},
  journal={arXiv preprint arXiv:2503.01840},
  year={2025}
}