Distillation Trainer
Overview
The Distillation Trainer implements on-policy knowledge distillation as described in On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
The abstract from the paper is the following:
Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution.
The DistillationTrainer is designed to efficiently distill teacher models of any size into smaller students. It extends the ideas of the GKDTrainer with three key optimizations:
- Generation buffer – decouples the training microbatch size from the generation batch size, letting vLLM batch many prompts in a single call across gradient accumulation steps. This alone can speed up training by up to 40x.
- Teacher server support – moves the teacher to an external vLLM server so it does not need to fit on the same GPUs as the student.
- Binary-encoded logprob payloads – packs log-probabilities into base64-encoded NumPy arrays instead of nested JSON lists, shrinking transfer payloads by ~5x.
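The idea behind binary-encoded payloads can be sketched with only the standard library. This sketch uses struct with explicit little-endian float32 in place of NumPy (an assumption made so the example is self-contained); encode_logprobs and decode_logprobs are hypothetical helper names, not the trainer's API, and the real wire format (dtype, byte order, array shapes) may differ.

```python
import base64
import json
import math
import struct
from typing import List, Sequence


def encode_logprobs(values: Sequence[float]) -> str:
    """Pack floats as little-endian float32 bytes, then base64-encode them as ASCII text."""
    raw = struct.pack(f"<{len(values)}f", *values)
    return base64.b64encode(raw).decode("ascii")


def decode_logprobs(payload: str) -> List[float]:
    """Invert encode_logprobs: base64-decode, then unpack 4-byte float32 values."""
    raw = base64.b64decode(payload)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))


# Hypothetical log-probabilities, compared as a JSON list vs. a base64 payload.
logprobs = [math.log((i + 1) / 1001) for i in range(1000)]
as_json = json.dumps(logprobs)
as_b64 = encode_logprobs(logprobs)
print(len(as_json), len(as_b64))
```

Each float32 value costs about 5.3 base64 characters, whereas a full-precision float in JSON typically takes close to 20 characters including the separator. The trade-off in this sketch is that float32 drops some precision relative to Python's float64.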
The Distillation Trainer is currently part of the trl.experimental namespace. APIs may change without notice while the feature is iterated on.
Quick start
```python
from datasets import load_dataset
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

# 2. Configure distillation
config = DistillationConfig(
    output_dir="results/distill-qwen-gsm8k",
    num_train_epochs=1,
    bf16=True,
    save_strategy="no",
    # Distillation
    lmbda=1.0,  # fully on-policy (student generates)
    beta=1.0,  # reverse KL
    # Teacher
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
)

# 3. Train
trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=config,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model()
```
Usage tips
The experimental.distillation.DistillationTrainer needs three key parameters set via experimental.distillation.DistillationConfig:
- lmbda: controls the student data fraction, i.e., the proportion of on-policy, student-generated outputs. When lmbda=0.0, training is fully off-policy (dataset completions only). When lmbda=1.0, training is fully on-policy (the student generates all completions). For values in between, each gradient accumulation slice is randomly assigned as on- or off-policy based on lmbda.
- beta: controls the interpolation in the generalized Jensen-Shannon divergence. When beta=0.0, the loss approximates the forward KL divergence; when beta=1.0, it approximates the reverse KL divergence. Values in between interpolate between the two.
- loss_top_k: number of top tokens used for the KL/JSD loss. Set it to 0 for exact full-vocabulary computation (local teacher only), or to a value > 0 for a top-k approximation. See the section on using an external teacher server below for the top-k constraints that apply there.
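To make the role of beta concrete, here is a minimal pure-Python sketch of a generalized JSD between two discrete distributions. It assumes the mixture M = beta * teacher + (1 - beta) * student and special-cases the endpoints as forward KL (beta=0) and reverse KL (beta=1). The names kl and generalized_jsd are hypothetical; this is an illustrative formulation, not the trainer's implementation, which works on per-token log-probabilities and also handles top-k truncation, masking, and reduction.

```python
import math
from typing import Sequence


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute 0.

    Assumes q_i > 0 wherever p_i > 0.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(student: Sequence[float], teacher: Sequence[float], beta: float) -> float:
    """Interpolate between forward KL (beta=0) and reverse KL (beta=1)."""
    if beta == 0.0:
        return kl(teacher, student)  # forward KL: KL(teacher || student)
    if beta == 1.0:
        return kl(student, teacher)  # reverse KL: KL(student || teacher)
    # Mixture distribution M = beta * teacher + (1 - beta) * student.
    mix = [beta * t + (1.0 - beta) * s for s, t in zip(student, teacher)]
    return beta * kl(teacher, mix) + (1.0 - beta) * kl(student, mix)


student = [0.7, 0.2, 0.1]
teacher = [0.5, 0.3, 0.2]
for b in (0.0, 0.5, 1.0):
    print(b, round(generalized_jsd(student, teacher, b), 4))
```

The endpoints are handled separately because, with this mixture, both KL terms vanish at beta=0 and beta=1 (the mixture collapses onto one of the two distributions), so plugging the endpoints into the mixture formula would not recover the forward and reverse KL.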
On-policy vs. off-policy
Setting lmbda=1.0 (fully on-policy) generally outperforms off-policy distillation because the student learns from its own mistakes rather than imitating trajectories it may never produce. The generation buffer ensures on-policy training stays efficient: prompts across gradient accumulation steps are batched into a single vLLM call.
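The interplay of lmbda and the generation buffer can be sketched as follows, assuming prompts are plain strings and generate_fn stands in for a single batched vLLM call. plan_slices, buffered_generate, and generate_fn are hypothetical names for illustration; the trainer's actual buffering logic may differ.

```python
import random
from typing import Callable, List, Optional, Sequence


def plan_slices(num_slices: int, lmbda: float, rng: random.Random) -> List[bool]:
    """Assign each gradient accumulation slice: True = on-policy, False = off-policy."""
    # rng.random() is in [0, 1), so lmbda=0.0 never selects on-policy
    # and lmbda=1.0 always does.
    return [rng.random() < lmbda for _ in range(num_slices)]


def buffered_generate(
    slices: Sequence[Sequence[str]],
    on_policy: Sequence[bool],
    generate_fn: Callable[[List[str]], List[str]],
) -> List[Optional[List[str]]]:
    """Generate completions for all on-policy slices with a single generate_fn call."""
    # Flatten the prompts of every on-policy slice into one batch.
    batch = [prompt for prompts, on in zip(slices, on_policy) if on for prompt in prompts]
    completions = generate_fn(batch) if batch else []
    # Split the flat completions back into per-slice lists.
    results: List[Optional[List[str]]] = []
    offset = 0
    for prompts, on in zip(slices, on_policy):
        if on:
            results.append(completions[offset : offset + len(prompts)])
            offset += len(prompts)
        else:
            results.append(None)  # off-policy slice: dataset completions are used instead
    return results


# Usage with a stand-in for the vLLM call.
def fake_generate(prompts: List[str]) -> List[str]:
    return [p + " -> completion" for p in prompts]


rng = random.Random(0)
slices = [["q1", "q2"], ["q3", "q4"], ["q5", "q6"], ["q7", "q8"]]
modes = plan_slices(len(slices), lmbda=0.5, rng=rng)
print(modes, buffered_generate(slices, modes, fake_generate))
```

In this sketch, generation happens once per group of gradient accumulation slices rather than once per microbatch, so the inference engine sees one large batch instead of many small ones.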
Using an external teacher server
For teachers that do not fit on training GPUs (e.g., 100B+ parameters), host the teacher on a separate vLLM server and set use_teacher_server=True with teacher_model_server_url:
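```python
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

config = DistillationConfig(
    output_dir="distilled-model",
    use_teacher_server=True,
    teacher_model_server_url="http://teacher-host:8000",
    loss_top_k=1,  # required with teacher server when beta > 0
    beta=1.0,
    lmbda=1.0,
)

# `dataset` is a prompt-only conversational dataset, e.g. the one built in the quick start
trainer = DistillationTrainer(
    model="Qwen/Qwen3-4B",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```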
When using the teacher server:
- loss_top_k must be > 0 when beta=0.0 (forward KL)
- loss_top_k must be exactly 1 when beta > 0 (reverse KL or JSD)
- reverse_kl_top_1_mode="argmax" is not supported
- Liger kernel is not supported
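The constraints above can be summarized as a small validation function. check_teacher_server_settings is a hypothetical helper written for illustration, not the check the trainer itself performs; it encodes only the four rules listed, and it takes reverse_kl_top_1_mode as an optional string because the set of other accepted values is not documented here.

```python
from typing import Optional


def check_teacher_server_settings(
    beta: float,
    loss_top_k: int,
    reverse_kl_top_1_mode: Optional[str] = None,
    use_liger_kernel: bool = False,
) -> None:
    """Raise ValueError if the settings conflict with the teacher-server rules."""
    if beta == 0.0 and loss_top_k <= 0:
        raise ValueError("loss_top_k must be > 0 with a teacher server when beta=0.0 (forward KL)")
    if beta > 0.0 and loss_top_k != 1:
        raise ValueError("loss_top_k must be exactly 1 with a teacher server when beta > 0")
    if reverse_kl_top_1_mode == "argmax":
        raise ValueError('reverse_kl_top_1_mode="argmax" is not supported with a teacher server')
    if use_liger_kernel:
        raise ValueError("the Liger kernel is not supported with a teacher server")


# Valid: reverse KL with top-1, and forward KL with a top-k approximation.
check_teacher_server_settings(beta=1.0, loss_top_k=1)
check_teacher_server_settings(beta=0.0, loss_top_k=20)
```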
Expected dataset type
The dataset should be formatted as a conversational language modeling dataset:
```python
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```
When using fully on-policy distillation (lmbda=1.0), the assistant turn can be omitted since the student will generate its own completions:
```python
{"messages": [{"role": "user", "content": "What color is the sky?"}]}
```
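If a dataset already contains assistant turns, a minimal way to produce the prompt-only form is to drop the trailing assistant messages. to_prompt_only is a hypothetical helper; keep the completions when lmbda is below 1.0, since off-policy slices train on the dataset completions.

```python
from typing import Any, Dict, List


def to_prompt_only(example: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a conversational row with trailing assistant turns removed."""
    messages: List[Dict[str, str]] = list(example["messages"])  # shallow copy; input is untouched
    while messages and messages[-1]["role"] == "assistant":
        messages.pop()
    return {"messages": messages}


row = {"messages": [{"role": "user", "content": "What color is the sky?"},
                    {"role": "assistant", "content": "It is blue."}]}
print(to_prompt_only(row))
```

With the datasets library, the same function could be applied with dataset.map(to_prompt_only), which overwrites the messages column of each row.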
Example script
Use trl/experimental/distillation/distillation.py to launch distillation training from the command line. The script supports full training, mixed on/off-policy, and LoRA via the standard ModelConfig flags.
```bash
# Full training (off-policy only, lmbda=0):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.0 \
    --output_dir distilled-model \
    --num_train_epochs 1

# Mixed on/off-policy (lmbda=0.5):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.5 \
    --beta 0.5 \
    --output_dir distilled-model \
    --num_train_epochs 1
```
DistillationTrainer
class trl.experimental.distillation.DistillationTrainer
Trainer for knowledge distillation from a teacher model to a student model.
Supports:
- Generalized JSD loss (forward KL, reverse KL, or interpolated JSD via beta)
- On-policy / off-policy mixing via lmbda (buffered across gradient accumulation)
- Local teacher model or external teacher via vLLM server
- Student on-policy generation via vLLM or model.generate()
- Liger kernel for memory-efficient fused JSD loss
train
train(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None)
Main training entry point.
Parameters:
resume_from_checkpoint (str or bool, optional) : If a str, the local path to a checkpoint saved by a previous instance of Trainer. If True, load the last checkpoint in args.output_dir saved by a previous instance of Trainer. If present, training resumes from the model/optimizer/scheduler states loaded here.
trial (optuna.Trial or dict[str, Any], optional) : The trial run or the hyperparameter dictionary for hyperparameter search.
ignore_keys_for_eval (list[str], optional) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
Returns:
transformers.trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
save_model
Saves the model so it can be reloaded with from_pretrained().
Only the main process saves.
push_to_hub
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.
Parameters:
commit_message (str, optional, defaults to "End of training") : Message to commit while pushing.
blocking (bool, optional, defaults to True) : Whether the function should return only when the git push has finished.
token (str, optional, defaults to None) : Token with write permission, overriding the token set in the Trainer's args.
revision (str, optional) : The git revision to commit from. Defaults to the head of the "main" branch.
kwargs (dict[str, Any], optional) : Additional keyword arguments passed along to Trainer.create_model_card.
Returns:
The URL of the repository where the model was pushed if blocking=True, or a Future object tracking the
progress of the commit if blocking=False.
DistillationConfig
class trl.experimental.distillation.DistillationConfig
Configuration class for the DistillationTrainer.
Extends TrainingArguments with parameters specific to knowledge distillation. This config is independent of SFTConfig — all necessary fields are declared here.
Using HfArgumentParser, this class can be turned into argparse arguments that can be specified on the command line.