# Distillation Trainer

## Overview

The Distillation Trainer implements on-policy knowledge distillation as described in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.

> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution.

The `DistillationTrainer` is designed to efficiently distill teacher models of any size into smaller students. It extends the ideas behind the `GKDTrainer` with three key optimizations:

1. **Generation buffer**: decouples the training microbatch size from the generation batch size, letting vLLM batch many prompts into a single call across gradient accumulation steps. This alone can speed up training by up to 40x.
2. **Teacher server support**: moves the teacher to an external vLLM server so it does not need to fit on the same GPUs as the student.
3. **Binary-encoded logprob payloads**: packs log-probabilities into base64-encoded NumPy arrays instead of nested JSON lists, shrinking transfer payloads by ~5x.
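The idea behind the binary-encoded payloads can be illustrated with a small sketch. It uses Python's standard-library `array` module as a stand-in for NumPy's `ndarray.tobytes()` / `np.frombuffer()`, and the function names and payload layout (a `shape` field plus a base64 `data` field) are hypothetical; this is not the trainer's actual wire format.

```python
import array
import base64
import json
import random


def encode_logprobs(rows):
    """Hypothetical encoder: pack a 2-D list of log-probs as base64 float32 bytes."""
    n_rows, n_cols = len(rows), len(rows[0])
    # 'f' is a 4-byte C float on common platforms, i.e. the same layout as float32
    flat = array.array("f", (x for row in rows for x in row))
    data = base64.b64encode(flat.tobytes()).decode("ascii")
    return {"shape": [n_rows, n_cols], "data": data}


def decode_logprobs(payload):
    """Hypothetical decoder: invert encode_logprobs back into a 2-D list."""
    n_rows, n_cols = payload["shape"]
    flat = array.array("f")
    flat.frombytes(base64.b64decode(payload["data"]))
    # Slice the flat buffer back into rows of length n_cols
    return [flat[i * n_cols:(i + 1) * n_cols].tolist() for i in range(n_rows)]


# Example: 256 tokens with 8 log-probabilities each
rng = random.Random(0)
logprobs = [[rng.uniform(-20.0, 0.0) for _ in range(8)] for _ in range(256)]

json_payload = json.dumps(logprobs)                      # nested JSON lists
binary_payload = json.dumps(encode_logprobs(logprobs))   # base64 float32 bytes
print(len(json_payload), len(binary_payload))
```

Each float costs 4 bytes (about 5.3 characters after base64) instead of a long decimal string in JSON. The round trip is lossy only up to float32 precision.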
> [!NOTE]
> The Distillation Trainer is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.

## Quick start

```python
from datasets import load_dataset

from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

# 2. Configure distillation
config = DistillationConfig(
    output_dir="results/distill-qwen-gsm8k",
    num_train_epochs=1,
    bf16=True,
    save_strategy="no",
    # Distillation
    lmbda=1.0,  # fully on-policy (student generates)
    beta=1.0,  # reverse KL
    # Teacher
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
)

# 3. Train
trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=config,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model()
```

## Usage tips

The [`experimental.distillation.DistillationTrainer`] is controlled by three key parameters, set via [`experimental.distillation.DistillationConfig`]:

* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, training is fully off-policy (dataset completions only). When `lmbda=1.0`, training is fully on-policy (the student generates all completions). For values in between, each gradient accumulation slice is randomly assigned as on- or off-policy based on `lmbda`.
* `beta`: controls the interpolation in the Generalized Jensen-Shannon Divergence. When `beta=0.0`, the loss approximates forward KL divergence, while `beta=1.0` approximates reverse KL divergence. Values in between interpolate.
* `loss_top_k`: number of top tokens to use for the KL/JSD loss. Set it to `0` for exact full-vocabulary computation (local teacher only), or to a value `> 0` for a top-k approximation.
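To make the roles of `beta` and `lmbda` concrete, here is a minimal pure-Python sketch over a single token's logits. It assumes the convention described above (forward KL at `beta=0.0`, reverse KL at `beta=1.0`) and, for intermediate values, a generalized JSD built on the mixture `M = (1 - beta) * student + beta * teacher`; it also assumes `lmbda` acts as the probability that a slice is on-policy. All function names are hypothetical, and this is not the trainer's implementation (no batching, masking, or top-k handling).

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(student_logits, teacher_logits, beta):
    """Hypothetical helper: beta=0 -> forward KL, beta=1 -> reverse KL, else generalized JSD."""
    s = softmax(student_logits)
    t = softmax(teacher_logits)
    if beta == 0.0:
        return kl(t, s)  # forward KL: KL(teacher || student)
    if beta == 1.0:
        return kl(s, t)  # reverse KL: KL(student || teacher)
    # Assumed mixture distribution interpolating student and teacher
    m = [(1 - beta) * si + beta * ti for si, ti in zip(s, t)]
    return beta * kl(t, m) + (1 - beta) * kl(s, m)


def is_on_policy(lmbda, rng):
    """Hypothetical helper: a slice is on-policy with probability lmbda."""
    return rng.random() < lmbda


student = [2.0, 1.0, 0.1]
teacher = [0.5, 2.5, 0.3]
for beta in (0.0, 0.5, 1.0):
    print(beta, generalized_jsd(student, teacher, beta))

rng = random.Random(0)
slices = [is_on_policy(0.5, rng) for _ in range(1000)]
print(sum(slices) / len(slices))  # fraction of on-policy slices, close to lmbda
```

Forward KL (`beta=0.0`) penalizes the student for assigning low probability to tokens the teacher likes (mode-covering), whereas reverse KL (`beta=1.0`) penalizes the student for placing mass where the teacher does not (mode-seeking).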
See the section on using an external teacher server below for the additional `loss_top_k` constraints that apply in that setup.

### On-policy vs. off-policy

Setting `lmbda=1.0` (fully on-policy) generally outperforms off-policy distillation because the student learns from its own mistakes rather than imitating trajectories it may never produce. The generation buffer keeps on-policy training efficient: prompts across gradient accumulation steps are batched into a single vLLM call.

### Using an external teacher server

For teachers that do not fit on the training GPUs (e.g., 100B+ parameters), host the teacher on a separate vLLM server and set `use_teacher_server=True` together with `teacher_model_server_url`:

```python
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

config = DistillationConfig(
    output_dir="distilled-model",
    use_teacher_server=True,
    teacher_model_server_url="http://teacher-host:8000",
    loss_top_k=1,  # required with teacher server when beta > 0
    beta=1.0,
    lmbda=1.0,
)

# `dataset` is a prompt-only conversational dataset, e.g. the one built in the quick start
trainer = DistillationTrainer(
    model="Qwen/Qwen3-4B",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```

When using the teacher server:

- `loss_top_k` must be `> 0` when `beta=0.0` (forward KL)
- `loss_top_k` must be exactly `1` when `beta > 0` (reverse KL or JSD)
- `reverse_kl_top_1_mode="argmax"` is not supported
- The Liger kernel is not supported

### Expected dataset type

The dataset should be formatted as a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```

When using fully on-policy distillation (`lmbda=1.0`), the assistant turn can be omitted since the student generates its own completions:

```python
{"messages": [{"role": "user", "content": "What color is the sky?"}]}
```

## Example script

Use [`trl/experimental/distillation/distillation.py`](https://github.com/huggingface/trl/blob/main/trl/experimental/distillation/distillation.py) to launch distillation training from the command line.
The script supports full-parameter training, mixed on/off-policy training, and LoRA via the standard `ModelConfig` flags.

```bash
# Full training (off-policy only, lmbda=0):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.0 \
    --output_dir distilled-model \
    --num_train_epochs 1
```

```bash
# Mixed on/off-policy (lmbda=0.5):
python trl/experimental/distillation/distillation.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --lmbda 0.5 \
    --beta 0.5 \
    --output_dir distilled-model \
    --num_train_epochs 1
```

## DistillationTrainer

[[autodoc]] experimental.distillation.DistillationTrainer
    - train
    - save_model
    - push_to_hub

## DistillationConfig

[[autodoc]] experimental.distillation.DistillationConfig