| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Experimental Features","local":"experimental-features","sections":[{"title":"Current Experimental Features","local":"current-experimental-features","sections":[{"title":"BEMA for Reference Model","local":"bema-for-reference-model","sections":[],"depth":3},{"title":"GFPO","local":"gfpo","sections":[],"depth":3},{"title":"GSPO-token","local":"gspo-token","sections":[],"depth":3},{"title":"GRPO With Replay Buffer","local":"grpo-with-replay-buffer","sections":[{"title":"Usage","local":"usage","sections":[],"depth":4}],"depth":3}],"depth":2},{"title":"Promotion Path (Simple)","local":"promotion-path-simple","sections":[],"depth":2},{"title":"FAQ","local":"faq","sections":[],"depth":2}],"depth":1}"> | |
| <link rel="modulepreload" href="/docs/trl/pr_4331/en/_app/immutable/chunks/CodeBlock.3660b8ed.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Experimental Features","local":"experimental-features","sections":[{"title":"Current Experimental Features","local":"current-experimental-features","sections":[{"title":"BEMA for Reference Model","local":"bema-for-reference-model","sections":[],"depth":3},{"title":"GFPO","local":"gfpo","sections":[],"depth":3},{"title":"GSPO-token","local":"gspo-token","sections":[],"depth":3},{"title":"GRPO With Replay Buffer","local":"grpo-with-replay-buffer","sections":[{"title":"Usage","local":"usage","sections":[],"depth":4}],"depth":3}],"depth":2},{"title":"Promotion Path (Simple)","local":"promotion-path-simple","sections":[],"depth":2},{"title":"FAQ","local":"faq","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="experimental-features" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#experimental-features"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Experimental Features</span></h1> <p data-svelte-h="svelte-15l6usj">The 
<code>trl.experimental</code> namespace provides a minimal, clearly separated space for fast iteration on new ideas.</p> <blockquote class="warning" data-svelte-h="svelte-1jedoer"><p><strong>Stability contract:</strong> Anything under <code>trl.experimental</code> may change or be removed in <em>any</em> release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.</p></blockquote> <h2 class="relative group"><a id="current-experimental-features" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#current-experimental-features"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Current Experimental Features</span></h2> <p data-svelte-h="svelte-3gid1n">The following modules are currently available under <a href="https://github.com/huggingface/trl/tree/main/trl/experimental" rel="nofollow"><code>trl.experimental</code></a>. | |
| This list is not exhaustive and may change at any time.</p> <h3 class="relative group"><a id="bema-for-reference-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#bema-for-reference-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>BEMA for Reference Model</span></h3> <p data-svelte-h="svelte-1n2qlx2">This feature implements the BEMA algorithm to update the reference model during DPO training.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl.experimental.bema_for_ref_model <span class="hljs-keyword">import</span> BEMACallback, DPOTrainer | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM, AutoTokenizer | |
| pref_dataset = load_dataset(<span class="hljs-string">"trl-internal-testing/zen"</span>, <span class="hljs-string">"standard_preference"</span>, split=<span class="hljs-string">"train"</span>) | |
| ref_model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"</span>) | |
| bema_callback = BEMACallback(update_ref_model=<span class="hljs-literal">True</span>) | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"</span>) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| trainer = DPOTrainer( | |
|     model=model, | |
|     ref_model=ref_model, | |
|     train_dataset=pref_dataset, | |
|     processing_class=tokenizer, | |
|     callbacks=[bema_callback], | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="gfpo" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gfpo"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GFPO</span></h3> <p data-svelte-h="svelte-1dy87lu">This feature implements the GFPO algorithm, which encourages concise reasoning in the model’s generated outputs, as proposed in the paper <a href="https://huggingface.co/papers/2508.09726" rel="nofollow">Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning</a>.</p> <p data-svelte-h="svelte-1lfn8in">To activate GFPO in <code>GFPOTrainer</code>:</p> <ul data-svelte-h="svelte-1h9orgs"><li>Set <code>num_remains_in_group</code> in <code>GFPOConfig</code>.</li> <li>Define a group filter function and pass it as <code>group_filter_func</code> to <code>GFPOTrainer</code>. <code>group_filter_func</code> scores the <code>num_generations</code> completions in each group, and <code>GFPOTrainer</code> keeps the top <code>num_remains_in_group</code> completions by score as the new group. 
The model is then trained on the filtered group.</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># train_gfpo.py</span> | |
| <span class="hljs-keyword">from</span> trl.experimental.gfpo <span class="hljs-keyword">import</span> GFPOConfig, GFPOTrainer | |
| <span class="hljs-comment"># Dummy group filter that scores each completion by its index within the group</span> | |
| <span class="hljs-keyword">class</span> <span class="hljs-title class_">GroupFilter</span>: | |
|     <span class="hljs-keyword">def</span> <span class="hljs-title function_">__call__</span>(<span class="hljs-params">self, group_completions, group_rewards, **kwargs</span>): | |
|         group_scores = [] | |
|         <span class="hljs-keyword">for</span> completions, rewards <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(group_completions, group_rewards): | |
|             scores = [<span class="hljs-built_in">float</span>(i) <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-built_in">len</span>(completions))] | |
|             group_scores.append(scores) | |
|         <span class="hljs-keyword">return</span> group_scores | |
| training_args = GFPOConfig( | |
|     output_dir=<span class="hljs-string">"Qwen3-0.6B-GFPO"</span>, | |
|     per_device_train_batch_size=<span class="hljs-number">4</span>, | |
|     num_remains_in_group=<span class="hljs-number">2</span>, | |
|     bf16=<span class="hljs-literal">True</span>, | |
| ) | |
| trainer = GFPOTrainer( | |
|     model=<span class="hljs-string">"Qwen/Qwen3-0.6B"</span>, | |
|     reward_funcs=..., | |
|     train_dataset=..., | |
|     args=training_args, | |
|     group_filter_func=GroupFilter(), | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="gspo-token" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gspo-token"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GSPO-token</span></h3> <p data-svelte-h="svelte-npbvk6">In the paper <a href="https://huggingface.co/papers/2507.18071" rel="nofollow">Group Sequence Policy Optimization</a>, the authors propose GSPO-token, a token-level variant of the GSPO objective. 
To use GSPO-token, you can use the <code>GRPOTrainer</code> class in <code>trl.experimental.gspo_token</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl.experimental.gspo_token <span class="hljs-keyword">import</span> GRPOTrainer | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> GRPOConfig | |
| training_args = GRPOConfig( | |
|     importance_sampling_level=<span class="hljs-string">"sequence_token"</span>, | |
|     ... | |
| )<!-- HTML_TAG_END --></pre></div> <blockquote class="warning"><p>To leverage GSPO-token, the user will need to provide the per-token advantage <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>A</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>t</mi></mrow></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex"> \hat{A_{i,t}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2329em;vertical-align:-0.2861em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> for each token <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math 
xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex"> t </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span><!-- HTML_TAG_END --> in the sequence <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex"> i </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6595em;"></span><span class="mord mathnormal">i</span></span></span></span><!-- HTML_TAG_END --> (i.e., make <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>A</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>t</mi></mrow></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex"> \hat{A_{i,t}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2329em;vertical-align:-0.2861em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal 
mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> vary with <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>t</mi></mrow><annotation encoding="application/x-tex"> t </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.6151em;"></span><span class="mord mathnormal">t</span></span></span></span><!-- HTML_TAG_END -->, which is not the case here, where <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>A</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>t</mi></mrow></msub><mo>^</mo></mover><mo>=</mo><mover accent="true"><msub><mi>A</mi><mi>i</mi></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex"> \hat{A_{i,t}}=\hat{A_{i}} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2329em;vertical-align:-0.2861em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2861em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0968em;vertical-align:-0.15em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.9468em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">A</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3117em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span style="top:-3.2523em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" 
style="left:-0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END -->). Otherwise, the GSPO-token gradient is equivalent to that of the original GSPO implementation.</p></blockquote> <h3 class="relative group"><a id="grpo-with-replay-buffer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#grpo-with-replay-buffer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GRPO With Replay Buffer</span></h3> <p data-svelte-h="svelte-avq1w">This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation, and which therefore yield zero advantages, with groups that have high rewards and standard deviation and were used for training in earlier batches.</p> <h4 class="relative group"><a id="usage" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#usage"><span><svg class="" xmlns="http://www.w3.org/2000/svg" 
xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Usage</span></h4> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> 
trl.experimental.grpo_with_replay_buffer <span class="hljs-keyword">import</span> GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| dataset = load_dataset(<span class="hljs-string">"trl-internal-testing/zen"</span>, <span class="hljs-string">"standard_prompt_only"</span>, split=<span class="hljs-string">"train"</span>) | |
| <span class="hljs-comment"># Guarantee that some rewards have 0 std</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">custom_reward_func</span>(<span class="hljs-params">completions, **kwargs</span>): | |
|     <span class="hljs-keyword">if</span> torch.rand(<span class="hljs-number">1</span>).item() &lt; <span class="hljs-number">0.25</span>: | |
|         <span class="hljs-keyword">return</span> [<span class="hljs-number">0</span>] * <span class="hljs-built_in">len</span>(completions) <span class="hljs-comment"># identical rewards give this group zero std</span> | |
|     <span class="hljs-keyword">else</span>: | |
|         <span class="hljs-keyword">return</span> torch.rand(<span class="hljs-built_in">len</span>(completions)).tolist() | |
| training_args = GRPOWithReplayBufferConfig( | |
| output_dir=self.tmp_dir, | |
| learning_rate=<span class="hljs-number">1e-4</span>, | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| num_generations=<span class="hljs-number">4</span>, | |
| max_completion_length=<span class="hljs-number">8</span>, | |
| replay_buffer_size=<span class="hljs-number">8</span>, | |
| report_to=<span class="hljs-string">"none"</span>, | |
| ) | |
| trainer = GRPOTrainer( | |
| model=<span class="hljs-string">"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"</span>, | |
| reward_funcs=[custom_reward_func], | |
| args=training_args, | |
| train_dataset=dataset, | |
| ) | |
| previous_trainable_params = {n: param.clone() <span class="hljs-keyword">for</span> n, param <span class="hljs-keyword">in</span> trainer.model.named_parameters()} | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1n8heb1">To silence the runtime notice:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">export</span> TRL_EXPERIMENTAL_SILENCE=1<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="promotion-path-simple" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#promotion-path-simple"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 
1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Promotion Path (Simple)</span></h2> <ol data-svelte-h="svelte-wbh5n4"><li><strong>Prototype outside the main repo:</strong> Start development in your own fork or a separate repository to iterate quickly.</li> <li><strong>Experimental inclusion:</strong> Once it’s ready for early users, move the idea into <code>trl.experimental.<feature></code>.</li> <li><strong>Improve:</strong> Add tests, a short doc/example, and demonstrate the usage.</li> <li><strong>Promote:</strong> Once the API proves stable and there is clear interest or adoption from the community, move it into <code>trl.<feature></code> (stable module).</li></ol> <h2 class="relative group"><a id="faq" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#faq"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 
0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>FAQ</span></h2> <p data-svelte-h="svelte-yr2qwe"><strong>Why not just use branches?</strong>
Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback.</p> <p data-svelte-h="svelte-1crmgan"><strong>Can these APIs change or vanish without warning?</strong>
Yes. Anything inside <code>trl.experimental</code> can change or disappear in <em>any</em> release.</p> <p data-svelte-h="svelte-9oxni8"><strong>Should I use this in production?</strong>
Only if you are fine with updating your code quickly when things change.</p> <p data-svelte-h="svelte-11i4jal"><strong>Will maintainers promptly fix issues in <code>trl.experimental</code>?</strong>
Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API.</p>
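<p>Because anything under <code>trl.experimental</code> can change or disappear in any release, code that depends on it may want to fail softly. The following is a minimal sketch, not part of TRL: <code>load_experimental</code> is a hypothetical helper name, and setting <code>TRL_EXPERIMENTAL_SILENCE</code> from Python assumes the variable is read when the experimental module is imported.</p>

```python
import importlib
import os

# Assumption: the variable is checked at import time, so it is set before
# any trl.experimental import. This mirrors `export TRL_EXPERIMENTAL_SILENCE=1`.
os.environ["TRL_EXPERIMENTAL_SILENCE"] = "1"


def load_experimental(module_path, attr):
    """Return `attr` from `module_path`, or None if either is missing.

    Hypothetical helper: it lets calling code detect that an experimental
    feature was moved, renamed, or removed instead of crashing on import.
    """
    try:
        module = importlib.import_module(module_path)
    except ImportError:
        # Covers ModuleNotFoundError: the package or submodule is not there.
        return None
    return getattr(module, attr, None)


trainer_cls = load_experimental(
    "trl.experimental.grpo_with_replay_buffer",
    "GRPOWithReplayBufferTrainer",
)
if trainer_cls is None:
    print("Experimental trainer not available in this environment.")
else:
    print(f"Found experimental trainer: {trainer_cls.__name__}")
```

<p>Pinning the exact <code>trl</code> version in your requirements is a complementary safeguard, since a guarded import only tells you that something is gone, not what replaced it.</p>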