Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Training customization","local":"training-customization","sections":[{"title":"Train on multiple GPUs / nodes","local":"train-on-multiple-gpus--nodes","sections":[{"title":"Distributed training with DeepSpeed","local":"distributed-training-with-deepspeed","sections":[],"depth":3}],"depth":2},{"title":"Use different optimizers","local":"use-different-optimizers","sections":[{"title":"Use LION optimizer","local":"use-lion-optimizer","sections":[],"depth":3}],"depth":2},{"title":"Add a learning rate scheduler","local":"add-a-learning-rate-scheduler","sections":[],"depth":2},{"title":"Memory efficient fine-tuning by sharing layers","local":"memory-efficient-fine-tuning-by-sharing-layers","sections":[],"depth":2},{"title":"Pass 8-bit reference models","local":"pass-8-bit-reference-models","sections":[],"depth":2},{"title":"Use the CUDA cache optimizer","local":"use-the-cuda-cache-optimizer","sections":[],"depth":2},{"title":"Use score scaling/normalization/clipping","local":"use-score-scalingnormalizationclipping","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/trl/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/scheduler.85c25b89.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/singletons.98fe034d.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/paths.eb9df337.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/index.c142fe32.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/0.5efac18d.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/nodes/8.d12ebc10.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/CodeBlock.a5e95a57.js"> | |
| <link rel="modulepreload" href="/docs/trl/main/en/_app/immutable/chunks/EditOnGithub.a592e7aa.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Training customization","local":"training-customization","sections":[{"title":"Train on multiple GPUs / nodes","local":"train-on-multiple-gpus--nodes","sections":[{"title":"Distributed training with DeepSpeed","local":"distributed-training-with-deepspeed","sections":[],"depth":3}],"depth":2},{"title":"Use different optimizers","local":"use-different-optimizers","sections":[{"title":"Use LION optimizer","local":"use-lion-optimizer","sections":[],"depth":3}],"depth":2},{"title":"Add a learning rate scheduler","local":"add-a-learning-rate-scheduler","sections":[],"depth":2},{"title":"Memory efficient fine-tuning by sharing layers","local":"memory-efficient-fine-tuning-by-sharing-layers","sections":[],"depth":2},{"title":"Pass 8-bit reference models","local":"pass-8-bit-reference-models","sections":[],"depth":2},{"title":"Use the CUDA cache optimizer","local":"use-the-cuda-cache-optimizer","sections":[],"depth":2},{"title":"Use score scaling/normalization/clipping","local":"use-score-scalingnormalizationclipping","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="training-customization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-customization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 
56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training customization</span></h1> <p data-svelte-h="svelte-1fi3h7b">TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques.</p> <h2 class="relative group"><a id="train-on-multiple-gpus--nodes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#train-on-multiple-gpus--nodes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Train on multiple GPUs / nodes</span></h2> <p data-svelte-h="svelte-ctt0hb">The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. 
To do so, first create an 🤗 Accelerate config file by running</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ln7prb">and answering the questions according to your multi-gpu / multi-node setup. 
You can then launch distributed training by running:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch your_script.py<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15pg0q3">We also provide config files in the <a href="https://github.com/huggingface/trl/tree/main/examples/accelerate_configs" rel="nofollow">examples folder</a> that can be used as templates. 
To use these templates, simply pass the path to the config file when launching a job, e.g.:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-la5c3h">Refer to the <a href="https://github.com/huggingface/trl/tree/main/examples" rel="nofollow">examples page</a> for more details.</p> <h3 class="relative group"><a id="distributed-training-with-deepspeed" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#distributed-training-with-deepspeed"><span><svg class="" xmlns="http://www.w3.org/2000/svg" 
xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Distributed training with DeepSpeed</span></h3> <p data-svelte-h="svelte-190jsut">All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform 
-translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-e8r4tv">Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the <code>zero3_init_context_manager()</code> context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the <a href="https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py" rel="nofollow"><code>sentiment_tuning</code></a> example:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; 
border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin | |
| <span class="hljs-keyword">if</span> ds_plugin <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span> <span class="hljs-keyword">and</span> ds_plugin.is_zero3_init_enabled(): | |
| <span class="hljs-keyword">with</span> ds_plugin.zero3_init_context_manager(enable=<span class="hljs-literal">False</span>): | |
| sentiment_pipe = pipeline(<span class="hljs-string">"sentiment-analysis"</span>, model=<span class="hljs-string">"lvwerra/distilbert-imdb"</span>, device=device) | |
| <span class="hljs-keyword">else</span>: | |
| sentiment_pipe = pipeline(<span class="hljs-string">"sentiment-analysis"</span>, model=<span class="hljs-string">"lvwerra/distilbert-imdb"</span>, device=device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hhlpqt">Consult the 🤗 Accelerate <a href="https://huggingface.co/docs/accelerate/usage_guides/deepspeed" rel="nofollow">documentation</a> for more information about the DeepSpeed plugin.</p> <h2 class="relative group"><a id="use-different-optimizers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-different-optimizers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Use different optimizers</span></h2> <p data-svelte-h="svelte-lmd9ef">By default, the <code>PPOTrainer</code> creates a <code>torch.optim.Adam</code> optimizer. 
You can create and define a different optimizer and pass it to <code>PPOTrainer</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> GPT2Tokenizer | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead | |
| <span class="hljs-comment"># 1. load a pretrained model</span> | |
| model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| tokenizer = GPT2Tokenizer.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| <span class="hljs-comment"># 2. define config</span> | |
| ppo_config = {<span class="hljs-string">'batch_size'</span>: <span class="hljs-number">1</span>, <span class="hljs-string">'learning_rate'</span>:<span class="hljs-number">1e-5</span>} | |
| config = PPOConfig(**ppo_config) | |
| <span class="hljs-comment"># 3. Create optimizer</span> | |
| optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) | |
| <span class="hljs-comment"># 4. initialize trainer</span> | |
| ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jl9zz3">For memory efficient fine-tuning, you can also pass <code>Adam8bit</code> optimizer from <code>bitsandbytes</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">import</span> bitsandbytes <span class="hljs-keyword">as</span> bnb | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> GPT2Tokenizer | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead | |
| <span class="hljs-comment"># 1. load a pretrained model</span> | |
| model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| tokenizer = GPT2Tokenizer.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| <span class="hljs-comment"># 2. define config</span> | |
| ppo_config = {<span class="hljs-string">'batch_size'</span>: <span class="hljs-number">1</span>, <span class="hljs-string">'learning_rate'</span>:<span class="hljs-number">1e-5</span>} | |
| config = PPOConfig(**ppo_config) | |
| <span class="hljs-comment"># 3. Create optimizer</span> | |
| optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate) | |
| <span class="hljs-comment"># 4. initialize trainer</span> | |
| ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="use-lion-optimizer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-lion-optimizer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Use LION optimizer</span></h3> <p data-svelte-h="svelte-wheov0">You can use the new <a href="https://huggingface.co/papers/2302.06675" rel="nofollow">LION optimizer from Google</a> as well. First, take the source code of the optimizer definition <a href="https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py" rel="nofollow">here</a>, and copy it so that you can import the optimizer. 
Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->optimizer = Lion(<span class="hljs-built_in">filter</span>(<span class="hljs-keyword">lambda</span> p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate) | |
| ... | |
| ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1wncfxf">We advise you to use the learning rate that you would use for <code>Adam</code> divided by 3 as pointed out <a href="https://github.com/lucidrains/lion-pytorch#lion---pytorch" rel="nofollow">here</a>. We observed an improvement when using this optimizer compared to classic Adam (check the full logs <a href="https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada" rel="nofollow">here</a>):</p> <div style="text-align: center" data-svelte-h="svelte-knha24"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png"></div> <h2 class="relative group"><a id="add-a-learning-rate-scheduler" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#add-a-learning-rate-scheduler"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Add a learning rate scheduler</span></h2> <p data-svelte-h="svelte-1tamlje">You can also play with your training by adding learning rate schedulers!</p> <div class="code-block relative"><div 
class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> GPT2Tokenizer | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead | |
| <span class="hljs-comment"># 1. load a pretrained model</span> | |
| model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| tokenizer = GPT2Tokenizer.from_pretrained(<span class="hljs-string">'gpt2'</span>) | |
| <span class="hljs-comment"># 2. define config</span> | |
| ppo_config = {<span class="hljs-string">'batch_size'</span>: <span class="hljs-number">1</span>, <span class="hljs-string">'learning_rate'</span>:<span class="hljs-number">1e-5</span>} | |
| config = PPOConfig(**ppo_config) | |
| <span class="hljs-comment"># 3. Create optimizer</span> | |
| optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) | |
| lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=<span class="hljs-number">0.9</span>) | |
| <span class="hljs-comment"># 4. initialize trainer</span> | |
| ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="memory-efficient-fine-tuning-by-sharing-layers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#memory-efficient-fine-tuning-by-sharing-layers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Memory efficient fine-tuning by sharing layers</span></h2> <p data-svelte-h="svelte-hswwhh">Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model | |
| <span class="hljs-comment"># 1. load a pretrained model</span> | |
| model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'bigscience/bloom-560m'</span>) | |
| ref_model = create_reference_model(model, num_shared_layers=<span class="hljs-number">6</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">'bigscience/bloom-560m'</span>) | |
| <span class="hljs-comment"># 2. initialize trainer</span> | |
| ppo_config = {<span class="hljs-string">'batch_size'</span>: <span class="hljs-number">1</span>} | |
| config = PPOConfig(**ppo_config) | |
| ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="pass-8-bit-reference-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pass-8-bit-reference-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Pass 8-bit reference models</span></h2> <div data-svelte-h="svelte-n87ems"><p>Since <code>trl</code> supports all keyword arguments when loading a model from <code>transformers</code> using <code>from_pretrained</code>, you can also leverage <code>load_in_8bit</code> from <code>transformers</code> for more memory efficient fine-tuning.</p> <p>Read more about 8-bit model loading in <code>transformers</code> <a href="https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition" rel="nofollow">here</a>.</p></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code 
excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 0. imports</span> | |
| <span class="hljs-comment"># pip install bitsandbytes</span> | |
| <span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer | |
| <span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead | |
| <span class="hljs-comment"># 1. load a pretrained model</span> | |
| model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'bigscience/bloom-560m'</span>) | |
| ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(<span class="hljs-string">'bigscience/bloom-560m'</span>, device_map=<span class="hljs-string">"auto"</span>, load_in_8bit=<span class="hljs-literal">True</span>) | |
| tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">'bigscience/bloom-560m'</span>) | |
| <span class="hljs-comment"># 2. initialize trainer</span> | |
| ppo_config = {<span class="hljs-string">'batch_size'</span>: <span class="hljs-number">1</span>} | |
| config = PPOConfig(**ppo_config) | |
| ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="use-the-cuda-cache-optimizer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-the-cuda-cache-optimizer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Use the CUDA cache optimizer</span></h2> <p data-svelte-h="svelte-kivsgi">When training large models, you should better handle the CUDA cache by iteratively clearing it. 
To do so, simply pass <code>optimize_cuda_cache=True</code> to <code>PPOConfig</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->config = PPOConfig(..., optimize_cuda_cache=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="use-score-scalingnormalizationclipping" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#use-score-scalingnormalizationclipping"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 
1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Use score scaling/normalization/clipping</span></h2> <p data-svelte-h="svelte-dhbjox">As suggested by <a href="https://huggingface.co/papers/2307.04964" rel="nofollow">Secrets of RLHF in Large Language Models Part I: PPO</a>, we support score (aka reward) scaling/normalization/clipping to improve training stability via <code>PPOConfig</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> PPOConfig | |
| ppo_config = { | |
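| <span class="hljs-string">'use_score_scaling'</span>: <span class="hljs-literal">True</span>, | |
| <span class="hljs-string">'use_score_norm'</span>: <span class="hljs-literal">True</span>, | |
| <span class="hljs-string">'score_clip'</span>: <span class="hljs-number">0.5</span>, | |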
| } | |
| config = PPOConfig(**ppo_config)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hdjzdk">To run <code>ppo.py</code>, you can use the following command:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->python examples/scripts/ppo<span class="hljs-selector-class">.py</span> <span class="hljs-attr">--log_with</span> wandb <span class="hljs-attr">--use_score_scaling</span> <span class="hljs-attr">--use_score_norm</span> <span class="hljs-attr">--score_clip</span> <span class="hljs-number">0.5</span><!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/trl/blob/main/docs/source/customization.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span 
data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_5yobsv = { | |
| assets: "/docs/trl/main/en", | |
| base: "/docs/trl/main/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/trl/main/en/_app/immutable/entry/start.183b226a.js"), | |
| import("/docs/trl/main/en/_app/immutable/entry/app.9853b7f5.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 8], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 47.2 kB
- Xet hash:
- be99d649d4901667f9330ce0a4ce88400daec15ccf8f120252361de37c2ea555
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.