<!-- Scrape/file-viewer artifact (not part of the document), preserved verbatim:
Buckets:

rtrm's picture
download
raw
97.8 kB
-->
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Trainer&quot;,&quot;local&quot;:&quot;trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Basic usage&quot;,&quot;local&quot;:&quot;basic-usage&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Checkpoints&quot;,&quot;local&quot;:&quot;checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Customize the Trainer&quot;,&quot;local&quot;:&quot;customize-the-trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Callbacks&quot;,&quot;local&quot;:&quot;callbacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Logging&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;NEFTune&quot;,&quot;local&quot;:&quot;neftune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Liger Kernel&quot;,&quot;local&quot;:&quot;liger-kernel&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Optimizers&quot;,&quot;local&quot;:&quot;optimizers&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;GaLore&quot;,&quot;local&quot;:&quot;galore&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;LOMO optimizer&quot;,&quot;local&quot;:&quot;lomo-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GrokAdamW optimizer&quot;,&quot;local&quot;:&quot;grokadamw-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Schedule Free Optimizer&quot;,&quot;local&quot;:&quot;schedule-free-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accelerate and Trainer&quot;,&quot;local&quot;:&quot;accelerate-and-trainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link rel="stylesheet" href="/docs/transformers/pr_33913/en/_app/immutable/assets/0.e3b0c442.css">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/scheduler.25b97de1.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/singletons.62a184e0.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.e188933d.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/paths.51881b9e.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/index.d9030fc9.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/0.05e395f5.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/nodes/452.de0effe6.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/Tip.baa67368.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/CodeBlock.e6cd0d95.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/EditOnGithub.91d95064.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/HfOption.1e589c90.js">
<link rel="modulepreload" href="/docs/transformers/pr_33913/en/_app/immutable/chunks/stores.c3f24f16.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Trainer&quot;,&quot;local&quot;:&quot;trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Basic usage&quot;,&quot;local&quot;:&quot;basic-usage&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Checkpoints&quot;,&quot;local&quot;:&quot;checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Customize the Trainer&quot;,&quot;local&quot;:&quot;customize-the-trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Callbacks&quot;,&quot;local&quot;:&quot;callbacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Logging&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;NEFTune&quot;,&quot;local&quot;:&quot;neftune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Liger Kernel&quot;,&quot;local&quot;:&quot;liger-kernel&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Optimizers&quot;,&quot;local&quot;:&quot;optimizers&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;GaLore&quot;,&quot;local&quot;:&quot;galore&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;LOMO optimizer&quot;,&quot;local&quot;:&quot;lomo-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GrokAdamW optimizer&quot;,&quot;local&quot;:&quot;grokadamw-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Schedule Free Optimizer&quot;,&quot;local&quot;:&quot;schedule-free-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accelerate and Trainer&quot;,&quot;local&quot;:&quot;accelerate-and-trainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Trainer</span></h1> <p data-svelte-h="svelte-srtlbx">The <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> is a complete training and evaluation loop for PyTorch models implemented in the Transformers library. You only need to pass it the necessary pieces for training (model, tokenizer, dataset, evaluation function, training hyperparameters, etc.), and the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class takes care of the rest. This makes it easier to start training faster without manually writing your own training loop. 
But at the same time, <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> is very customizable and offers a ton of training options so you can tailor it to your exact training needs.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-12jzg5">In addition to the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class, Transformers also provides a <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Seq2SeqTrainer">Seq2SeqTrainer</a> class for sequence-to-sequence tasks like translation or summarization. There is also the <a href="https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTTrainer" rel="nofollow">SFTTrainer</a> class from the <a href="https://hf.co/docs/trl" rel="nofollow">TRL</a> library which wraps the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class and is optimized for training language models like Llama-2 and Mistral with autoregressive techniques. <a href="https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTTrainer" rel="nofollow">SFTTrainer</a> also supports features like sequence packing, LoRA, quantization, and DeepSpeed for efficiently scaling to any model size.</p> <br> <p data-svelte-h="svelte-7ytry2">Feel free to check out the <a href="./main_classes/trainer">API reference</a> for these other <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a>-type classes to learn more about when to use which one. In general, <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> is the most versatile option and is appropriate for a broad spectrum of tasks. 
<a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Seq2SeqTrainer">Seq2SeqTrainer</a> is designed for sequence-to-sequence tasks and <a href="https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTTrainer" rel="nofollow">SFTTrainer</a> is designed for training language models.</p></div> <p data-svelte-h="svelte-17cwfvo">Before you start, make sure <a href="https://hf.co/docs/accelerate" rel="nofollow">Accelerate</a> - a library for enabling and running PyTorch training across distributed environments - is installed.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install accelerate
<span class="hljs-comment"># upgrade</span>
pip install accelerate --upgrade<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-a82l3h">This guide provides an overview of the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class.</p> <h2 class="relative group"><a id="basic-usage" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#basic-usage"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Basic usage</span></h2> <p data-svelte-h="svelte-1yl3pjc"><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> includes all the code you’ll find in a basic training loop:</p> <ol data-svelte-h="svelte-300ub0"><li>perform a training step to calculate the loss</li> <li>calculate the gradients with the <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.backward" rel="nofollow">backward</a> method</li> <li>update the weights based on the gradients</li> <li>repeat this process until you’ve reached a predetermined number of epochs</li></ol> <p data-svelte-h="svelte-ts83c3">The <a 
href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class abstracts all of this code away so you don’t have to worry about manually writing a training loop every time or if you’re just getting started with PyTorch and training. You only need to provide the essential components required for training, such as a model and a dataset, and the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class handles everything else.</p> <p data-svelte-h="svelte-3go597">If you want to specify any training options or hyperparameters, you can find them in the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> class. For example, let’s define where to save the model in <code>output_dir</code> and push the model to the Hub after training with <code>push_to_hub=True</code>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" 
style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;your-model&quot;</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
per_device_train_batch_size=<span class="hljs-number">16</span>,
per_device_eval_batch_size=<span class="hljs-number">16</span>,
num_train_epochs=<span class="hljs-number">2</span>,
weight_decay=<span class="hljs-number">0.01</span>,
eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
load_best_model_at_end=<span class="hljs-literal">True</span>,
push_to_hub=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lvnx9p">Pass <code>training_args</code> to the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> along with a model, dataset, something to preprocess the dataset with (depending on your data type it could be a tokenizer, feature extractor or image processor), a data collator, and a function to compute the metrics you want to track during training.</p> <p data-svelte-h="svelte-c44now">Finally, call <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.train">train()</a> to start training!</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=dataset[<span class="hljs-string">&quot;test&quot;</span>],
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="checkpoints" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#checkpoints"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Checkpoints</span></h3> <p data-svelte-h="svelte-f2gz59">The <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class saves your model checkpoints to the directory specified in the <code>output_dir</code> parameter of <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. You’ll find the checkpoints saved in a <code>checkpoint-000</code> subfolder where the numbers at the end correspond to the training step. 
Saving checkpoints is useful for resuming training later.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># resume from latest checkpoint</span>
trainer.train(resume_from_checkpoint=<span class="hljs-literal">True</span>)
<span class="hljs-comment"># resume from specific checkpoint saved in output directory</span>
trainer.train(resume_from_checkpoint=<span class="hljs-string">&quot;your-model/checkpoint-1000&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-p7csl6">You can save your checkpoints (the optimizer state is not saved by default) to the Hub by setting <code>push_to_hub=True</code> in <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to commit and push them. Other options for deciding how your checkpoints are saved are set up in the <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.hub_strategy" rel="nofollow"><code>hub_strategy</code></a> parameter:</p> <ul data-svelte-h="svelte-1f7sie2"><li><code>hub_strategy=&quot;checkpoint&quot;</code> pushes the latest checkpoint to a subfolder named “last-checkpoint” from which you can resume training</li> <li><code>hub_strategy=&quot;all_checkpoints&quot;</code> pushes all checkpoints to the directory defined in <code>output_dir</code> (you’ll see one checkpoint per folder in your model repository)</li></ul> <p data-svelte-h="svelte-1cs1tjc">When you resume training from a checkpoint, the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> tries to keep the Python, NumPy, and PyTorch RNG states the same as they were when the checkpoint was saved. But because PyTorch has various non-deterministic default settings, the RNG states aren’t guaranteed to be the same. If you want to enable full determinism, take a look at the <a href="https://pytorch.org/docs/stable/notes/randomness#controlling-sources-of-randomness" rel="nofollow">Controlling sources of randomness</a> guide to learn what you can enable to make your training fully deterministic. 
Keep in mind though that by making certain settings deterministic, training may be slower.</p> <h2 class="relative group"><a id="customize-the-trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#customize-the-trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Customize the Trainer</span></h2> <p data-svelte-h="svelte-1itvinc">While the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class is designed to be accessible and easy-to-use, it also offers a lot of customizability for more adventurous users. Many of the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a>’s methods can be subclassed and overridden to support the functionality you want, without having to rewrite the entire training loop from scratch to accommodate it. 
These methods include:</p> <ul data-svelte-h="svelte-18qbdh5"><li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.get_train_dataloader">get_train_dataloader()</a> creates a training DataLoader</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.get_eval_dataloader">get_eval_dataloader()</a> creates an evaluation DataLoader</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.get_test_dataloader">get_test_dataloader()</a> creates a test DataLoader</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.log">log()</a> logs information on the various objects that watch training</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.create_optimizer_and_scheduler">create_optimizer_and_scheduler()</a> creates an optimizer and learning rate scheduler if they weren’t passed in the <code>__init__</code>; these can also be separately customized with <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.create_optimizer">create_optimizer()</a> and <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.create_scheduler">create_scheduler()</a> respectively</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.compute_loss">compute_loss()</a> computes the loss on a batch of training inputs</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.training_step">training_step()</a> performs the training step</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.prediction_step">prediction_step()</a> performs the prediction and test step</li> <li><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.evaluate">evaluate()</a> evaluates the model and returns the evaluation metrics</li> <li><a 
href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.predict">predict()</a> makes predictions (with metrics if labels are available) on the test set</li></ul> <p data-svelte-h="svelte-1edijll">For example, if you want to customize the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.compute_loss">compute_loss()</a> method to use a weighted loss instead.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> nn
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
<span class="hljs-keyword">class</span> <span class="hljs-title class_">CustomTrainer</span>(<span class="hljs-title class_ inherited__">Trainer</span>):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_loss</span>(<span class="hljs-params">self, model, inputs, return_outputs=<span class="hljs-literal">False</span></span>):
labels = inputs.pop(<span class="hljs-string">&quot;labels&quot;</span>)
<span class="hljs-comment"># forward pass</span>
outputs = model(**inputs)
logits = outputs.get(<span class="hljs-string">&quot;logits&quot;</span>)
<span class="hljs-comment"># compute custom loss for 3 labels with different weights</span>
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([<span class="hljs-number">1.0</span>, <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>], device=model.device))
loss = loss_fct(logits.view(-<span class="hljs-number">1</span>, self.model.config.num_labels), labels.view(-<span class="hljs-number">1</span>))
<span class="hljs-keyword">return</span> (loss, outputs) <span class="hljs-keyword">if</span> return_outputs <span class="hljs-keyword">else</span> loss<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="callbacks" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#callbacks"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Callbacks</span></h3> <p data-svelte-h="svelte-1vzqlvv">Another option for customizing the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> is to use <a href="callbacks">callbacks</a>. Callbacks <em>don’t change</em> anything in the training loop. They inspect the training loop state and then execute some action (early stopping, logging results, etc.) depending on the state. 
In other words, a callback can’t be used to implement something like a custom loss function and you’ll need to subclass and override the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer.compute_loss">compute_loss()</a> method for that.</p> <p data-svelte-h="svelte-ckk1ok">For example, if you want to add an early stopping callback to the training loop after 10 steps.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainerCallback
<span class="hljs-keyword">class</span> <span class="hljs-title class_">EarlyStoppingCallback</span>(<span class="hljs-title class_ inherited__">TrainerCallback</span>):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, num_steps=<span class="hljs-number">10</span></span>):
self.num_steps = num_steps
<span class="hljs-keyword">def</span> <span class="hljs-title function_">on_step_end</span>(<span class="hljs-params">self, args, state, control, **kwargs</span>):
<span class="hljs-keyword">if</span> state.global_step &gt;= self.num_steps:
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;should_training_stop&quot;</span>: <span class="hljs-literal">True</span>}
<span class="hljs-keyword">else</span>:
<span class="hljs-keyword">return</span> {}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ygxhy3">Then pass it to the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a>’s <code>callbacks</code> parameter.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=dataset[<span class="hljs-string">&quot;test&quot;</span>],
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback()],
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="logging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#logging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Logging</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-6jeeaq">Check out the <a href="./main_classes/logging">logging</a> API reference for more information about the different logging levels.</p></div> <p data-svelte-h="svelte-4gijgx">The <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> is set to <code>logging.INFO</code> by default which reports errors, warnings, and other basic information. A <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> replica - in distributed environments - is set to <code>logging.WARNING</code> which only reports errors and warnings. 
You can change the logging level with the <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level" rel="nofollow"><code>log_level</code></a> and <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level_replica" rel="nofollow"><code>log_level_replica</code></a> parameters in <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>.</p> <p data-svelte-h="svelte-3l1nfv">To configure the log level setting for each node, use the <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.log_on_each_node" rel="nofollow"><code>log_on_each_node</code></a> parameter to determine whether to use the log level on each node or only on the main node.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1vc8x31"><a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> sets the log level separately for each node in the <code>Trainer.__init__()</code> method, so you may want to consider setting this sooner if you’re using other Transformers functionalities before creating the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> object.</p></div> <p data-svelte-h="svelte-14qffqk">For example, to set your main code and modules to use the same log level according to each node:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" 
xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->logger = logging.getLogger(__name__)
logging.basicConfig(
<span class="hljs-built_in">format</span>=<span class="hljs-string">&quot;%(asctime)s - %(levelname)s - %(name)s - %(message)s&quot;</span>,
datefmt=<span class="hljs-string">&quot;%m/%d/%Y %H:%M:%S&quot;</span>,
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
trainer = Trainer(...)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-fy4z00">Use different combinations of <code>log_level</code> and <code>log_level_replica</code> to configure what gets logged on each of the nodes.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">single node </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">multi-node </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->my_app.py ... 
--log_level warning --log_level_replica error<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="neftune" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#neftune"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>NEFTune</span></h2> <p data-svelte-h="svelte-1scrzgx"><a href="https://hf.co/papers/2310.05914" rel="nofollow">NEFTune</a> is a technique that can improve performance by adding noise to the embedding vectors during training. 
To enable it in <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a>, set the <code>neftune_noise_alpha</code> parameter in <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to control how much noise is added.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, Trainer
training_args = TrainingArguments(..., neftune_noise_alpha=<span class="hljs-number">0.1</span>)
trainer = Trainer(..., args=training_args)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1tyg2gl">NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.</p> <h2 class="relative group"><a id="liger-kernel" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#liger-kernel"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Liger Kernel</span></h2> <p data-svelte-h="svelte-1s4q0bp"><a href="https://github.com/linkedin/Liger-Kernel" rel="nofollow">Liger-Kernel</a> is a collection of Triton kernels developed by LinkedIn designed specifically for LLM training. We have implemented Hugging Face Compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. 
The kernel works out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400">Gain +20% throughput and reduce memory usage by 60% on LLaMA 3-8B model training. Achieve longer context lengths and larger batch sizes. It’s also useful if you want to scale up your model to multi-head training or large vocabulary sizes. Unleash multi-head training (medusa) and more. See details and examples in <a href="https://github.com/linkedin/Liger-Kernel/tree/main/examples" rel="nofollow">Liger</a>.</div> <p data-svelte-h="svelte-y2r2m1">First make sure to install the official Liger repository:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- 
HTML_TAG_START -->pip install liger-kernel<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1jlq39z">You should pass <code>use_liger_kernel=True</code> to apply liger kernel on your model, for example:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;your-model&quot;</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
per_device_train_batch_size=<span class="hljs-number">16</span>,
per_device_eval_batch_size=<span class="hljs-number">16</span>,
num_train_epochs=<span class="hljs-number">2</span>,
weight_decay=<span class="hljs-number">0.01</span>,
eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
load_best_model_at_end=<span class="hljs-literal">True</span>,
push_to_hub=<span class="hljs-literal">True</span>,
use_liger_kernel=<span class="hljs-literal">True</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1t0qhj8">The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found <a href="https://github.com/linkedin/Liger-Kernel" rel="nofollow">here</a>. When <code>use_liger_kernel</code> is set to <code>True</code>, the corresponding layers in the original model will be patched with Liger’s efficient implementation, so you don’t need to do anything extra other than setting the argument value.</p> <h2 class="relative group"><a id="optimizers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#optimizers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Optimizers</span></h2> <p data-svelte-h="svelte-5eb3w7">You can choose a built-in optimizer for training using:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" 
aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(..., optim=<span class="hljs-string">&quot;adamw_torch&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-keyyt">See <a href="https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py" rel="nofollow"><code>OptimizerNames</code></a> for a full list of choices. We include advanced examples in the sections below.</p> <p data-svelte-h="svelte-1va5gwm">You can also use an arbitrary PyTorch optimizer via:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
optimizer_cls = torch.optim.AdamW
optimizer_kwargs = {
<span class="hljs-string">&quot;lr&quot;</span>: <span class="hljs-number">4e-3</span>,
<span class="hljs-string">&quot;betas&quot;</span>: (<span class="hljs-number">0.9</span>, <span class="hljs-number">0.999</span>),
<span class="hljs-string">&quot;weight_decay&quot;</span>: <span class="hljs-number">0.05</span>,
}
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(..., optimizer_cls_and_kwargs=(optimizer_cls, optimizer_kwargs))<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="galore" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#galore"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GaLore</span></h3> <p data-svelte-h="svelte-qi406x">Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA.</p> <p data-svelte-h="svelte-6gjlmq">First make sure to install GaLore official repository:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install galore-torch<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g9eiy6">Then simply add one of <code>[&quot;galore_adamw&quot;, &quot;galore_adafactor&quot;, &quot;galore_adamw_8bit&quot;]</code> in <code>optim</code> together with <code>optim_target_modules</code>, which can be a list of strings, regex or full path corresponding to the target module names you want to adapt. 
Below is an end-to-end example script (make sure to <code>pip install trl datasets</code>):</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw&quot;</span>,
optim_target_modules=[<span class="hljs-string">r&quot;.*.attn.*&quot;</span>, <span class="hljs-string">r&quot;.*.mlp.*&quot;</span>]
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-i13fla">To pass extra arguments supported by GaLore, you should pass correctly <code>optim_args</code>, for example:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw&quot;</span>,
optim_target_modules=[<span class="hljs-string">r&quot;.*.attn.*&quot;</span>, <span class="hljs-string">r&quot;.*.mlp.*&quot;</span>],
optim_args=<span class="hljs-string">&quot;rank=64, update_proj_gap=100, scale=0.10&quot;</span>,
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ag0i93">You can read more about the method in the <a href="https://github.com/jiaweizzhao/GaLore" rel="nofollow">original repository</a> or the <a href="https://arxiv.org/abs/2403.03507" rel="nofollow">paper</a>.</p> <p data-svelte-h="svelte-14lab64">Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained while remaining layers will be optimized in the conventional manner.</p> <p data-svelte-h="svelte-k79irj">Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards.</p> <p data-svelte-h="svelte-uvc2be">You can also perform layer-wise optimization by post-pending the optimizer name with <code>layerwise</code> like below:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: 
transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw_layerwise&quot;</span>,
optim_target_modules=[<span class="hljs-string">r&quot;.*.attn.*&quot;</span>, <span class="hljs-string">r&quot;.*.mlp.*&quot;</span>]
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-c2nuo5">Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see <a href="https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory" rel="nofollow">this appropriate section</a> for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please <a href="https://github.com/huggingface/transformers/issues" rel="nofollow">raise an issue on GitHub</a> if you encounter such issue.</p> <h3 class="relative group"><a id="lomo-optimizer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#lomo-optimizer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>LOMO optimizer</span></h3> <p data-svelte-h="svelte-1nv1q2g">The LOMO optimizers have been introduced in <a href="https://hf.co/papers/2306.09782" rel="nofollow">Full Parameter Fine-Tuning for Large Language Models with Limited Resources</a> and <a href="https://hf.co/papers/2310.10195" 
rel="nofollow">AdaLomo: Low-memory Optimization with Adaptive Learning Rate</a>.
They both consist of an efficient full-parameter fine-tuning method. These optimizers fuse the gradient computation and the parameter update in one step to reduce memory usage. Supported optimizers for LOMO are <code>&quot;lomo&quot;</code> and <code>&quot;adalomo&quot;</code>. First either install LOMO from pypi <code>pip install lomo-optim</code> or install it from source with <code>pip install git+https://github.com/OpenLMLab/LOMO.git</code>.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-q6vma3">According to the authors, it is recommended to use <code>AdaLomo</code> without <code>grad_norm</code> to get better performance and higher throughput.</p></div> <p data-svelte-h="svelte-1ed4he8">Below is a simple script to demonstrate how to fine-tune <a href="https://huggingface.co/google/gemma-2b" rel="nofollow">google/gemma-2b</a> on IMDB dataset in full precision:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full 
transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> trl
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-lomo&quot;</span>,
max_steps=<span class="hljs-number">1000</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
optim=<span class="hljs-string">&quot;adalomo&quot;</span>,
gradient_checkpointing=<span class="hljs-literal">True</span>,
logging_strategy=<span class="hljs-string">&quot;steps&quot;</span>,
logging_steps=<span class="hljs-number">1</span>,
learning_rate=<span class="hljs-number">2e-6</span>,
save_strategy=<span class="hljs-string">&quot;no&quot;</span>,
run_name=<span class="hljs-string">&quot;lomo-imdb&quot;</span>,
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=<span class="hljs-literal">True</span>).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">1024</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="grokadamw-optimizer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#grokadamw-optimizer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GrokAdamW optimizer</span></h3> <p data-svelte-h="svelte-19cxj8v">The GrokAdamW optimizer is designed to enhance training performance and stability, particularly for models that benefit from grokking signal functions. 
To use GrokAdamW, first install the optimizer package with <code>pip install grokadamw</code>.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-kys0va">GrokAdamW is particularly useful for models that require advanced optimization techniques to achieve better performance and stability.</p></div> <p data-svelte-h="svelte-sjszbo">Below is a simple script to demonstrate how to fine-tune <a href="https://huggingface.co/google/gemma-2b" rel="nofollow">google/gemma-2b</a> on the IMDB dataset using the GrokAdamW optimizer:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoTokenizer, AutoModelForCausalLM, Trainer
<span class="hljs-comment"># Load the IMDB dataset</span>
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
<span class="hljs-comment"># Define the training arguments</span>
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-grokadamw&quot;</span>,
max_steps=<span class="hljs-number">1000</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
optim=<span class="hljs-string">&quot;grokadamw&quot;</span>,
logging_strategy=<span class="hljs-string">&quot;steps&quot;</span>,
logging_steps=<span class="hljs-number">1</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
save_strategy=<span class="hljs-string">&quot;no&quot;</span>,
run_name=<span class="hljs-string">&quot;grokadamw-imdb&quot;</span>,
)
<span class="hljs-comment"># Load the model and tokenizer</span>
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=<span class="hljs-literal">True</span>).to(<span class="hljs-number">0</span>)
<span class="hljs-comment"># Initialize the Trainer</span>
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
)
<span class="hljs-comment"># Train the model</span>
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-pxxonu">This script demonstrates how to fine-tune the <code>google/gemma-2b</code> model on the IMDB dataset using the GrokAdamW optimizer. The <code>TrainingArguments</code> are configured to use GrokAdamW, and the dataset is passed to the <code>Trainer</code> for training.</p> <h3 class="relative group"><a id="schedule-free-optimizer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#schedule-free-optimizer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Schedule Free Optimizer</span></h3> <p data-svelte-h="svelte-w5uo13">The Schedule Free optimizers have been introduced in <a href="https://hf.co/papers/2405.15682" rel="nofollow">The Road Less Scheduled</a>.
Schedule-Free learning replaces the momentum of the base optimizer with a combination of averaging and interpolation, to completely remove the need to anneal the learning rate with a traditional schedule.
Supported optimizers for SFO are <code>&quot;schedule_free_adamw&quot;</code> and <code>&quot;schedule_free_sgd&quot;</code>. First install schedulefree from pypi <code>pip install schedulefree</code>.</p> <p data-svelte-h="svelte-1ed4he8">Below is a simple script to demonstrate how to fine-tune <a href="https://huggingface.co/google/gemma-2b" rel="nofollow">google/gemma-2b</a> on IMDB dataset in full precision:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> trl
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-schedulefree&quot;</span>,
max_steps=<span class="hljs-number">1000</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
optim=<span class="hljs-string">&quot;schedule_free_adamw&quot;</span>,
gradient_checkpointing=<span class="hljs-literal">True</span>,
logging_strategy=<span class="hljs-string">&quot;steps&quot;</span>,
logging_steps=<span class="hljs-number">1</span>,
learning_rate=<span class="hljs-number">2e-6</span>,
save_strategy=<span class="hljs-string">&quot;no&quot;</span>,
run_name=<span class="hljs-string">&quot;sfo-imdb&quot;</span>,
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=<span class="hljs-literal">True</span>).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">1024</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="accelerate-and-trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerate-and-trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accelerate and Trainer</span></h2> <p data-svelte-h="svelte-1fzp47z">The <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> class is powered by <a href="https://hf.co/docs/accelerate" rel="nofollow">Accelerate</a>, a library for easily training PyTorch models in distributed environments with support for integrations such as <a href="https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/" rel="nofollow">FullyShardedDataParallel (FSDP)</a> and <a href="https://www.deepspeed.ai/" rel="nofollow">DeepSpeed</a>.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1gqp6lx">Learn more about FSDP sharding strategies, CPU offloading, and 
more with the <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> in the <a href="fsdp">Fully Sharded Data Parallel</a> guide.</p></div> <p data-svelte-h="svelte-1rk54vp">To use Accelerate with <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a>, run the <a href="https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-config" rel="nofollow"><code>accelerate config</code></a> command to set up training for your training environment. This command creates a <code>config_file.yaml</code> that’ll be used when you launch your training script. For example, some example configurations you can set up are:</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">DistributedDataParallel </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">FSDP </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">DeepSpeed </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">DeepSpeed with Accelerate plugin </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" 
aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attr">compute_environment:</span> <span class="hljs-string">LOCAL_MACHINE</span>
<span class="hljs-attr">distributed_type:</span> <span class="hljs-string">MULTI_GPU</span>
<span class="hljs-attr">downcast_bf16:</span> <span class="hljs-string">&#x27;no&#x27;</span>
<span class="hljs-attr">gpu_ids:</span> <span class="hljs-string">all</span>
<span class="hljs-attr">machine_rank:</span> <span class="hljs-number">0</span> <span class="hljs-comment">#change rank as per the node</span>
<span class="hljs-attr">main_process_ip:</span> <span class="hljs-number">192.168</span><span class="hljs-number">.20</span><span class="hljs-number">.1</span>
<span class="hljs-attr">main_process_port:</span> <span class="hljs-number">9898</span>
<span class="hljs-attr">main_training_function:</span> <span class="hljs-string">main</span>
<span class="hljs-attr">mixed_precision:</span> <span class="hljs-string">fp16</span>
<span class="hljs-attr">num_machines:</span> <span class="hljs-number">2</span>
<span class="hljs-attr">num_processes:</span> <span class="hljs-number">8</span>
<span class="hljs-attr">rdzv_backend:</span> <span class="hljs-string">static</span>
<span class="hljs-attr">same_network:</span> <span class="hljs-literal">true</span>
<span class="hljs-attr">tpu_env:</span> []
<span class="hljs-attr">tpu_use_cluster:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">tpu_use_sudo:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">use_cpu:</span> <span class="hljs-literal">false</span><!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-1l55lf4">The <a href="https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch" rel="nofollow"><code>accelerate launch</code></a> command is the recommended way to launch your training script on a distributed system with Accelerate and <a href="/docs/transformers/pr_33913/en/main_classes/trainer#transformers.Trainer">Trainer</a> with the parameters specified in <code>config_file.yaml</code>. This file is saved to the Accelerate cache folder and automatically loaded when you run <code>accelerate launch</code>.</p> <p data-svelte-h="svelte-cneufk">For example, to run the <a href="https://github.com/huggingface/transformers/blob/f4db565b695582891e43a5e042e5d318e28f20b8/examples/pytorch/text-classification/run_glue.py#L4" rel="nofollow">run_glue.py</a> training script with the FSDP configuration:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0
h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name <span class="hljs-variable">$TASK_NAME</span> \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/<span class="hljs-variable">$TASK_NAME</span>/ \
--overwrite_output_dir<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13yow9q">You could also specify the parameters from the <code>config_file.yaml</code> file directly in the command line:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --num_processes=2 \
--use_fsdp \
--mixed_precision=bf16 \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap=<span class="hljs-string">&quot;BertLayer&quot;</span> \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name <span class="hljs-variable">$TASK_NAME</span> \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/<span class="hljs-variable">$TASK_NAME</span>/ \
--overwrite_output_dir<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zjyylh">Check out the <a href="https://huggingface.co/docs/accelerate/basic_tutorials/launch" rel="nofollow">Launching your Accelerate scripts</a> tutorial to learn more about <code>accelerate launch</code> and custom configurations.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/trainer.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_z647wz = {
assets: "/docs/transformers/pr_33913/en",
base: "/docs/transformers/pr_33913/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/transformers/pr_33913/en/_app/immutable/entry/start.b67f883f.js"),
import("/docs/transformers/pr_33913/en/_app/immutable/entry/app.e436b1f2.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 452],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
97.8 kB
·
Xet hash:
fbc94728808542d7467420231dd3871a52a67190c7fafffb1615ad8387b2759f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.