Buckets:

rtrm's picture
download
raw
72.4 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Trainer&quot;,&quot;local&quot;:&quot;trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Checkpoints&quot;,&quot;local&quot;:&quot;checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Logging&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Customize&quot;,&quot;local&quot;:&quot;customize&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Callbacks&quot;,&quot;local&quot;:&quot;callbacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accelerate&quot;,&quot;local&quot;:&quot;accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Optimizations&quot;,&quot;local&quot;:&quot;optimizations&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;torch.compile&quot;,&quot;local&quot;:&quot;torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GaLore&quot;,&quot;local&quot;:&quot;galore&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Liger&quot;,&quot;local&quot;:&quot;liger&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;NEFTune&quot;,&quot;local&quot;:&quot;neftune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/pr_36839/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/scheduler.01eeda35.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/singletons.177df05e.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.4862150a.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/paths.517376d1.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.6dd51b66.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/0.8897c14d.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/491.876e4d1c.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/Tip.de9bae2b.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/CodeBlock.864da1b0.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/EditOnGithub.7faefd25.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/HfOption.f7f04550.js">
<link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/stores.318eade7.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Trainer&quot;,&quot;local&quot;:&quot;trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Checkpoints&quot;,&quot;local&quot;:&quot;checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Logging&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Customize&quot;,&quot;local&quot;:&quot;customize&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Callbacks&quot;,&quot;local&quot;:&quot;callbacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accelerate&quot;,&quot;local&quot;:&quot;accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Optimizations&quot;,&quot;local&quot;:&quot;optimizations&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;torch.compile&quot;,&quot;local&quot;:&quot;torchcompile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;GaLore&quot;,&quot;local&quot;:&quot;galore&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Liger&quot;,&quot;local&quot;:&quot;liger&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;NEFTune&quot;,&quot;local&quot;:&quot;neftune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Trainer</span></h1> <p data-svelte-h="svelte-17y3lle"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> is a complete training and evaluation loop for Transformers’ PyTorch models. Plug a model, preprocessor, dataset, and training arguments into <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> and let it handle the rest to start training faster.</p> <p data-svelte-h="svelte-1brfiow"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> is also powered by <a href="https://hf.co/docs/accelerate/index" rel="nofollow">Accelerate</a>, a library for handling large models for distributed training.</p> <p data-svelte-h="svelte-itpr3w">This guide will show you how <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> works and how to customize it for your use case with a callback.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" 
height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip install accelerate --upgrade<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9spklj"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> contains all the necessary components of a training loop.</p> <ol data-svelte-h="svelte-13k8azk"><li>calculate the loss from a training step</li> <li>calculate the gradients with the <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.backward" rel="nofollow">backward</a> method</li> <li>update the weights based on the gradients</li> <li>repeat until the predetermined number of epochs is reached</li></ol> <p data-svelte-h="svelte-pd3rsx">Manually coding this training loop every time can be inconvenient or a barrier if you’re just getting started with machine learning. 
<a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> abstracts this process, allowing you to focus on the model, dataset, and training design choices.</p> <p data-svelte-h="svelte-1qnpmk8">Configure your training with hyperparameters and options from <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>, which supports many features such as distributed training, torch.compile, mixed precision training, and saving the model to the Hub.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-zccfj9">The number of parameters available in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> may be intimidating at first. If there is a specific hyperparameter or feature you want to use, try searching for it directly. Otherwise, feel free to start with the default values and gradually customize them as you become more familiar with the training process.</p></div> <p data-svelte-h="svelte-1qtcje1">The example below shows a <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> configuration that evaluates and saves the model at the end of each epoch. 
It also loads the best model found during training and pushes it to the Hub.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(
    output_dir=<span class="hljs-string">&quot;your-model&quot;</span>,
    learning_rate=<span class="hljs-number">2e-5</span>,
    per_device_train_batch_size=<span class="hljs-number">16</span>,
    per_device_eval_batch_size=<span class="hljs-number">16</span>,
    num_train_epochs=<span class="hljs-number">2</span>,
    weight_decay=<span class="hljs-number">0.01</span>,
    eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
    save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
    load_best_model_at_end=<span class="hljs-literal">True</span>,
    push_to_hub=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1gzkq6n">Pass your model, dataset, preprocessor, and <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a>, and call <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.train">train()</a> to start training.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-imgvnc">Refer to the <a href="./training">Fine-tuning</a> guide for a more complete overview of the training process.</p></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset[<span class="hljs-string">&quot;train&quot;</span>],
    eval_dataset=dataset[<span class="hljs-string">&quot;test&quot;</span>],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="checkpoints" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#checkpoints"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Checkpoints</span></h2> <p data-svelte-h="svelte-t2wjzo"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> saves checkpoints (the optimizer state is not saved by default) to the directory in <code>output_dir</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to a subfolder named <code>checkpoint-000</code>. The number at the end is the training step at which the checkpoint was saved.</p> <p data-svelte-h="svelte-wo0sum">Saving checkpoints is useful for resuming training or recovering your training progress if you encounter an error. 
Set the <code>resume_from_checkpoint</code> parameter in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.train">train()</a> to resume training from the last checkpoint or a specific checkpoint.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">latest checkpoint </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">specific checkpoint </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train(resume_from_checkpoint=<span 
class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-jh9l5w">Checkpoints can be saved to the Hub by setting <code>push_to_hub=True</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. The default method (<code>&quot;every_save&quot;</code>) saves a checkpoint to the Hub every time a model is saved, which is typically the final model at the end of training. Some other options for deciding how to save checkpoints to the Hub include the following.</p> <ul data-svelte-h="svelte-1nmyt88"><li><code>hub_strategy=&quot;end&quot;</code> only pushes a checkpoint when <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.save_model">save_model()</a> is called</li> <li><code>hub_strategy=&quot;checkpoint&quot;</code> pushes the latest checkpoint to a subfolder named <em>last-checkpoint</em> from which training can be resumed</li> <li><code>hub_strategy=&quot;all_checkpoints&quot;</code> pushes all checkpoints to the Hub with one checkpoint per subfolder in your model repository</li></ul> <p data-svelte-h="svelte-pkbot6"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> attempts to maintain the same Python, NumPy, and PyTorch RNG states when you resume training from a checkpoint. But PyTorch has various non-deterministic settings which can’t guarantee the RNG states are identical. 
To enable full determinism, refer to the <a href="https://pytorch.org/docs/stable/notes/randomness#controlling-sources-of-randomness" rel="nofollow">Controlling sources of randomness</a> guide to learn what settings to adjust to make training fully deterministic (some settings may result in slower training).</p> <h2 class="relative group"><a id="logging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#logging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Logging</span></h2> <p data-svelte-h="svelte-19cqfij"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> is set to <code>logging.INFO</code> by default to report errors, warnings, and other basic information. 
Use <code>log_level()</code> to change the logging level and log verbosity.</p> <p data-svelte-h="svelte-13ujenl">The example below sets the main code and modules to use the same log level.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->logger = logging.getLogger(__name__)
logging.basicConfig(
    <span class="hljs-built_in">format</span>=<span class="hljs-string">&quot;%(asctime)s - %(levelname)s - %(name)s - %(message)s&quot;</span>,
    datefmt=<span class="hljs-string">&quot;%m/%d/%Y %H:%M:%S&quot;</span>,
    handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
trainer = Trainer(...)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hdekwa">In a distributed environment, <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> replicas are set to <code>logging.WARNING</code> to only report errors and warnings. Use <code>log_level_replica()</code> to change the logging level and log verbosity. To configure the log level for each node, use <code>log_on_each_node()</code> to determine whether to use a specific log level on each node or only the main node.</p> <p data-svelte-h="svelte-iuct8l">Use different combinations of <code>log_level</code> and <code>log_level_replica</code> to configure what gets logged on each node.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">single node </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">multi-node </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute 
pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->my_app.py ... --log_level warning --log_level_replica error<!-- HTML_TAG_END --></pre></div> </div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-5chms7">The log level is separately set for each node in the <code>__init__()</code> method. Consider setting this sooner if you’re using other Transformers functionalities before creating the <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> instance.</p></div> <h2 class="relative group"><a id="customize" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#customize"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 
0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Customize</span></h2> <p data-svelte-h="svelte-1cbvniw">Tailor <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> to your use case by subclassing or overriding its methods to support the functionality you want to add or use, without rewriting the entire training loop from scratch. The table below lists some of the methods that can be customized.</p> <table data-svelte-h="svelte-6ggdkm"><thead><tr><th>method</th> <th>description</th></tr></thead> <tbody><tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.get_train_dataloader">get_train_dataloader()</a></td> <td>create a training DataLoader</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.get_eval_dataloader">get_eval_dataloader()</a></td> <td>create an evaluation DataLoader</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.get_test_dataloader">get_test_dataloader()</a></td> <td>create a test DataLoader</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.log">log()</a></td> <td>log information about the training process</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.create_optimizer_and_scheduler">create_optimizer_and_scheduler()</a></td> <td>create an optimizer and learning rate scheduler (can also be separately customized with <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.create_optimizer">create_optimizer()</a> and <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.create_scheduler">create_scheduler()</a> if they weren’t passed in <code>__init__</code>)</td></tr> <tr><td><a 
href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.compute_loss">compute_loss()</a></td> <td>compute the loss of a batch of training inputs</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.training_step">training_step()</a></td> <td>perform the training step</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.prediction_step">prediction_step()</a></td> <td>perform the prediction and test step</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.evaluate">evaluate()</a></td> <td>evaluate the model and return the evaluation metric</td></tr> <tr><td><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.predict">predict()</a></td> <td>make a prediction (with metrics if labels are available) on the test set</td></tr></tbody></table> <p data-svelte-h="svelte-e2g5v1">For example, to use weighted loss, rewrite <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.compute_loss">compute_loss()</a> inside <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div 
class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> nn
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
<span class="hljs-keyword">class</span> <span class="hljs-title class_">CustomTrainer</span>(<span class="hljs-title class_ inherited__">Trainer</span>):
    <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_loss</span>(<span class="hljs-params">self, model, inputs, return_outputs=<span class="hljs-literal">False</span>, num_items_in_batch=<span class="hljs-literal">None</span></span>):
        labels = inputs.pop(<span class="hljs-string">&quot;labels&quot;</span>)
        <span class="hljs-comment"># forward pass</span>
        outputs = model(**inputs)
        logits = outputs.get(<span class="hljs-string">&quot;logits&quot;</span>)
        <span class="hljs-comment"># compute custom loss for 3 labels with different weights</span>
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([<span class="hljs-number">1.0</span>, <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>], device=model.device))
        loss = loss_fct(logits.view(-<span class="hljs-number">1</span>, self.model.config.num_labels), labels.view(-<span class="hljs-number">1</span>))
        <span class="hljs-keyword">return</span> (loss, outputs) <span class="hljs-keyword">if</span> return_outputs <span class="hljs-keyword">else</span> loss<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="callbacks" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#callbacks"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Callbacks</span></h3> <p data-svelte-h="svelte-1d0q7ic"><a href="./main_classes/callback">Callbacks</a> are another way to customize <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a>, but they don’t change anything <em>inside the training loop</em>. Instead, a callback inspects the training loop state and executes some action (early stopping, logging, etc.) depending on the state. 
For example, you can’t implement a custom loss function with a callback because that requires overriding <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer.compute_loss">compute_loss()</a>.</p> <p data-svelte-h="svelte-efpti5">To use a callback, create a class that inherits from <a href="/docs/transformers/pr_36839/en/main_classes/callback#transformers.TrainerCallback">TrainerCallback</a> and implements the functionality you want. Then pass the callback to the <code>callbacks</code> parameter in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a>. The example below implements an early stopping callback that stops training after 10 steps.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span 
class="hljs-keyword">import</span> TrainerCallback, Trainer
<span class="hljs-keyword">class</span> <span class="hljs-title class_">EarlyStoppingCallback</span>(<span class="hljs-title class_ inherited__">TrainerCallback</span>):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, num_steps=<span class="hljs-number">10</span></span>):
self.num_steps = num_steps
<span class="hljs-keyword">def</span> <span class="hljs-title function_">on_step_end</span>(<span class="hljs-params">self, args, state, control, **kwargs</span>):
<span class="hljs-keyword">if</span> state.global_step &gt;= self.num_steps:
control.should_training_stop = <span class="hljs-literal">True</span>
<span class="hljs-keyword">return</span> control
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=dataset[<span class="hljs-string">&quot;test&quot;</span>],
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback()],
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accelerate</span></h2> <p data-svelte-h="svelte-pt3qay"><a href="https://hf.co/docs/accelerate/index" rel="nofollow">Accelerate</a> is a library that simplifies training in distributed environments and across different hardware. 
Its integration with <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> means <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> supports distributed training frameworks like <a href="https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/" rel="nofollow">Fully Sharded Data Parallel (FSDP)</a> and <a href="https://www.deepspeed.ai/" rel="nofollow">DeepSpeed</a>.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-10gsu19">Learn more about FSDP sharding strategies, CPU offloading, and more with <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> in the <a href="./fsdp">Fully Sharded Data Parallel</a> guide.</p></div> <p data-svelte-h="svelte-357ij6">To use Accelerate with <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a>, run the <a href="https://hf.co/docs/accelerate/package_reference/cli#accelerate-config" rel="nofollow">accelerate config</a> command to configure your training environment. This command creates a <code>config_file.yaml</code> file that stores the configuration settings of your training environment and it’s used whenever you launch your training script. 
Some example distributed training configurations are shown below.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">DistributedDataParallel </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">FullyShardedDataParallel </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">DeepSpeed </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">DeepSpeed with Accelerate plugin </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div 
class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attr">compute_environment:</span> <span class="hljs-string">LOCAL_MACHINE</span>
<span class="hljs-attr">distributed_type:</span> <span class="hljs-string">MULTI_GPU</span>
<span class="hljs-attr">downcast_bf16:</span> <span class="hljs-string">&#x27;no&#x27;</span>
<span class="hljs-attr">gpu_ids:</span> <span class="hljs-string">all</span>
<span class="hljs-attr">machine_rank:</span> <span class="hljs-number">0</span> <span class="hljs-comment">#change rank as per the node</span>
<span class="hljs-attr">main_process_ip:</span> <span class="hljs-number">192.168</span><span class="hljs-number">.20</span><span class="hljs-number">.1</span>
<span class="hljs-attr">main_process_port:</span> <span class="hljs-number">9898</span>
<span class="hljs-attr">main_training_function:</span> <span class="hljs-string">main</span>
<span class="hljs-attr">mixed_precision:</span> <span class="hljs-string">fp16</span>
<span class="hljs-attr">num_machines:</span> <span class="hljs-number">2</span>
<span class="hljs-attr">num_processes:</span> <span class="hljs-number">8</span>
<span class="hljs-attr">rdzv_backend:</span> <span class="hljs-string">static</span>
<span class="hljs-attr">same_network:</span> <span class="hljs-literal">true</span>
<span class="hljs-attr">tpu_env:</span> []
<span class="hljs-attr">tpu_use_cluster:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">tpu_use_sudo:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">use_cpu:</span> <span class="hljs-literal">false</span><!-- HTML_TAG_END --></pre></div> <hfoption id="Tensor parallelism with PyTorch 2"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attr">compute_environment:</span> <span class="hljs-string">LOCAL_MACHINE</span>
<span class="hljs-attr">tp_config:</span>
<span class="hljs-attr">tp_size:</span> <span class="hljs-number">4</span>
<span class="hljs-attr">distributed_type:</span> <span class="hljs-string">TP</span>
<span class="hljs-attr">downcast_bf16:</span> <span class="hljs-string">&#x27;no&#x27;</span>
<span class="hljs-attr">machine_rank:</span> <span class="hljs-number">0</span>
<span class="hljs-attr">main_training_function:</span> <span class="hljs-string">main</span>
<span class="hljs-attr">mixed_precision:</span> <span class="hljs-string">&#x27;no&#x27;</span>
<span class="hljs-attr">num_machines:</span> <span class="hljs-number">1</span>
<span class="hljs-attr">num_processes:</span> <span class="hljs-number">4</span>
<span class="hljs-attr">rdzv_backend:</span> <span class="hljs-string">static</span>
<span class="hljs-attr">same_network:</span> <span class="hljs-literal">true</span>
<span class="hljs-attr">tpu_env:</span> []
<span class="hljs-attr">tpu_use_cluster:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">tpu_use_sudo:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">use_cpu:</span> <span class="hljs-literal">false</span><!-- HTML_TAG_END --></pre></div></hfoption></div> <p data-svelte-h="svelte-g8vtta">Run <a href="https://hf.co/docs/accelerate/package_reference/cli#accelerate-launch" rel="nofollow">accelerate launch</a> to start training with the configurations set in <code>config_file.yaml</code>. This file is saved to the Accelerate cache folder and automatically loaded when you run <code>accelerate launch</code>.</p> <p data-svelte-h="svelte-1w2ffu3">The example below launches the <a href="../../../examples/pytorch/text-classification/run_glue">run_glue.py</a> script with the FSDP configuration shown earlier. Parameters from the <code>config_file.yaml</code> file can also be directly set in the command line.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START 
-->accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name <span class="hljs-variable">$TASK_NAME</span> \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/<span class="hljs-variable">$TASK_NAME</span>/ \
--overwrite_output_dir<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1ih0qfg">Refer to the <a href="https://hf.co/docs/accelerate/basic_tutorials/launch" rel="nofollow">Launching your Accelerate scripts</a> tutorial to learn more about <code>accelerate launch</code> and custom configurations.</p></div> <h2 class="relative group"><a id="optimizations" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#optimizations"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Optimizations</span></h2> <p data-svelte-h="svelte-mc9q72"><a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> supports various optimizations to improve <em>training</em> performance - reduce memory and increase training speed - and <em>model</em> performance.</p> <h3 class="relative group"><a id="torchcompile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute 
with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchcompile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.compile</span></h3> <p data-svelte-h="svelte-1wim1fg"><a href="./perf_torch_compile">torch.compile</a> can significantly speed up training and reduce computational overhead. Configure your torch.compile settings in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. 
Set <code>torch_compile</code> to <code>True</code>, and select a backend and compile mode.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(
torch_compile=<span class="hljs-literal">True</span>,
torch_compile_backend=<span class="hljs-string">&quot;inductor&quot;</span>,
torch_compile_mode=<span class="hljs-string">&quot;default&quot;</span>,
...,
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="galore" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#galore"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GaLore</span></h3> <p data-svelte-h="svelte-1vgfgwy"><a href="https://hf.co/papers/2403.03507" rel="nofollow">Gradient Low-Rank Projection (GaLore)</a> significantly reduces memory usage when training large language models (LLMs). 
One of GaLore’s key benefits is <em>full-parameter</em> learning, unlike low-rank adaptation methods like <a href="https://hf.co/papers/2106.09685" rel="nofollow">LoRA</a>, which produces better model performance.</p> <p data-svelte-h="svelte-uszk55">Install the <a href="https://github.com/jiaweizzhao/GaLore" rel="nofollow">GaLore</a> library, <a href="https://hf.co/docs/trl/index" rel="nofollow">TRL</a>, and <a href="https://hf.co/docs/datasets/index" rel="nofollow">Datasets</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install galore-torch trl datasets<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-avpqe5">Pick a GaLore optimizer (<code>&quot;galore_adamw&quot;</code>, <code>&quot;galore_adafactor&quot;</code>, <code>&quot;galore_adamw_8bit&quot;</code>) and pass it to the <code>optim</code> parameter in <a 
href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. Use the <code>optim_target_modules</code> parameter to specify which modules to adapt (can be a list of strings, regex, or a full path).</p> <p data-svelte-h="svelte-ho9ecg">Extra parameters supported by GaLore, <code>rank</code>, <code>update_proj_gap</code>, and <code>scale</code>, should be passed to the <code>optim_args</code> parameter in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>.</p> <p data-svelte-h="svelte-1cvj6mo">The example below enables GaLore with <a href="https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTTrainer" rel="nofollow">SFTTrainer</a> that targets the <code>attn</code> and <code>mlp</code> layers with regex.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-jkq35c">It can take some time before training starts (~3 minutes for a 2B model on a NVIDIA A100).</p></div> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">GaLore optimizer </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">GaLore optimizer with layerwise optimization </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code 
excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw&quot;</span>,
optim_target_modules=[<span class="hljs-string">r&quot;.*.attn.*&quot;</span>, <span class="hljs-string">r&quot;.*.mlp.*&quot;</span>],
optim_args=<span class="hljs-string">&quot;rank=64, update_proj_gap=100, scale=0.10&quot;</span>,
)
config = AutoConfig.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>)
tokenizer = AutoTokenizer.from_pretrained(<span class="hljs-string">&quot;google/gemma-2b&quot;</span>)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-193bcdy">Only linear layers that are considered GaLore layers can be trained with low-rank decomposition. The rest of the model layers are optimized in the usual way.</p> <h3 class="relative group"><a id="liger" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#liger"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Liger</span></h3> <p data-svelte-h="svelte-1hqueez"><a href="https://github.com/linkedin/Liger-Kernel" rel="nofollow">Liger Kernel</a> is a collection of layers such as RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more that have been fused into a single Triton kernel for training LLMs. These kernels are also compatible with FlashAttention, FSDP, and DeepSpeed. As a result, Liger Kernel can increase multi-GPU training throughput and reduce memory usage. 
This is useful for multi-head training and supporting larger vocabulary sizes, larger batch sizes, and longer context lengths.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install liger-kernel<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-b5hse5">Enable Liger Kernel for training by setting <code>use_liger_kernel=True</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. 
This patches the corresponding layers in the model with Liger's kernels.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-v7ceec">Liger Kernel supports Llama, Gemma, Mistral, and Mixtral models. Refer to the <a href="https://github.com/linkedin/Liger-Kernel#patching" rel="nofollow">patching</a> list for the latest list of supported models.</p></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;your-model&quot;</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
per_device_train_batch_size=<span class="hljs-number">16</span>,
per_device_eval_batch_size=<span class="hljs-number">16</span>,
num_train_epochs=<span class="hljs-number">2</span>,
weight_decay=<span class="hljs-number">0.01</span>,
eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
load_best_model_at_end=<span class="hljs-literal">True</span>,
push_to_hub=<span class="hljs-literal">True</span>,
use_liger_kernel=<span class="hljs-literal">True</span>
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="neftune" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#neftune"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>NEFTune</span></h3> <p data-svelte-h="svelte-1pey9e8"><a href="https://hf.co/papers/2310.05914" rel="nofollow">NEFTune</a> adds noise to the embedding vectors during training to improve model performance. 
Enable it in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.Trainer">Trainer</a> with the <code>neftune_noise_alpha</code> parameter in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to control how much noise is added.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, Trainer
training_args = TrainingArguments(..., neftune_noise_alpha=<span class="hljs-number">0.1</span>)
trainer = Trainer(..., args=training_args)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dfz8sx">The original embedding layer is restored after training to avoid any unexpected behavior.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/trainer.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
  // Client-side bootstrap for this page.
  //
  // The path configuration is assigned without a declaration, so in this
  // classic (non-module) script it becomes a property of the global object.
  // NOTE(review): presumably the imported start module reads this global —
  // confirm against the bundle before renaming or scoping it.
  __sveltekit_1bm5psi = {
    assets: "/docs/transformers/pr_36839/en",
    base: "/docs/transformers/pr_36839/en",
    env: {}
  };

  // Capture the mount point synchronously: document.currentScript is only
  // populated while this <script> element is being executed, so it must be
  // read before any asynchronous boundary.
  const mountPoint = document.currentScript.parentElement;

  // One entry per id in `node_ids` below; both are null here.
  const routeData = [null, null];

  // Load both entry modules in parallel, then hand control to the start
  // module. The async function runs synchronously up to its first `await`,
  // so both imports are kicked off immediately.
  (async () => {
    const [kit, app] = await Promise.all([
      import("/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"),
      import("/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js")
    ]);

    kit.start(app, mountPoint, {
      node_ids: [0, 491],
      data: routeData,
      form: null,
      error: null
    });
  })();
}
</script>

Xet Storage Details

Size:
72.4 kB
·
Xet hash:
c9f331b23e26769f2c4ea518a7243bda3d8713ef77ec910b5f75c609360b5e32

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.