# Llama-3.2-1B-MTP-k8: Multi-Token Prediction on a Single Consumer GPU
This is a reproduction of "Multi-Token Prediction via Self-Distillation" (arXiv 2602.06019), adapted for a single NVIDIA RTX 5090 (32GB). The original paper used 4x NVIDIA GH200 (384GB total) with Llama-3.1-8B; we scaled it down to Llama-3.2-1B on consumer hardware.
## What is Multi-Token Prediction (MTP)?
Standard language models predict one token at a time (autoregressive decoding). MTP trains the model to predict multiple future tokens simultaneously using online self-distillation:
- A frozen teacher (the original model) generates soft probability distributions
- A trainable student (same architecture) learns to predict k future tokens at each position
- At inference, ConfAdapt decoding emits multiple tokens when the model is confident, falling back to single-token when uncertain
The result: faster inference with minimal quality loss.
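As a concrete illustration, here is a minimal, hypothetical sketch of ConfAdapt-style emission (the function name, list-based distributions, and the fallback rule are illustrative, not the paper's exact algorithm):

```python
# Hypothetical sketch of ConfAdapt-style emission (illustrative, not the
# paper's exact algorithm). At each decoding step the MTP head yields k
# probability distributions, one per future position; we greedily emit
# tokens while confidence stays above the threshold, always emitting at
# least one token so decoding makes progress.

def confadapt_emit(step_probs, threshold=0.90):
    """step_probs: list of k probability distributions (list[float]).
    Returns the token ids to emit at this decoding step."""
    emitted = []
    for dist in step_probs:
        confidence = max(dist)
        token = dist.index(confidence)
        if emitted and confidence < threshold:
            break  # uncertain: fall back to what we have so far
        emitted.append(token)
    return emitted

# Confident on the first two future positions, uncertain on the third:
probs = [
    [0.05, 0.95],  # position +1: 0.95 >= 0.90 -> emit
    [0.08, 0.92],  # position +2: 0.92 >= 0.90 -> emit
    [0.60, 0.40],  # position +3: 0.60 <  0.90 -> stop
]
print(confadapt_emit(probs))  # -> [1, 1]
```

When every position clears the threshold, all k tokens are emitted in one step; when none do, decoding degrades gracefully to ordinary one-token-at-a-time generation.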
## Results: GSM8K 8-shot Chain-of-Thought
| Configuration | Exact Match (flexible) | Exact Match (strict) | Latency (s/sample) |
|---|---|---|---|
| Baseline (Llama-3.2-1B, standard AR) | 7.13% ± 0.71 | 6.07% ± 0.66 | ~1.5 |
| MTP k=1 (single token, quality check) | 5.23% ± 0.61 | 2.96% ± 0.47 | ~2.4 |
| MTP k=8 + ConfAdapt 90% | 5.08% ± 0.60 | 3.03% ± 0.47 | ~1.3 |
### Key Findings
- ConfAdapt works: k=8 with ConfAdapt matches k=1 quality while being 1.8x faster (avg 2.82 tokens emitted per step)
- Quality drop is expected: The ~2% accuracy drop from baseline is consistent with our smaller setup (1B model, 500M training tokens vs paper's 8B model, 2B tokens)
- The core claim holds: Multi-token decoding via ConfAdapt preserves generation quality while improving throughput, even on a tiny 1B model
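The 1.8x figure follows directly from the per-sample timings in the results table:

```python
# Speedup of MTP k=8 + ConfAdapt over the MTP k=1 quality check,
# computed from the per-sample latencies in the results table.
mtp_k1_latency = 2.4     # s/sample, MTP k=1
confadapt_latency = 1.3  # s/sample, MTP k=8 + ConfAdapt 90%
print(round(mtp_k1_latency / confadapt_latency, 1))  # -> 1.8
```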
## Training Details

### What We Changed from the Paper
| Parameter | Paper (8B / 4x GH200) | Ours (1B / 1x RTX 5090) |
|---|---|---|
| Base model | Llama-3.1-8B | Llama-3.2-1B |
| GPUs | 4x GH200 (96GB each) | 1x RTX 5090 (32GB) |
| FSDP mesh | 1x4 | 1x1 (no FSDP) |
| k_toks | Randomized 2-16 across ranks | Fixed 8 |
| Training tokens | 2B | 500M |
| micro_batch_size | 32 | 8 |
| global_batch_size | 128 | 64 (grad accumulation) |
| mask_region_ct | 5 | 1 |
| rollout_multiplier | 4 | 2 |
| Template | Chat (Instruct tokenizer) | Plain text (base tokenizer) |
### What We Kept the Same
- Supervision method: Soft teacher via KL divergence (paper's recommended self-distillation)
- Dataset: MetaMathQA (`jwkirchenbauer/metamathqa-grouped-split`)
- Sequence length: 160 tokens
- Peak learning rate: 1e-5
- Optimizer: AdamW with cosine decay
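As a minimal sketch of that schedule (pure Python, no warmup; the helper name and warmup-free shape are illustrative, not taken from the training code):

```python
import math

PEAK_LR = 1e-5        # peak learning rate, as listed above
TOTAL_STEPS = 48_828  # total optimizer steps, from the training metrics below

def cosine_lr(step, peak=PEAK_LR, total=TOTAL_STEPS, min_lr=0.0):
    """Cosine decay from `peak` down to `min_lr` over `total` steps."""
    progress = min(step, total) / total
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * progress))

print(cosine_lr(0))            # starts at the peak LR
print(cosine_lr(TOTAL_STEPS))  # fully decayed
```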
### Training Metrics
- Total steps: 48,828
- Training time: ~17 hours on RTX 5090
- Final train loss: ~0.9
- Final val loss: 1.895 (perplexity 6.65)
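The perplexity figure is just the exponential of the validation cross-entropy loss:

```python
import math

val_loss = 1.895                 # final validation loss (nats/token)
perplexity = math.exp(val_loss)  # perplexity = exp(cross-entropy)
print(round(perplexity, 2))      # -> 6.65
```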
## Why k=8 Instead of the Paper's Randomized k=2-16?
The paper's approach randomizes k across GPU ranks each step. With 4 GPUs, a single global batch can mix four different horizons (e.g., k=2, k=5, k=12, k=16), so the model learns to handle any prediction horizon.
With a single GPU, we can only train one k value per step. We chose k=8 as a middle ground — large enough to demonstrate meaningful multi-token speedup, small enough to fit in 32GB VRAM.
This is an important tradeoff: our model is specialized for k=8, while the paper's model generalizes across all k values. A production deployment would benefit from the paper's multi-GPU randomized approach.
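The contrast can be sketched as follows (the helper name and seeding scheme are illustrative, not from the paper's code):

```python
import random

def sample_k(step, rank, k_min=2, k_max=16):
    """Paper-style: each GPU rank draws its own horizon k every step,
    so one global batch mixes several prediction horizons. The seed is
    illustrative; it just needs to be deterministic per (step, rank)."""
    return random.Random(1_000_003 * step + rank).randint(k_min, k_max)

# Paper: 4 ranks, potentially 4 different horizons in the same step.
print([sample_k(step=0, rank=r) for r in range(4)])

# Ours: a single GPU can only train one horizon per step.
FIXED_K = 8
```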
## Infrastructure: Running on Consumer Hardware
This reproduction ran entirely on a home Kubernetes cluster:
- GPU: NVIDIA RTX 5090 (32GB, Blackwell architecture / sm_120)
- System: 16GB RAM, Debian 13
- Stack: Kubernetes + containerd + NVIDIA device plugin
- PyTorch: Nightly build with CUDA 12.8 (required for Blackwell sm_120 support)
### Challenges We Solved
- Blackwell GPU support: RTX 5090 (sm_120) requires PyTorch nightly with cu128 — stable releases don't include sm_120 yet
- Single-GPU checkpoint saving: The original code uses `torch.distributed.all_reduce()` for checkpoint state sync, which crashes when distributed is not initialized. We added an `is_initialized()` guard
- W&B configuration: The default config points to the paper authors' organization. Override with `--wandb.entity=null`
- HuggingFace checkpoint format: The litgpt converter outputs `model.pth`, but transformers expects `pytorch_model.bin`
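For the checkpoint-format mismatch, a one-off rename is enough. An illustrative helper (the function name is ours; the filenames mirror the description above, not code from the repo):

```python
from pathlib import Path

def fix_checkpoint_name(ckpt_dir):
    """Rename litgpt's model.pth to the pytorch_model.bin filename
    that transformers' from_pretrained() looks for."""
    src = Path(ckpt_dir) / "model.pth"
    dst = Path(ckpt_dir) / "pytorch_model.bin"
    if src.exists() and not dst.exists():
        src.rename(dst)  # same weights, transformers-friendly name
    return dst
```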
## Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "celestialcreator/Llama-3.2-1B-MTP-k8",
    trust_remote_code=True,
    torch_dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained("celestialcreator/Llama-3.2-1B-MTP-k8")

# Standard generation (single token, works like any Llama model)
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
For MTP inference with ConfAdapt decoding, use the `mtp-lm` evaluation-harness fork.
## Reproduction Guide
Full reproduction instructions with Kubernetes manifests and configs: GitHub Fork
## Citation
If you use this model, please cite the original paper:
```bibtex
@article{kirchenbauer2025multitokenpredictionselfdistillation,
  title={Multi-Token Prediction via Self-Distillation},
  author={John Kirchenbauer and Jonas Geiping and Yuxin Wen and Tom Goldstein},
  journal={arXiv preprint arXiv:2602.06019},
  year={2025}
}
```
## Acknowledgments
- Original paper and code by John Kirchenbauer et al.
- Built with LitGPT, PyTorch, and lm-evaluation-harness