Buckets:

hf-doc-build/doc-dev / trl /pr_5607 /en /reward_trainer.html
HuggingFaceDocBuilder's picture
download
raw
175 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Reward Modeling&quot;,&quot;local&quot;:&quot;reward-modeling&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Overview&quot;,&quot;local&quot;:&quot;overview&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Quick start&quot;,&quot;local&quot;:&quot;quick-start&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Expected dataset type and format&quot;,&quot;local&quot;:&quot;expected-dataset-type-and-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Looking deeper into the training method&quot;,&quot;local&quot;:&quot;looking-deeper-into-the-training-method&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Preprocessing and tokenization&quot;,&quot;local&quot;:&quot;preprocessing-and-tokenization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Computing the loss&quot;,&quot;local&quot;:&quot;computing-the-loss&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Logged metrics&quot;,&quot;local&quot;:&quot;logged-metrics&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Customization&quot;,&quot;local&quot;:&quot;customization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Model initialization&quot;,&quot;local&quot;:&quot;model-initialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Train adapters with PEFT&quot;,&quot;local&quot;:&quot;train-adapters-with-peft&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Tool Calling with Reward Modeling&quot;,&quot;local&quot;:&quot;tool-calling-with-reward-modeling&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;RewardTrainer&quot;,&quot;local&quot;:&quot;trl.RewardTrainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;RewardConfig&quot;,&quot;local&quot;:&quot;trl.RewardConfig&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/trl/pr_5607/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/entry/start.151d81bd.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/scheduler.7b731bd4.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/singletons.2cf51804.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/index.ac28c20f.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/paths.ba01f37d.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/entry/app.3d9a91c0.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/preload-helper.e1689b3a.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/index.cc268345.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/nodes/0.cd288160.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/nodes/48.2d520d8f.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.f0d99f98.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/Docstring.03f7b462.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/CodeBlock.169a125f.js">
<link rel="modulepreload" href="/docs/trl/pr_5607/en/_app/immutable/chunks/ExampleCodeBlock.415f9452.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Reward Modeling&quot;,&quot;local&quot;:&quot;reward-modeling&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Overview&quot;,&quot;local&quot;:&quot;overview&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Quick start&quot;,&quot;local&quot;:&quot;quick-start&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Expected dataset type and format&quot;,&quot;local&quot;:&quot;expected-dataset-type-and-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Looking deeper into the training method&quot;,&quot;local&quot;:&quot;looking-deeper-into-the-training-method&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Preprocessing and tokenization&quot;,&quot;local&quot;:&quot;preprocessing-and-tokenization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Computing the loss&quot;,&quot;local&quot;:&quot;computing-the-loss&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Logged metrics&quot;,&quot;local&quot;:&quot;logged-metrics&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Customization&quot;,&quot;local&quot;:&quot;customization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Model initialization&quot;,&quot;local&quot;:&quot;model-initialization&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Train adapters with PEFT&quot;,&quot;local&quot;:&quot;train-adapters-with-peft&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Tool Calling with Reward Modeling&quot;,&quot;local&quot;:&quot;tool-calling-with-reward-modeling&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;RewardTrainer&quot;,&quot;local&quot;:&quot;trl.RewardTrainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;RewardConfig&quot;,&quot;local&quot;:&quot;trl.RewardConfig&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="reward-modeling" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#reward-modeling"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Reward Modeling</span></h1> <p data-svelte-h="svelte-1rjidu2"><a href="https://huggingface.co/models?other=reward-trainer,trl" rel="nofollow"><img src="https://img.shields.io/badge/All_models-Reward_Trainer-blue" alt="model badge"></a></p> <h2 class="relative group"><a id="overview" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#overview"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Overview</span></h2> <p data-svelte-h="svelte-1ti5dgc">TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models.</p> <p data-svelte-h="svelte-a1ehbo">This post-training method was contributed by <a href="https://huggingface.co/ybelkada" rel="nofollow">Younes Belkada</a>.</p> <h2 class="relative group"><a id="quick-start" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#quick-start"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Quick start</span></h2> <p data-svelte-h="svelte-1rk39qy">This example demonstrates how to train a reward model using the <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> from TRL. We train a <a href="https://huggingface.co/Qwen/Qwen3-0.6B" rel="nofollow">Qwen 3 0.6B</a> model on the <a href="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized" rel="nofollow">UltraFeedback dataset</a>, large-scale, fine-grained, diverse preference dataset.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> RewardTrainer
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
trainer = RewardTrainer(
model=<span class="hljs-string">&quot;Qwen/Qwen3-0.6B&quot;</span>,
train_dataset=load_dataset(<span class="hljs-string">&quot;trl-lib/ultrafeedback_binarized&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>),
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&sidebar=hidden&runs=reward_qwen3-0.6B_ultrafeedback2" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameborder="0"></iframe> <h2 class="relative group"><a id="expected-dataset-type-and-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#expected-dataset-type-and-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Expected dataset type and format</span></h2> <p data-svelte-h="svelte-m5hlgx"><a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> supports <a href="dataset_formats#preference">preference</a> datasets type (both implicit and explicit prompt). The <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> is compatible with both <a href="dataset_formats#standard">standard</a> and <a href="dataset_formats#conversational">conversational</a> dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Standard preference (implicit prompt)</span>
{<span class="hljs-string">&quot;chosen&quot;</span>: <span class="hljs-string">&quot;The sky is blue.&quot;</span>,
<span class="hljs-string">&quot;rejected&quot;</span>: <span class="hljs-string">&quot;The sky is green.&quot;</span>}
<span class="hljs-comment"># Conversational preference (implicit prompt)</span>
{<span class="hljs-string">&quot;chosen&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;What color is the sky?&quot;</span>},
{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;It is blue.&quot;</span>}],
<span class="hljs-string">&quot;rejected&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;What color is the sky?&quot;</span>},
{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;It is green.&quot;</span>}]}
<span class="hljs-comment"># Standard preference (explicit prompt)</span>
{<span class="hljs-string">&quot;prompt&quot;</span>: <span class="hljs-string">&quot;The sky is&quot;</span>,
<span class="hljs-string">&quot;chosen&quot;</span>: <span class="hljs-string">&quot; blue.&quot;</span>,
<span class="hljs-string">&quot;rejected&quot;</span>: <span class="hljs-string">&quot; green.&quot;</span>}
<span class="hljs-comment"># Conversational preference (explicit prompt)</span>
{<span class="hljs-string">&quot;prompt&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;What color is the sky?&quot;</span>}],
<span class="hljs-string">&quot;chosen&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;It is blue.&quot;</span>}],
<span class="hljs-string">&quot;rejected&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: <span class="hljs-string">&quot;It is green.&quot;</span>}]}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-wa19cs">If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the <a href="https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k" rel="nofollow">lmarena-ai/arena-human-preference-55k</a> dataset:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">import</span> json
dataset = load_dataset(<span class="hljs-string">&quot;lmarena-ai/arena-human-preference-55k&quot;</span>)
<span class="hljs-comment"># Filter out ties</span>
dataset = dataset.<span class="hljs-built_in">filter</span>(<span class="hljs-keyword">lambda</span> example: example[<span class="hljs-string">&quot;winner_tie&quot;</span>] == <span class="hljs-number">0</span>)
<span class="hljs-comment"># Create &#x27;chosen&#x27; and &#x27;rejected&#x27; fields based on the winner column</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">response_a_b_to_chosen_rejected</span>(<span class="hljs-params">example</span>):
<span class="hljs-keyword">if</span> example[<span class="hljs-string">&quot;winner_model_a&quot;</span>] == <span class="hljs-number">1</span>:
example[<span class="hljs-string">&quot;chosen&quot;</span>] = example[<span class="hljs-string">&quot;response_a&quot;</span>]
example[<span class="hljs-string">&quot;rejected&quot;</span>] = example[<span class="hljs-string">&quot;response_b&quot;</span>]
<span class="hljs-keyword">else</span>:
example[<span class="hljs-string">&quot;chosen&quot;</span>] = example[<span class="hljs-string">&quot;response_b&quot;</span>]
example[<span class="hljs-string">&quot;rejected&quot;</span>] = example[<span class="hljs-string">&quot;response_a&quot;</span>]
<span class="hljs-keyword">return</span> example
dataset = dataset.<span class="hljs-built_in">map</span>(response_a_b_to_chosen_rejected)
<span class="hljs-comment"># Convert to conversational format</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">make_conversation</span>(<span class="hljs-params">example</span>):
prompt = json.loads(example[<span class="hljs-string">&quot;prompt&quot;</span>])[<span class="hljs-number">0</span>] <span class="hljs-comment"># &#x27;[&quot;What color is the sky?&quot;]&#x27; -&gt; &quot;What color is the sky?&quot;</span>
chosen = json.loads(example[<span class="hljs-string">&quot;chosen&quot;</span>])[<span class="hljs-number">0</span>]
rejected = json.loads(example[<span class="hljs-string">&quot;rejected&quot;</span>])[<span class="hljs-number">0</span>]
<span class="hljs-keyword">return</span> {
<span class="hljs-string">&quot;chosen&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: prompt}, {<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: chosen}],
<span class="hljs-string">&quot;rejected&quot;</span>: [{<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;user&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: prompt}, {<span class="hljs-string">&quot;role&quot;</span>: <span class="hljs-string">&quot;assistant&quot;</span>, <span class="hljs-string">&quot;content&quot;</span>: rejected}],
}
dataset = dataset.<span class="hljs-built_in">map</span>(make_conversation)
<span class="hljs-comment"># Keep only necessary columns</span>
dataset = dataset.select_columns([<span class="hljs-string">&quot;chosen&quot;</span>, <span class="hljs-string">&quot;rejected&quot;</span>])
<span class="hljs-built_in">print</span>(<span class="hljs-built_in">next</span>(<span class="hljs-built_in">iter</span>(dataset[<span class="hljs-string">&quot;train&quot;</span>])))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-punctuation">{</span>
<span class="hljs-attr">&quot;chosen&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-punctuation">[</span>
<span class="hljs-punctuation">{</span><span class="hljs-attr">&quot;role&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;user&quot;</span><span class="hljs-punctuation">,</span> <span class="hljs-attr">&quot;content&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;Is it morally right to try to have a certain percentage of females on managerial positions?&quot;</span><span class="hljs-punctuation">}</span><span class="hljs-punctuation">,</span>
<span class="hljs-punctuation">{</span><span class="hljs-attr">&quot;role&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;assistant&quot;</span><span class="hljs-punctuation">,</span> <span class="hljs-attr">&quot;content&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;The question of whether it is morally right to aim for a certain percentage of females...&quot;</span><span class="hljs-punctuation">}</span><span class="hljs-punctuation">,</span>
<span class="hljs-punctuation">]</span><span class="hljs-punctuation">,</span>
<span class="hljs-attr">&quot;rejected&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-punctuation">[</span>
<span class="hljs-punctuation">{</span><span class="hljs-attr">&quot;role&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;user&quot;</span><span class="hljs-punctuation">,</span> <span class="hljs-attr">&quot;content&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;Is it morally right to try to have a certain percentage of females on managerial positions?&quot;</span><span class="hljs-punctuation">}</span><span class="hljs-punctuation">,</span>
<span class="hljs-punctuation">{</span><span class="hljs-attr">&quot;role&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;assistant&quot;</span><span class="hljs-punctuation">,</span> <span class="hljs-attr">&quot;content&quot;</span><span class="hljs-punctuation">:</span> <span class="hljs-string">&quot;As an AI, I don&#x27;t have personal beliefs or opinions. However, ...&quot;</span><span class="hljs-punctuation">}</span><span class="hljs-punctuation">,</span>
<span class="hljs-punctuation">]</span><span class="hljs-punctuation">,</span>
<span class="hljs-punctuation">}</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="looking-deeper-into-the-training-method" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#looking-deeper-into-the-training-method"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Looking deeper into the training method</span></h2> <p data-svelte-h="svelte-cp5aph">Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences.</p> <p data-svelte-h="svelte-r5fokq">This section breaks down how reward modeling works in practice, covering the key steps: <strong>preprocessing</strong> and <strong>loss computation</strong>.</p> <h3 class="relative group"><a id="preprocessing-and-tokenization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#preprocessing-and-tokenization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Preprocessing and tokenization</span></h3> <p data-svelte-h="svelte-87ovcn">During training, each example is expected to contain a <strong>chosen</strong> and <strong>rejected</strong> field. For more details on the expected formats, see <a href="dataset_formats#preference">Dataset formats - Preference</a>.
The <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> tokenizes each input using the model’s tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization.</p> <h3 class="relative group"><a id="computing-the-loss" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#computing-the-loss"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Computing the loss</span></h3> <p>Let <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex"> x </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal">x</span></span></span></span><!-- HTML_TAG_END --> be the input sequence (prompt) and <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo>+</mo></msup></mrow><annotation encoding="application/x-tex"> y^+ </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9658em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> and <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo></mo></msup></mrow><annotation encoding="application/x-tex"> y^- </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9658em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> be the chosen and rejected sequences respectively. Under the Bradley-Terry model (<a href="https://www.jstor.org/stable/2334029" rel="nofollow" data-svelte-h="svelte-18polgg">Bradley &amp; Terry, 1952</a>), the probability that <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo>+</mo></msup></mrow><annotation encoding="application/x-tex"> y^+ </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9658em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> is preferred over <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo></mo></msup></mrow><annotation encoding="application/x-tex"> y^- </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9658em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> given a reward function <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi></mrow><annotation encoding="application/x-tex"> r </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span></span></span></span><!-- HTML_TAG_END --> is <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><msup><mi>y</mi><mo>+</mo></msup><mo></mo><msup><mi>y</mi><mo></mo></msup><mi mathvariant="normal"></mi><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><mi>σ</mi><mo stretchy="false">(</mo><mi>r</mi><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><msup><mi>y</mi><mo>+</mo></msup><mo stretchy="false">)</mo><mo></mo><mi>r</mi><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><msup><mi>y</mi><mo></mo></msup><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex"> p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)−r(x, y^-)) </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0213em;vertical-align:-0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel"></span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0213em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span><span class="mord"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.0213em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222em;"></span></span><span class="base"><span class="strut" style="height:1.0213em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span><span class="mclose">))</span></span></span></span><!-- HTML_TAG_END -->, where <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>σ</mi></mrow><annotation encoding="application/x-tex"> σ </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.4306em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span></span></span></span><!-- HTML_TAG_END --> is the sigmoid function.</p> <p>The reward model <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>r</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><mi>y</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex"> r_\theta(x, y) </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="mclose">)</span></span></span></span><!-- HTML_TAG_END --> is trained to assign higher scores to preferred responses <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo>+</mo></msup></mrow><annotation encoding="application/x-tex"> y^+ </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9658em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END --> over non-preferred ones <!-- HTML_TAG_START --><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>y</mi><mo></mo></msup></mrow><annotation encoding="application/x-tex"> y^- </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.9658em;vertical-align:-0.1944em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7713em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span></span></span></span><!-- HTML_TAG_END -->. The loss is then defined as the negative log-likelihood of the observed preferences:
<!-- HTML_TAG_START --><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi mathvariant="script">L</mi><mo stretchy="false">(</mo><mi>θ</mi><mo stretchy="false">)</mo><mo>=</mo><mo></mo><msub><mi mathvariant="double-struck">E</mi><mrow><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><msup><mi>y</mi><mo>+</mo></msup><mo separator="true">,</mo><msup><mi>y</mi><mo></mo></msup><mo stretchy="false">)</mo><mo></mo><mi mathvariant="script">D</mi></mrow></msub><mrow><mo fence="true">[</mo><mi>log</mi><mo></mo><mi>σ</mi><mo stretchy="false">(</mo><msub><mi>r</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><msup><mi>y</mi><mo>+</mo></msup><mo stretchy="false">)</mo><mo></mo><msub><mi>r</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><msup><mi>y</mi><mo></mo></msup><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo fence="true">]</mo></mrow><mi mathvariant="normal">.</mi></mrow><annotation encoding="application/x-tex">
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathcal">L</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2778em;"></span></span><span class="base"><span class="strut" style="height:1.2052em;vertical-align:-0.3552em;"></span><span class="mord"></span><span class="mord"><span class="mord mathbb">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathnormal mtight">x</span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7027em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span><span class="mpunct mtight">,</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.7027em;"><span style="top:-2.786em;margin-right:0.0714em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span><span class="mclose mtight">)</span><span class="mrel mtight"></span><span class="mord mathcal mtight" style="margin-right:0.02778em;">D</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.3552em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size1">[</span></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord mathnormal" style="margin-right:0.03588em;">σ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8213em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">+</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mbin"></span><span class="mspace" style="margin-right:0.2222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">r</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361em;"><span style="top:-2.55em;margin-left:-0.0278em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8213em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight"></span></span></span></span></span></span></span></span><span class="mclose">))</span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size1">]</span></span></span><span class="mspace" style="margin-right:0.1667em;"></span><span class="mord">.</span></span></span></span></span><!-- HTML_TAG_END --></p> <blockquote class="tip" data-svelte-h="svelte-1nlt6oo"><p>The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, <a href="https://huggingface.co/papers/2312.09244" rel="nofollow">Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking</a> proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the <code>center_rewards_coefficient</code> parameter in the <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardConfig">RewardConfig</a>. The recommended value is <code>1e-2</code>.</p></blockquote> <h2 class="relative group"><a id="logged-metrics" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#logged-metrics"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Logged metrics</span></h2> <p data-svelte-h="svelte-132s7j9">While training and evaluating we record the following reward metrics:</p> <ul data-svelte-h="svelte-ecm0rq"><li><code>global_step</code>: The total number of optimizer steps taken so far.</li> <li><code>epoch</code>: The current epoch number, based on dataset iteration.</li> <li><code>num_tokens</code>: The total number of tokens processed so far.</li> <li><code>loss</code>: The average loss over the last logging interval.</li> <li><code>accuracy</code>: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.</li> <li><code>min_reward</code>: The minimum reward score assigned by the model. This value is averaged over the logging interval.</li> <li><code>mean_reward</code>: The average reward score assigned by the model over the last logging interval.</li> <li><code>max_reward</code>: The maximum reward score assigned by the model. This value is averaged over the logging interval.</li> <li><code>margin</code>: The average margin (difference between chosen and rejected rewards) over the last logging interval.</li> <li><code>learning_rate</code>: The current learning rate, which may change dynamically if a scheduler is used.</li> <li><code>grad_norm</code>: The L2 norm of the gradients, computed before gradient clipping.</li></ul> <h2 class="relative group"><a id="customization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#customization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Customization</span></h2> <h3 class="relative group"><a id="model-initialization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-initialization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model initialization</span></h3> <p data-svelte-h="svelte-1ieqrp4">You can directly pass the kwargs of the <code>from_pretrained()</code> method to the <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardConfig">RewardConfig</a>. For example, if you want to load a model in a different precision, analogous to</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->model = AutoModelForSequenceClassification.from_pretrained(<span class="hljs-string">&quot;Qwen/Qwen3-0.6B&quot;</span>, dtype=torch.bfloat16)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1riioh2">you can do so by passing the <code>model_init_kwargs={&quot;dtype&quot;: torch.bfloat16}</code> argument to the <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardConfig">RewardConfig</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> RewardConfig
training_args = RewardConfig(
model_init_kwargs={<span class="hljs-string">&quot;dtype&quot;</span>: torch.bfloat16},
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zb69hw">Note that all keyword arguments of <code>from_pretrained()</code> are supported, except for <code>num_labels</code>, which is automatically set to 1.</p> <h3 class="relative group"><a id="train-adapters-with-peft" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#train-adapters-with-peft"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Train adapters with PEFT</span></h3> <p data-svelte-h="svelte-t2zuq8">We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> RewardTrainer
<span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> LoraConfig
dataset = load_dataset(<span class="hljs-string">&quot;trl-lib/ultrafeedback_binarized&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
trainer = RewardTrainer(
<span class="hljs-string">&quot;Qwen/Qwen3-4B&quot;</span>,
train_dataset=dataset,
peft_config=LoraConfig(modules_to_save=[<span class="hljs-string">&quot;score&quot;</span>]) <span class="hljs-comment"># important to include the score head when base model is not a sequence classification model</span>
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-11e9nof">You can also continue training your <a href="https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel" rel="nofollow">PeftModel</a>. For that, first load a <code>PeftModel</code> outside <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> and pass it directly to the trainer without the <code>peft_config</code> argument being passed.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> RewardTrainer
<span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(<span class="hljs-string">&quot;trl-lib/Qwen3-4B-Reward-LoRA&quot;</span>, is_trainable=<span class="hljs-literal">True</span>)
dataset = load_dataset(<span class="hljs-string">&quot;trl-lib/Capybara&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
trainer = RewardTrainer(
model=model,
train_dataset=dataset,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <blockquote class="tip"><p data-svelte-h="svelte-dc5ccy">When training adapters, you typically use a higher learning rate (≈1e‑3) since only new parameters are being learned.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->RewardConfig(learning_rate=<span class="hljs-number">1e-3</span>, ...)<!-- HTML_TAG_END --></pre></div></blockquote> <h2 class="relative group"><a id="tool-calling-with-reward-modeling" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#tool-calling-with-reward-modeling"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Tool Calling with Reward Modeling</span></h2> <p data-svelte-h="svelte-iquuc2">The <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> fully supports fine-tuning models with <em>tool calling</em> capabilities. In this case, each dataset example should include:</p> <ul data-svelte-h="svelte-1vlmw2d"><li>The conversation messages, including any tool calls (<code>tool_calls</code>) and tool responses (<code>tool</code> role messages)</li> <li>The list of available tools in the <code>tools</code> column, typically provided as JSON schemas</li></ul> <p data-svelte-h="svelte-vl4ede">For details on the expected dataset structure, see the <a href="dataset_formats#tool-calling">Dataset Format — Tool Calling</a> section.</p> <h2 class="relative group"><a id="trl.RewardTrainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>RewardTrainer</span></h2> <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> <div><span class="group flex space-x-1.5 items-center text-gray-800 bg-gradient-to-r rounded-tr-lg -mt-4 -ml-4 pt-3 px-2.5" id="trl.RewardTrainer"><!-- HTML_TAG_START --><h3 class="!m-0"><span class="flex-1 break-all md:text-lg bg-gradient-to-r px-2.5 py-1.5 rounded-xl from-indigo-50/70 to-white dark:from-gray-900 dark:to-gray-950 dark:text-indigo-300 text-indigo-700"><svg class="mr-1.5 text-indigo-500 inline-block -mt-0.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width=".8em" height=".8em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path class="uim-quaternary" d="M20.23 7.24L12 12L3.77 7.24a1.98 1.98 0 0 1 .7-.71L11 2.76c.62-.35 1.38-.35 2 0l6.53 3.77c.29.173.531.418.7.71z" opacity=".25" fill="currentColor"></path><path class="uim-tertiary" d="M12 12v9.5a2.09 2.09 0 0 1-.91-.21L4.5 17.48a2.003 2.003 0 0 1-1-1.73v-7.5a2.06 2.06 0 0 1 .27-1.01L12 12z" opacity=".5" fill="currentColor"></path><path class="uim-primary" d="M20.5 8.25v7.5a2.003 2.003 0 0 1-1 1.73l-6.62 3.82c-.275.13-.576.198-.88.2V12l8.23-4.76c.175.308.268.656.27 1.01z" fill="currentColor"></path></svg><span class="font-light">class</span> <span class="font-medium">trl.</span><span class="font-semibold">RewardTrainer</span></span></h3><!-- HTML_TAG_END --> <a id="trl.RewardTrainer" class="header-link invisible with-hover:group-hover:visible pr-2" href="#trl.RewardTrainer"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></a> <a class="!ml-auto !text-gray-400 !no-underline text-sm flex items-center" href="https://github.com/huggingface/trl/blob/vr_5607/trl/trainer/reward_trainer.py#L229" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span class="hidden md:block mx-0.5 hover:!underline" data-svelte-h="svelte-122apf4">source</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span></a></span> <p class="font-mono text-xs md:text-sm !leading-relaxed !my-6"><span data-svelte-h="svelte-8mvn6a">(</span> <span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">model<span class="opacity-60">: str | PreTrainedModel | PeftModel</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">args<span class="opacity-60">: trl.trainer.reward_config.RewardConfig | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">data_collator<span class="opacity-60">: collections.abc.Callable[[list[typing.Any]], dict[str, typing.Any]] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">train_dataset<span class="opacity-60">: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_dataset<span class="opacity-60">: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | dict[str, datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">processing_class<span class="opacity-60">: transformers.tokenization_utils_base.PreTrainedTokenizerBase | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">compute_metrics<span class="opacity-60">: collections.abc.Callable[[transformers.trainer_utils.EvalPrediction], dict] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">callbacks<span class="opacity-60">: list[transformers.trainer_callback.TrainerCallback] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">optimizers<span class="opacity-60">: tuple = (None, None)</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">optimizer_cls_and_kwargs<span class="opacity-60">: tuple[type[torch.optim.optimizer.Optimizer], dict[str, typing.Any]] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">preprocess_logits_for_metrics<span class="opacity-60">: collections.abc.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">peft_config<span class="opacity-60">: PeftConfig | None = None</span></span> </span> <span data-svelte-h="svelte-1jq0pl7">)</span> </p> <div class="!mb-10 relative docstring-details "> <p class="flex items-center font-semibold !mt-2 !mb-2 text-gray-800" data-svelte-h="svelte-lt6pb6">Parameters <span class="flex-auto border-t-2 border-gray-100 dark:border-gray-700 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.model" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.model"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>model</strong> (<code>str</code> or <a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel" rel="nofollow">PreTrainedModel</a> or <a href="https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel" rel="nofollow">PeftModel</a>) &#x2014;
Model to be trained. Can be either:</p>
<ul>
<li>A string, being the <em>model id</em> of a pretrained model hosted inside a model repo on huggingface.co, or a
path to a <em>directory</em> containing model weights saved using
<a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.save_pretrained" rel="nofollow">save_pretrained</a>, e.g., <code>&apos;./my_model_directory/&apos;</code>. The model is loaded
using <code>AutoModelForSequenceClassification.from_pretrained</code> with the keyword arguments in
<code>args.model_init_kwargs</code>.</li>
<li>A sequence classification <a href="https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel" rel="nofollow">PreTrainedModel</a> object.</li>
<li>A sequence classification <a href="https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel" rel="nofollow">PeftModel</a> object.</li>
</ul><!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.args" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.args"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>args</strong> (<a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardConfig">RewardConfig</a>, <em>optional</em>) &#x2014;
Configuration for this trainer. If <code>None</code>, a default configuration is used.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.data_collator" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.data_collator"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>data_collator</strong> (<code>DataCollator</code>, <em>optional</em>) &#x2014;
Function to use to form a batch from a list of elements of the processed <code>train_dataset</code> or <code>eval_dataset</code>.
Will default to <code>DataCollatorForPreference</code>.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.train_dataset" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.train_dataset"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>train_dataset</strong> (<a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset" rel="nofollow">Dataset</a> or <a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.IterableDataset" rel="nofollow">IterableDataset</a>) &#x2014;
Dataset to use for training. This trainer supports <a href="#preference">preference</a> type (both implicit and
explicit prompt). The format of the samples can be either:</p>
<ul>
<li><a href="dataset_formats#standard">Standard</a>: Each sample contains plain text.</li>
<li><a href="dataset_formats#conversational">Conversational</a>: Each sample contains structured messages (e.g., role
and content).</li>
</ul>
<p>The trainer also supports processed datasets (tokenized) as long as they contain <code>chosen_ids</code> and
<code>rejected_ids</code> fields.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.eval_dataset" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.eval_dataset"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>eval_dataset</strong> (<a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset" rel="nofollow">Dataset</a>, <a href="https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.IterableDataset" rel="nofollow">IterableDataset</a> or <code>dict[str, Dataset | IterableDataset]</code>) &#x2014;
Dataset to use for evaluation. It must meet the same requirements as <code>train_dataset</code>.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.processing_class" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.processing_class"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>processing_class</strong> (<a href="https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase" rel="nofollow">PreTrainedTokenizerBase</a>, <em>optional</em>) &#x2014;
Tokenizer used to process the data. If <code>None</code>, the tokenizer is loaded from the model&#x2019;s name with
<a href="https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained" rel="nofollow">from_pretrained</a>. A padding token, <code>processing_class.pad_token</code>, must be
set. If the processing class has not set a padding token, <code>processing_class.eos_token</code> will be used as the
default.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.compute_metrics" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.compute_metrics"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>compute_metrics</strong> (<code>Callable[[EvalPrediction], dict]</code>, <em>optional</em>) &#x2014;
The function that will be used to compute metrics at evaluation. Must take a
<a href="https://huggingface.co/docs/transformers/main/en/internal/trainer_utils#transformers.EvalPrediction" rel="nofollow">EvalPrediction</a> and return a dictionary string to metric values. When passing
<a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardConfig">RewardConfig</a> with <code>batch_eval_metrics</code> set to <code>True</code>, your <code>compute_metrics</code> function must take a
boolean <code>compute_result</code> argument. This will be triggered after the last eval batch to signal that the
function needs to calculate and return the global summary statistics rather than accumulating the
batch-level statistics.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.callbacks" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.callbacks"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>callbacks</strong> (list of <a href="https://huggingface.co/docs/transformers/main/en/main_classes/callback#transformers.TrainerCallback" rel="nofollow">TrainerCallback</a>, <em>optional</em>) &#x2014;
List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
in <a href="https://huggingface.co/docs/transformers/main_classes/callback" rel="nofollow">here</a>.</p>
<p>If you want to remove one of the default callbacks used, use the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.remove_callback" rel="nofollow">remove_callback</a>
method.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.optimizers" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.optimizers"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>optimizers</strong> (<code>tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]</code>, <em>optional</em>, defaults to <code>(None, None)</code>) &#x2014;
A tuple containing the optimizer and the scheduler to use. Will default to an instance of <code>AdamW</code> on your
model and a scheduler given by <a href="https://huggingface.co/docs/transformers/main/en/main_classes/optimizer_schedules#transformers.get_linear_schedule_with_warmup" rel="nofollow">get_linear_schedule_with_warmup</a> controlled by <code>args</code>.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.optimizer_cls_and_kwargs" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.optimizer_cls_and_kwargs"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>optimizer_cls_and_kwargs</strong> (<code>tuple[Type[torch.optim.Optimizer], Dict[str, Any]]</code>, <em>optional</em>) &#x2014;
A tuple containing the optimizer class and keyword arguments to use. Overrides <code>optim</code> and <code>optim_args</code> in
<code>args</code>. Incompatible with the <code>optimizers</code> argument.</p>
<p>Unlike <code>optimizers</code>, this argument avoids the need to place model parameters on the correct devices before
initializing the Trainer.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.preprocess_logits_for_metrics" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.preprocess_logits_for_metrics"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>preprocess_logits_for_metrics</strong> (<code>Callable[[torch.Tensor, torch.Tensor], torch.Tensor]</code>, <em>optional</em>) &#x2014;
A function that preprocess the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by <code>compute_metrics</code>.</p>
<p>Note that the labels (second parameter) will be <code>None</code> if the dataset does not have them.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.peft_config" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.peft_config"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>peft_config</strong> (<a href="https://huggingface.co/docs/peft/main/en/package_reference/config#peft.PeftConfig" rel="nofollow">PeftConfig</a>, <em>optional</em>) &#x2014;
PEFT configuration used to wrap the model. If <code>None</code>, the model is not wrapped. Note that if the loaded
model is a causal LM, it&#x2019;s highly recommended to set <code>modules_to_save=[&quot;score&quot;]</code> in the PEFT configuration
to ensure that the reward head is properly trained.<!-- HTML_TAG_END --> </span></span> </li></ul> </div></div> <p data-svelte-h="svelte-1prnqj7">Trainer for Outcome-supervised Reward Models (ORM).</p> <p data-svelte-h="svelte-10vjtjm">This class is a wrapper around the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer" rel="nofollow">Trainer</a> class and inherits all of its attributes and methods.</p> <div class="relative group rounded-md"><a id="trl.RewardTrainer.example" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.example"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <p data-svelte-h="svelte-11lpom8">Example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> RewardTrainer
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
dataset = load_dataset(<span class="hljs-string">&quot;trl-lib/ultrafeedback_binarized&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
trainer = RewardTrainer(
model=<span class="hljs-string">&quot;Qwen/Qwen2.5-0.5B-Instruct&quot;</span>,
train_dataset=dataset,
)
trainer.train()<!-- HTML_TAG_END --></pre></div></div> <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> <div><span class="group flex space-x-1.5 items-center text-gray-800 bg-gradient-to-r rounded-tr-lg -mt-4 -ml-4 pt-3 px-2.5" id="trl.RewardTrainer.train"><!-- HTML_TAG_START --><h4 class="!m-0"><span class="flex-1 rounded-xl py-0.5 break-all bg-gradient-to-r from-blue-50/60 to-white dark:from-gray-900 dark:to-gray-950 text-blue-700 dark:text-blue-300 font-medium px-2"><svg width="1em" height="1em" viewBox="0 0 32 33" class="mr-1 inline-block -mt-0.5" xmlns="http://www.w3.org/2000/svg"><path d="M5.80566 18.3545C4.90766 17.4565 4.90766 16.0005 5.80566 15.1025L14.3768 6.53142C15.2748 5.63342 16.7307 5.63342 17.6287 6.53142L26.1999 15.1025C27.0979 16.0005 27.0979 17.4565 26.1999 18.3545L17.6287 26.9256C16.7307 27.8236 15.2748 27.8236 14.3768 26.9256L5.80566 18.3545Z" fill="currentColor" fill-opacity="0.25"/><path fill-rule="evenodd" clip-rule="evenodd" d="M16.4801 13.9619C16.4801 12.9761 16.7467 12.5436 16.9443 12.3296C17.1764 12.078 17.5731 11.8517 18.2275 11.707C18.8821 11.5623 19.638 11.5342 20.4038 11.5582C20.7804 11.57 21.1341 11.5932 21.4719 11.6156L21.5263 11.6193C21.8195 11.6389 22.1626 11.6618 22.4429 11.6618V7.40825C22.3209 7.40825 22.1219 7.39596 21.7544 7.37149C21.4202 7.34925 20.9976 7.32115 20.5371 7.30672C19.6286 7.27824 18.4672 7.29779 17.3093 7.55377C16.1512 7.8098 14.8404 8.33724 13.8181 9.4452C12.7612 10.5907 12.2266 12.1236 12.2266 13.9619V15.0127H10.6836V19.2662H12.2266V26.6332H16.4801V19.2662H20.3394V15.0127H16.4801V13.9619Z" fill="currentColor"/></svg>train</span></h4><!-- HTML_TAG_END --> <a id="trl.RewardTrainer.train" class="header-link invisible with-hover:group-hover:visible pr-2" href="#trl.RewardTrainer.train"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></a> <a class="!ml-auto !text-gray-400 !no-underline text-sm flex items-center" href="https://github.com/huggingface/trl/blob/vr_5607/transformers/trainer.py#L1323" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span class="hidden md:block mx-0.5 hover:!underline" data-svelte-h="svelte-122apf4">source</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span></a></span> <p class="font-mono text-xs md:text-sm !leading-relaxed !my-6"><span data-svelte-h="svelte-8mvn6a">(</span> <span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">resume_from_checkpoint<span class="opacity-60">: str | bool | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">trial<span class="opacity-60">: optuna.Trial | dict[str, Any] | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ignore_keys_for_eval<span class="opacity-60">: list[str] | None = None</span></span> </span> <span data-svelte-h="svelte-1jq0pl7">)</span> <span class="font-bold" data-svelte-h="svelte-1j6k10o"></span> <span class="rounded hover:bg-gray-400 cursor-pointer"><!-- HTML_TAG_START --><script context="module">export const metadata = 'undefined';</script><span><code>~trainer_utils.TrainOutput</code></span><!-- HTML_TAG_END --></span></p> <div class="!mb-10 relative docstring-details "> <p class="flex items-center font-semibold !mt-2 !mb-2 text-gray-800" data-svelte-h="svelte-lt6pb6">Parameters <span class="flex-auto border-t-2 border-gray-100 dark:border-gray-700 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.train.resume_from_checkpoint" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.train.resume_from_checkpoint"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>resume_from_checkpoint</strong> (<code>str</code> or <code>bool</code>, <em>optional</em>) &#x2014;
If a <code>str</code>, local path to a saved checkpoint as saved by a previous instance of <code>Trainer</code>. If a
<code>bool</code> and equals <code>True</code>, load the last checkpoint in <em>args.output_dir</em> as saved by a previous instance
of <code>Trainer</code>. If present, training will resume from the model/optimizer/scheduler states loaded here.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.train.trial" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.train.trial"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>trial</strong> (<code>optuna.Trial</code> or <code>dict[str, Any]</code>, <em>optional</em>) &#x2014;
The trial run or the hyperparameter dictionary for hyperparameter search.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.train.ignore_keys_for_eval" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.train.ignore_keys_for_eval"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>ignore_keys_for_eval</strong> (<code>list[str]</code>, <em>optional</em>) &#x2014;
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions for evaluation during the training.<!-- HTML_TAG_END --> </span></span> </li></ul> <div id="trl.RewardTrainer.train.returns" class="flex items-center font-semibold space-x-3 text-base !mt-0 !mb-0 text-gray-800 rounded "><p class="text-base">Returns</p> <!-- HTML_TAG_START --><script context="module">export const metadata = 'undefined';</script>
<p><code>~trainer_utils.TrainOutput</code></p>
<!-- HTML_TAG_END --> <span class="flex-auto border-t-2 border-gray-100 dark:border-gray-700"></span></div> <p class="text-base"><!-- HTML_TAG_START --><script context="module">export const metadata = 'undefined';</script>
<p>Object containing the global step count, training loss, and metrics.</p>
<!-- HTML_TAG_END --></p> </div></div> <p data-svelte-h="svelte-1cilnet">Main training entry point.</p></div> <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> <div><span class="group flex space-x-1.5 items-center text-gray-800 bg-gradient-to-r rounded-tr-lg -mt-4 -ml-4 pt-3 px-2.5" id="trl.RewardTrainer.save_model"><!-- HTML_TAG_START --><h4 class="!m-0"><span class="flex-1 rounded-xl py-0.5 break-all bg-gradient-to-r from-blue-50/60 to-white dark:from-gray-900 dark:to-gray-950 text-blue-700 dark:text-blue-300 font-medium px-2"><svg width="1em" height="1em" viewBox="0 0 32 33" class="mr-1 inline-block -mt-0.5" xmlns="http://www.w3.org/2000/svg"><path d="M5.80566 18.3545C4.90766 17.4565 4.90766 16.0005 5.80566 15.1025L14.3768 6.53142C15.2748 5.63342 16.7307 5.63342 17.6287 6.53142L26.1999 15.1025C27.0979 16.0005 27.0979 17.4565 26.1999 18.3545L17.6287 26.9256C16.7307 27.8236 15.2748 27.8236 14.3768 26.9256L5.80566 18.3545Z" fill="currentColor" fill-opacity="0.25"/><path fill-rule="evenodd" clip-rule="evenodd" d="M16.4801 13.9619C16.4801 12.9761 16.7467 12.5436 16.9443 12.3296C17.1764 12.078 17.5731 11.8517 18.2275 11.707C18.8821 11.5623 19.638 11.5342 20.4038 11.5582C20.7804 11.57 21.1341 11.5932 21.4719 11.6156L21.5263 11.6193C21.8195 11.6389 22.1626 11.6618 22.4429 11.6618V7.40825C22.3209 7.40825 22.1219 7.39596 21.7544 7.37149C21.4202 7.34925 20.9976 7.32115 20.5371 7.30672C19.6286 7.27824 18.4672 7.29779 17.3093 7.55377C16.1512 7.8098 14.8404 8.33724 13.8181 9.4452C12.7612 10.5907 12.2266 12.1236 12.2266 13.9619V15.0127H10.6836V19.2662H12.2266V26.6332H16.4801V19.2662H20.3394V15.0127H16.4801V13.9619Z" fill="currentColor"/></svg>save_model</span></h4><!-- HTML_TAG_END --> <a id="trl.RewardTrainer.save_model" class="header-link invisible with-hover:group-hover:visible pr-2" href="#trl.RewardTrainer.save_model"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></a> <a class="!ml-auto !text-gray-400 !no-underline text-sm flex items-center" href="https://github.com/huggingface/trl/blob/vr_5607/transformers/trainer.py#L3746" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span class="hidden md:block mx-0.5 hover:!underline" data-svelte-h="svelte-122apf4">source</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span></a></span> <p class="font-mono text-xs md:text-sm !leading-relaxed !my-6"><span data-svelte-h="svelte-8mvn6a">(</span> <span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">output_dir<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">_internal_call<span class="opacity-60">: bool = False</span></span> </span> <span data-svelte-h="svelte-1jq0pl7">)</span> </p> <div class="!mb-10 relative docstring-details "> </div></div> <p data-svelte-h="svelte-r8h4ov">Will save the model, so you can reload it using <code>from_pretrained()</code>.</p> <p data-svelte-h="svelte-1e6bius">Will only save from the main process.</p></div> <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> <div><span class="group flex space-x-1.5 items-center text-gray-800 bg-gradient-to-r rounded-tr-lg -mt-4 -ml-4 pt-3 px-2.5" id="trl.RewardTrainer.push_to_hub"><!-- HTML_TAG_START --><h4 class="!m-0"><span class="flex-1 rounded-xl py-0.5 break-all bg-gradient-to-r from-blue-50/60 to-white dark:from-gray-900 dark:to-gray-950 text-blue-700 dark:text-blue-300 font-medium px-2"><svg width="1em" height="1em" viewBox="0 0 32 33" class="mr-1 inline-block -mt-0.5" xmlns="http://www.w3.org/2000/svg"><path d="M5.80566 18.3545C4.90766 17.4565 4.90766 16.0005 5.80566 15.1025L14.3768 6.53142C15.2748 5.63342 16.7307 5.63342 17.6287 6.53142L26.1999 15.1025C27.0979 16.0005 27.0979 17.4565 26.1999 18.3545L17.6287 26.9256C16.7307 27.8236 15.2748 27.8236 14.3768 26.9256L5.80566 18.3545Z" fill="currentColor" fill-opacity="0.25"/><path fill-rule="evenodd" clip-rule="evenodd" d="M16.4801 13.9619C16.4801 12.9761 16.7467 12.5436 16.9443 12.3296C17.1764 12.078 17.5731 11.8517 18.2275 11.707C18.8821 11.5623 19.638 11.5342 20.4038 11.5582C20.7804 11.57 21.1341 11.5932 21.4719 11.6156L21.5263 11.6193C21.8195 11.6389 22.1626 11.6618 22.4429 11.6618V7.40825C22.3209 7.40825 22.1219 7.39596 21.7544 7.37149C21.4202 7.34925 20.9976 7.32115 20.5371 7.30672C19.6286 7.27824 18.4672 7.29779 17.3093 7.55377C16.1512 7.8098 14.8404 8.33724 13.8181 9.4452C12.7612 10.5907 12.2266 12.1236 12.2266 13.9619V15.0127H10.6836V19.2662H12.2266V26.6332H16.4801V19.2662H20.3394V15.0127H16.4801V13.9619Z" fill="currentColor"/></svg>push_to_hub</span></h4><!-- HTML_TAG_END --> <a id="trl.RewardTrainer.push_to_hub" class="header-link invisible with-hover:group-hover:visible pr-2" href="#trl.RewardTrainer.push_to_hub"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></a> <a class="!ml-auto !text-gray-400 !no-underline text-sm flex items-center" href="https://github.com/huggingface/trl/blob/vr_5607/transformers/trainer.py#L3993" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span class="hidden md:block mx-0.5 hover:!underline" data-svelte-h="svelte-122apf4">source</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span></a></span> <p class="font-mono text-xs md:text-sm !leading-relaxed !my-6"><span data-svelte-h="svelte-8mvn6a">(</span> <span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">commit_message<span class="opacity-60">: str | None = 'End of training'</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">blocking<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">token<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-pointer"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">revision<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">**kwargs<span class="opacity-60"></span></span> </span> <span data-svelte-h="svelte-1jq0pl7">)</span> </p> <div class="!mb-10 relative docstring-details "> <p class="flex items-center font-semibold !mt-2 !mb-2 text-gray-800" data-svelte-h="svelte-lt6pb6">Parameters <span class="flex-auto border-t-2 border-gray-100 dark:border-gray-700 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.push_to_hub.commit_message" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.push_to_hub.commit_message"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>commit_message</strong> (<code>str</code>, <em>optional</em>, defaults to <code>&quot;End of training&quot;</code>) &#x2014;
Message to commit while pushing.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.push_to_hub.blocking" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.push_to_hub.blocking"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>blocking</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>True</code>) &#x2014;
Whether the function should return only when the <code>git push</code> has finished.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.push_to_hub.token" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.push_to_hub.token"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>token</strong> (<code>str</code>, <em>optional</em>, defaults to <code>None</code>) &#x2014;
Token with write permission to overwrite Trainer&#x2019;s original args.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.push_to_hub.revision" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.push_to_hub.revision"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>revision</strong> (<code>str</code>, <em>optional</em>) &#x2014;
The git revision to commit from. Defaults to the head of the &#x201C;main&#x201D; branch.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardTrainer.push_to_hub.kwargs" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardTrainer.push_to_hub.kwargs"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>kwargs</strong> (<code>dict[str, Any]</code>, <em>optional</em>) &#x2014;
Additional keyword arguments passed along to <code>~Trainer.create_model_card</code>.<!-- HTML_TAG_END --> </span></span> </li></ul> </div></div> <p data-svelte-h="svelte-8tudwd">Upload <code>self.model</code> and <code>self.processing_class</code> to the 🤗 model hub on the repo <code>self.args.hub_model_id</code>.</p></div></div> <h2 class="relative group"><a id="trl.RewardConfig" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>RewardConfig</span></h2> <div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8"> <div><span class="group flex space-x-1.5 items-center text-gray-800 bg-gradient-to-r rounded-tr-lg -mt-4 -ml-4 pt-3 px-2.5" id="trl.RewardConfig"><!-- HTML_TAG_START --><h3 class="!m-0"><span class="flex-1 break-all md:text-lg bg-gradient-to-r px-2.5 py-1.5 rounded-xl from-indigo-50/70 to-white dark:from-gray-900 dark:to-gray-950 dark:text-indigo-300 text-indigo-700"><svg class="mr-1.5 text-indigo-500 inline-block -mt-0.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width=".8em" height=".8em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><path class="uim-quaternary" d="M20.23 7.24L12 12L3.77 7.24a1.98 1.98 0 0 1 .7-.71L11 2.76c.62-.35 1.38-.35 2 0l6.53 3.77c.29.173.531.418.7.71z" opacity=".25" fill="currentColor"></path><path class="uim-tertiary" d="M12 12v9.5a2.09 2.09 0 0 1-.91-.21L4.5 17.48a2.003 2.003 0 0 1-1-1.73v-7.5a2.06 2.06 0 0 1 .27-1.01L12 12z" opacity=".5" fill="currentColor"></path><path class="uim-primary" d="M20.5 8.25v7.5a2.003 2.003 0 0 1-1 1.73l-6.62 3.82c-.275.13-.576.198-.88.2V12l8.23-4.76c.175.308.268.656.27 1.01z" fill="currentColor"></path></svg><span class="font-light">class</span> <span class="font-medium">trl.</span><span class="font-semibold">RewardConfig</span></span></h3><!-- HTML_TAG_END --> <a id="trl.RewardConfig" class="header-link invisible with-hover:group-hover:visible pr-2" href="#trl.RewardConfig"><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></a> <a class="!ml-auto !text-gray-400 !no-underline text-sm flex items-center" href="https://github.com/huggingface/trl/blob/vr_5607/trl/trainer/reward_config.py#L23" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span class="hidden md:block mx-0.5 hover:!underline" data-svelte-h="svelte-122apf4">source</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span></a></span> <p class="font-mono text-xs md:text-sm !leading-relaxed !my-6"><span data-svelte-h="svelte-8mvn6a">(</span> <span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">output_dir<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">per_device_train_batch_size<span class="opacity-60">: int = 8</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">num_train_epochs<span class="opacity-60">: float = 3.0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">max_steps<span class="opacity-60">: int = -1</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">learning_rate<span class="opacity-60">: float = 0.0001</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">lr_scheduler_type<span class="opacity-60">: transformers.trainer_utils.SchedulerType | str = 'linear'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">lr_scheduler_kwargs<span class="opacity-60">: dict | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">warmup_steps<span class="opacity-60">: float = 0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">optim<span class="opacity-60">: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">optim_args<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">weight_decay<span class="opacity-60">: float = 0.0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">adam_beta1<span class="opacity-60">: float = 0.9</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">adam_beta2<span class="opacity-60">: float = 0.999</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">adam_epsilon<span class="opacity-60">: float = 1e-08</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">optim_target_modules<span class="opacity-60">: None | str | list[str] = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">gradient_accumulation_steps<span class="opacity-60">: int = 1</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">average_tokens_across_devices<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">max_grad_norm<span class="opacity-60">: float = 1.0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">label_smoothing_factor<span class="opacity-60">: float = 0.0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">bf16<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">fp16<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">bf16_full_eval<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">fp16_full_eval<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">tf32<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">gradient_checkpointing<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">gradient_checkpointing_kwargs<span class="opacity-60">: dict[str, typing.Any] | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">torch_compile<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">torch_compile_backend<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">torch_compile_mode<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">use_liger_kernel<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">liger_kernel_config<span class="opacity-60">: dict[str, bool] | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">use_cache<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">neftune_noise_alpha<span class="opacity-60">: float | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">torch_empty_cache_steps<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">auto_find_batch_size<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">logging_strategy<span class="opacity-60">: transformers.trainer_utils.IntervalStrategy | str = 'steps'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">logging_steps<span class="opacity-60">: float = 10</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">logging_first_step<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">log_on_each_node<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">logging_nan_inf_filter<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">include_num_input_tokens_seen<span class="opacity-60">: str | bool = 'no'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">log_level<span class="opacity-60">: str = 'passive'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">log_level_replica<span class="opacity-60">: str = 'warning'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">disable_tqdm<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">report_to<span class="opacity-60">: None | str | list[str] = 'none'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">run_name<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">project<span class="opacity-60">: str = 'huggingface'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">trackio_space_id<span class="opacity-60">: str | None = 'trackio'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_strategy<span class="opacity-60">: transformers.trainer_utils.IntervalStrategy | str = 'no'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_steps<span class="opacity-60">: float | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_delay<span class="opacity-60">: float = 0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">per_device_eval_batch_size<span class="opacity-60">: int = 8</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">prediction_loss_only<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_on_start<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_do_concat_batches<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_use_gather_object<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eval_accumulation_steps<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">include_for_metrics<span class="opacity-60">: list = &lt;factory></span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">batch_eval_metrics<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">save_only_model<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">save_strategy<span class="opacity-60">: transformers.trainer_utils.SaveStrategy | str = 'steps'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">save_steps<span class="opacity-60">: float = 500</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">save_on_each_node<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">save_total_limit<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">enable_jit_checkpoint<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">push_to_hub<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">hub_token<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">hub_private_repo<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">hub_model_id<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">hub_strategy<span class="opacity-60">: transformers.trainer_utils.HubStrategy | str = 'every_save'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">hub_always_push<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">hub_revision<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">load_best_model_at_end<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">metric_for_best_model<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">greater_is_better<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ignore_data_skip<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">restore_callback_states_from_checkpoint<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">full_determinism<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">seed<span class="opacity-60">: int = 42</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">data_seed<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">use_cpu<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">accelerator_config<span class="opacity-60">: dict | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">parallelism_config<span class="opacity-60">: accelerate.parallelism_config.ParallelismConfig | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">dataloader_drop_last<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">dataloader_num_workers<span class="opacity-60">: int = 0</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">dataloader_pin_memory<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">dataloader_persistent_workers<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">dataloader_prefetch_factor<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">remove_unused_columns<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">label_names<span class="opacity-60">: list[str] | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">train_sampling_strategy<span class="opacity-60">: str = 'random'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">length_column_name<span class="opacity-60">: str = 'length'</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ddp_find_unused_parameters<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ddp_bucket_cap_mb<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ddp_broadcast_buffers<span class="opacity-60">: bool | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ddp_backend<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">ddp_timeout<span class="opacity-60">: int = 1800</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">fsdp<span class="opacity-60">: list[transformers.trainer_utils.FSDPOption] | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">fsdp_config<span class="opacity-60">: dict[str, typing.Any] | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">deepspeed<span class="opacity-60">: dict | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">debug<span class="opacity-60">: str | list[transformers.debug_utils.DebugOption] = ''</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">skip_memory_metrics<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">do_train<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">do_eval<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">do_predict<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">resume_from_checkpoint<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">warmup_ratio<span class="opacity-60">: float | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">logging_dir<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">local_rank<span class="opacity-60">: int = -1</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">model_init_kwargs<span class="opacity-60">: dict[str, typing.Any] | str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">chat_template_path<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">disable_dropout<span class="opacity-60">: bool = True</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">dataset_num_proc<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">eos_token<span class="opacity-60">: str | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">max_length<span class="opacity-60">: int | None = 1024</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">pad_to_multiple_of<span class="opacity-60">: int | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">center_rewards_coefficient<span class="opacity-60">: float | None = None</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">activation_offloading<span class="opacity-60">: bool = False</span></span> </span><span class="comma cursor-default"><span class="rounded hover:bg-black hover:text-white dark:hover:bg-white dark:hover:text-black">pad_token<span class="opacity-60">: str | None = None</span></span> </span> <span data-svelte-h="svelte-1jq0pl7">)</span> </p> <div class="!mb-10 relative docstring-details "> <p class="flex items-center font-semibold">Parameters that control the model <span class="flex-auto border-t-2 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.model_init_kwargs" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.model_init_kwargs"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>model_init_kwargs</strong> (<code>dict[str, Any]</code>, <em>optional</em>) &#x2014;
Keyword arguments for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained" rel="nofollow">from_pretrained</a>, used when the <code>model</code>
argument of the <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a> is provided as a string. If you&#x2019;re training a MoE architecture and want
to include the load balancing/auxiliary loss as a part of the final loss, remember to set
<code>output_router_logits=True</code> in this dictionary.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.chat_template_path" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.chat_template_path"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>chat_template_path</strong> (<code>str</code>, <em>optional</em>) &#x2014;
If specified, sets the model&#x2019;s chat template. This can either be the path to a tokenizer (local directory
or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
ensure that any special tokens referenced in the template are added to the tokenizer and that the model&#x2019;s
embedding layer is resized accordingly.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.disable_dropout" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.disable_dropout"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>disable_dropout</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>True</code>) &#x2014;
Whether to disable dropout in the model.<!-- HTML_TAG_END --> </span></span> </li> </ul><p class="flex items-center font-semibold">Parameters that control the data preprocessing <span class="flex-auto border-t-2 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.dataset_num_proc" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.dataset_num_proc"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>dataset_num_proc</strong> (<code>int</code>, <em>optional</em>) &#x2014;
Number of processes to use for processing the dataset.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.eos_token" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.eos_token"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>eos_token</strong> (<code>str</code>, <em>optional</em>) &#x2014;
Token used to indicate the end of a turn or sequence. If <code>None</code>, it defaults to
<code>processing_class.eos_token</code>.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.max_length" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.max_length"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>max_length</strong> (<code>int</code> or <code>None</code>, <em>optional</em>, defaults to <code>1024</code>) &#x2014;
Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
exceeds this value. If <code>None</code>, no filtering is applied.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.pad_to_multiple_of" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.pad_to_multiple_of"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>pad_to_multiple_of</strong> (<code>int</code>, <em>optional</em>) &#x2014;
If set, the sequences will be padded to a multiple of this value.<!-- HTML_TAG_END --> </span></span> </li> </ul><p class="flex items-center font-semibold">Parameters that control the training <span class="flex-auto border-t-2 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.center_rewards_coefficient" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.center_rewards_coefficient"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>center_rewards_coefficient</strong> (<code>float</code>, <em>optional</em>) &#x2014;
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
<a href="https://huggingface.co/papers/2312.09244" rel="nofollow">https://huggingface.co/papers/2312.09244</a>, Eq. 2). Recommended value: <code>0.01</code>.<!-- HTML_TAG_END --> </span></span> </li><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.activation_offloading" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.activation_offloading"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>activation_offloading</strong> (<code>bool</code>, <em>optional</em>, defaults to <code>False</code>) &#x2014;
Whether to offload the activations to the CPU.<!-- HTML_TAG_END --> </span></span> </li> </ul><p class="flex items-center font-semibold">Deprecated parameters <span class="flex-auto border-t-2 ml-3"></span></p> <ul class="px-2"><li class="text-base !pl-4 my-3 rounded "><span class="group flex space-x-1.5 items-start"><a id="trl.RewardConfig.pad_token" class="header-link block pr-0.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trl.RewardConfig.pad_token"><span><svg class="text-smd" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span><!-- HTML_TAG_START --><strong>pad_token</strong> &#x2014;</p>
<deprecated version="1.1.0">
<p>Parameter <code>pad_token</code> is deprecated and will be removed in version v2.0.0. Set <code>tokenizer.pad_token</code>
directly and pass it as <code>processing_class</code> to the trainer instead.</p>
</deprecated><!-- HTML_TAG_END --> </span></span> </li> </ul> </div></div> <p data-svelte-h="svelte-o9kf85">Configuration class for the <a href="/docs/trl/pr_5607/en/reward_trainer#trl.RewardTrainer">RewardTrainer</a>.</p> <p data-svelte-h="svelte-3thf4f">This class includes only the parameters that are specific to Reward training. For a full list of training
arguments, please refer to the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a> documentation. Note that default values in this
class may differ from those in <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a>.</p> <p data-svelte-h="svelte-ekuf1t">Using <a href="https://huggingface.co/docs/transformers/main/en/internal/trainer_utils#transformers.HfArgumentParser" rel="nofollow">HfArgumentParser</a> we can turn this class into
<a href="https://docs.python.org/3/library/argparse#module-argparse" rel="nofollow">argparse</a> arguments that can be specified on the
command line.</p> <blockquote class="note" data-svelte-h="svelte-mu4yq0"><p>These parameters have default values different from <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments" rel="nofollow">TrainingArguments</a>:</p> <ul><li><code>logging_steps</code>: Defaults to <code>10</code> instead of <code>500</code>.</li> <li><code>gradient_checkpointing</code>: Defaults to <code>True</code> instead of <code>False</code>.</li> <li><code>bf16</code>: Defaults to <code>True</code> if <code>fp16</code> is not set, instead of <code>False</code>.</li> <li><code>learning_rate</code>: Defaults to <code>1e-4</code> instead of <code>5e-5</code>.</li></ul></blockquote></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/reward_trainer.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_1hqaf25 = {
assets: "/docs/trl/pr_5607/en",
base: "/docs/trl/pr_5607/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/trl/pr_5607/en/_app/immutable/entry/start.151d81bd.js"),
import("/docs/trl/pr_5607/en/_app/immutable/entry/app.3d9a91c0.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 48],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
175 kB
·
Xet hash:
4f68eec7ac6c3a6e256265908e7486bc01570184154a7b46b30b0bd428eb1813

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.