Buckets:

HuggingFaceDocBuilder's picture
download
raw
55.1 kB
import{s as Ht,n as Vt,o as Yt}from"../chunks/scheduler.7b731bd4.js";import{S as Pt,i as Kt,e as i,s as l,c as m,h as Ot,a as s,d as o,b as a,f as de,g as c,j as d,k as U,l as r,m as n,n as p,t as u,o as _,p as g}from"../chunks/index.cc268345.js";import{C as eo,H as M,E as to}from"../chunks/MermaidChart.svelte_svelte_type_style_lang.f0d99f98.js";import{D as Je}from"../chunks/Docstring.03f7b462.js";import{C as ye}from"../chunks/CodeBlock.169a125f.js";function oo(wt){let y,xe,Te,Ue,N,Ne,I,Ie,j,je,C,Jt='The Distillation Trainer implements on-policy knowledge distillation as described in <a href="https://huggingface.co/papers/2306.13649" rel="nofollow">On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes</a> by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.',Ce,k,xt="<p>Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher’s distribution.</p>",ke,$,Ut="The <code>DistillationTrainer</code> is designed for distilling teacher models of all sizes into smaller students efficiently. It extends the ideas from the <code>GKDTrainer</code> with three key optimizations:",$e,L,Nt="<li><strong>Generation buffer</strong> – decouples the training microbatch size from the generation batch size, letting vLLM batch many prompts in a single call across gradient accumulation steps. This alone can speed up training by up to 40x.</li> <li><strong>Teacher server support</strong> – moves the teacher to an external vLLM server so it does not need to fit on the same GPUs as the student.</li> <li><strong>Binary-encoded logprob payloads</strong> – packs log-probabilities into base64-encoded NumPy arrays instead of nested JSON lists, shrinking transfer payloads by ~5x.</li>",Le,T,It="<p>The Distillation Trainer is currently part of the <code>trl.experimental</code> namespace. APIs may change without notice while the feature is iterated on.</p>",De,D,Ge,G,We,W,qe,q,jt='The <a href="/docs/trl/pr_5607/en/distillation_trainer#trl.experimental.distillation.DistillationTrainer">experimental.distillation.DistillationTrainer</a> needs three key parameters set via <a href="/docs/trl/pr_5607/en/distillation_trainer#trl.experimental.distillation.DistillationConfig">experimental.distillation.DistillationConfig</a>:',Be,B,Ct="<li><code>lmbda</code>: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When <code>lmbda=0.0</code>, training is fully off-policy (dataset completions only). When <code>lmbda=1.0</code>, training is fully on-policy (student generates all completions). For values in between, each gradient accumulation slice is randomly assigned as on- or off-policy based on <code>lmbda</code>.</li> <li><code>beta</code>: controls the interpolation in the Generalized Jensen-Shannon Divergence. When <code>beta=0.0</code> the loss approximates forward KL divergence, while <code>beta=1.0</code> approximates reverse KL divergence. Values in between interpolate.</li> <li><code>loss_top_k</code>: number of top tokens to use for the KL/JSD loss. Set to <code>0</code> for exact full-vocabulary computation (local teacher only), or <code>&gt; 0</code> for a top-k approximation. See more about top-k with external teacher server below.</li>",Fe,F,Ze,Z,kt="Setting <code>lmbda=1.0</code> (fully on-policy) generally outperforms off-policy distillation because the student learns from its own mistakes rather than imitating trajectories it may never produce. The generation buffer ensures on-policy training stays efficient: prompts across gradient accumulation steps are batched into a single vLLM call.",Ae,A,Re,R,$t="For teachers that do not fit on training GPUs (e.g., 100B+ parameters), host the teacher on a separate vLLM server and set <code>use_teacher_server=True</code> with <code>teacher_model_server_url</code>:",ze,z,Ee,E,Lt="When using the teacher server:",Qe,Q,Dt="<li><code>loss_top_k</code> must be <code>&gt; 0</code> when <code>beta=0.0</code> (forward KL)</li> <li><code>loss_top_k</code> must be exactly <code>1</code> when <code>beta &gt; 0</code> (reverse KL or JSD)</li> <li><code>reverse_kl_top_1_mode=&quot;argmax&quot;</code> is not supported</li> <li>Liger kernel is not supported</li>",Se,S,Xe,X,Gt='The dataset should be formatted as a <a href="dataset_formats#conversational">conversational</a> <a href="dataset_formats#language_modeling">language modeling</a> dataset:',He,H,Ve,V,Wt="When using fully on-policy distillation (<code>lmbda=1.0</code>), the assistant turn can be omitted since the student will generate its own completions:",Ye,Y,Pe,P,Ke,K,qt='Use <a href="https://github.com/huggingface/trl/blob/main/trl/experimental/distillation/distillation.py" rel="nofollow"><code>trl/experimental/distillation/distillation.py</code></a> to launch distillation training from the command line. The script supports full training, mixed on/off-policy, and LoRA via the standard <code>ModelConfig</code> flags.',Oe,O,et,ee,tt,te,ot,h,oe,mt,me,Bt="Trainer for knowledge distillation from a teacher model to a student model.",ct,ce,Ft="Supports:",pt,pe,Zt="<li>Generalized JSD loss (forward KL, reverse KL, or interpolated JSD via <code>beta</code>)</li> <li>On-policy / off-policy mixing via <code>lmbda</code> (buffered across gradient accumulation)</li> <li>Local teacher model or external teacher via vLLM server</li> <li>Student on-policy generation via vLLM or model.generate()</li> <li>Liger kernel for memory-efficient fused JSD loss</li>",ut,w,ne,_t,ue,At="Main training entry point.",gt,b,le,ht,_e,Rt="Will save the model, so you can reload it using <code>from_pretrained()</code>.",ft,ge,zt="Will only save from the main process.",vt,J,ae,bt,he,Et="Upload <code>self.model</code> and <code>self.processing_class</code> to the 🤗 model hub on the repo <code>self.args.hub_model_id</code>.",nt,ie,lt,v,se,Mt,fe,Qt="Configuration class for the <code>DistillationTrainer</code>.",yt,ve,St=`Extends <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a> with parameters specific to knowledge distillation. This config is
independent of <a href="/docs/trl/pr_5607/en/sft_trainer#trl.SFTConfig">SFTConfig</a> — all necessary fields are declared here.`,Tt,be,Xt=`Using <a href="https://huggingface.co/docs/transformers/main/en/internal/trainer_utils#transformers.HfArgumentParser" rel="nofollow">HfArgumentParser</a> we can turn this class into
<a href="https://docs.python.org/3/library/argparse#module-argparse" rel="nofollow">argparse</a> arguments that can be specified on the
command line.`,at,re,it,we,st;return N=new eo({props:{containerStyle:"float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"}}),I=new M({props:{title:"Distillation Trainer",local:"distillation-trainer",headingTag:"h1"}}),j=new M({props:{title:"Overview",local:"overview",headingTag:"h2"}}),D=new M({props:{title:"Quick start",local:"quick-start",headingTag:"h2"}}),G=new ye({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwbG9hZF9kYXRhc2V0JTBBZnJvbSUyMHRybC5leHBlcmltZW50YWwuZGlzdGlsbGF0aW9uJTIwaW1wb3J0JTIwRGlzdGlsbGF0aW9uQ29uZmlnJTJDJTIwRGlzdGlsbGF0aW9uVHJhaW5lciUwQSUwQSUyMyUyMDEuJTIwTG9hZCUyMGRhdGFzZXQlMjBhbmQlMjBmb3JtYXQlMjBhcyUyMHByb21wdC1vbmx5JTIwY2hhdCUyMG1lc3NhZ2VzJTBBZGF0YXNldCUyMCUzRCUyMGxvYWRfZGF0YXNldCglMjJvcGVuYWklMkZnc204ayUyMiUyQyUyMCUyMm1haW4lMjIlMkMlMjBzcGxpdCUzRCUyMnRyYWluJTIyKSUwQWRhdGFzZXQlMjAlM0QlMjBkYXRhc2V0Lm1hcCglMEElMjAlMjAlMjAlMjBsYW1iZGElMjB4JTNBJTIwJTdCJTIybWVzc2FnZXMlMjIlM0ElMjAlNUIlN0IlMjJyb2xlJTIyJTNBJTIwJTIydXNlciUyMiUyQyUyMCUyMmNvbnRlbnQlMjIlM0ElMjB4JTVCJTIycXVlc3Rpb24lMjIlNUQlN0QlNUQlN0QlMkMlMEElMjAlMjAlMjAlMjByZW1vdmVfY29sdW1ucyUzRGRhdGFzZXQuY29sdW1uX25hbWVzJTJDJTBBKSUwQSUwQSUyMyUyMDIuJTIwQ29uZmlndXJlJTIwZGlzdGlsbGF0aW9uJTBBY29uZmlnJTIwJTNEJTIwRGlzdGlsbGF0aW9uQ29uZmlnKCUwQSUyMCUyMCUyMCUyMG91dHB1dF9kaXIlM0QlMjJyZXN1bHRzJTJGZGlzdGlsbC1xd2VuLWdzbThrJTIyJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluX2Vwb2NocyUzRDElMkMlMEElMjAlMjAlMjAlMjBiZjE2JTNEVHJ1ZSUyQyUwQSUyMCUyMCUyMCUyMHNhdmVfc3RyYXRlZ3klM0QlMjJubyUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMyUyMERpc3RpbGxhdGlvbiUwQSUyMCUyMCUyMCUyMGxtYmRhJTNEMS4wJTJDJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIzJTIwZnVsbHklMjBvbi1wb2xpY3klMjAoc3R1ZGVudCUyMGdlbmVyYXRlcyklMEElMjAlMjAlMjAlMjBiZXRhJTNEMS4wJTJDJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIzJTIwcmV2ZXJzZSUyMEtMJTBBJTIwJTIwJTIwJTIwJTIzJTIwVGVhY2hlciUwQSUyMCUyMCUyMCUyMHRlYWNoZXJfbW9kZWxfaW5pdF9rd2FyZ3MlM0QlN0IlMjJ0b3JjaF9kdHlwZSUyMiUzQSUyMCUyMmJmbG9hdDE2JTIyJTdEJTJDJTBBKSUwQSUwQSUyMyUyMDMuJTIwVHJhaW4lMEF0cmFpbmVyJTIwJTNEJTIwRGlzdGlsbGF0aW9uVHJhaW5lciglMEElMjAlMjAlMjAlMjBtb2RlbCUzRCUyMlF3ZW4lMkZRd2VuMi41LTEuNUItSW5zdHJ1Y3QlMjIlMkMlMEElMjAlMjAlMjAlMjB0ZWFjaGVyX21vZGVsJTNEJTIyUXdlbiUyRlF3ZW4yLjUtN0ItSW5zdHJ1Y3QlMjIlMkMlMEElMjAlMjAlMjAlMjBhcmdzJTNEY29uZmlnJTJDJTBBJTIwJTIwJTIwJTIwdHJhaW5fZGF0YXNldCUzRGRhdGFzZXQlMkMlMEEpJTBBdHJhaW5lci50cmFpbigpJTBBdHJhaW5lci5zYXZlX21vZGVsKCk=",highlighted:`<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> trl.experimental.distillation <span class="hljs-keyword">import</span> DistillationConfig, DistillationTrainer
<span class="hljs-comment"># 1. Load dataset and format as prompt-only chat messages</span>
dataset = load_dataset(<span class="hljs-string">&quot;openai/gsm8k&quot;</span>, <span class="hljs-string">&quot;main&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
dataset = dataset.<span class="hljs-built_in">map</span>(
<span class="hljs-keyword">lambda</span> x: {<span class="hljs-string">&quot;messages&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: x[<span class="hljs-string">&quot;question&quot;</span>]}]},
remove_columns=dataset.column_names,
)
<span class="hljs-comment"># 2. Configure distillation</span>
config = DistillationConfig(
output_dir=<span class="hljs-string">&quot;results/distill-qwen-gsm8k&quot;</span>,
num_train_epochs=<span class="hljs-number">1</span>,
bf16=<span class="hljs-literal">True</span>,
save_strategy=<span class="hljs-string">&quot;no&quot;</span>,
<span class="hljs-comment"># Distillation</span>
lmbda=<span class="hljs-number">1.0</span>, <span class="hljs-comment"># fully on-policy (student generates)</span>
beta=<span class="hljs-number">1.0</span>, <span class="hljs-comment"># reverse KL</span>
<span class="hljs-comment"># Teacher</span>
teacher_model_init_kwargs={<span class="hljs-string">&quot;torch_dtype&quot;</span>: <span class="hljs-string">&quot;bfloat16&quot;</span>},
)
<span class="hljs-comment"># 3. Train</span>
trainer = DistillationTrainer(
model=<span class="hljs-string">&quot;Qwen/Qwen2.5-1.5B-Instruct&quot;</span>,
teacher_model=<span class="hljs-string">&quot;Qwen/Qwen2.5-7B-Instruct&quot;</span>,
args=config,
train_dataset=dataset,
)
trainer.train()
trainer.save_model()`,wrap:!1}}),W=new M({props:{title:"Usage tips",local:"usage-tips",headingTag:"h2"}}),F=new M({props:{title:"On-policy vs. off-policy",local:"on-policy-vs-off-policy",headingTag:"h3"}}),A=new M({props:{title:"Using an external teacher server",local:"using-an-external-teacher-server",headingTag:"h3"}}),z=new ye({props:{code:"Y29uZmlnJTIwJTNEJTIwRGlzdGlsbGF0aW9uQ29uZmlnKCUwQSUyMCUyMCUyMCUyMG91dHB1dF9kaXIlM0QlMjJkaXN0aWxsZWQtbW9kZWwlMjIlMkMlMEElMjAlMjAlMjAlMjB1c2VfdGVhY2hlcl9zZXJ2ZXIlM0RUcnVlJTJDJTBBJTIwJTIwJTIwJTIwdGVhY2hlcl9tb2RlbF9zZXJ2ZXJfdXJsJTNEJTIyaHR0cCUzQSUyRiUyRnRlYWNoZXItaG9zdCUzQTgwMDAlMjIlMkMlMEElMjAlMjAlMjAlMjBsb3NzX3RvcF9rJTNEMSUyQyUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMyUyMHJlcXVpcmVkJTIwd2l0aCUyMHRlYWNoZXIlMjBzZXJ2ZXIlMjB3aGVuJTIwYmV0YSUyMCUzRSUyMDAlMEElMjAlMjAlMjAlMjBiZXRhJTNEMS4wJTJDJTBBJTIwJTIwJTIwJTIwbG1iZGElM0QxLjAlMkMlMEEpJTBBJTBBdHJhaW5lciUyMCUzRCUyMERpc3RpbGxhdGlvblRyYWluZXIoJTBBJTIwJTIwJTIwJTIwbW9kZWwlM0QlMjJRd2VuJTJGUXdlbjMtNEIlMjIlMkMlMEElMjAlMjAlMjAlMjBhcmdzJTNEY29uZmlnJTJDJTBBJTIwJTIwJTIwJTIwdHJhaW5fZGF0YXNldCUzRGRhdGFzZXQlMkMlMEEpJTBBdHJhaW5lci50cmFpbigp",highlighted:`config = DistillationConfig(
output_dir=<span class="hljs-string">&quot;distilled-model&quot;</span>,
use_teacher_server=<span class="hljs-literal">True</span>,
teacher_model_server_url=<span class="hljs-string">&quot;http://teacher-host:8000&quot;</span>,
loss_top_k=<span class="hljs-number">1</span>, <span class="hljs-comment"># required with teacher server when beta &gt; 0</span>
beta=<span class="hljs-number">1.0</span>,
lmbda=<span class="hljs-number">1.0</span>,
)
trainer = DistillationTrainer(
model=<span class="hljs-string">&quot;Qwen/Qwen3-4B&quot;</span>,
args=config,
train_dataset=dataset,
)
trainer.train()`,wrap:!1}}),S=new M({props:{title:"Expected dataset type",local:"expected-dataset-type",headingTag:"h3"}}),H=new ye({props:{code:"JTdCJTIybWVzc2FnZXMlMjIlM0ElMjAlNUIlN0IlMjJyb2xlJTIyJTNBJTIwJTIydXNlciUyMiUyQyUyMCUyMmNvbnRlbnQlMjIlM0ElMjAlMjJXaGF0JTIwY29sb3IlMjBpcyUyMHRoZSUyMHNreSUzRiUyMiU3RCUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCU3QiUyMnJvbGUlMjIlM0ElMjAlMjJhc3Npc3RhbnQlMjIlMkMlMjAlMjJjb250ZW50JTIyJTNBJTIwJTIySXQlMjBpcyUyMGJsdWUuJTIyJTdEJTVEJTdE",highlighted:`{<span class="hljs-string">&quot;messages&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;What color is the sky?&quot;</span>},
{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;It is blue.&quot;</span>}]}`,wrap:!1}}),Y=new ye({props:{code:"JTdCJTIybWVzc2FnZXMlMjIlM0ElMjAlNUIlN0IlMjJyb2xlJTIyJTNBJTIwJTIydXNlciUyMiUyQyUyMCUyMmNvbnRlbnQlMjIlM0ElMjAlMjJXaGF0JTIwY29sb3IlMjBpcyUyMHRoZSUyMHNreSUzRiUyMiU3RCU1RCU3RA==",highlighted:'{<span class="hljs-string">&quot;messages&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;What color is the sky?&quot;</span>}]}',wrap:!1}}),P=new M({props:{title:"Example script",local:"example-script",headingTag:"h2"}}),O=new ye({props:{code:"JTIzJTIwRnVsbCUyMHRyYWluaW5nJTIwKG9mZi1wb2xpY3klMjBvbmx5JTJDJTIwbG1iZGElM0QwKSUzQSUwQXB5dGhvbiUyMHRybCUyRmV4cGVyaW1lbnRhbCUyRmRpc3RpbGxhdGlvbiUyRmRpc3RpbGxhdGlvbi5weSUyMCU1QyUwQSUyMCUyMCUyMCUyMC0tbW9kZWxfbmFtZV9vcl9wYXRoJTIwUXdlbiUyRlF3ZW4yLjUtMC41Qi1JbnN0cnVjdCUyMCU1QyUwQSUyMCUyMCUyMCUyMC0tdGVhY2hlcl9tb2RlbF9uYW1lX29yX3BhdGglMjBRd2VuJTJGUXdlbjIuNS0xLjVCLUluc3RydWN0JTIwJTVDJTBBJTIwJTIwJTIwJTIwLS1kYXRhc2V0X25hbWUlMjB0cmwtbGliJTJGY2hhdGJvdF9hcmVuYV9jb21wbGV0aW9ucyUyMCU1QyUwQSUyMCUyMCUyMCUyMC0tbGVhcm5pbmdfcmF0ZSUyMDJlLTUlMjAlNUMlMEElMjAlMjAlMjAlMjAtLXBlcl9kZXZpY2VfdHJhaW5fYmF0Y2hfc2l6ZSUyMDQlMjAlNUMlMEElMjAlMjAlMjAlMjAtLWdyYWRpZW50X2FjY3VtdWxhdGlvbl9zdGVwcyUyMDglMjAlNUMlMEElMjAlMjAlMjAlMjAtLWxtYmRhJTIwMC4wJTIwJTVDJTBBJTIwJTIwJTIwJTIwLS1vdXRwdXRfZGlyJTIwZGlzdGlsbGVkLW1vZGVsJTIwJTVDJTBBJTIwJTIwJTIwJTIwLS1udW1fdHJhaW5fZXBvY2hzJTIwMQ==",highlighted:`<span class="hljs-comment"># Full training (off-policy only, lmbda=0):</span>
python trl/experimental/distillation/distillation.py \\
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \\
--teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \\
--dataset_name trl-lib/chatbot_arena_completions \\
--learning_rate 2e-5 \\
--per_device_train_batch_size 4 \\
--gradient_accumulation_steps 8 \\
--lmbda 0.0 \\
--output_dir distilled-model \\
--num_train_epochs 1`,wrap:!1}}),ee=new ye({props:{code:"JTIzJTIwTWl4ZWQlMjBvbiUyRm9mZi1wb2xpY3klMjAobG1iZGElM0QwLjUpJTNBJTBBcHl0aG9uJTIwdHJsJTJGZXhwZXJpbWVudGFsJTJGZGlzdGlsbGF0aW9uJTJGZGlzdGlsbGF0aW9uLnB5JTIwJTVDJTBBJTIwJTIwJTIwJTIwLS1tb2RlbF9uYW1lX29yX3BhdGglMjBRd2VuJTJGUXdlbjIuNS0wLjVCLUluc3RydWN0JTIwJTVDJTBBJTIwJTIwJTIwJTIwLS10ZWFjaGVyX21vZGVsX25hbWVfb3JfcGF0aCUyMFF3ZW4lMkZRd2VuMi41LTEuNUItSW5zdHJ1Y3QlMjAlNUMlMEElMjAlMjAlMjAlMjAtLWRhdGFzZXRfbmFtZSUyMHRybC1saWIlMkZjaGF0Ym90X2FyZW5hX2NvbXBsZXRpb25zJTIwJTVDJTBBJTIwJTIwJTIwJTIwLS1sZWFybmluZ19yYXRlJTIwMmUtNSUyMCU1QyUwQSUyMCUyMCUyMCUyMC0tcGVyX2RldmljZV90cmFpbl9iYXRjaF9zaXplJTIwNCUyMCU1QyUwQSUyMCUyMCUyMCUyMC0tZ3JhZGllbnRfYWNjdW11bGF0aW9uX3N0ZXBzJTIwOCUyMCU1QyUwQSUyMCUyMCUyMCUyMC0tbG1iZGElMjAwLjUlMjAlNUMlMEElMjAlMjAlMjAlMjAtLWJldGElMjAwLjUlMjAlNUMlMEElMjAlMjAlMjAlMjAtLW91dHB1dF9kaXIlMjBkaXN0aWxsZWQtbW9kZWwlMjAlNUMlMEElMjAlMjAlMjAlMjAtLW51bV90cmFpbl9lcG9jaHMlMjAx",highlighted:`<span class="hljs-comment"># Mixed on/off-policy (lmbda=0.5):</span>
python trl/experimental/distillation/distillation.py \\
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \\
--teacher_model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \\
--dataset_name trl-lib/chatbot_arena_completions \\
--learning_rate 2e-5 \\
--per_device_train_batch_size 4 \\
--gradient_accumulation_steps 8 \\
--lmbda 0.5 \\
--beta 0.5 \\
--output_dir distilled-model \\
--num_train_epochs 1`,wrap:!1}}),te=new M({props:{title:"DistillationTrainer",local:"trl.experimental.distillation.DistillationTrainer",headingTag:"h2"}}),oe=new Je({props:{name:"class trl.experimental.distillation.DistillationTrainer",anchor:"trl.experimental.distillation.DistillationTrainer",parameters:[{name:"model",val:": transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module | str | None = None"},{name:"teacher_model",val:": transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module | str = None"},{name:"args",val:": trl.experimental.distillation.distillation_config.DistillationConfig | None = None"},{name:"data_collator",val:": collections.abc.Callable[[list[typing.Any]], dict[str, typing.Any]] | None = None"},{name:"train_dataset",val:": datasets.arrow_dataset.Dataset | None = None"},{name:"eval_dataset",val:": datasets.arrow_dataset.Dataset | dict[str, datasets.arrow_dataset.Dataset] | None = None"},{name:"processing_class",val:": transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.image_processing_utils.BaseImageProcessor | transformers.feature_extraction_utils.FeatureExtractionMixin | transformers.processing_utils.ProcessorMixin | None = None"},{name:"compute_metrics",val:": collections.abc.Callable[[transformers.trainer_utils.EvalPrediction], dict] | None = None"},{name:"callbacks",val:": list[transformers.trainer_callback.TrainerCallback] | None = None"},{name:"optimizers",val:": tuple = (None, None)"},{name:"preprocess_logits_for_metrics",val:": collections.abc.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None"},{name:"peft_config",val:": typing.Optional[ForwardRef('PeftConfig')] = None"}],source:"https://github.com/huggingface/trl/blob/vr_5607/trl/experimental/distillation/distillation_trainer.py#L356"}}),ne=new Je({props:{name:"train",anchor:"trl.experimental.distillation.DistillationTrainer.train",parameters:[{name:"resume_from_checkpoint",val:": str | bool | None = None"},{name:"trial",val:": optuna.Trial | dict[str, Any] | None = None"},{name:"ignore_keys_for_eval",val:": list[str] | None = None"}],parametersDescription:[{anchor:"trl.experimental.distillation.DistillationTrainer.train.resume_from_checkpoint",description:`<strong>resume_from_checkpoint</strong> (<code>str</code> or <code>bool</code>, <em>optional</em>) &#x2014;
If a <code>str</code>, local path to a saved checkpoint as saved by a previous instance of <code>Trainer</code>. If a
<code>bool</code> and equals <code>True</code>, load the last checkpoint in <em>args.output_dir</em> as saved by a previous instance
of <code>Trainer</code>. If present, training will resume from the model/optimizer/scheduler states loaded here.`,name:"resume_from_checkpoint"},{anchor:"trl.experimental.distillation.DistillationTrainer.train.trial",description:`<strong>trial</strong> (<code>optuna.Trial</code> or <code>dict[str, Any]</code>, <em>optional</em>) &#x2014;
The trial run or the hyperparameter dictionary for hyperparameter search.`,name:"trial"},{anchor:"trl.experimental.distillation.DistillationTrainer.train.ignore_keys_for_eval",description:`<strong>ignore_keys_for_eval</strong> (<code>list[str]</code>, <em>optional</em>) &#x2014;
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions for evaluation during the training.`,name:"ignore_keys_for_eval"}],source:"https://github.com/huggingface/trl/blob/vr_5607/transformers/trainer.py#L1323",returnDescription:`<script context="module">export const metadata = 'undefined';<\/script>
<p>Object containing the global step count, training loss, and metrics.</p>
`,returnType:`<script context="module">export const metadata = 'undefined';<\/script>
<p><code>~trainer_utils.TrainOutput</code></p>
`}}),le=new Je({props:{name:"save_model",anchor:"trl.experimental.distillation.DistillationTrainer.save_model",parameters:[{name:"output_dir",val:": str | None = None"},{name:"_internal_call",val:": bool = False"}],source:"https://github.com/huggingface/trl/blob/vr_5607/transformers/trainer.py#L3746"}}),ae=new Je({props:{name:"push_to_hub",anchor:"trl.experimental.distillation.DistillationTrainer.push_to_hub",parameters:[{name:"commit_message",val:": str | None = 'End of training'"},{name:"blocking",val:": bool = True"},{name:"token",val:": str | None = None"},{name:"revision",val:": str | None = None"},{name:"**kwargs",val:""}],parametersDescription:[{anchor:"trl.experimental.distillation.DistillationTrainer.push_to_hub.commit_message",description:`<strong>commit_message</strong> (<code>str</code>, <em>optional</em>, defaults to <code>&quot;End of training&quot;</code>) &#x2014;
Message to commit while pushing.`,name:"commit_message"},{anchor:"trl.experimental.distillation.DistillationTrainer.push_to_hub.blocking",description:`<strong>blocking</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>True</code>) &#x2014;
Whether the function should return only when the <code>git push</code> has finished.`,name:"blocking"},{anchor:"trl.experimental.distillation.DistillationTrainer.push_to_hub.token",description:`<strong>token</strong> (<code>str</code>, <em>optional</em>, defaults to <code>None</code>) &#x2014;
Token with write permission to overwrite Trainer&#x2019;s original args.`,name:"token"},{anchor:"trl.experimental.distillation.DistillationTrainer.push_to_hub.revision",description:`<strong>revision</strong> (<code>str</code>, <em>optional</em>) &#x2014;
The git revision to commit from. Defaults to the head of the &#x201C;main&#x201D; branch.`,name:"revision"},{anchor:"trl.experimental.distillation.DistillationTrainer.push_to_hub.kwargs",description:`<strong>kwargs</strong> (<code>dict[str, Any]</code>, <em>optional</em>) &#x2014;
Additional keyword arguments passed along to <code>~Trainer.create_model_card</code>.`,name:"kwargs"}],source:"https://github.com/huggingface/trl/blob/vr_5607/transformers/trainer.py#L3993",returnDescription:`<script context="module">export const metadata = 'undefined';<\/script>
<p>The URL of the repository where the model was pushed if <code>blocking=False</code>, or a <code>Future</code> object tracking the
progress of the commit if <code>blocking=True</code>.</p>
`}}),ie=new M({props:{title:"DistillationConfig",local:"trl.experimental.distillation.DistillationConfig",headingTag:"h2"}}),se=new Je({props:{name:"class trl.experimental.distillation.DistillationConfig",anchor:"trl.experimental.distillation.DistillationConfig",parameters:[{name:"output_dir",val:": str | None = None"},{name:"per_device_train_batch_size",val:": int = 8"},{name:"num_train_epochs",val:": float = 3.0"},{name:"max_steps",val:": int = -1"},{name:"learning_rate",val:": float = 1e-06"},{name:"lr_scheduler_type",val:": transformers.trainer_utils.SchedulerType | str = 'linear'"},{name:"lr_scheduler_kwargs",val:": dict | str | None = None"},{name:"warmup_steps",val:": float = 0"},{name:"optim",val:": transformers.training_args.OptimizerNames | str = 'adamw_torch_fused'"},{name:"optim_args",val:": str | None = None"},{name:"weight_decay",val:": float = 0.0"},{name:"adam_beta1",val:": float = 0.9"},{name:"adam_beta2",val:": float = 0.999"},{name:"adam_epsilon",val:": float = 1e-08"},{name:"optim_target_modules",val:": None | str | list[str] = None"},{name:"gradient_accumulation_steps",val:": int = 1"},{name:"average_tokens_across_devices",val:": bool = True"},{name:"max_grad_norm",val:": float = 1.0"},{name:"label_smoothing_factor",val:": float = 0.0"},{name:"bf16",val:": bool | None = None"},{name:"fp16",val:": bool = False"},{name:"bf16_full_eval",val:": bool = False"},{name:"fp16_full_eval",val:": bool = False"},{name:"tf32",val:": bool | None = None"},{name:"gradient_checkpointing",val:": bool = True"},{name:"gradient_checkpointing_kwargs",val:": dict[str, typing.Any] | str | None = None"},{name:"torch_compile",val:": bool = False"},{name:"torch_compile_backend",val:": str | None = None"},{name:"torch_compile_mode",val:": str | None = None"},{name:"use_liger_kernel",val:": bool = False"},{name:"liger_kernel_config",val:": dict[str, bool] | None = None"},{name:"use_cache",val:": bool = False"},{name:"neftune_noise_alpha",val:": float | None = None"},{name:"torch_empty_cache_steps",val:": int | None = None"},{name:"auto_find_batch_size",val:": bool = False"},{name:"logging_strategy",val:": transformers.trainer_utils.IntervalStrategy | str = 'steps'"},{name:"logging_steps",val:": float = 10"},{name:"logging_first_step",val:": bool = False"},{name:"log_on_each_node",val:": bool = True"},{name:"logging_nan_inf_filter",val:": bool = True"},{name:"include_num_input_tokens_seen",val:": str | bool = 'no'"},{name:"log_level",val:": str = 'passive'"},{name:"log_level_replica",val:": str = 'warning'"},{name:"disable_tqdm",val:": bool | None = None"},{name:"report_to",val:": None | str | list[str] = 'none'"},{name:"run_name",val:": str | None = None"},{name:"project",val:": str = 'huggingface'"},{name:"trackio_space_id",val:": str | None = 'trackio'"},{name:"eval_strategy",val:": transformers.trainer_utils.IntervalStrategy | str = 'no'"},{name:"eval_steps",val:": float | None = None"},{name:"eval_delay",val:": float = 0"},{name:"per_device_eval_batch_size",val:": int = 8"},{name:"prediction_loss_only",val:": bool = False"},{name:"eval_on_start",val:": bool = False"},{name:"eval_do_concat_batches",val:": bool = True"},{name:"eval_use_gather_object",val:": bool = False"},{name:"eval_accumulation_steps",val:": int | None = None"},{name:"include_for_metrics",val:": list = <factory>"},{name:"batch_eval_metrics",val:": bool = False"},{name:"save_only_model",val:": bool = False"},{name:"save_strategy",val:": transformers.trainer_utils.SaveStrategy | str = 'steps'"},{name:"save_steps",val:": float = 500"},{name:"save_on_each_node",val:": bool = False"},{name:"save_total_limit",val:": int | None = None"},{name:"enable_jit_checkpoint",val:": bool = False"},{name:"push_to_hub",val:": bool = False"},{name:"hub_token",val:": str | None = None"},{name:"hub_private_repo",val:": bool | None = None"},{name:"hub_model_id",val:": str | None = None"},{name:"hub_strategy",val:": transformers.trainer_utils.HubStrategy | str = 'every_save'"},{name:"hub_always_push",val:": bool = False"},{name:"hub_revision",val:": str | None = None"},{name:"load_best_model_at_end",val:": bool = False"},{name:"metric_for_best_model",val:": str | None = None"},{name:"greater_is_better",val:": bool | None = None"},{name:"ignore_data_skip",val:": bool = False"},{name:"restore_callback_states_from_checkpoint",val:": bool = False"},{name:"full_determinism",val:": bool = False"},{name:"seed",val:": int = 42"},{name:"data_seed",val:": int | None = None"},{name:"use_cpu",val:": bool = False"},{name:"accelerator_config",val:": dict | str | None = None"},{name:"parallelism_config",val:": accelerate.parallelism_config.ParallelismConfig | None = None"},{name:"dataloader_drop_last",val:": bool = False"},{name:"dataloader_num_workers",val:": int = 0"},{name:"dataloader_pin_memory",val:": bool = True"},{name:"dataloader_persistent_workers",val:": bool = False"},{name:"dataloader_prefetch_factor",val:": int | None = None"},{name:"remove_unused_columns",val:": bool = True"},{name:"label_names",val:": list[str] | None = None"},{name:"train_sampling_strategy",val:": str = 'random'"},{name:"length_column_name",val:": str = 'length'"},{name:"ddp_find_unused_parameters",val:": bool | None = None"},{name:"ddp_bucket_cap_mb",val:": int | None = None"},{name:"ddp_broadcast_buffers",val:": bool | None = None"},{name:"ddp_backend",val:": str | None = None"},{name:"ddp_timeout",val:": int = 1800"},{name:"fsdp",val:": list[transformers.trainer_utils.FSDPOption] | str | None = None"},{name:"fsdp_config",val:": dict[str, typing.Any] | str | None = None"},{name:"deepspeed",val:": dict | str | None = None"},{name:"debug",val:": str | list[transformers.debug_utils.DebugOption] = ''"},{name:"skip_memory_metrics",val:": bool = True"},{name:"do_train",val:": bool = False"},{name:"do_eval",val:": bool = False"},{name:"do_predict",val:": bool = False"},{name:"resume_from_checkpoint",val:": str | None = None"},{name:"warmup_ratio",val:": float | None = None"},{name:"logging_dir",val:": str | None = None"},{name:"local_rank",val:": int = -1"},{name:"model_init_kwargs",val:": dict[str, typing.Any] | str | None = None"},{name:"max_length",val:": int | None = 1024"},{name:"temperature",val:": float = 1.0"},{name:"lmbda",val:": float = 1.0"},{name:"beta",val:": float = 1.0"},{name:"reverse_kl_top_1_mode",val:": str = 'sampled'"},{name:"max_completion_length",val:": int = 512"},{name:"max_prompt_length",val:": int | None = None"},{name:"disable_dropout",val:": bool = True"},{name:"teacher_model_name_or_path",val:": str | None = None"},{name:"teacher_model_revision",val:": str | None = None"},{name:"teacher_model_init_kwargs",val:": dict[str, typing.Any] | str | None = None"},{name:"use_teacher_server",val:": bool = False"},{name:"teacher_model_server_url",val:": str | None = None"},{name:"loss_top_k",val:": int = 1"},{name:"loss_add_tail",val:": bool = True"},{name:"num_generations",val:": int = 1"},{name:"generation_batch_size",val:": int | None = None"},{name:"top_p",val:": float = 0.95"},{name:"top_k",val:": int = 0"},{name:"use_vllm",val:": bool = False"},{name:"vllm_mode",val:": str = 'colocate'"},{name:"vllm_server_base_url",val:": str | None = None"},{name:"vllm_server_host",val:": str = '0.0.0.0'"},{name:"vllm_server_port",val:": int = 8001"},{name:"vllm_server_timeout",val:": float = 240.0"},{name:"vllm_group_port",val:": int = 51216"},{name:"vllm_gpu_memory_utilization",val:": float = 0.3"},{name:"vllm_tensor_parallel_size",val:": int = 1"},{name:"vllm_max_model_length",val:": int | None = None"},{name:"vllm_model_impl",val:": str = 'vllm'"},{name:"vllm_structured_outputs_regex",val:": str | None = None"},{name:"vllm_sync_frequency",val:": int = 1"},{name:"vllm_enable_sleep_mode",val:": bool = False"},{name:"wandb_entity",val:": str | None = None"},{name:"wandb_project",val:": str | None = None"},{name:"wandb_run_group",val:": str | None = None"},{name:"log_completions",val:": bool = False"},{name:"log_completions_steps",val:": int = 100"},{name:"num_completions_to_print",val:": int | None = None"}],source:"https://github.com/huggingface/trl/blob/vr_5607/trl/experimental/distillation/distillation_config.py#L23",parameterGroups:[{title:"Parameters that control the model",parametersDescription:[{anchor:"trl.experimental.distillation.DistillationConfig.model_init_kwargs",description:`<strong>model_init_kwargs</strong> (<code>dict[str, Any]</code>, <em>optional</em>) &#x2014;
Keyword arguments for <code>AutoModelForCausalLM.from_pretrained</code>, used when the <code>model</code> argument of the
trainer is provided as a string.`,name:"model_init_kwargs"},{anchor:"trl.experimental.distillation.DistillationConfig.max_length",description:`<strong>max_length</strong> (<code>int</code> or <code>None</code>, <em>optional</em>, defaults to <code>1024</code>) &#x2014;
Maximum total sequence length (prompt + completion) for tokenization and truncation.`,name:"max_length"}]},{title:"Parameters that control the distillation",parametersDescription:[{anchor:"trl.experimental.distillation.DistillationConfig.temperature",description:`<strong>temperature</strong> (<code>float</code>, <em>optional</em>, defaults to <code>1.0</code>) &#x2014;
Temperature for sampling during generation and for computing the distillation loss. Higher values produce
softer probability distributions.`,name:"temperature"},{anchor:"trl.experimental.distillation.DistillationConfig.lmbda",description:`<strong>lmbda</strong> (<code>float</code>, <em>optional</em>, defaults to <code>1.0</code>) &#x2014;
Probability of using on-policy (student-generated) data for each gradient accumulation slice. A value of
<code>0.0</code> means fully off-policy (dataset completions only), <code>1.0</code> means fully on-policy.`,name:"lmbda"},{anchor:"trl.experimental.distillation.DistillationConfig.beta",description:`<strong>beta</strong> (<code>float</code>, <em>optional</em>, defaults to <code>1.0</code>) &#x2014;
Interpolation coefficient for the Generalized Jensen-Shannon Divergence loss. When <code>0.0</code>, the loss is the
forward KL divergence. When <code>1.0</code>, the loss is the reverse KL divergence. When <code>0.5</code>, it is the standard
JSD.`,name:"beta"},{anchor:"trl.experimental.distillation.DistillationConfig.reverse_kl_top_1_mode",description:`<strong>reverse_kl_top_1_mode</strong> (<code>str</code>, <em>optional</em>, defaults to <code>&quot;sampled&quot;</code>) &#x2014;
Selection rule for the reverse-KL top-1 token when <code>beta &gt; 0</code> and <code>loss_top_k == 1</code>. <code>&quot;sampled&quot;</code> uses the
actual completion token in the batch. <code>&quot;argmax&quot;</code> uses the student&#x2019;s highest-probability token. This
setting does not affect the forward-KL support, which always uses the teacher&#x2019;s top-1 token. Ignored when
<code>beta == 0</code> or <code>loss_top_k != 1</code>.`,name:"reverse_kl_top_1_mode"},{anchor:"trl.experimental.distillation.DistillationConfig.max_completion_length",description:`<strong>max_completion_length</strong> (<code>int</code>, <em>optional</em>, defaults to <code>512</code>) &#x2014;
Maximum number of tokens to generate per completion during on-policy generation.`,name:"max_completion_length"},{anchor:"trl.experimental.distillation.DistillationConfig.disable_dropout",description:`<strong>disable_dropout</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>True</code>) &#x2014;
Whether to disable dropout in the student model during training.`,name:"disable_dropout"}]},{title:"Parameters that control the teacher model",parametersDescription:[{anchor:"trl.experimental.distillation.DistillationConfig.teacher_model_name_or_path",description:`<strong>teacher_model_name_or_path</strong> (<code>str</code> or <code>None</code>, <em>optional</em>) &#x2014;
Model name or path for the teacher model. Used when the teacher is loaded locally.`,name:"teacher_model_name_or_path"},{anchor:"trl.experimental.distillation.DistillationConfig.teacher_model_revision",description:`<strong>teacher_model_revision</strong> (<code>str</code> or <code>None</code>, <em>optional</em>) &#x2014;
Model revision of the teacher model (e.g., branch name, tag, or commit hash).`,name:"teacher_model_revision"},{anchor:"trl.experimental.distillation.DistillationConfig.teacher_model_init_kwargs",description:`<strong>teacher_model_init_kwargs</strong> (<code>dict[str, Any]</code> or <code>None</code>, <em>optional</em>) &#x2014;
Keyword arguments passed to <code>AutoModelForCausalLM.from_pretrained</code> when instantiating the teacher model
from a string.`,name:"teacher_model_init_kwargs"},{anchor:"trl.experimental.distillation.DistillationConfig.use_teacher_server",description:`<strong>use_teacher_server</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>False</code>) &#x2014;
Whether to use an external vLLM teacher server instead of a local teacher model.`,name:"use_teacher_server"},{anchor:"trl.experimental.distillation.DistillationConfig.teacher_model_server_url",description:`<strong>teacher_model_server_url</strong> (<code>str</code> or <code>None</code>, <em>optional</em>) &#x2014;
Base URL of a vLLM server hosting the teacher model (e.g., <code>&quot;http://localhost:8000&quot;</code>). When set, teacher
logprobs are fetched from the server instead of running a local forward pass when <code>use_teacher_server=True</code>.`,name:"teacher_model_server_url"},{anchor:"trl.experimental.distillation.DistillationConfig.loss_top_k",description:`<strong>loss_top_k</strong> (<code>int</code>, <em>optional</em>, defaults to <code>1</code>) &#x2014;
Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are
restricted to these K tokens and re-normalized before computing divergence. If 0, the full vocabulary
is used. For local teachers, the general support rule is teacher top-k for forward KL, student top-k for
reverse KL, and the union for mixed JSD. When <code>beta &gt; 0</code> and <code>loss_top_k == 1</code>, the forward support still
uses the teacher&#x2019;s top-1 token, while the reverse top-1 token is controlled by <code>reverse_kl_top_1_mode</code>.
When <code>use_teacher_server=True</code>, the pure forward path (<code>beta=0</code>) requires this to be positive and uses the
teacher&#x2019;s top-k logprobs for the forward term. When <code>beta &gt; 0</code>, server-backed distillation requires
<code>loss_top_k == 1</code> and only supports <code>&quot;sampled&quot;</code> reverse top-1 tokens.`,name:"loss_top_k"},{anchor:"trl.experimental.distillation.DistillationConfig.loss_add_tail",description:`<strong>loss_add_tail</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>True</code>) &#x2014;
Whether to append a tail bucket that represents the remaining probability mass outside the selected top-k
support when computing the loss.`,name:"loss_add_tail"}]},{title:"Parameters that control on-policy generation",parametersDescription:[{anchor:"trl.experimental.distillation.DistillationConfig.num_generations",description:`<strong>num_generations</strong> (<code>int</code>, <em>optional</em>, defaults to <code>1</code>) &#x2014;
Number of completions to generate per prompt during on-policy generation.`,name:"num_generations"},{anchor:"trl.experimental.distillation.DistillationConfig.generation_batch_size",description:`<strong>generation_batch_size</strong> (<code>int</code> or <code>None</code>, <em>optional</em>) &#x2014;
Number of unique prompts per worker per optimizer step. If <code>None</code>, computed from
<code>(per_device_train_batch_size * gradient_accumulation_steps) // num_generations</code>.`,name:"generation_batch_size"},{anchor:"trl.experimental.distillation.DistillationConfig.top_p",description:`<strong>top_p</strong> (<code>float</code>, <em>optional</em>, defaults to <code>0.95</code>) &#x2014;
Top-p (nucleus) sampling parameter for on-policy generation.`,name:"top_p"},{anchor:"trl.experimental.distillation.DistillationConfig.top_k",description:`<strong>top_k</strong> (<code>int</code>, <em>optional</em>, defaults to <code>0</code>) &#x2014;
Top-k sampling parameter for on-policy generation. <code>0</code> disables top-k filtering.`,name:"top_k"}]},{title:"Parameters that control vLLM for student generation",parametersDescription:[{anchor:"trl.experimental.distillation.DistillationConfig.use_vllm",description:`<strong>use_vllm</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>False</code>) &#x2014;
Whether to use vLLM for generating on-policy completions from the student model.`,name:"use_vllm"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_mode",description:`<strong>vllm_mode</strong> (<code>str</code>, <em>optional</em>, defaults to <code>&quot;colocate&quot;</code>) &#x2014;
Mode for student vLLM integration. Either <code>&quot;server&quot;</code> or <code>&quot;colocate&quot;</code>.`,name:"vllm_mode"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_server_base_url",description:`<strong>vllm_server_base_url</strong> (<code>str</code> or <code>None</code>, <em>optional</em>) &#x2014;
Base URL for the student vLLM server. If provided, <code>vllm_server_host</code> and <code>vllm_server_port</code> are ignored.`,name:"vllm_server_base_url"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_server_host",description:`<strong>vllm_server_host</strong> (<code>str</code>, <em>optional</em>, defaults to <code>&quot;0.0.0.0&quot;</code>) &#x2014;
Host of the student vLLM server.`,name:"vllm_server_host"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_server_port",description:`<strong>vllm_server_port</strong> (<code>int</code>, <em>optional</em>, defaults to <code>8001</code>) &#x2014;
Port of the student vLLM server.`,name:"vllm_server_port"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_server_timeout",description:`<strong>vllm_server_timeout</strong> (<code>float</code>, <em>optional</em>, defaults to <code>240.0</code>) &#x2014;
Timeout for connecting to the student vLLM server.`,name:"vllm_server_timeout"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_group_port",description:`<strong>vllm_group_port</strong> (<code>int</code>, <em>optional</em>, defaults to <code>51216</code>) &#x2014;
Port for the vLLM weight-update group (NCCL communicator).`,name:"vllm_group_port"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_gpu_memory_utilization",description:`<strong>vllm_gpu_memory_utilization</strong> (<code>float</code>, <em>optional</em>, defaults to <code>0.3</code>) &#x2014;
GPU memory utilization for the colocated student vLLM engine.`,name:"vllm_gpu_memory_utilization"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_tensor_parallel_size",description:`<strong>vllm_tensor_parallel_size</strong> (<code>int</code>, <em>optional</em>, defaults to <code>1</code>) &#x2014;
Tensor parallel size for the colocated student vLLM engine.`,name:"vllm_tensor_parallel_size"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_max_model_length",description:`<strong>vllm_max_model_length</strong> (<code>int</code> or <code>None</code>, <em>optional</em>) &#x2014;
Maximum model sequence length for the colocated vLLM engine.`,name:"vllm_max_model_length"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_model_impl",description:`<strong>vllm_model_impl</strong> (<code>str</code>, <em>optional</em>, defaults to <code>&quot;vllm&quot;</code>) &#x2014;
Model implementation backend for vLLM. Use <code>&quot;vllm&quot;</code> or <code>&quot;transformers&quot;</code>.`,name:"vllm_model_impl"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_structured_outputs_regex",description:`<strong>vllm_structured_outputs_regex</strong> (<code>str</code> or <code>None</code>, <em>optional</em>) &#x2014;
Regex pattern for vLLM structured outputs.`,name:"vllm_structured_outputs_regex"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_sync_frequency",description:`<strong>vllm_sync_frequency</strong> (<code>int</code>, <em>optional</em>, defaults to <code>1</code>) &#x2014;
Frequency (in training steps) to synchronize student model weights to the vLLM engine.`,name:"vllm_sync_frequency"},{anchor:"trl.experimental.distillation.DistillationConfig.vllm_enable_sleep_mode",description:`<strong>vllm_enable_sleep_mode</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>False</code>) &#x2014;
Enable vLLM sleep mode to offload student weights during the optimizer step.`,name:"vllm_enable_sleep_mode"}]},{title:"Parameters that control logging",parametersDescription:[{anchor:"trl.experimental.distillation.DistillationConfig.log_completions",description:`<strong>log_completions</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>False</code>) &#x2014;
Whether to log a sample of (prompt, completion) pairs every <code>log_completions_steps</code> steps. If <code>rich</code> is
installed, it prints the sample. If <code>wandb</code> and/or <code>trackio</code> logging is enabled, it logs it to <code>wandb</code>
and/or <code>trackio</code>.`,name:"log_completions"},{anchor:"trl.experimental.distillation.DistillationConfig.log_completions_steps",description:`<strong>log_completions_steps</strong> (<code>int</code>, <em>optional</em>, defaults to <code>100</code>) &#x2014;
Number of steps between logging completions. Only used if <code>log_completions</code> is <code>True</code>.`,name:"log_completions_steps"},{anchor:"trl.experimental.distillation.DistillationConfig.num_completions_to_print",description:`<strong>num_completions_to_print</strong> (<code>int</code> or <code>None</code>, <em>optional</em>) &#x2014;
Number of completions to print. If <code>None</code>, all completions are logged.`,name:"num_completions_to_print"}]}]}}),re=new to({props:{source:"https://github.com/huggingface/trl/blob/main/docs/source/distillation_trainer.md"}}),{c(){y=i("meta"),xe=l(),Te=i("p"),Ue=l(),m(N.$$.fragment),Ne=l(),m(I.$$.fragment),Ie=l(),m(j.$$.fragment),je=l(),C=i("p"),C.innerHTML=Jt,Ce=l(),k=i("blockquote"),k.innerHTML=xt,ke=l(),$=i("p"),$.innerHTML=Ut,$e=l(),L=i("ol"),L.innerHTML=Nt,Le=l(),T=i("blockquote"),T.innerHTML=It,De=l(),m(D.$$.fragment),Ge=l(),m(G.$$.fragment),We=l(),m(W.$$.fragment),qe=l(),q=i("p"),q.innerHTML=jt,Be=l(),B=i("ul"),B.innerHTML=Ct,Fe=l(),m(F.$$.fragment),Ze=l(),Z=i("p"),Z.innerHTML=kt,Ae=l(),m(A.$$.fragment),Re=l(),R=i("p"),R.innerHTML=$t,ze=l(),m(z.$$.fragment),Ee=l(),E=i("p"),E.textContent=Lt,Qe=l(),Q=i("ul"),Q.innerHTML=Dt,Se=l(),m(S.$$.fragment),Xe=l(),X=i("p"),X.innerHTML=Gt,He=l(),m(H.$$.fragment),Ve=l(),V=i("p"),V.innerHTML=Wt,Ye=l(),m(Y.$$.fragment),Pe=l(),m(P.$$.fragment),Ke=l(),K=i("p"),K.innerHTML=qt,Oe=l(),m(O.$$.fragment),et=l(),m(ee.$$.fragment),tt=l(),m(te.$$.fragment),ot=l(),h=i("div"),m(oe.$$.fragment),mt=l(),me=i("p"),me.textContent=Bt,ct=l(),ce=i("p"),ce.textContent=Ft,pt=l(),pe=i("ul"),pe.innerHTML=Zt,ut=l(),w=i("div"),m(ne.$$.fragment),_t=l(),ue=i("p"),ue.textContent=At,gt=l(),b=i("div"),m(le.$$.fragment),ht=l(),_e=i("p"),_e.innerHTML=Rt,ft=l(),ge=i("p"),ge.textContent=zt,vt=l(),J=i("div"),m(ae.$$.fragment),bt=l(),he=i("p"),he.innerHTML=Et,nt=l(),m(ie.$$.fragment),lt=l(),v=i("div"),m(se.$$.fragment),Mt=l(),fe=i("p"),fe.innerHTML=Qt,yt=l(),ve=i("p"),ve.innerHTML=St,Tt=l(),be=i("p"),be.innerHTML=Xt,at=l(),m(re.$$.fragment),it=l(),we=i("p"),this.h()},l(e){const t=Ot("svelte-u9bgzb",document.head);y=s(t,"META",{name:!0,content:!0}),t.forEach(o),xe=a(e),Te=s(e,"P",{}),de(Te).forEach(o),Ue=a(e),c(N.$$.fragment,e),Ne=a(e),c(I.$$.fragment,e),Ie=a(e),c(j.$$.fragment,e),je=a(e),C=s(e,"P",{"data-svelte-h":!0}),d(C)!=="svelte-15pf7f9"&&(C.innerHTML=Jt),Ce=a(e),k=s(e,"BLOCKQUOTE",{"data-svelte-h":!0}),d(k)!=="svelte-11akxtc"&&(k.innerHTML=xt),ke=a(e),$=s(e,"P",{"data-svelte-h":!0}),d($)!=="svelte-1fdfsb3"&&($.innerHTML=Ut),$e=a(e),L=s(e,"OL",{"data-svelte-h":!0}),d(L)!=="svelte-1g11z5x"&&(L.innerHTML=Nt),Le=a(e),T=s(e,"BLOCKQUOTE",{class:!0,"data-svelte-h":!0}),d(T)!=="svelte-1mf99ue"&&(T.innerHTML=It),De=a(e),c(D.$$.fragment,e),Ge=a(e),c(G.$$.fragment,e),We=a(e),c(W.$$.fragment,e),qe=a(e),q=s(e,"P",{"data-svelte-h":!0}),d(q)!=="svelte-1h7t8i7"&&(q.innerHTML=jt),Be=a(e),B=s(e,"UL",{"data-svelte-h":!0}),d(B)!=="svelte-1wrzi5"&&(B.innerHTML=Ct),Fe=a(e),c(F.$$.fragment,e),Ze=a(e),Z=s(e,"P",{"data-svelte-h":!0}),d(Z)!=="svelte-1hoe89p"&&(Z.innerHTML=kt),Ae=a(e),c(A.$$.fragment,e),Re=a(e),R=s(e,"P",{"data-svelte-h":!0}),d(R)!=="svelte-19a76y9"&&(R.innerHTML=$t),ze=a(e),c(z.$$.fragment,e),Ee=a(e),E=s(e,"P",{"data-svelte-h":!0}),d(E)!=="svelte-1iggf6u"&&(E.textContent=Lt),Qe=a(e),Q=s(e,"UL",{"data-svelte-h":!0}),d(Q)!=="svelte-15rldx5"&&(Q.innerHTML=Dt),Se=a(e),c(S.$$.fragment,e),Xe=a(e),X=s(e,"P",{"data-svelte-h":!0}),d(X)!=="svelte-1drngiz"&&(X.innerHTML=Gt),He=a(e),c(H.$$.fragment,e),Ve=a(e),V=s(e,"P",{"data-svelte-h":!0}),d(V)!=="svelte-1wwqce1"&&(V.innerHTML=Wt),Ye=a(e),c(Y.$$.fragment,e),Pe=a(e),c(P.$$.fragment,e),Ke=a(e),K=s(e,"P",{"data-svelte-h":!0}),d(K)!=="svelte-ilhjb7"&&(K.innerHTML=qt),Oe=a(e),c(O.$$.fragment,e),et=a(e),c(ee.$$.fragment,e),tt=a(e),c(te.$$.fragment,e),ot=a(e),h=s(e,"DIV",{class:!0});var f=de(h);c(oe.$$.fragment,f),mt=a(f),me=s(f,"P",{"data-svelte-h":!0}),d(me)!=="svelte-1nh45w"&&(me.textContent=Bt),ct=a(f),ce=s(f,"P",{"data-svelte-h":!0}),d(ce)!=="svelte-3plf16"&&(ce.textContent=Ft),pt=a(f),pe=s(f,"UL",{"data-svelte-h":!0}),d(pe)!=="svelte-1ucyylw"&&(pe.innerHTML=Zt),ut=a(f),w=s(f,"DIV",{class:!0});var rt=de(w);c(ne.$$.fragment,rt),_t=a(rt),ue=s(rt,"P",{"data-svelte-h":!0}),d(ue)!=="svelte-1cilnet"&&(ue.textContent=At),rt.forEach(o),gt=a(f),b=s(f,"DIV",{class:!0});var Me=de(b);c(le.$$.fragment,Me),ht=a(Me),_e=s(Me,"P",{"data-svelte-h":!0}),d(_e)!=="svelte-r8h4ov"&&(_e.innerHTML=Rt),ft=a(Me),ge=s(Me,"P",{"data-svelte-h":!0}),d(ge)!=="svelte-1e6bius"&&(ge.textContent=zt),Me.forEach(o),vt=a(f),J=s(f,"DIV",{class:!0});var dt=de(J);c(ae.$$.fragment,dt),bt=a(dt),he=s(dt,"P",{"data-svelte-h":!0}),d(he)!=="svelte-8tudwd"&&(he.innerHTML=Et),dt.forEach(o),f.forEach(o),nt=a(e),c(ie.$$.fragment,e),lt=a(e),v=s(e,"DIV",{class:!0});var x=de(v);c(se.$$.fragment,x),Mt=a(x),fe=s(x,"P",{"data-svelte-h":!0}),d(fe)!=="svelte-ym7xdc"&&(fe.innerHTML=Qt),yt=a(x),ve=s(x,"P",{"data-svelte-h":!0}),d(ve)!=="svelte-4u9beg"&&(ve.innerHTML=St),Tt=a(x),be=s(x,"P",{"data-svelte-h":!0}),d(be)!=="svelte-ekuf1t"&&(be.innerHTML=Xt),x.forEach(o),at=a(e),c(re.$$.fragment,e),it=a(e),we=s(e,"P",{}),de(we).forEach(o),this.h()},h(){U(y,"name","hf:doc:metadata"),U(y,"content",no),U(T,"class","note"),U(w,"class","docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"),U(b,"class","docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"),U(J,"class","docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"),U(h,"class","docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"),U(v,"class","docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8")},m(e,t){r(document.head,y),n(e,xe,t),n(e,Te,t),n(e,Ue,t),p(N,e,t),n(e,Ne,t),p(I,e,t),n(e,Ie,t),p(j,e,t),n(e,je,t),n(e,C,t),n(e,Ce,t),n(e,k,t),n(e,ke,t),n(e,$,t),n(e,$e,t),n(e,L,t),n(e,Le,t),n(e,T,t),n(e,De,t),p(D,e,t),n(e,Ge,t),p(G,e,t),n(e,We,t),p(W,e,t),n(e,qe,t),n(e,q,t),n(e,Be,t),n(e,B,t),n(e,Fe,t),p(F,e,t),n(e,Ze,t),n(e,Z,t),n(e,Ae,t),p(A,e,t),n(e,Re,t),n(e,R,t),n(e,ze,t),p(z,e,t),n(e,Ee,t),n(e,E,t),n(e,Qe,t),n(e,Q,t),n(e,Se,t),p(S,e,t),n(e,Xe,t),n(e,X,t),n(e,He,t),p(H,e,t),n(e,Ve,t),n(e,V,t),n(e,Ye,t),p(Y,e,t),n(e,Pe,t),p(P,e,t),n(e,Ke,t),n(e,K,t),n(e,Oe,t),p(O,e,t),n(e,et,t),p(ee,e,t),n(e,tt,t),p(te,e,t),n(e,ot,t),n(e,h,t),p(oe,h,null),r(h,mt),r(h,me),r(h,ct),r(h,ce),r(h,pt),r(h,pe),r(h,ut),r(h,w),p(ne,w,null),r(w,_t),r(w,ue),r(h,gt),r(h,b),p(le,b,null),r(b,ht),r(b,_e),r(b,ft),r(b,ge),r(h,vt),r(h,J),p(ae,J,null),r(J,bt),r(J,he),n(e,nt,t),p(ie,e,t),n(e,lt,t),n(e,v,t),p(se,v,null),r(v,Mt),r(v,fe),r(v,yt),r(v,ve),r(v,Tt),r(v,be),n(e,at,t),p(re,e,t),n(e,it,t),n(e,we,t),st=!0},p:Vt,i(e){st||(u(N.$$.fragment,e),u(I.$$.fragment,e),u(j.$$.fragment,e),u(D.$$.fragment,e),u(G.$$.fragment,e),u(W.$$.fragment,e),u(F.$$.fragment,e),u(A.$$.fragment,e),u(z.$$.fragment,e),u(S.$$.fragment,e),u(H.$$.fragment,e),u(Y.$$.fragment,e),u(P.$$.fragment,e),u(O.$$.fragment,e),u(ee.$$.fragment,e),u(te.$$.fragment,e),u(oe.$$.fragment,e),u(ne.$$.fragment,e),u(le.$$.fragment,e),u(ae.$$.fragment,e),u(ie.$$.fragment,e),u(se.$$.fragment,e),u(re.$$.fragment,e),st=!0)},o(e){_(N.$$.fragment,e),_(I.$$.fragment,e),_(j.$$.fragment,e),_(D.$$.fragment,e),_(G.$$.fragment,e),_(W.$$.fragment,e),_(F.$$.fragment,e),_(A.$$.fragment,e),_(z.$$.fragment,e),_(S.$$.fragment,e),_(H.$$.fragment,e),_(Y.$$.fragment,e),_(P.$$.fragment,e),_(O.$$.fragment,e),_(ee.$$.fragment,e),_(te.$$.fragment,e),_(oe.$$.fragment,e),_(ne.$$.fragment,e),_(le.$$.fragment,e),_(ae.$$.fragment,e),_(ie.$$.fragment,e),_(se.$$.fragment,e),_(re.$$.fragment,e),st=!1},d(e){e&&(o(xe),o(Te),o(Ue),o(Ne),o(Ie),o(je),o(C),o(Ce),o(k),o(ke),o($),o($e),o(L),o(Le),o(T),o(De),o(Ge),o(We),o(qe),o(q),o(Be),o(B),o(Fe),o(Ze),o(Z),o(Ae),o(Re),o(R),o(ze),o(Ee),o(E),o(Qe),o(Q),o(Se),o(Xe),o(X),o(He),o(Ve),o(V),o(Ye),o(Pe),o(Ke),o(K),o(Oe),o(et),o(tt),o(ot),o(h),o(nt),o(lt),o(v),o(at),o(it),o(we)),o(y),g(N,e),g(I,e),g(j,e),g(D,e),g(G,e),g(W,e),g(F,e),g(A,e),g(z,e),g(S,e),g(H,e),g(Y,e),g(P,e),g(O,e),g(ee,e),g(te,e),g(oe),g(ne),g(le),g(ae),g(ie,e),g(se),g(re,e)}}}const no='{"title":"Distillation Trainer","local":"distillation-trainer","sections":[{"title":"Overview","local":"overview","sections":[],"depth":2},{"title":"Quick start","local":"quick-start","sections":[],"depth":2},{"title":"Usage tips","local":"usage-tips","sections":[{"title":"On-policy vs. off-policy","local":"on-policy-vs-off-policy","sections":[],"depth":3},{"title":"Using an external teacher server","local":"using-an-external-teacher-server","sections":[],"depth":3},{"title":"Expected dataset type","local":"expected-dataset-type","sections":[],"depth":3}],"depth":2},{"title":"Example script","local":"example-script","sections":[],"depth":2},{"title":"DistillationTrainer","local":"trl.experimental.distillation.DistillationTrainer","sections":[],"depth":2},{"title":"DistillationConfig","local":"trl.experimental.distillation.DistillationConfig","sections":[],"depth":2}],"depth":1}';function lo(wt){return Yt(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}class co extends Pt{constructor(y){super(),Kt(this,y,lo,oo,Ht,{})}}export{co as component};

Xet Storage Details

Size:
55.1 kB
·
Xet hash:
a5b8806d3f40e8f8ab71af3a4e0b4d2f791a7ce5886ff96328d518dba93c5a1f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.