Llama-3.2-1B-MTP-k8: Multi-Token Prediction on a Single Consumer GPU

This is a reproduction of "Multi-Token Prediction via Self-Distillation" (arXiv 2602.06019) adapted for a single NVIDIA RTX 5090 (32GB). The original paper used 4x NVIDIA GH200 (384GB total) with Llama-3.1-8B. We scaled it down to Llama-3.2-1B on consumer hardware.

What is Multi-Token Prediction (MTP)?

Standard language models predict one token at a time (autoregressive decoding). The paper's MTP method trains the model to predict multiple future tokens at once using online self-distillation:

  1. A frozen teacher (the original model) generates soft probability distributions
  2. A trainable student (same architecture) learns to predict k future tokens at each position
  3. At inference, ConfAdapt decoding emits multiple tokens per step when the model is confident and falls back to single-token decoding when it is uncertain
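The ConfAdapt rule in step 3 can be sketched as a prefix-acceptance test. This is a minimal sketch under our own assumptions, not the paper's implementation: we assume each decoding step proposes k tokens, each with the probability the model assigned to it, and we read "ConfAdapt 90%" as a confidence threshold of 0.9. The function name `confadapt_accept` is hypothetical.

```python
from typing import Sequence


def confadapt_accept(confidences: Sequence[float], threshold: float = 0.9) -> int:
    """Return how many of the k proposed tokens to emit this step.

    Assumed rule (hypothetical): emit the longest prefix of proposals whose
    probabilities all meet `threshold`, but always emit at least one token,
    so decoding falls back to ordinary single-token generation when even
    the first proposal is uncertain.
    """
    if not confidences:
        raise ValueError("need at least one proposed token")
    confident_prefix = 0
    for p in confidences:
        if p < threshold:
            break  # stop at the first uncertain position
        confident_prefix += 1
    return max(confident_prefix, 1)


# With k=8 proposals, only the two leading confident tokens are emitted.
print(confadapt_accept([0.95, 0.92, 0.50, 0.99, 0.97, 0.91, 0.93, 0.96]))  # 2
```

Under this reading, the "avg 2.82 tokens emitted per step" reported below is the mean of this return value over a run.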

The result: faster inference with minimal quality loss.

Results: GSM8K 8-shot Chain-of-Thought

| Configuration | Exact Match (flexible, ± s.e.) | Exact Match (strict, ± s.e.) | Latency (s/sample, lower is better) |
|---|---|---|---|
| Baseline (Llama-3.2-1B, standard AR) | 7.13% ± 0.71 | 6.07% ± 0.66 | ~1.5 |
| MTP k=1 (single token, quality check) | 5.23% ± 0.61 | 2.96% ± 0.47 | ~2.4 |
| MTP k=8 + ConfAdapt 90% | 5.08% ± 0.60 | 3.03% ± 0.47 | ~1.3 |
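The ± values in the table are consistent with the binomial standard error of an accuracy measured on GSM8K's 1,319 test problems, sqrt(p(1 − p)/n). The short check below recomputes them; treating each problem as an independent Bernoulli trial is our assumption.

```python
import math

GSM8K_TEST_SIZE = 1319  # problems in the GSM8K test split

# Reported accuracies (%) from the results table: (flexible, strict)
reported = {
    "baseline": (7.13, 6.07),
    "mtp_k1": (5.23, 2.96),
    "mtp_k8_confadapt": (5.08, 3.03),
}


def binomial_stderr_pct(acc_pct: float, n: int = GSM8K_TEST_SIZE) -> float:
    """Standard error of a proportion, expressed in percentage points."""
    p = acc_pct / 100.0
    return 100.0 * math.sqrt(p * (1.0 - p) / n)


for name, (flexible, strict) in reported.items():
    # e.g. the baseline row prints ±0.71 and ±0.66, matching the table
    print(f"{name:18s} ±{binomial_stderr_pct(flexible):.2f}  ±{binomial_stderr_pct(strict):.2f}")
```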

Key Findings

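  • ConfAdapt works: k=8 with ConfAdapt matches MTP k=1 quality (within the standard errors) while being 1.8x faster per sample (avg 2.82 tokens emitted per step); against the standard AR baseline the speedup is smaller, about 1.15x (1.5 s vs 1.3 s per sample)
  • Quality drop is expected: accuracy falls by about 2 points (flexible) and 3 points (strict) relative to the baseline. The drop is already present at k=1, so it comes from the MTP fine-tuning rather than from ConfAdapt, and it is consistent with our smaller setup (1B model and 500M training tokens vs the paper's 8B model and 2B tokens)
  • The core claim holds: relative to the same MTP model decoding one token at a time, multi-token decoding via ConfAdapt preserves generation quality while improving throughput, even on a tiny 1B model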
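The headline ratios follow directly from the results table: the 1.8x figure compares k=8 + ConfAdapt with MTP k=1, while the gain over the standard baseline is smaller. A quick recomputation, using plain arithmetic on the reported values only:

```python
# Reported per-sample latencies (seconds) and exact-match accuracies (%)
latency_s = {"baseline": 1.5, "mtp_k1": 2.4, "mtp_k8_confadapt": 1.3}
flexible = {"baseline": 7.13, "mtp_k1": 5.23, "mtp_k8_confadapt": 5.08}
strict = {"baseline": 6.07, "mtp_k1": 2.96, "mtp_k8_confadapt": 3.03}

speedup_vs_k1 = latency_s["mtp_k1"] / latency_s["mtp_k8_confadapt"]          # ~1.85x
speedup_vs_baseline = latency_s["baseline"] / latency_s["mtp_k8_confadapt"]  # ~1.15x

drop_flexible = flexible["baseline"] - flexible["mtp_k8_confadapt"]  # ~2.05 points
drop_strict = strict["baseline"] - strict["mtp_k8_confadapt"]        # ~3.04 points

print(f"speedup vs k=1: {speedup_vs_k1:.2f}x, vs baseline: {speedup_vs_baseline:.2f}x")
print(f"accuracy drop vs baseline: {drop_flexible:.2f} (flexible), {drop_strict:.2f} (strict) points")
```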

Training Details

What We Changed from the Paper

| Parameter | Paper (8B / 4x GH200) | Ours (1B / 1x RTX 5090) |
|---|---|---|
| Base model | Llama-3.1-8B | Llama-3.2-1B |
| GPUs | 4x GH200 (96GB each) | 1x RTX 5090 (32GB) |
| FSDP mesh | 1x4 | 1x1 (no FSDP) |
| k_toks | Randomized 2-16 across ranks | Fixed 8 |
| Training tokens | 2B | 500M |
| micro_batch_size | 32 | 8 |
| global_batch_size | 128 | 64 (via gradient accumulation) |
| mask_region_ct | 5 | 1 |
| rollout_multiplier | 4 | 2 |
| Template | Chat (Instruct tokenizer) | Plain text (base tokenizer) |
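The batch settings imply very different amounts of gradient accumulation. Assuming global_batch_size counts sequences across all ranks (our reading of the config), the paper's 4 GPUs × micro-batch 32 fill a global batch of 128 in one pass, while a single RTX 5090 with micro-batch 8 needs 8 accumulation steps to reach 64. The `RunConfig` dataclass is a hypothetical container for illustration, not the repository's config schema.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    """Hypothetical summary of the batch settings in the table above."""
    num_gpus: int
    micro_batch_size: int
    global_batch_size: int

    @property
    def grad_accum_steps(self) -> int:
        # Sequences processed in one forward/backward pass across all GPUs
        per_pass = self.num_gpus * self.micro_batch_size
        if self.global_batch_size % per_pass:
            raise ValueError("global batch must be a multiple of num_gpus * micro_batch_size")
        return self.global_batch_size // per_pass


paper = RunConfig(num_gpus=4, micro_batch_size=32, global_batch_size=128)
ours = RunConfig(num_gpus=1, micro_batch_size=8, global_batch_size=64)

print(paper.grad_accum_steps, ours.grad_accum_steps)  # 1 8
```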

What We Kept the Same

  • Supervision method: Soft teacher via KL divergence (paper's recommended self-distillation)
  • Dataset: MetaMathQA (jwkirchenbauer/metamathqa-grouped-split)
  • Sequence length: 160 tokens
  • Peak learning rate: 1e-5
  • Optimizer: AdamW with cosine decay
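The soft-teacher objective can be illustrated with a per-position KL divergence between the frozen teacher's and the student's token distributions. This is a minimal stdlib sketch that assumes the forward direction KL(teacher ‖ student), averaged over predicted positions; the paper's exact direction, temperature and weighting may differ, and all function names here are hypothetical.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over one vocabulary-sized vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_teacher_student(teacher_logits: Sequence[float], student_logits: Sequence[float]) -> float:
    """KL(p_teacher || p_student) at a single position."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def distill_loss(teacher_seq: Sequence[Sequence[float]],
                 student_seq: Sequence[Sequence[float]]) -> float:
    """Mean per-position KL over all predicted positions (assumed weighting)."""
    if len(teacher_seq) != len(student_seq) or not teacher_seq:
        raise ValueError("need matching, non-empty sequences of logits")
    total = sum(kl_teacher_student(t, s) for t, s in zip(teacher_seq, student_seq))
    return total / len(teacher_seq)


# Teacher is 50/50, student is 90/10: KL = ln(5/3) ≈ 0.511 nats
teacher = [[math.log(0.5), math.log(0.5)]]
student = [[math.log(0.9), math.log(0.1)]]
print(round(distill_loss(teacher, student), 3))
```

In an actual training loop the same quantity would be computed on GPU tensors over the full vocabulary, for example with `torch.nn.functional.kl_div`, which takes the student's log-probabilities as input.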

Training Metrics

  • Total steps: 48,828
  • Training time: ~17 hours on RTX 5090
  • Final train loss: ~0.9
  • Final val loss: 1.895 (perplexity 6.65)
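The reported step count matches the token budget if each optimizer step consumes global_batch_size sequences of the full 160 tokens: 500M / (64 × 160) ≈ 48,828. The snippet reproduces that arithmetic along with the implied average time per step and the perplexity conversion (perplexity = exp(loss), assuming the validation loss is natural-log cross-entropy, which agrees with the reported 6.65). The per-step time is only an average over the approximate 17 hours.

```python
import math

TRAIN_TOKENS = 500_000_000
SEQ_LEN = 160
GLOBAL_BATCH = 64
TRAIN_HOURS = 17          # approximate, as reported
FINAL_VAL_LOSS = 1.895    # assumed natural-log cross-entropy

tokens_per_step = GLOBAL_BATCH * SEQ_LEN               # 10,240 tokens
total_steps = TRAIN_TOKENS // tokens_per_step          # 48,828 steps
seconds_per_step = TRAIN_HOURS * 3600 / total_steps    # ~1.25 s on average
val_perplexity = math.exp(FINAL_VAL_LOSS)              # ~6.65

print(total_steps, round(seconds_per_step, 2), round(val_perplexity, 2))  # 48828 1.25 6.65
```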

Why k=8 Instead of the Paper's Randomized k=2-16?

The paper's approach randomizes k across GPU ranks at each step. With 4 GPUs, a single step might, for example, train k=2, 5, 12 and 16 on the four ranks at once, so over training the model learns to handle any prediction horizon in that range.

With a single GPU, we train only one k value per step. We chose k=8 as a middle ground: large enough to demonstrate a meaningful multi-token speedup, small enough to fit in 32GB of VRAM.

This is an important tradeoff: our model is specialized for k=8, while the paper's model is trained to generalize across k values from 2 to 16. A production deployment would likely benefit from the paper's multi-GPU randomized approach.
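The difference between the two schemes can be sketched as a per-step choice of k for each rank. We assume uniform integer sampling over 2-16 and an RNG seeded by the step number so that every rank computes the same assignment; the paper's actual sampling scheme may differ, and `k_per_rank` is a hypothetical helper.

```python
import random
from typing import Optional


def k_per_rank(step: int, world_size: int, fixed_k: Optional[int] = None,
               k_min: int = 2, k_max: int = 16, seed: int = 0) -> list[int]:
    """Return the prediction horizon k used by each rank at a given step.

    fixed_k set  -> every rank uses the same k (our single-GPU k=8 run).
    fixed_k None -> each rank draws k uniformly from [k_min, k_max]
                    (an assumed stand-in for the paper's randomized k).
    """
    if fixed_k is not None:
        return [fixed_k] * world_size
    # Seeding with the step keeps the assignment identical on every rank
    rng = random.Random(seed * 1_000_003 + step)
    return [rng.randint(k_min, k_max) for _ in range(world_size)]


print(k_per_rank(step=0, world_size=1, fixed_k=8))  # [8]
print(k_per_rank(step=0, world_size=4))             # four values in 2..16
```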

Infrastructure: Running on Consumer Hardware

This reproduction ran entirely on a home Kubernetes cluster:

  • GPU: NVIDIA RTX 5090 (32GB, Blackwell architecture / sm_120)
  • System: 16GB RAM, Debian 13
  • Stack: Kubernetes + containerd + NVIDIA device plugin
  • PyTorch: Nightly build with CUDA 12.8 (required for Blackwell sm_120 support)
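To confirm that an installed PyTorch build can actually target the RTX 5090, one can compare the GPU's compute capability (12.0, i.e. sm_120) with the architectures the build was compiled for. The helper name and report keys are our own; the torch calls are standard.

```python
import torch


def blackwell_support_report() -> dict:
    """Summarize whether this PyTorch build ships kernels for sm_120."""
    report = {
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,          # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "compiled_archs": torch.cuda.get_arch_list(),  # [] without CUDA
        "device_capability": None,
        "device_is_sm_120": False,
        "sm_120_compiled": False,
    }
    if report["cuda_available"]:
        capability = torch.cuda.get_device_capability(0)
        report["device_capability"] = capability
        report["device_is_sm_120"] = capability == (12, 0)
        report["sm_120_compiled"] = "sm_120" in report["compiled_archs"]
    return report


print(blackwell_support_report())
```

On an RTX 5090 with a cu128 build, both `device_is_sm_120` and `sm_120_compiled` should be `True`; a build that lacks sm_120 shows up as `device_is_sm_120=True` with `sm_120_compiled=False`.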

Challenges We Solved

  1. Blackwell GPU support: the RTX 5090 (sm_120) requires a PyTorch build with CUDA 12.8 (cu128); at the time of this run, only nightly builds included sm_120
  2. Single-GPU checkpoint saving: the original code calls torch.distributed.all_reduce() to sync checkpoint state, which raises an error when no distributed process group has been initialized. We added a torch.distributed.is_initialized() guard
  3. W&B configuration: Default config points to the paper authors' organization. Override with --wandb.entity=null
  4. HuggingFace checkpoint format: The litgpt converter outputs model.pth but transformers expects pytorch_model.bin
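The single-GPU fix in item 2 amounts to skipping the collective when no process group exists. A minimal sketch of that guard follows; the helper name is hypothetical, and the sum reduction is an assumption, since the original checkpoint code may reduce differently.

```python
import torch
import torch.distributed as dist


def all_reduce_if_distributed(tensor: torch.Tensor) -> torch.Tensor:
    """Sum `tensor` across ranks when a process group exists, else return it unchanged.

    A single-GPU run launched without torchrun never initializes the default
    process group, so calling dist.all_reduce directly would raise an error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # in-place across ranks
    return tensor


# Single-process usage: the value passes through untouched.
step_tensor = torch.tensor([48828.0])
print(all_reduce_if_distributed(step_tensor).item())  # 48828.0
```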

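For item 4, one possible workaround is to expose the converter's output under the filename transformers looks for. This sketch assumes that model.pth already holds a Hugging Face-format state dict (which is what litgpt's conversion to HF format writes) and that config.json sits in the same directory; the helper name is hypothetical.

```python
import shutil
from pathlib import Path


def expose_as_pytorch_model_bin(checkpoint_dir: str) -> Path:
    """Copy litgpt's converted model.pth to the filename transformers loads.

    Assumes model.pth is a Hugging Face-format state dict saved with
    torch.save, so only the filename needs to change.
    """
    ckpt = Path(checkpoint_dir)
    src = ckpt / "model.pth"
    dst = ckpt / "pytorch_model.bin"
    if not src.is_file():
        raise FileNotFoundError(f"no converter output at {src}")
    if not dst.exists():
        shutil.copy2(src, dst)  # same payload, new name; keeps the original file
    return dst
```

Copying instead of renaming leaves the original file in place for tools that still expect model.pth.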
Usage

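import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "celestialcreator/Llama-3.2-1B-MTP-k8"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    # float16 is intended for GPU; use float32 when running on CPU
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Standard generation (one token per step, works like any Llama model)
inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    pad_token_id=tokenizer.eos_token_id,  # the base Llama tokenizer defines no pad token
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))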

For MTP inference with ConfAdapt decoding, use the mtp-lm evaluation harness fork.

Reproduction Guide

Full reproduction instructions with Kubernetes manifests and configs: GitHub Fork

Citation

If you use this model, please cite the original paper:

@article{kirchenbauer2025multitokenpredictionselfdistillation,
  title={Multi-Token Prediction via Self-Distillation},
  author={John Kirchenbauer and Jonas Geiping and Yuxin Wen and Tom Goldstein},
  journal={arXiv preprint arXiv:2602.06019},
  year={2025}
}

Acknowledgments
