Qwen3-Next-80B-A3B-Instruct – GSM8K MTP Finetuned (Epoch 0)

This is Qwen/Qwen3-Next-80B-A3B-Instruct with its native Multi-Token Prediction (MTP) head replaced by a version finetuned on GSM8K math reasoning data to improve speculative decoding acceptance rates.

What changed

The MTP head weights (mtp.* keys) have been replaced with weights finetuned using the speculators library. All other model weights (transformer layers, embeddings, lm_head) are identical to the original Qwen3-Next-80B release.

Training details

Base model: Qwen/Qwen3-Next-80B-A3B-Instruct
Training dataset: openai/gsm8k (train split, 7,473 samples)
Hidden states: regenerated from the verifier on GSM8K responses
Training framework: speculators
Epochs trained: 1 (epoch 0 checkpoint)
Learning rate: 5e-5
Batch size: 16 (4 GPUs, FSDP)
Step weights: [0.51, 0.31, 0.18] (β=0.6 exponential decay)
Loss: multi-step MTP loss (3 steps, teacher-forced)
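The step weights above follow directly from the stated β=0.6 exponential decay: weight k is proportional to β^k, normalized to sum to 1. A quick check that reproduces the listed values:

```python
# Step weights from the beta=0.6 exponential decay:
# w_k is proportional to beta**k for k = 0..2, normalized to sum to 1.
beta = 0.6
raw = [beta**k for k in range(3)]            # [1.0, 0.6, 0.36]
total = sum(raw)                             # 1.96
weights = [round(w / total, 2) for w in raw]
print(weights)  # [0.51, 0.31, 0.18]
```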

Acceptance rate results

Evaluated on RedHatAI/speculator_benchmarks:math_reasoning.jsonl with num_speculative_tokens=3:

| Checkpoint                | Pos0  | Pos1  | Pos2  | Mean accepted tokens |
|---------------------------|-------|-------|-------|----------------------|
| Base (original MTP head)  | 0.873 | 0.778 | 0.675 | 2.01                 |
| This model (epoch 0)      | 0.934 | 0.891 | 0.838 | 2.46                 |

+22.5% improvement in mean accepted tokens over the base MTP head.
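As a consistency check, the mean accepted tokens column can be derived from the positional rates, assuming a draft token at position k is only considered once all earlier positions were accepted (so the per-position rates chain multiplicatively). This assumption reproduces the reported values:

```python
# Positional acceptance rates from the table above.
base = [0.873, 0.778, 0.675]
tuned = [0.934, 0.891, 0.838]

def mean_accepted(rates):
    # Expected accepted tokens per verifier step, assuming position k
    # is only reached if positions 0..k-1 were all accepted.
    expected, survive = 0.0, 1.0
    for r in rates:
        survive *= r
        expected += survive
    return expected

print(round(mean_accepted(base), 2))   # 2.01
print(round(mean_accepted(tuned), 2))  # 2.46
```

The 2.46 / 2.01 ratio is also where the +22.5% figure comes from.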

Usage

Load directly with vLLM – the MTP head is embedded in the model weights:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="inference-optimization/Qwen3-Next-80B-A3B-Instruct-GSM8K-MTP-finetuned",
    tokenizer_mode="auto",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.8,
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 3,
    },
    enable_chunked_prefill=False,
)

sampling_params = SamplingParams(temperature=0.6, top_p=0.95)
outputs = llm.generate(["Solve: Janet has 3 apples..."], sampling_params)
```
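For a rough sense of what the acceptance gain buys at inference time, a back-of-the-envelope sketch (an assumption, not a measured benchmark: it counts one bonus token per verifier step plus the accepted speculative tokens, and ignores draft overhead and batching effects):

```python
# Rough tokens-per-verifier-step estimate. Assumes each verifier step
# yields the accepted speculative tokens plus one bonus token; actual
# wall-clock speedup also depends on draft cost and batch size.
base_mean, tuned_mean = 2.01, 2.46
base_tps = 1 + base_mean    # ~3.01 tokens per verifier step
tuned_tps = 1 + tuned_mean  # ~3.46 tokens per verifier step
print(round(tuned_tps / base_tps, 3))  # ~15% fewer verifier steps
```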

Generation pipeline

This checkpoint was produced using the speculators library via the following steps:

  1. Data generation – hidden states extracted from Qwen3-Next-80B-A3B-Instruct on GSM8K using examples/fast_mtp/generate_dataset.py
  2. Finetuning – MTP head trained with examples/fast_mtp/04_finetune.py
  3. Weight stitching – finetuned MTP weights stitched back into the verifier using examples/fast_mtp/stitch_weights.py
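The stitching step amounts to overwriting the verifier's mtp.* tensors with the finetuned ones while leaving everything else untouched. A minimal sketch of that logic (the function name and key-checking behavior are illustrative, not the actual stitch_weights.py implementation):

```python
def stitch_mtp_weights(verifier_state, mtp_state, prefix="mtp."):
    """Return a copy of the verifier state dict with its MTP-head
    tensors replaced by the finetuned ones.

    Assumes the MTP head's keys share the ``mtp.`` prefix described
    above; all non-MTP tensors are left untouched.
    """
    stitched = dict(verifier_state)
    for key, tensor in mtp_state.items():
        if not key.startswith(prefix):
            raise KeyError(f"unexpected non-MTP key: {key}")
        if key not in stitched:
            raise KeyError(f"MTP key missing from verifier: {key}")
        stitched[key] = tensor
    return stitched
```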