Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"GPU","local":"gpu","sections":[{"title":"Trainer","local":"trainer","sections":[{"title":"Batch size","local":"batch-size","sections":[],"depth":3},{"title":"Gradient accumulation","local":"gradient-accumulation","sections":[],"depth":3},{"title":"Gradient checkpointing","local":"gradient-checkpointing","sections":[],"depth":3},{"title":"Mixed precision","local":"mixed-precision","sections":[],"depth":3},{"title":"Optimizers","local":"optimizers","sections":[],"depth":3},{"title":"Data preloading","local":"data-preloading","sections":[],"depth":3}],"depth":2},{"title":"PyTorch","local":"pytorch","sections":[{"title":"torch.empty_cache_steps","local":"torchemptycachesteps","sections":[],"depth":3},{"title":"torch.compile","local":"torchcompile","sections":[],"depth":3},{"title":"Scaled dot production attention","local":"scaled-dot-production-attention","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/transformers/pr_36839/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/scheduler.01eeda35.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/singletons.177df05e.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.4862150a.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/paths.517376d1.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/index.6dd51b66.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/0.8897c14d.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/nodes/421.15145b15.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/CodeBlock.864da1b0.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/EditOnGithub.7faefd25.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/HfOption.f7f04550.js"> | |
| <link rel="modulepreload" href="/docs/transformers/pr_36839/en/_app/immutable/chunks/stores.318eade7.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"GPU","local":"gpu","sections":[{"title":"Trainer","local":"trainer","sections":[{"title":"Batch size","local":"batch-size","sections":[],"depth":3},{"title":"Gradient accumulation","local":"gradient-accumulation","sections":[],"depth":3},{"title":"Gradient checkpointing","local":"gradient-checkpointing","sections":[],"depth":3},{"title":"Mixed precision","local":"mixed-precision","sections":[],"depth":3},{"title":"Optimizers","local":"optimizers","sections":[],"depth":3},{"title":"Data preloading","local":"data-preloading","sections":[],"depth":3}],"depth":2},{"title":"PyTorch","local":"pytorch","sections":[{"title":"torch.empty_cache_steps","local":"torchemptycachesteps","sections":[],"depth":3},{"title":"torch.compile","local":"torchcompile","sections":[],"depth":3},{"title":"Scaled dot production attention","local":"scaled-dot-production-attention","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="gpu" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gpu"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 
56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GPU</span></h1> <p data-svelte-h="svelte-3v6hvz">GPUs are commonly used to train deep learning models due to their high memory bandwidth and parallel processing capabilities. Depending on your GPU and model size, it is even possible to train models with billions of parameters. The key is to find the right balance between GPU memory utilization and training speed (data throughput and training time).</p> <p data-svelte-h="svelte-ryqnw1">This guide will show you the features available in Transformers and PyTorch for efficiently training a model on GPUs. In many cases, you’ll want to use a combination of these features to optimize training.</p> <p data-svelte-h="svelte-n9nqqn">Refer to the table below to quickly help you identify the features relevant to your training scenario.</p> <table data-svelte-h="svelte-ig1emq"><thead><tr><th>Feature</th> <th>Training speed</th> <th>Memory usage</th></tr></thead> <tbody><tr><td>batch size</td> <td>yes</td> <td>yes</td></tr> <tr><td>gradient accumulation</td> <td>no</td> <td>yes</td></tr> <tr><td>gradient checkpointing</td> <td>no</td> <td>yes</td></tr> <tr><td>mixed precision</td> <td>yes</td> <td>depends</td></tr> <tr><td>optimizers</td> <td>yes</td> <td>yes</td></tr> <tr><td>data preloading</td> <td>yes</td> <td>no</td></tr> <tr><td>torch_empty_cache_steps</td> <td>no</td> <td>yes</td></tr> <tr><td>torch.compile</td> <td>yes</td> <td>no</td></tr> <tr><td>PEFT</td> <td>no</td> <td>yes</td></tr></tbody></table> <h2 class="relative group"><a id="trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Trainer</span></h2> <p data-svelte-h="svelte-d32ivx"><a href="./trainer">Trainer</a> supports many useful training features that can be configured through <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>. This section highlights some of the more important features for optimizing training.</p> <h3 class="relative group"><a id="batch-size" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#batch-size"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Batch 
size</span></h3> <p data-svelte-h="svelte-wr8hhu">Batch size is one of the most important hyperparameters for efficient GPU training because it affects memory usage and training speed. Larger batch sizes lead to faster training because they take advantage of a GPU’s parallel processing power. It is recommended to use batch sizes that are powers of 2, such as 8, 16, 32, 64, 128, 256, 512, etc. The batch size depends on your GPU and the model’s data type.</p> <p data-svelte-h="svelte-1gek0ou">Configure <code>per_device_train_batch_size</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">256</span>, | |
| per_device_eval_batch_size=<span class="hljs-number">256</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1is384h">Refer to the NVIDIA <a href="https://docs.nvidia.com/deeplearning/performance/dl-performance-fully-connected/index.html#input-features" rel="nofollow">Performance</a> guide to learn more about how the number of input features, the number of output neurons, and the batch size affect performance. These values determine the dimensions of the General Matrix Multiplications (GEMMs) performed by the GPU, and larger values are generally better for parallelization and efficiency.</p> <p data-svelte-h="svelte-hcpy7s">The <a href="https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" rel="nofollow">Tensor Core Requirements</a> section is also useful for selecting a batch size that maximizes the speed of tensor multiplication based on the data type and GPU. For example, multiples of 8 are recommended for fp16, unless it’s an A100 GPU, in which case use multiples of 64.</p> <p data-svelte-h="svelte-1ftr8dz">Finally, consider <a href="https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#dim-quantization" rel="nofollow">Dimension Quantization Effects</a> when these values are small. Tile quantization occurs when matrix dimensions aren’t divisible by a GPU’s thread block tile size, causing the GPU to underutilize its resources. 
Selecting the correct batch size multiplier, such that the matrix is divisible by the tile size, can significantly speed up training.</p> <h3 class="relative group"><a id="gradient-accumulation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gradient-accumulation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Gradient accumulation</span></h3> <p data-svelte-h="svelte-1g17env">Gradient accumulation overcomes memory constraints - useful when the batch size you want to train with doesn’t fit on a single GPU - by accumulating gradients over multiple smaller mini-batches before updating the parameters. Because each mini-batch is small, fewer activations have to be kept in memory at once, while the accumulated gradients still let you train with a larger <em>effective batch size</em> (normally, the parameters are updated after every single batch of data). 
Training can slow down, though, because each parameter update now requires several smaller forward and backward passes, which may not fully utilize the GPU.</p> <p data-svelte-h="svelte-1tsssk9">Configure <code>gradient_accumulation_steps</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to enable gradient accumulation. The effective batch size is <code>per_device_train_batch_size</code> multiplied by <code>gradient_accumulation_steps</code> (and by the number of GPUs).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| <span class="hljs-comment"># effective batch size of 64</span> | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-k7npc4">Avoid using too many gradient accumulation steps because they can significantly slow down training. Consider the example below, where the maximum batch size that’ll fit on your GPU is 4. You should keep your batch size at 4 to better utilize the GPU.</p> <table data-svelte-h="svelte-bg5pgv"><thead><tr><th>batch size</th> <th>gradient accumulation steps</th> <th>effective batch size</th> <th></th></tr></thead> <tbody><tr><td>1</td> <td>64</td> <td>64</td> <td>👎</td></tr> <tr><td>4</td> <td>16</td> <td>64</td> <td>👍</td></tr></tbody></table> <h3 class="relative group"><a id="gradient-checkpointing" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gradient-checkpointing"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Gradient checkpointing</span></h3> <p data-svelte-h="svelte-18nswaa">Gradient checkpointing reduces memory usage by storing only some of the intermediate activations during the forward pass and recomputing the remaining activations when they are needed during the backward pass. 
This avoids storing <em>all</em> of the intermediate activations from the forward pass, which can require a lot of memory overhead. However, it comes at the cost of slower training speed (~20%).</p> <p data-svelte-h="svelte-3eudc9">Configure <code>gradient_checkpointing()</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to enable gradient checkpointing.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| gradient_checkpointing=<span class="hljs-literal">True</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="mixed-precision" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#mixed-precision"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Mixed precision</span></h3> <p data-svelte-h="svelte-45ttfx">Mixed precision accelerates training by performing some calculations in half-precision (fp16) and some in full-precision (fp32). The half-precision calculations boost training speed because they’re less computationally expensive than the same calculations in full-precision. 
Meanwhile, preserving some of the calculations in full-precision maintains accuracy.</p> <p data-svelte-h="svelte-1qka8ue">There are several data types available for mixed precision training.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">fp16 </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">bf16 </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">tf32 </div></div> <div class="language-select"><p data-svelte-h="svelte-1bscn8l">The main advantage of mixed precision training is saving the activations in fp16.</p> <p data-svelte-h="svelte-1c90l0">Configure <code>fp16()</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to enable mixed precision training with the fp16 data type.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div 
class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| gradient_checkpointing=<span class="hljs-literal">True</span>, | |
| fp16=<span class="hljs-literal">True</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-dmja6y">fp16 isn’t memory-optimized because the gradients that are computed in fp16 are converted back to fp32 during the optimization step. You may end up using more GPU memory, especially for small batch sizes, because there are now two versions (fp16 and fp32) of the model on the GPU.</p> </div> <h3 class="relative group"><a id="optimizers" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#optimizers"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Optimizers</span></h3> <p data-svelte-h="svelte-16728rv">Transformers implements the <a href="https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html" rel="nofollow">AdamW (adamw_torch)</a> optimizer from PyTorch by default. But because it stores a weighted average of past gradients, it requires additional memory proportional to the number of model parameters to store the past gradients. This can be an issue when training very large models, and in such cases, you should consider choosing a different optimizer. 
For example, if you have <a href="https://nvidia.github.io/apex/index.html" rel="nofollow">Apex</a> installed for either <a href="https://github.com/NVIDIA/apex" rel="nofollow">NVIDIA</a> or <a href="https://github.com/ROCm/apex" rel="nofollow">AMD</a> GPUs, then the <code>adamw_apex_fused</code> optimizer provides the fastest training of all the AdamW optimizers.</p> <p data-svelte-h="svelte-1e5u9bf">Configure <code>optim</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to choose an optimizer.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| gradient_checkpointing=<span class="hljs-literal">True</span>, | |
| bf16=<span class="hljs-literal">True</span>, | |
| optim=<span class="hljs-string">"adamw_bnb_8bit"</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17at21e">There are many optimizers to choose from (refer to <a href="https://github.com/huggingface/transformers/blob/34f4080ff59b1668d919a1ba9f8bc4a3a2a3f478/src/transformers/training_args.py#L145" rel="nofollow">OptimizerNames</a> for the full list of supported optimizers) depending on your training scenario. For example, Adafactor can significantly reduce memory requirements by storing weighted averages of the rows and columns instead of each element in the matrix, at the cost of slower convergence. Another example is using an <a href="https://huggingface.co/docs/bitsandbytes" rel="nofollow">8-bit AdamW optimizer</a> from bitsandbytes to quantize optimizer states. The optimizer states are stored in lower precision and dequantized before being used in the optimizer step.</p> <p data-svelte-h="svelte-50osy">Refer to the <a href="./optimizers">optimizer</a> guide to learn about more specialized optimizers.</p> <h3 class="relative group"><a id="data-preloading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-preloading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Data 
preloading</span></h3> <p data-svelte-h="svelte-pe87pi">Data preloading loads and prepares batches of data in advance on the CPU to ensure the GPU is continuously working, reducing GPU idling and increasing utilization. There are two ways to preload data to ensure the GPU is always working.</p> <ol data-svelte-h="svelte-1rkn94"><li>Allocate pinned memory on the CPU to store the data and transfer it directly to the GPU.</li> <li>Increase the number of CPU threads or workers to preload the data faster.</li></ol> <p data-svelte-h="svelte-txknbc">Configure <code>dataloader_pin_memory()</code> and <code>dataloader_num_workers()</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to allocate pinned memory and increase the number of workers.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre 
class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| gradient_checkpointing=<span class="hljs-literal">True</span>, | |
| bf16=<span class="hljs-literal">True</span>, | |
| optim=<span class="hljs-string">"adamw_bnb_8bit"</span>, | |
| dataloader_pin_memory=<span class="hljs-literal">True</span>, | |
| dataloader_num_workers=<span class="hljs-number">4</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="pytorch" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch</span></h2> <p data-svelte-h="svelte-13irw7t">PyTorch provides several features for reducing memory requirements and increasing training speed. 
These features can often be enabled in Transformers by only adding a few lines of code.</p> <h3 class="relative group"><a id="torchemptycachesteps" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchemptycachesteps"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.empty_cache_steps</span></h3> <p data-svelte-h="svelte-1nu9tjf">The <a href="https://pytorch.org/docs/stable/generated/torch.cuda.empty_cache.html#torch.cuda.empty_cache" rel="nofollow">torch.cuda.empty_cache</a> function releases unused cached memory, which can help avoid out-of-memory (OOM) errors at the cost of ~10% slower training.</p> <p data-svelte-h="svelte-10dt3jc">Use <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.torch_empty_cache_steps" rel="nofollow">torch_empty_cache_steps()</a> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to empty the cache every N training steps (every 4 steps in the example below).</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm 
focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| gradient_checkpointing=<span class="hljs-literal">True</span>, | |
| bf16=<span class="hljs-literal">True</span>, | |
| optim=<span class="hljs-string">"adamw_bnb_8bit"</span>, | |
| dataloader_pin_memory=<span class="hljs-literal">True</span>, | |
| dataloader_num_workers=<span class="hljs-number">4</span>, | |
| torch_empty_cache_steps=<span class="hljs-number">4</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="torchcompile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchcompile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.compile</span></h3> <p data-svelte-h="svelte-bl5sl8"><a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html" rel="nofollow">torch.compile</a> compiles PyTorch code into optimized kernels that significantly speed up training. This feature relies on TorchDynamo to capture PyTorch graphs with the Frame Evaluation API. 
The graph can be further compiled into optimized kernels for different backends.</p> <p data-svelte-h="svelte-h2pswg">Configure <code>torch_compile()</code> in <a href="/docs/transformers/pr_36839/en/main_classes/trainer#transformers.TrainingArguments">TrainingArguments</a> to enable it, and configure <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.torch_compile_backend" rel="nofollow">torch_compile_backend()</a> to select a backend to use.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments | |
| args = TrainingArguments( | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">16</span>, | |
| gradient_checkpointing=<span class="hljs-literal">True</span>, | |
| bf16=<span class="hljs-literal">True</span>, | |
| optim=<span class="hljs-string">"adamw_bnb_8bit"</span>, | |
| dataloader_pin_memory=<span class="hljs-literal">True</span>, | |
| dataloader_num_workers=<span class="hljs-number">4</span>, | |
| torch_empty_cache_steps=<span class="hljs-number">4</span>, | |
| torch_compile=<span class="hljs-literal">True</span>, | |
| torch_compile_backend=<span class="hljs-string">"inductor"</span>, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ci7jd6">Refer to the table below to help you choose the right backend for your training scenario.</p> <table data-svelte-h="svelte-3qwh9d"><thead><tr><th>backend</th> <th>description</th> <th>goal</th></tr></thead> <tbody><tr><td>eager</td> <td>uses PyTorch to run the extracted GraphModule</td> <td>debugging</td></tr> <tr><td>aot_eager</td> <td>uses PyTorch eager mode for AOTAutograd’s extracted forward and backward graphs</td> <td>debugging</td></tr> <tr><td>inductor</td> <td>uses TorchInductor with AOTAutograd and CUDA Graphs by leveraging Triton kernels</td> <td>training and inference</td></tr> <tr><td>nvfuser</td> <td>uses nvFuser with TorchScript</td> <td>training and inference</td></tr> <tr><td>aot_nvfuser</td> <td>uses nvFuser with AOTAutograd</td> <td>training and inference</td></tr> <tr><td>aot_cudagraphs</td> <td>uses CUDA Graphs with AOTAutograd</td> <td>training and inference</td></tr> <tr><td>ofi</td> <td>uses TorchScript’s <a href="https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html#torch-jit-optimize-for-inference" rel="nofollow">optimize_for_inference</a></td> <td>inference</td></tr> <tr><td>fx2trt</td> <td>uses <a href="https://pytorch.org/TensorRT/tutorials/getting_started_with_fx_path.html" rel="nofollow">Torch-TensorRT</a></td> <td>inference</td></tr> <tr><td>onnxrt</td> <td>uses <a href="https://onnxruntime.ai/" rel="nofollow">ONNX-RT</a> for CPU and GPU inference</td> <td>inference</td></tr> <tr><td>ipex</td> <td>uses <a href="https://github.com/intel/intel-extension-for-pytorch" rel="nofollow">IPEX</a> for CPU inference</td> <td>inference</td></tr></tbody></table> <h3 class="relative group"><a id="scaled-dot-production-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#scaled-dot-production-attention"><span><svg class="" 
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Scaled dot product attention</span></h3> <p data-svelte-h="svelte-sep83h"><a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">torch.nn.functional.scaled_dot_product_attention</a> (SDPA) is a native PyTorch implementation of the scaled dot product attention mechanism. SDPA is generally faster and more memory-efficient than the eager (manual) attention implementation in Transformers models. It supports three backends:</p> <ul data-svelte-h="svelte-jhgahq"><li><a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow">FlashAttention2</a> is automatically enabled for models with the fp16 or bf16 torch type. 
Make sure to cast your model to the appropriate type first.</li> <li><a href="https://github.com/facebookresearch/xformers" rel="nofollow">xFormers</a> or Memory-Efficient Attention supports models with the fp32 torch type.</li> <li>C++ implementation of scaled dot product attention.</li></ul> <p data-svelte-h="svelte-1ljwkb0">SDPA is enabled by default for PyTorch 2.1.1+, but it can be explicitly enabled by setting <code>attn_implementation="sdpa"</code> in <a href="/docs/transformers/pr_36839/en/main_classes/model#transformers.PreTrainedModel.from_pretrained">from_pretrained()</a>.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForCausalLM | |
| model = AutoModelForCausalLM.from_pretrained(<span class="hljs-string">"meta-llama/Llama-3.1-8B"</span>, device_map=<span class="hljs-string">"auto"</span>, attn_implementation=<span class="hljs-string">"sdpa"</span>)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/en/perf_train_gpu_one.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| // Build-generated client bootstrap for this docs page (SvelteKit). The bare | |
| // block scopes the const declarations below so they do not leak into window. | |
| // NOTE(review): emitted by the docs build; regenerate rather than hand-edit. | |
| // Undeclared assignment on purpose: creates a global config object holding | |
| // the asset/base URL prefixes of this PR preview build. Works because this | |
| // classic script has no "use strict". The hashed name presumably has to | |
| // match what the start.*.js runtime looks up, so do not rename it by hand. | |
| __sveltekit_1bm5psi = { | |
| assets: "/docs/transformers/pr_36839/en", | |
| base: "/docs/transformers/pr_36839/en", | |
| env: {} | |
| }; | |
| // document.currentScript is only set while this script runs synchronously, | |
| // so the host element is captured here, before the async imports below. | |
| const element = document.currentScript.parentElement; | |
| // One entry per id in node_ids below (both have length 2); null presumably | |
| // means no server-loaded data for that node. TODO confirm. | |
| const data = [null,null]; | |
| // Load the runtime entry (kit) and the app entry in parallel. These URLs | |
| // match the modulepreload hints in the document head. | |
| Promise.all([ | |
| import("/docs/transformers/pr_36839/en/_app/immutable/entry/start.6be8d590.js"), | |
| import("/docs/transformers/pr_36839/en/_app/immutable/entry/app.09748b4b.js") | |
| ]).then(([kit, app]) => { | |
| // Start the client runtime on the captured element for route nodes 0 and | |
| // 421 (presumably root layout and this page), with no form/error state. | |
| // NOTE(review): there is no .catch(); if either import fails, kit.start | |
| // never runs and the rejection is unhandled. | |
| kit.start(app, element, { | |
| node_ids: [0, 421], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 49.4 kB
- Xet hash:
- 0aa0646441c22716f2eee61e71b800532f11792e4d336f9c902260cbdeb453d8
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.