# Qwen3-Next-80B-A3B-Instruct – GSM8K MTP Finetuned (Epoch 0)
This is Qwen/Qwen3-Next-80B-A3B-Instruct with its native Multi-Token Prediction (MTP) head replaced by a version finetuned on GSM8K math reasoning data to improve speculative decoding acceptance rates.
## What changed
The MTP head weights (the `mtp.*` keys) have been replaced with weights finetuned using the speculators library. All other model weights (transformer layers, embeddings, `lm_head`) are identical to the original Qwen3-Next-80B release.
## Training details

| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-Next-80B-A3B-Instruct |
| Training dataset | openai/gsm8k (train split, 7473 samples) |
| Hidden states | Regenerated from the verifier (GSM8K responses) |
| Training framework | speculators |
| Epochs trained | 1 (epoch 0 checkpoint) |
| Learning rate | 5e-5 |
| Batch size | 16 (4 GPUs, FSDP) |
| Step weights | [0.51, 0.31, 0.18] (β=0.6 exponential decay) |
| Loss | Multi-step MTP loss (3 steps, teacher-forced) |
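The step weights in the table follow an exponential decay. A minimal sketch of the reconstruction, assuming weight k is proportional to β^k and the weights are normalized to sum to 1 (the values below round to exactly the ones listed above):

```python
def step_weights(beta: float, num_steps: int) -> list[float]:
    # Weight for prediction step k is proportional to beta**k,
    # normalized so all step weights sum to 1.
    raw = [beta**k for k in range(num_steps)]
    total = sum(raw)
    return [w / total for w in raw]

weights = step_weights(beta=0.6, num_steps=3)
print([round(w, 2) for w in weights])  # [0.51, 0.31, 0.18]
```

Later prediction steps are downweighted because they are harder to get right under teacher forcing and would otherwise dominate the loss.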
## Acceptance rate results

Evaluated on RedHatAI/speculator_benchmarks:math_reasoning.jsonl with `num_speculative_tokens=3`:
| Checkpoint | Pos0 | Pos1 | Pos2 | Mean Accepted Tokens |
|---|---|---|---|---|
| Base (original MTP head) | 0.873 | 0.778 | 0.675 | 2.01 |
| This model (epoch 0) | 0.934 | 0.891 | 0.838 | 2.46 |
+22.5% improvement in mean accepted tokens over the base MTP head.
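The mean-accepted-tokens column is consistent with treating acceptance as sequential: a draft token at position k only counts if every earlier position was also accepted. A hedged sketch of that reconstruction (the benchmark's exact aggregation may differ):

```python
def mean_accepted(pos_rates: list[float]) -> float:
    # Expected accepted draft tokens per verification step, assuming
    # position k is only reached when positions 0..k-1 were accepted.
    total = 0.0
    reach = 1.0
    for rate in pos_rates:
        reach *= rate  # probability of reaching and accepting this position
        total += reach
    return total

print(round(mean_accepted([0.873, 0.778, 0.675]), 2))  # base head: 2.01
print(round(mean_accepted([0.934, 0.891, 0.838]), 2))  # this model: 2.46
```

This is why the later positions matter disproportionately: the finetuned head's gains at Pos1 and Pos2 compound through the product.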
## Usage

Load directly with vLLM; the MTP head is embedded in the model weights:
```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="inference-optimization/Qwen3-Next-80B-A3B-Instruct-GSM8K-MTP-finetuned",
    tokenizer_mode="auto",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.8,
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 3,
    },
    enable_chunked_prefill=False,
)

sampling_params = SamplingParams(temperature=0.6, top_p=0.95)
outputs = llm.generate(["Solve: Janet has 3 apples..."], sampling_params)
```
## Generation pipeline

This checkpoint was produced with the speculators library via the following steps:
- **Data generation** – hidden states extracted from Qwen3-Next-80B-A3B-Instruct on GSM8K using `examples/fast_mtp/generate_dataset.py`
- **Finetuning** – MTP head trained with `examples/fast_mtp/04_finetune.py`
- **Weight stitching** – finetuned MTP weights stitched back into the verifier using `examples/fast_mtp/stitch_weights.py`
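Conceptually, the stitching step overwrites the verifier's `mtp.*` tensors with the finetuned ones and leaves everything else untouched. A minimal sketch on plain state dicts (`stitch_mtp_weights` is a hypothetical helper; the actual script additionally deals with safetensors shards and index files):

```python
def stitch_mtp_weights(verifier_state: dict, finetuned_state: dict) -> dict:
    # Copy the verifier state dict, then replace only the mtp.* entries
    # with their finetuned counterparts.
    stitched = dict(verifier_state)
    for key, tensor in finetuned_state.items():
        if key.startswith("mtp."):
            stitched[key] = tensor
    return stitched

# Toy example with scalars standing in for tensors.
verifier = {"model.layers.0.w": 1.0, "mtp.head.w": 2.0, "lm_head.w": 3.0}
finetuned = {"mtp.head.w": 9.0}
merged = stitch_mtp_weights(verifier, finetuned)
print(merged["mtp.head.w"], merged["lm_head.w"])  # 9.0 3.0
```

Because only `mtp.*` keys are replaced, the stitched checkpoint is bit-identical to the base model everywhere except the speculator head.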