Buckets:

hf-doc-build/doc-dev / trl /pr_4331 /en /experimental.html
rtrm's picture
download
raw
40.6 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Experimental Features&quot;,&quot;local&quot;:&quot;experimental-features&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Current Experimental Features&quot;,&quot;local&quot;:&quot;current-experimental-features&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;BEMA for Reference Model&quot;,&quot;local&quot;:&quot;bema-for-reference-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GFPO&quot;,&quot;local&quot;:&quot;gfpo&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GSPO-token&quot;,&quot;local&quot;:&quot;gspo-token&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GRPO With Replay Buffer&quot;,&quot;local&quot;:&quot;grpo-with-replay-buffer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Usage&quot;,&quot;local&quot;:&quot;usage&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Promotion Path (Simple)&quot;,&quot;local&quot;:&quot;promotion-path-simple&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;FAQ&quot;,&quot;local&quot;:&quot;faq&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/trl/pr_4331/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/entry/start.d2344edf.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/scheduler.7b731bd4.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/singletons.8fadd22e.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/index.ac28c20f.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/paths.d46f6aee.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/entry/app.1a672609.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/preload-helper.71df5523.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/index.cc268345.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/nodes/0.79a9cd4d.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/nodes/14.6e5fbb26.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.7a9dd569.js">
<link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/CodeBlock.3660b8ed.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Experimental Features&quot;,&quot;local&quot;:&quot;experimental-features&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Current Experimental Features&quot;,&quot;local&quot;:&quot;current-experimental-features&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;BEMA for Reference Model&quot;,&quot;local&quot;:&quot;bema-for-reference-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GFPO&quot;,&quot;local&quot;:&quot;gfpo&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GSPO-token&quot;,&quot;local&quot;:&quot;gspo-token&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GRPO With Replay Buffer&quot;,&quot;local&quot;:&quot;grpo-with-replay-buffer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Usage&quot;,&quot;local&quot;:&quot;usage&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Promotion Path (Simple)&quot;,&quot;local&quot;:&quot;promotion-path-simple&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;FAQ&quot;,&quot;local&quot;:&quot;faq&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="experimental-features" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#experimental-features"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Experimental Features</span></h1> <p data-svelte-h="svelte-15l6usj">The <code>trl.experimental</code> namespace provides a minimal, clearly separated space for fast iteration on new ideas.</p> <blockquote class="warning" data-svelte-h="svelte-1jedoer"><p><strong>Stability contract:</strong> Anything under <code>trl.experimental</code> may change or be removed in <em>any</em> release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.</p></blockquote> <h2 class="relative group"><a id="current-experimental-features" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#current-experimental-features"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Current Experimental Features</span></h2> <p data-svelte-h="svelte-3gid1n">The following modules are currently available under <a href="https://github.com/huggingface/trl/tree/main/trl/experimental" rel="nofollow"><code>trl.experimental</code></a>.
This list is not exhaustive and may change at any time.</p> <h3 class="relative group"><a id="bema-for-reference-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bema-for-reference-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>BEMA for Reference Model</span></h3> <p data-svelte-h="svelte-1n2qlx2">This feature implements the BEMA algorithm to update the reference model during DPO training.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl.experimental.bema_for_ref_model <span class="hljs-keyword">import</span> BEMACallback, DPOTrainer
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer
pref_dataset = load_dataset(<span class="hljs-string">&quot;trl-internal-testing/zen&quot;</span>, <span class="hljs-string">&quot;standard_preference&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
ref_model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;trl-internal-testing/tiny-Qwen2ForCausalLM-2.5&quot;</span>)
bema_callback = BEMACallback(update_ref_model=<span class="hljs-literal">True</span>)
model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;trl-internal-testing/tiny-Qwen2ForCausalLM-2.5&quot;</span>)
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;trl-internal-testing/tiny-Qwen2ForCausalLM-2.5&quot;</span>)
tokenizer.pad_token = tokenizer.eos_token
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
train_dataset=pref_dataset,
processing_class=tokenizer,
callbacks=[bema_callback],
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="gfpo" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gfpo"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GFPO</span></h3> <p data-svelte-h="svelte-1dy87lu">This feature implements the GFPO algorithm to enforce concise reasoning in the model’s output generation, as proposed in the paper <a href="https://huggingface.co/papers/2508.09726" rel="nofollow">Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning</a>.</p> <p data-svelte-h="svelte-1lfn8in">To activate GFPO in <code>GFPOTrainer</code>:</p> <ul data-svelte-h="svelte-1h9orgs"><li>set <code>num_remains_in_group</code> in <code>GFPOConfig</code></li> <li>define a group filter function and set it to <code>group_filter_func</code> in <code>GFPOTrainer</code>. <code>group_filter_func</code> will score the <code>num_generations</code> completions and The GFPOTrainer filters groups according to their scores to get top <code>num_remains_in_group</code> completions as a new group. Model will be trained on the filtered group.</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># train_gfpo.py</span>
<span class="hljs-keyword">from</span> trl.experimental.gfpo <span class="hljs-keyword">import</span> GFPOConfig, GFPOTrainer
<span class="hljs-comment"># dummy group filter to scores the completions based on its indice in group</span>
<span class="hljs-keyword">class</span> <span class="hljs-title class_">GroupFilter</span>:
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__call__</span>(<span class="hljs-params">self, group_completions, group_rewards, **kwargs</span>):
group_scores = []
<span class="hljs-keyword">for</span> completions, rewards <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(group_completions, group_rewards):
scores = [<span class="hljs-built_in">float</span>(i) <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(completions))]
group_scores.append(scores)
<span class="hljs-keyword">return</span> group_scores
training_args = GFPOConfig(
output_dir=<span class="hljs-string">&quot;Qwen3-0.6B-GFPO&quot;</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
num_remains_in_group=<span class="hljs-number">2</span>,
bf16=<span class="hljs-literal">True</span>,
)
trainer = GFPOTrainer(
model=<span class="hljs-string">&quot;Qwen/Qwen3-0.6B&quot;</span>,
reward_funcs=...,
train_dataset=...,
args=training_args,
group_filter_func=GroupFilter(),
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="gspo-token" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gspo-token"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GSPO-token</span></h3> <p data-svelte-h="svelte-npbvk6">In the paper <a href="https://huggingface.co/papers/2507.18071" rel="nofollow">Group Sequence Policy Optimization</a>, the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the <code>GRPOTrainer</code> class in <code>trl.experimental.gspo_token</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl.experimental.gspo_token <span class="hljs-keyword">import</span> GRPOTrainer
<span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> GRPOConfig
training_args = GRPOConfig(
importance_sampling_level=<span class="hljs-string">&quot;sequence_token&quot;</span>,
...
)<!-- HTML_TAG_END --></pre></div> <blockquote class="warning"><p>To leverage GSPO-token, the user will need to provide the per-token advantage <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>A</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>t</mi></mrow></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex"> \hat{A_{i,t}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2329em;vertical-align:-0.2861em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> for each token <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex"> t </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span><!-- HTML_TAG_END --> in the sequence <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex"> i </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span><!-- HTML_TAG_END --> (i.e., make <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>A</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>t</mi></mrow></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex"> \hat{A_{i,t}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2329em;vertical-align:-0.2861em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> varies with <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex"> t </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span><!-- HTML_TAG_END -->—which isn’t the case here, <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>A</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>t</mi></mrow></msub><mo>^</mo></mover><mo>=</mo><mover accent="true"><msub><mi>A</mi><mi>i</mi></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex"> \hat{A_{i,t}}=\hat{A_{i}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2329em;vertical-align:-0.2861em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0968em;vertical-align:-0.15em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END -->). Otherwise, GSPO-Token gradient is just equivalent to the original GSPO implementation.</p></blockquote> <h3 class="relative group"><a id="grpo-with-replay-buffer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#grpo-with-replay-buffer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GRPO With Replay Buffer</span></h3> <p data-svelte-h="svelte-avq1w">This experimental trainer, trains a model with GRPO but replaces groups (and corresponding completions) that have 0 standard deviation with groups with high rewards and standard deviation that’ve been used to train a model in prior batches.</p> <h4 class="relative group"><a id="usage" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#usage"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Usage</span></h4> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl.experimental.grpo_with_replay_buffer <span class="hljs-keyword">import</span> GRPOWithReplayBufferTrainer
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
dataset = load_dataset(<span class="hljs-string">&quot;trl-internal-testing/zen&quot;</span>, <span class="hljs-string">&quot;standard_prompt_only&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
<span class="hljs-comment"># Guarantee that some rewards have 0 std</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">custom_reward_func</span>(<span class="hljs-params">completions, **kwargs</span>):
<span class="hljs-keyword">if</span> torch.rand(<span class="hljs-number">1</span>).item() &lt; <span class="hljs-number">0.25</span>:
<span class="hljs-keyword">return</span> [<span class="hljs-number">0</span>] * <span class="hljs-built_in">len</span>(completions) <span class="hljs-comment"># simulate some None rewards</span>
<span class="hljs-keyword">else</span>:
<span class="hljs-keyword">return</span> torch.rand(<span class="hljs-built_in">len</span>(completions)).tolist()
training_args = GRPOWithReplayBufferConfig(
output_dir=self.tmp_dir,
learning_rate=<span class="hljs-number">1e-4</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
num_generations=<span class="hljs-number">4</span>,
max_completion_length=<span class="hljs-number">8</span>,
replay_buffer_size=<span class="hljs-number">8</span>,
report_to=<span class="hljs-string">&quot;none&quot;</span>,
)
trainer = GRPOTrainer(
model=<span class="hljs-string">&quot;trl-internal-testing/tiny-Qwen2ForCausalLM-2.5&quot;</span>,
reward_funcs=[custom_reward_func],
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() <span class="hljs-keyword">for</span> n, param <span class="hljs-keyword">in</span> trainer.model.named_parameters()}
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1n8heb1">To silence the runtime notice:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> TRL_EXPERIMENTAL_SILENCE=1<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="promotion-path-simple" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#promotion-path-simple"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Promotion Path (Simple)</span></h2> <ol data-svelte-h="svelte-wbh5n4"><li><strong>Prototype outside the main repo:</strong> Start development in your own fork or a separate repository to iterate quickly.</li> <li><strong>Experimental inclusion:</strong> Once it’s ready for early users, move the idea into <code>trl.experimental.&lt;feature&gt;</code>.</li> <li><strong>Improve:</strong> Add tests, a short doc/example, and demonstrate the usage.</li> <li><strong>Promote:</strong> Once the API proves stable and there is clear interest or adoption from the community, move it into <code>trl.&lt;feature&gt;</code> (stable module).</li></ol> <h2 class="relative group"><a id="faq" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#faq"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>FAQ</span></h2> <p data-svelte-h="svelte-yr2qwe"><strong>Why not just use branches?</strong>
Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback.</p> <p data-svelte-h="svelte-1crmgan"><strong>Can these APIs change or vanish without warning?</strong>
Yes. Anything inside <code>trl.experimental</code> can change or disappear in <em>any</em> release.</p> <p data-svelte-h="svelte-9oxni8"><strong>Should I use this in production?</strong>
Only if you are fine with updating your code quickly when things change.</p> <p data-svelte-h="svelte-11i4jal"><strong>Will maintainers promptly fix issues in <code>trl.experimental</code>?</strong>
Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/experimental.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_1nqn5y5 = {
assets: "/docs/trl/pr_4331/en",
base: "/docs/trl/pr_4331/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/trl/pr_4331/en/_app/immutable/entry/start.d2344edf.js"),
import("/docs/trl/pr_4331/en/_app/immutable/entry/app.1a672609.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 14],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
40.6 kB
·
Xet hash:
f682061915babd95880eb8cdc9b9b86e0e089256d2c48f7b94641a30dda30440

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.