Buckets:

hf-doc-build/doc-dev / trl /main /en /rloo_trainer.html
rtrm's picture
download
raw
47.9 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;RLOO Trainer&quot;,&quot;local&quot;:&quot;rloo-trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Get started&quot;,&quot;local&quot;:&quot;get-started&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Explanation of the logged metrics&quot;,&quot;local&quot;:&quot;explanation-of-the-logged-metrics&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cookbook&quot;,&quot;local&quot;:&quot;cookbook&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;What is my model doing exactly?&quot;,&quot;local&quot;:&quot;what-is-my-model-doing-exactly&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Implementation details&quot;,&quot;local&quot;:&quot;implementation-details&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Benchmark experiments&quot;,&quot;local&quot;:&quot;benchmark-experiments&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/trl/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/start.183b226a.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/scheduler.85c25b89.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/singletons.98fe034d.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/paths.eb9df337.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/index.c142fe32.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/0.5efac18d.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/30.0fb52b78.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/CodeBlock.a5e95a57.js">
<link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/EditOnGithub.a592e7aa.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;RLOO Trainer&quot;,&quot;local&quot;:&quot;rloo-trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Get started&quot;,&quot;local&quot;:&quot;get-started&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Explanation of the logged metrics&quot;,&quot;local&quot;:&quot;explanation-of-the-logged-metrics&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Cookbook&quot;,&quot;local&quot;:&quot;cookbook&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;What is my model doing exactly?&quot;,&quot;local&quot;:&quot;what-is-my-model-doing-exactly&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Implementation details&quot;,&quot;local&quot;:&quot;implementation-details&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Benchmark experiments&quot;,&quot;local&quot;:&quot;benchmark-experiments&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="rloo-trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#rloo-trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 
0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>RLOO Trainer</span></h1> <p data-svelte-h="svelte-1hke7q3">TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean score of the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.</p> <p data-svelte-h="svelte-k74mug">References:</p> <ul data-svelte-h="svelte-fqh62x"><li><a href="https://huggingface.co/papers/2402.14740" rel="nofollow">Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs</a></li> <li><a href="https://huggingface.co/papers/2205.09123" rel="nofollow">A2C is a special case of PPO</a></li> <li><a href="https://github.com/openai/lm-human-preferences" rel="nofollow">Fine-Tuning Language Models from Human Preferences</a></li> <li><a href="https://github.com/openai/summarize-from-feedback" rel="nofollow">Learning to Summarize from Human Feedback</a></li> <li><a href="https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo" rel="nofollow">The N Implementation Details of RLHF with PPO</a></li> <li><a href="https://huggingface.co/papers/2403.17031" rel="nofollow">The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization</a></li></ul> <h2 class="relative group"><a id="get-started" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 
with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#get-started"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Get started</span></h2> <p data-svelte-h="svelte-dp2pj8">To just run an RLOO script to make sure the trainer can run, you can use the following command to train an RLOO model with a dummy reward model.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform 
-translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->python examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="explanation-of-the-logged-metrics" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#explanation-of-the-logged-metrics"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Explanation of the logged metrics</span></h2> <p data-svelte-h="svelte-9rhmys">The logged metrics are as follows. 
Here is an example <a href="https://wandb.ai/huggingface/trl/runs/u2sqci34" rel="nofollow">tracked run at Weights and Biases</a>.</p> <ul data-svelte-h="svelte-8uf9kn"><li><code>eps</code>: Tracks the number of episodes per second.</li> <li><code>objective/kl</code>: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy.</li> <li><code>objective/entropy</code>: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.</li> <li><code>objective/non_score_reward</code>: The mean reward from non-score-related sources, basically <code>beta * kl.sum(1)</code>, where <code>beta</code> is the KL penalty coefficient and <code>kl</code> is the per-token KL divergence.</li> <li><code>objective/rlhf_reward</code>: The mean RLHF reward, which is <code>score - non_score_reward</code>.</li> <li><code>objective/scores</code>: The mean scores returned by the reward model / environment.</li> <li><code>policy/approxkl_avg</code>: The average approximate KL divergence between consecutive PPO policies. 
Note that this is not the same as <code>objective/kl</code>.</li> <li><code>policy/clipfrac_avg</code>: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.</li> <li><code>loss/policy_avg</code>: The average policy loss, indicating how well the policy is performing.</li> <li><code>val/clipfrac_avg</code>: The average fraction of value function updates that are clipped, similar to <code>policy/clipfrac_avg</code> but for the value function.</li> <li><code>policy/entropy_avg</code>: The average entropy of the policy during training, indicating how diverse the policy’s actions are.</li> <li><code>val/ratio</code>: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.</li> <li><code>val/ratio_var</code>: The variance of the <code>val/ratio</code>, indicating the variability in policy changes.</li> <li><code>val/num_eos_tokens</code>: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.</li> <li><code>lr</code>: The current learning rate used by the optimizer.</li> <li><code>episode</code>: The current global step or episode count in the training process.</li></ul> <h2 class="relative group"><a id="cookbook" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cookbook"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 
79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Cookbook</span></h2> <ul data-svelte-h="svelte-76e0mb"><li>Debugging TIP: <code>objective/rlhf_reward</code>: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.</li> <li>Debugging TIP: <code>val/ratio</code>: this number should float around 1.0, and it gets clipped by <code>--cliprange 0.2</code> with PPO’s surrogate loss. So if this <code>ratio</code> is too high (e.g. 2.0 or 1000.0) or too low (e.g. 0.1), it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and try to fix it.</li> <li>Memory TIP: If you are running out of memory, you can try to reduce the <code>--per_device_train_batch_size</code> or increase the <code>--gradient_accumulation_steps</code> to reduce the memory footprint.</li> <li>Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: <code>accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml</code>.</li> <li>Usage TIP: We recommend using the “EOS trick” via <code>--missing_eos_penalty</code>, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. 
This can help the model learn to generate more coherent completions.</li></ul> <h2 class="relative group"><a id="what-is-my-model-doing-exactly" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#what-is-my-model-doing-exactly"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>What is my model doing exactly?</span></h2> <p data-svelte-h="svelte-1x3ha3h">To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example <a href="https://wandb.ai/huggingface/trl/runs/u2sqci34" rel="nofollow">tracked run at Weights and Biases</a>, it looks like the following, allowing you to see the model’s response at different stages of training. 
By default we generate <code>--num_sample_generations 10</code> during training, but you can customize the number of generations.</p> <p data-svelte-h="svelte-15f99oh"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif"></p> <p data-svelte-h="svelte-98m2ov">In the logs the sampled generations look like</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I&#x27;m <span class="hljs-keyword">in</span> love <span class="hljs-keyword">with</span> a friend, <span class="hljs-keyword">and</span> │ <span class="hljs-number">3.921875</span> │
│ │ I don&#x27;t know how <span class="hljs-keyword">to</span> <span class="hljs-keyword">get</span> rid <span class="hljs-keyword">of</span> │ │
│ TITLE: How do you <span class="hljs-keyword">get</span> someone │ those feelings. I&#x27;m │ │
│ <span class="hljs-keyword">out of</span> your head? │ desperate.&lt;|endoftext|&gt;[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I&#x27;m <span class="hljs-number">22</span>, <span class="hljs-keyword">and</span> I have been <span class="hljs-keyword">with</span> <span class="hljs-keyword">my</span> │ │ │
│ girlfriend <span class="hljs-keyword">for</span> <span class="hljs-number">5</span> years now. We │ │ │
│ recently moved together. We&#x27;ve │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started <span class="hljs-keyword">to</span> │ │ │
│ have feelings <span class="hljs-keyword">for</span> an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend <span class="hljs-keyword">for</span> now <span class="hljs-number">3</span> │ │ │
│ years, <span class="hljs-keyword">and</span> has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, <span class="hljs-keyword">it</span> was hard <span class="hljs-keyword">to</span> hide │ │ │
│ them. After <span class="hljs-number">2</span> months <span class="hljs-keyword">of</span> <span class="hljs-keyword">me</span> │ │ │
│ being distant <span class="hljs-keyword">and</span> really sad, │ │ │
│ <span class="hljs-keyword">my</span> girlfriend forced <span class="hljs-keyword">me</span> <span class="hljs-keyword">to</span> <span class="hljs-built_in">say</span> │ │ │
│ what was bothering <span class="hljs-keyword">me</span>. I&#x27;m <span class="hljs-keyword">not</span> │ │ │
│ a good liar, <span class="hljs-keyword">and</span> now she knows. │ │ │
│ │ │ │
│ We decided <span class="hljs-keyword">to</span> give us a week │ │ │
│ alone, I went <span class="hljs-keyword">to</span> <span class="hljs-keyword">my</span> parents. │ │ │
│ │ │ │
│ Now, I&#x27;m completely lost. I │ │ │
│ keep <span class="hljs-keyword">on</span> thinking <span class="hljs-keyword">about</span> this │ │ │
│ person, <span class="hljs-keyword">and</span> I hate <span class="hljs-keyword">that</span>. I │ │ │
│ would like <span class="hljs-keyword">for</span> those feelings │ │ │
│ <span class="hljs-keyword">to</span> go away, <span class="hljs-keyword">to</span> leave <span class="hljs-keyword">me</span> alone. │ │ │
│ But I can&#x27;t. │ │ │
│ │ │ │
│ What do I do? It&#x27;s been <span class="hljs-number">3</span> │ │ │
│ months now, <span class="hljs-keyword">and</span> I&#x27;m just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke <span class="hljs-keyword">me</span> up <span class="hljs-keyword">with</span> a loud │ <span class="hljs-number">6.84375</span> │
│ │ TV. I blasted Gangnam Style <span class="hljs-keyword">on</span> │ │
│ TITLE: So, <span class="hljs-keyword">my</span> mom woke <span class="hljs-keyword">me</span> up │ <span class="hljs-keyword">repeat</span>, <span class="hljs-keyword">with</span> <span class="hljs-keyword">the</span> bass cranked │ │
│ <span class="hljs-keyword">with</span> a loud TV. │ up <span class="hljs-keyword">as</span> high <span class="hljs-keyword">as</span> <span class="hljs-keyword">it</span> could │ │
│ │ go.&lt;|endoftext|&gt;[PAD][PAD][PAD… │ │
│ POST: She was <span class="hljs-keyword">in</span> her living │ │ │
│ room, watching TV. This was <span class="hljs-keyword">at</span> │ │ │
│ <span class="hljs-keyword">about</span> <span class="hljs-number">8</span>:<span class="hljs-number">30</span> <span class="hljs-keyword">in</span> <span class="hljs-keyword">the</span> morning, <span class="hljs-keyword">and</span> │ │ │
│ she was exercising. She turned │ │ │
│ <span class="hljs-keyword">the</span> TV up extra loud <span class="hljs-keyword">to</span> hear <span class="hljs-keyword">it</span> │ │ │
│ <span class="hljs-keyword">over</span> her excercycle, <span class="hljs-keyword">and</span> woke │ │ │
│ <span class="hljs-keyword">me</span> up. I went <span class="hljs-keyword">in</span> there asking │ │ │
│ <span class="hljs-keyword">for</span> her <span class="hljs-keyword">to</span> turn <span class="hljs-keyword">it</span> down. She │ │ │
│ said she didn&#x27;t have <span class="hljs-keyword">to</span>; I │ │ │
│ explained <span class="hljs-keyword">that</span> I always used │ │ │
│ headphones so she didn&#x27;t have │ │ │
│ <span class="hljs-keyword">to</span> deal <span class="hljs-keyword">with</span> <span class="hljs-keyword">my</span> noise <span class="hljs-keyword">and</span> <span class="hljs-keyword">that</span> │ │ │
│ she should give <span class="hljs-keyword">me</span> a little │ │ │
│ more respect, <span class="hljs-keyword">given</span> <span class="hljs-keyword">that</span> I paid │ │ │
│ rent <span class="hljs-keyword">at</span> <span class="hljs-keyword">the</span> <span class="hljs-built_in">time</span>. │ │ │
│ │ │ │
│ She disagreed. I went <span class="hljs-keyword">back</span> <span class="hljs-keyword">to</span> │ │ │
│ <span class="hljs-keyword">my</span> room, rather pissed off <span class="hljs-keyword">at</span> │ │ │
│ <span class="hljs-keyword">the</span> lack <span class="hljs-keyword">of</span> equality. I had no │ │ │
│ lock <span class="hljs-keyword">on</span> <span class="hljs-keyword">my</span> door; <span class="hljs-keyword">but</span> I had a │ │ │
│ dresser right next <span class="hljs-keyword">to</span> <span class="hljs-keyword">it</span>, so I │ │ │
│ pulled one <span class="hljs-keyword">of</span> <span class="hljs-keyword">the</span> drawers out │ │ │
│ enough so <span class="hljs-keyword">that</span> <span class="hljs-keyword">it</span> caused <span class="hljs-keyword">the</span> │ │ │
│ door <span class="hljs-keyword">to</span> <span class="hljs-keyword">not</span> be openable. Then, │ │ │
│ I turned <span class="hljs-keyword">my</span> speakers up really │ │ │
│ loud <span class="hljs-keyword">and</span> blasted Gangnam Style │ │ │
│ <span class="hljs-keyword">on</span> <span class="hljs-keyword">repeat</span>, <span class="hljs-keyword">with</span> <span class="hljs-keyword">the</span> bass │ │ │
│ cranked up <span class="hljs-keyword">as</span> high <span class="hljs-keyword">as</span> <span class="hljs-keyword">it</span> could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style <span class="hljs-keyword">for</span> │ │ │
│ being overplayed, you will see │ │ │
│ why I chose <span class="hljs-keyword">that</span> particular │ │ │
│ song. I personally don&#x27;t mind │ │ │
│ <span class="hljs-keyword">it</span>. But here&#x27;s <span class="hljs-keyword">the</span> thing <span class="hljs-keyword">about</span> │ │ │
│ <span class="hljs-keyword">my</span> bass; <span class="hljs-keyword">it</span> vibrates <span class="hljs-keyword">the</span> walls, │ │ │
│ making one hell <span class="hljs-keyword">of</span> a lot <span class="hljs-keyword">of</span> │ │ │
│ noise. Needless <span class="hljs-keyword">to</span> <span class="hljs-built_in">say</span>, <span class="hljs-keyword">my</span> mom │ │ │
│ was <span class="hljs-keyword">not</span> pleased <span class="hljs-keyword">and</span> shut off │ │ │
│ <span class="hljs-keyword">the</span> internet. But <span class="hljs-keyword">it</span> was oh so │ │ │
│ worth <span class="hljs-keyword">it</span>. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="implementation-details" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#implementation-details"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Implementation details</span></h2> <p data-svelte-h="svelte-yp8drw">The bulk of RLOOTrainer is based on the PPO implementation, which is based on <a href="https://huggingface.co/papers/2403.17031" rel="nofollow">The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization</a>.</p> <p data-svelte-h="svelte-ytc2o8">Below is a vectorized advantage calculation for RLOO:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">test_rloo_reward</span>():
local_batch_size = <span class="hljs-number">3</span>
rloo_k = <span class="hljs-number">4</span>
rlhf_reward = torch.tensor([
<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-comment"># first rlhf reward for three prompts</span>
<span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">4</span>, <span class="hljs-comment"># second rlhf reward for three prompts</span>
<span class="hljs-number">5</span>, <span class="hljs-number">6</span>, <span class="hljs-number">7</span>, <span class="hljs-comment"># third rlhf reward for three prompts</span>
<span class="hljs-number">8</span>, <span class="hljs-number">9</span>, <span class="hljs-number">10</span>, <span class="hljs-comment"># fourth rlhf reward for three prompts</span>
]).<span class="hljs-built_in">float</span>() <span class="hljs-comment"># here we have 3 prompts which have 4 completions each</span>
baseline = (rlhf_reward.<span class="hljs-built_in">sum</span>(<span class="hljs-number">0</span>) - rlhf_reward) / (rloo_k - <span class="hljs-number">1</span>)
advantages = torch.zeros_like(rlhf_reward)
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">0</span>, <span class="hljs-built_in">len</span>(advantages), local_batch_size):
other_response_rlhf_rewards = []
<span class="hljs-keyword">for</span> j <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">0</span>, <span class="hljs-built_in">len</span>(advantages), local_batch_size):
<span class="hljs-keyword">if</span> i != j:
other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(<span class="hljs-number">0</span>)
<span class="hljs-keyword">assert</span> <span class="hljs-built_in">abs</span>(<span class="hljs-number">1</span> - (<span class="hljs-number">2</span> + <span class="hljs-number">5</span> + <span class="hljs-number">8</span>) / <span class="hljs-number">3</span> - advantages[<span class="hljs-number">0</span>].item()) &lt; <span class="hljs-number">1e-6</span> <span class="hljs-comment"># First rlhf reward for the first prompt</span>
<span class="hljs-keyword">assert</span> <span class="hljs-built_in">abs</span>(<span class="hljs-number">6</span> - (<span class="hljs-number">2</span> + <span class="hljs-number">3</span> + <span class="hljs-number">9</span>) / <span class="hljs-number">3</span> - advantages[<span class="hljs-number">7</span>].item()) &lt; <span class="hljs-number">1e-6</span> <span class="hljs-comment"># Third rlhf reward for the second prompt</span>
<span class="hljs-comment"># Vectorized implementation</span>
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
baseline = (rlhf_reward.<span class="hljs-built_in">sum</span>(<span class="hljs-number">0</span>) - rlhf_reward) / (rloo_k - <span class="hljs-number">1</span>)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="benchmark-experiments" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#benchmark-experiments"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Benchmark experiments</span></h2> <p data-svelte-h="svelte-po50kh">To validate that the RLOO implementation works, we ran an experiment on the 1B model. Here is the command we used to run the experiment. 
We take the SFT / RM models directly from <a href="https://huggingface.co/papers/2403.17031" rel="nofollow">The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch <span class="hljs-attr">--config_file</span> examples/accelerate_configs/deepspeed_zero2<span class="hljs-selector-class">.yaml</span> \
examples/scripts/rloo/rloo_tldr<span class="hljs-selector-class">.py</span> \
<span class="hljs-attr">--output_dir</span> models/minimal/rloo_tldr \
<span class="hljs-attr">--num_ppo_epochs</span> <span class="hljs-number">2</span> \
<span class="hljs-attr">--num_mini_batches</span> <span class="hljs-number">2</span> \
<span class="hljs-attr">--learning_rate</span> <span class="hljs-number">3</span>e-<span class="hljs-number">6</span> \
<span class="hljs-attr">--per_device_train_batch_size</span> <span class="hljs-number">8</span> \
<span class="hljs-attr">--gradient_accumulation_steps</span> <span class="hljs-number">8</span> \
<span class="hljs-attr">--total_episodes</span> <span class="hljs-number">1000000</span> \
<span class="hljs-attr">--model_name_or_path</span> EleutherAI/pythia-<span class="hljs-number">1</span>b-deduped \
<span class="hljs-attr">--sft_model_path</span> cleanrl/EleutherAI_pythia-<span class="hljs-number">1</span>b-deduped__sft__tldr \
<span class="hljs-attr">--reward_model_path</span> cleanrl/EleutherAI_pythia-<span class="hljs-number">1</span>b-deduped__reward__tldr \
<span class="hljs-attr">--local_rollout_forward_batch_size</span> <span class="hljs-number">16</span> \
<span class="hljs-attr">--missing_eos_penalty</span> <span class="hljs-number">1.0</span> \
<span class="hljs-attr">--stop_token</span> eos \
<span class="hljs-attr">--kl_coef</span> <span class="hljs-number">0.03</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yl0uf2">Checkpoints and experiment tracking are available at:</p> <ul data-svelte-h="svelte-61xz2y"><li><a href="https://huggingface.co/vwxyzjn/rloo_tldr" rel="nofollow">🤗 Model checkpoint</a></li> <li><a href="https://wandb.ai/huggingface/trl/runs/u2sqci34" rel="nofollow">🐝 Tracked experiment</a></li></ul> <p data-svelte-h="svelte-19n72p9">To evaluate, we use <a href="https://github.com/vllm-project/vllm" rel="nofollow">vLLM</a> to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see <a href="judges">Judges</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-101qmt1">The RLOO checkpoint achieves a 51.2% win rate against the reference TL;DRs, vs. 33.0% for the SFT checkpoint. This is a good sign that RLOO training is working as intended.</p> <p data-svelte-h="svelte-1qb748v">Metrics:</p> <p data-svelte-h="svelte-gnbqmv"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/rloo.png" alt="Training metrics of the RLOO TL;DR benchmark run"></p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># pip install openrlbenchmark==0.2.1a5</span>
<span class="hljs-comment"># see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation</span>
<span class="hljs-comment"># to use it, change `?we=huggingface&amp;wpn=trl` to your own project and `?tag=pr-1540` to your own tag</span>
python -m openrlbenchmark.rlops_multi_metrics \
--filters <span class="hljs-string">&#x27;?we=huggingface&amp;wpn=trl&amp;xaxis=train/episode&amp;ceik=output_dir&amp;cen=sft_model_path&amp;metrics=train/objective/rlhf_reward&amp;metrics=train/objective/scores&amp;metrics=train/objective/kl&amp;metrics=train/objective/non_score_reward&amp;metrics=train/objective/entropy&amp;metrics=train/policy/approxkl_avg&amp;metrics=train/policy/clipfrac_avg&amp;metrics=train/loss/policy_avg&amp;metrics=train/policy/entropy_avg&amp;metrics=train/val/ratio&amp;metrics=train/val/ratio_var&amp;metrics=train/val/num_eos_tokens&amp;metrics=train/lr&amp;metrics=train/eps&#x27;</span> \
<span class="hljs-string">&quot;cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540&quot;</span> \
--env-ids models/minimal/rloo_tldr \
--pc.ncols 4 \
--pc.ncols-legend 1 \
--pc.xlabel <span class="hljs-string">&quot;Episode&quot;</span> \
--output-filename benchmark/trl/pr-1540/rloo \
--scan-history<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/rloo_trainer.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Client-side bootstrap for this generated docs page.
// Wrapped in a bare block so the `const` bindings below (`element`, `data`)
// stay block-scoped rather than becoming page-level bindings.

// Path configuration (assets/base both point at this docs version);
// presumably read by the runtime loaded below. Assigned without a
// declaration, so in a classic (non-strict) script this creates a global.
__sveltekit_5yobsv = {
assets: "/docs/trl/main/en",
base: "/docs/trl/main/en",
env: {}
};
// `document.currentScript` is only set while this script runs synchronously
// (it is null inside async callbacks), so the mount element — this script's
// parent — must be captured here, before the asynchronous imports below.
const element = document.currentScript.parentElement;
// One entry per id in `node_ids` below; null presumably means "no preloaded
// data for this node" — TODO confirm against the SvelteKit version in use.
const data = [null,null];
// Load the runtime entry (`start`) and the application entry (`app`) in
// parallel; these are the same files modulepreloaded in <head>.
Promise.all([
import("/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"),
import("/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js")
]).then(([kit, app]) => {
// Start the client runtime, mounting into the captured element.
kit.start(app, element, {
// NOTE(review): presumably the layout node (0, matching nodes/0.*.js
// preloaded in <head>) and this page's route node (30) — confirm.
node_ids: [0, 30],
data,
// No form-action result or error state is handed to the runtime.
form: null,
error: null
});
});
// NOTE(review): there is no .catch() on this chain, so a failed chunk load
// only surfaces as an unhandled promise rejection.
}
</script>

Xet Storage Details

Size:
47.9 kB
·
Xet hash:
ffa4efc5e53776d3867b1d0c0c605f9d5f678effb49df85e586af4045d9a634e

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.