SpecJAX: A Speculative Decoding Library for TPUs
This post explains what speculative decoding is, why it's mathematically lossless (unlike quantization), why the GPU side of this problem is well served while the TPU side wasn't, and how SpecJAX organizes the TPU EAGLE3 stack today (SpecJAX : SpecForge :: sglang-jax : SGLang).
The Memory-Bandwidth Wall
When you generate text with an LLM, almost every cycle of your accelerator is wasted. An H100 peaks near 2,000 TFLOPS of BF16 compute, but at batch size 1 it runs at well under 1% utilization during decode. Autoregressive generation is memory-bandwidth bound: every new token requires streaming the entire model's weights from HBM once. On a 70 GB model and an H100's ~3.35 TB/s of bandwidth, that caps you at about 48 tokens per second — no matter how many tensor cores are sitting idle.
TPUs have the same problem for the same reason. A v5e chip has ~819 GB/s of HBM bandwidth; bigger pods scale the compute but not the per-chip roof. You can add chips to fit bigger models, but you can't add chips to make a single-stream decode any faster.
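To make the bandwidth ceiling concrete, here is the back-of-the-envelope arithmetic behind the numbers above. This is a rough upper bound that ignores KV-cache traffic and any compute overlap; the function name is ours, purely for illustration.

```python
def decode_ceiling_tok_per_s(model_bytes: float, hbm_bytes_per_s: float) -> float:
    """Rough upper bound on single-stream decode speed when every new token
    must stream the full model weights from HBM once (ignores KV-cache reads
    and any compute/communication overlap)."""
    return hbm_bytes_per_s / model_bytes

# ~70 GB of weights on an H100 with ~3.35 TB/s of HBM bandwidth:
print(f"{decode_ceiling_tok_per_s(70e9, 3.35e12):.1f} tok/s")  # ~47.9 tok/s
```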
The standard fixes — KV cache, continuous batching, tensor parallelism, FlashAttention — are already shipped in every modern inference engine. The remaining big levers are quantization and speculative decoding. They are not the same thing.
Speculative Decoding, Briefly
Speculative decoding uses two models: a small, fast draft model and the full target model.
At each step:
- The draft proposes candidate tokens with cheap forward passes.
- The target verifies all candidates in a single forward pass — the same cost as generating one token normally.
- A principled accept/reject rule decides which proposals to keep.
Left: standard decoding — one target forward pass per token. Right: speculative decoding — the draft proposes several tokens, the target verifies them in one pass, and multiple tokens emerge per target step.
The clever part is the verification rule. Let p(x) be the target's probability for a draft token x and q(x) the draft's probability. Accept the token with probability min(1, p(x)/q(x)); on rejection, resample from the residual distribution norm(max(0, p(x) − q(x))). Leviathan et al. (2023) and Chen et al. (2023) independently proved that the marginal distribution of accepted tokens is exactly the target's distribution — not approximately, exactly.
This is the key property: speculative decoding is mathematically lossless. Your model produces the same tokens it would have produced without it. Run it or don't — the output distribution is identical.
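A minimal sketch of that accept/reject rule in plain NumPy-style Python; the function and argument names are ours, not any library's API:

```python
import numpy as np

def verify_draft_token(token: int, target_probs: np.ndarray, draft_probs: np.ndarray,
                       rng: np.random.Generator) -> tuple[bool, int]:
    """Accept a draft token with probability min(1, p/q); on rejection,
    resample from the residual distribution norm(max(0, p - q)).
    Marginally, the emitted token is distributed exactly as target_probs."""
    p, q = target_probs[token], draft_probs[token]
    if rng.random() < min(1.0, p / q):
        return True, token                      # draft token accepted as-is
    residual = np.maximum(target_probs - draft_probs, 0.0)
    residual /= residual.sum()                  # renormalize the residual mass
    return False, int(rng.choice(len(residual), p=residual))
```

On acceptance the loop moves on to the draft's next proposal; on rejection the speculation round ends and the resampled token is emitted instead.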
Lossless vs. Quantization
It's worth lingering on "lossless," because the other big lever — quantization — doesn't have that property.
Quantization trades bits for bandwidth. INT8 weight-activation (W8A8) is nearly free on modern kernels and gives you ~1.8–2.4× throughput at roughly no quality cost on most benchmarks. But push further — W4A8, W4A4, mixed-precision KV cache — and the output distribution starts drifting. Numerical errors compound across layers, calibration sets matter, and long-context and reasoning tasks degrade first and loudest. The QuaRot, AWQ, and SmoothQuant lines of work exist precisely because the naïve approach breaks things.
Quantization is also a one-shot lever. Once you've quantized, you can't quantize again to get another 2×. It also changes the artifact you ship — INT4 weights are a different object from BF16 weights, with different numerics, different failure modes, and different hardware requirements.
Speculative decoding is different in every one of these dimensions:
| | Quantization (W4A4, INT4 KV) | Speculative decoding (EAGLE3) |
|---|---|---|
| Output identical to FP16 target? | No — measurable drift, especially on reasoning | Yes — mathematically guaranteed |
| Changes the target weights? | Yes — ships new weights | No — target weights unchanged |
| Composable with quantization? | n/a | Yes — stack both |
| Ceiling | ~2–4× before quality cliff | 1.5–2× per head, improves with better drafts |
| Hardware-agnostic? | No — INT4 kernels differ per vendor | Yes — weights are just weights |
These levers are complementary, not alternative. You can quantize your target to INT8 (or FP8), then add an EAGLE3 draft on top, and compose both speedups — which is exactly how we serve our FP8 GLM-4.7-Flash deployment.
Enter EAGLE3
Vanilla speculative decoding works, but its speedup is capped by the draft model's per-token acceptance rate α. With k draft proposals, the expected number of tokens per target pass is (1 − α^(k+1)) / (1 − α).
Small improvements to the acceptance rate matter a lot. Doubling the number of proposed tokens helps only if the draft is already accepting most of them.
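A quick worked example of that formula, with illustrative numbers of our own choosing:

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected accepted tokens per target forward pass, assuming an
    independent per-token acceptance rate alpha and k draft proposals."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

print(expected_tokens_per_pass(0.6, 4))   # ≈ 2.31 tokens per target pass
print(expected_tokens_per_pass(0.8, 4))   # ≈ 3.36 — raising alpha pays off fast
```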
EAGLE3 pushes α up hard. Instead of using an independent small LLM as the draft, EAGLE3 trains a single-transformer-block draft head that consumes the target's own hidden states and predicts its next-token distribution directly. The supervision signal is the target's full softmax (via KL divergence), not one-hot ground-truth labels — which teaches the draft to preserve the target's uncertainty over near-synonymous alternatives. That's how every model we trained lands in the 59–66% first-token acceptance band.
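To make the training signal concrete, here is a minimal JAX sketch of that distillation-style loss. This is our illustration, not SpecJAX's actual code; `draft_logits` and `target_logits` are assumed names for the draft head's output and the target's logits at the same positions.

```python
import jax
import jax.numpy as jnp

def draft_head_loss(draft_logits: jnp.ndarray, target_logits: jnp.ndarray) -> jnp.ndarray:
    """KL(target || draft) averaged over positions: the draft head is pushed to
    match the target's full softmax rather than one-hot labels, so it learns
    the target's uncertainty over near-synonymous next tokens."""
    target_logp = jax.nn.log_softmax(target_logits, axis=-1)
    draft_logp = jax.nn.log_softmax(draft_logits, axis=-1)
    kl = jnp.sum(jnp.exp(target_logp) * (target_logp - draft_logp), axis=-1)
    return kl.mean()

# Example shapes: [batch, seq, vocab]
draft_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 32000))
target_logits = jax.random.normal(jax.random.PRNGKey(1), (2, 16, 32000))
print(draft_head_loss(draft_logits, target_logits))
```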
In short: quantization reduces the bytes loaded per forward pass; EAGLE3 reduces the number of forward passes per accepted token. Multiply them together.
The GPU Side Was Already Solved
On NVIDIA hardware, this stack is mature:
- Training: SpecForge — the SGLang team's PyTorch/CUDA framework for training EAGLE3 draft heads.
- Inference: SGLang — production inference server with first-class EAGLE3 support:
`--speculative-algorithm EAGLE3 --speculative-draft-model-path …` and you're done.
We use both internally, and when we need a draft head for an NVIDIA target fleet this is the pipeline we reach for. It works, it's well-documented, and it's what most teams doing speculative decoding today default to.
The TPU Gap
We have significant TPU capacity through Google's TRC program, and for throughput-oriented training workloads a v4-32 or v6e pod is an excellent deal per FLOP. So the natural question: can we run SpecForge on TPU via PyTorch/XLA?
We tried. It didn't work — libtpu mmap corruption, XLA recompilation storms on data-dependent MoE shapes, and the kind of cross-framework debugging that eats a week and produces nothing shippable. After enough attempts, we concluded the PyTorch/XLA route wasn't viable for this workload and started over in pure JAX.
That became SpecJAX. The mental model is clean:
| Phase | GPU / NVIDIA | TPU / Google |
|---|---|---|
| Train EAGLE3 draft head | SpecForge (PyTorch/CUDA) | SpecJAX (pure JAX/XLA) |
| Serve with speculation | SGLang | sglang-jax (we maintain a patched fork) |
SpecJAX : SpecForge :: sglang-jax : SGLang.
SpecJAX is intentionally minimal: no Flax, no nn.Module, no mutable state. All forward passes are stateless pure functions over flat parameter dictionaries loaded directly from safetensors — which is exactly the shape JAX's JIT and SPMD sharding want. MoE expert dispatch uses static-shape einsum so XLA compiles once and stays compiled. The 2D (dp, tp) mesh scales from TP=4 on a single v4-32 host up to TP=8 across two v6e-8 hosts for 32B+ targets.
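A toy illustration of that design, in the spirit of (but not copied from) the SpecJAX codebase: a stateless forward pass over a flat parameter dict, JIT-compiled and sharded over a 2D (dp, tp) mesh. All names here are ours.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

# Flat parameter dict, as if loaded straight from safetensors (no Flax, no nn.Module).
params = {
    "mlp.up_proj": jnp.zeros((1024, 4096), jnp.bfloat16),
    "mlp.down_proj": jnp.zeros((4096, 1024), jnp.bfloat16),
}

def mlp_forward(params, x):
    """Pure, stateless forward: params in, activations out."""
    h = jax.nn.silu(x @ params["mlp.up_proj"])
    return h @ params["mlp.down_proj"]

# A tiny 2D (dp, tp) mesh; on a real pod the axes span hosts and chips.
devices = np.array(jax.devices()).reshape(1, -1)   # (dp, tp)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Shard the MLP's intermediate dimension across the tp axis.
shardings = {
    "mlp.up_proj": NamedSharding(mesh, P(None, "tp")),
    "mlp.down_proj": NamedSharding(mesh, P("tp", None)),
}
params = {k: jax.device_put(v, shardings[k]) for k, v in params.items()}

x = jnp.ones((8, 1024), jnp.bfloat16)
y = jax.jit(mlp_forward)(params, x)   # XLA compiles once; SPMD handles the rest
print(y.shape)                        # (8, 1024)
```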
The Nine Models
Nine draft heads, all trained with SpecJAX, all public:
| Target | Params | Hardware | acc₀ | Draft Head |
|---|---|---|---|---|
| Llama-3.2-3B-Instruct | 3.2B | TPU v6e-4 | 60.6% | thoughtworks/Llama-3.2-3B-Instruct-Eagle3 |
| Llama-3.1-8B-Instruct | 8B | TPU v4-32 | 60.5% | thoughtworks/Llama-3.1-8B-Instruct-Eagle3 |
| Qwen2.5-7B-Instruct | 7.1B | TPU v4-32 | 61.8% | thoughtworks/Qwen2.5-7B-Instruct-Eagle3 |
| Qwen2.5-14B-Instruct | 14B | TPU v4-32 | 60.2% | thoughtworks/Qwen2.5-14B-Instruct-Eagle3 |
| DeepSeek-R1-Distill-Qwen-7B | 7.6B | TPU v5e-32 | 61.5% | thoughtworks/DeepSeek-R1-Distill-Qwen-7B-Eagle3 |
| DeepSeek-R1-Distill-Qwen-14B | 14B | TPU v4-32 | 65.8% | thoughtworks/DeepSeek-R1-Distill-Qwen-14B-Eagle3 |
| Qwen3-8B | 8B | TPU v4-32 | 60.0% | thoughtworks/Qwen3-8B-Eagle3 |
| Qwen3-14B | 14B | TPU v4-32 | 60.1% | thoughtworks/Qwen3-14B-Eagle3 |
| Qwen3-32B | 32B | TPU v6e-16 (TP=8) | 59.3% | thoughtworks/Qwen3-32B-Eagle3 |
acc₀ is the first-token acceptance rate on ShareGPT at temperature 0, measured during training evaluation.
A few highlights worth pulling out:
- DeepSeek-R1-Distill-Qwen-14B hit 65.8% acc₀ — our highest. Reasoning-distilled models produce chain-of-thought traces with repeated syntactic structure, which is unusually predictable for a small draft head. Training took 84 minutes on a v4-32 pod.
- Qwen3-32B is — as far as we know — the first EAGLE3 draft head trained with TP=8 spanning two hosts. Getting there required three fixes: Qwen3's explicit `head_dim=128` (the derived value would be 80), DP-rank computation under multi-host TP groups, and `process_allgather` for multi-host sharded checkpoint writes.
- Every model clears 59.3% acc₀. The EAGLE3 paper's headline acceptance rates sit in the same band; our pure-JAX reimplementation reproduces them cleanly.
Training at Scale: 84 Minutes to 12 Hours
Training wall-clock time per model. Ranges from 84 minutes (DeepSeek-R1-Distill-Qwen-14B on v4-32) to 12 hours (Qwen3-32B on v6e-16 with TP=8). Dataset is a 54K-sample mix of ShareGPT (45%), UltraChat-200K (35%), and Open-PerfectBlend (20%).
The 3B target fits on a single v6e-4 host and finishes in just over two hours. The 7–14B targets train on v4-32 or v5e-32 pods with TP=4 × DP=4 and typically complete in 2–4 hours. The 32B jump required moving past single-host tensor-parallel groups — our first TP=8 configuration spanning two v6e-8 hosts — which was the main engineering lift of the whole release.
Serving the Models on TPU
Inference happens on sglang-jax, the JAX/TPU counterpart to SGLang. Use our tails-mpt/sglang-jax fork, which ships with the EAGLE3 layer-capture patches for Qwen2/2.5 targets and the tied-embeddings fix for Llama-3.2 pre-applied:
git clone https://github.com/tails-mpt/sglang-jax.git
cd sglang-jax && pip install -e .
python -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-8B \
--speculative-algorithm EAGLE3 \
--speculative-draft-model-path thoughtworks/Qwen3-8B-Eagle3 \
--speculative-num-steps 3 \
--speculative-eagle-topk 1 \
--speculative-num-draft-tokens 4 \
--tp-size 4 --dtype bfloat16
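Once the server is up, querying it looks like any other SGLang deployment. A hedged example, assuming the default port (30000) and the OpenAI-compatible endpoint that upstream SGLang exposes; check the sglang-jax docs if your build differs:

```python
import requests

# Assumes the launch_server command above is running locally on the default port.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
        "max_tokens": 128,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```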
An Honest Note on sglang-jax Performance
We want to be direct: the inference side of the TPU EAGLE3 stack is not in the same place as the GPU side. sglang-jax is a young project, and the upstream repository concedes that performance on nearly every currently supported model family still has room to grow — their own docs flag Llama, Qwen 2, Qwen 2 MoE, and others as "performance needs to improve," with only parts of the Qwen 3 family marked as mature.
The EAGLE3 verify / tree-building codepath specifically isn't fully JIT-compiled yet. Our models load correctly, generate correct output, and hit the acceptance rates reported above — but wall-clock throughput on sglang-jax today is behind what the same algorithm delivers on SGLang for GPU. We don't want to oversell that. If your priority right now is raw tok/s, the GPU stack remains the mature choice.
What we are doing is preparing the ground. TPU adoption for LLM serving is growing, and the sglang-jax team is actively optimizing the pipeline. When those optimizations land, nine production-ready EAGLE3 draft heads will already be sitting on the Hub, and pointing your serving stack at thoughtworks/<target>-Eagle3 will be a one-line change. The models don't need to wait for the runtime.
Training Your Own
SpecJAX is MIT-licensed and open:
git clone https://github.com/tails-mpt/SpecJAX.git
cd SpecJAX && pip install -e .
python -m specjax.train \
--target-model-path /path/to/your-model \
--target-model-type qwen3 \
--data-path data/sharegpt.jsonl \
--output-dir /path/to/checkpoints
Supported target architectures: Llama 3.x, Qwen 2 / 2.5 / 3 (including MoE), DeepSeek-R1-Distill-Qwen, and GLM-4.7-Flash. Adding a new target is typically a few hundred lines of pure JAX to implement its forward function. If you'd like a draft head for something we haven't covered, open an issue or a PR — this is meant to be a community effort.
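To give a sense of what "implement its forward function" means, here is a hypothetical, heavily simplified skeleton — not SpecJAX's actual interface — showing the shape of the per-architecture surface: one pure function over a flat param dict that returns the hidden states and logits an EAGLE3 draft head needs.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    """Standard RMSNorm, written out since there is no module framework."""
    return weight * x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

def toy_target_forward(params, input_ids):
    """Returns (hidden_states, logits): hidden states for the draft head to
    condition on, logits to distill. Real targets insert their decoder blocks
    between the embedding and the final norm; this toy keeps only the shape."""
    h = params["embed_tokens"][input_ids]     # [batch, seq, d_model]
    h = rms_norm(h, params["final_norm"])     # ...decoder blocks would go here...
    logits = h @ params["lm_head"].T          # [batch, seq, vocab]
    return h, logits

params = {
    "embed_tokens": jax.random.normal(jax.random.PRNGKey(0), (128, 64)),
    "final_norm": jnp.ones((64,)),
    "lm_head": jax.random.normal(jax.random.PRNGKey(1), (128, 64)),
}
hidden, logits = jax.jit(toy_target_forward)(params, jnp.array([[1, 2, 3]]))
print(hidden.shape, logits.shape)   # (1, 3, 64) (1, 3, 128)
```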
Why We Built This in the Open
The GPU EAGLE3 stack works, and most practitioners will default to it. That's fine. But "default to GPU" stops being an option the moment you want to use the TPU capacity you actually have — whether that's a TRC allocation, a Trillium reservation, or a v4 pod sitting idle in your org. For that case, there wasn't a good answer. Now there is.
Nine checkpoints is a start, not a destination. If there's a target model you want an EAGLE3 head for, or a piece of SpecJAX you'd like to see improved, the repository is open and the issue tracker is empty, waiting for yours.
Links
- SpecJAX (training framework): github.com/tails-mpt/SpecJAX
- sglang-jax fork (TPU inference): github.com/tails-mpt/sglang-jax
- sglang-jax upstream: github.com/sgl-project/sglang-jax
- SpecForge (GPU counterpart): github.com/sgl-project/SpecForge
- SGLang (GPU counterpart): github.com/sgl-project/sglang
- All 9 models: huggingface.co/thoughtworks
- EAGLE3 paper: Li et al., 2025 — arXiv:2503.01840
- Speculative decoding foundations: Leviathan et al., 2023 · Chen et al., 2023
Citation
@article{li2025eagle3,
title={EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test},
author={Li, Yuhui and Wei, Fangyun and Zhang, Chao and Zhang, Hongyang},
journal={arXiv preprint arXiv:2503.01840},
year={2025}
}






