| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Performing gradient accumulation with 🤗 Accelerate","local":"performing-gradient-accumulation-with--accelerate","sections":[{"title":"Converting it to 🤗 Accelerate","local":"converting-it-to--accelerate","sections":[],"depth":2},{"title":"Letting 🤗 Accelerate handle gradient accumulation","local":"letting--accelerate-handle-gradient-accumulation","sections":[],"depth":2},{"title":"The finished code","local":"the-finished-code","sections":[],"depth":2},{"title":"Self-contained example","local":"self-contained-example","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/accelerate/main/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/entry/start.2ea03080.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/scheduler.defa9a21.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/singletons.aff0b9fc.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/index.beade68d.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/paths.2c85d1a6.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/entry/app.e6812672.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/index.fe795e71.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/nodes/0.39c84d5d.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/nodes/42.dfc30789.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/Tip.179eb360.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/CodeBlock.42404125.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/EditOnGithub.0f575778.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Performing gradient accumulation with 🤗 Accelerate","local":"performing-gradient-accumulation-with--accelerate","sections":[{"title":"Converting it to 🤗 Accelerate","local":"converting-it-to--accelerate","sections":[],"depth":2},{"title":"Letting 🤗 Accelerate handle gradient accumulation","local":"letting--accelerate-handle-gradient-accumulation","sections":[],"depth":2},{"title":"The finished code","local":"the-finished-code","sections":[],"depth":2},{"title":"Self-contained example","local":"self-contained-example","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="performing-gradient-accumulation-with--accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#performing-gradient-accumulation-with--accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Performing gradient accumulation with 🤗 Accelerate</span></h1> <p data-svelte-h="svelte-1762c2w">Gradient accumulation is 
a technique that lets you train with larger effective batch sizes than | |
| your machine would normally be able to fit into memory. This is done by accumulating gradients over | |
| several batches and stepping the optimizer only after a set number of batches have been processed.</p> <p data-svelte-h="svelte-1u2vc4p">While standard gradient accumulation code technically works in a distributed setup, it is not the most efficient | |
| method for doing so and you may experience considerable slowdowns!</p> <p data-svelte-h="svelte-elnr4k">In this tutorial you will see how to quickly set up gradient accumulation and perform it with the utilities provided in 🤗 Accelerate, | |
| which can total to adding just one new line of code!</p> <p data-svelte-h="svelte-n9bxrd">This example will use a very simplistic PyTorch training loop that performs gradient accumulation every two batches:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->device = <span class="hljs-string">"cuda"</span> | |
| model.to(device) | |
| gradient_accumulation_steps = <span class="hljs-number">2</span> | |
| <span class="hljs-keyword">for</span> index, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(training_dataloader): | |
| inputs, targets = batch | |
| inputs = inputs.to(device) | |
| targets = targets.to(device) | |
| outputs = model(inputs) | |
| loss = loss_function(outputs, targets) | |
| loss = loss / gradient_accumulation_steps | |
| loss.backward() | |
| <span class="hljs-keyword">if</span> (index + <span class="hljs-number">1</span>) % gradient_accumulation_steps == <span class="hljs-number">0</span>: | |
| optimizer.step() | |
| scheduler.step() | |
| optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="converting-it-to--accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#converting-it-to--accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Converting it to 🤗 Accelerate</span></h2> <p data-svelte-h="svelte-41b2zq">First the code shown earlier will be converted to utilize 🤗 Accelerate without the special gradient accumulation helper:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" 
transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-addition">+ from accelerate import Accelerator</span> | |
| <span class="hljs-addition">+ accelerator = Accelerator()</span> | |
| <span class="hljs-addition">+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(</span> | |
| <span class="hljs-addition">+ model, optimizer, training_dataloader, scheduler</span> | |
| <span class="hljs-addition">+ )</span> | |
| for index, batch in enumerate(training_dataloader): | |
| inputs, targets = batch | |
| <span class="hljs-deletion">- inputs = inputs.to(device)</span> | |
| <span class="hljs-deletion">- targets = targets.to(device)</span> | |
| outputs = model(inputs) | |
| loss = loss_function(outputs, targets) | |
| loss = loss / gradient_accumulation_steps | |
| <span class="hljs-addition">+ accelerator.backward(loss)</span> | |
| if (index+1) % gradient_accumulation_steps == 0: | |
| optimizer.step() | |
| scheduler.step() | |
| optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-s0dec7">In its current state, this code is not going to perform gradient accumulation efficiently due to a process called gradient synchronization. Read more about that in the <a href="../concept_guides/gradient_synchronization">Concepts tutorial</a>!</p></div> <h2 class="relative group"><a id="letting--accelerate-handle-gradient-accumulation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#letting--accelerate-handle-gradient-accumulation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Letting 🤗 Accelerate handle gradient accumulation</span></h2> <p data-svelte-h="svelte-70ndtd">All that is left now is to let 🤗 Accelerate handle the gradient accumulation for us. 
To do so you should pass in a <code>gradient_accumulation_steps</code> parameter to <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a>, dictating the number | |
| of steps to perform before each call to <code>step()</code> and how to automatically adjust the loss during the call to <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.backward">backward()</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> from accelerate import Accelerator | |
| <span class="hljs-deletion">- accelerator = Accelerator()</span> | |
| <span class="hljs-addition">+ accelerator = Accelerator(gradient_accumulation_steps=2)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-j5vx6o">Alternatively, you can pass in a <code>gradient_accumulation_plugin</code> parameter to the <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> object’s <code>__init__</code>, which will allow you to further customize the gradient accumulation behavior. | |
| Read more about that in the <a href="../package_reference/accelerator#accelerate.utils.GradientAccumulationPlugin">GradientAccumulationPlugin</a> docs.</p> <p data-svelte-h="svelte-1yxrf5e">From here you can use the <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> context manager from inside your training loop to automatically perform the gradient accumulation for you! | |
| You just wrap it around the entire training part of our code:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-deletion">- for index, batch in enumerate(training_dataloader):</span> | |
| <span class="hljs-addition">+ for batch in training_dataloader:</span> | |
| <span class="hljs-addition">+ with accelerator.accumulate(model):</span> | |
| inputs, targets = batch | |
| outputs = model(inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1mcvl9w">You can remove all the special checks for the step number and the loss adjustment:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-deletion">- loss = loss / gradient_accumulation_steps</span> | |
| accelerator.backward(loss) | |
| <span class="hljs-deletion">- if (index+1) % gradient_accumulation_steps == 0:</span> | |
| optimizer.step() | |
| scheduler.step() | |
| optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g13ex2">As you can see the <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> is able to keep track of the batch number you are on and it will automatically know whether to step through the prepared optimizer and how to adjust the loss.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1niv7yj">Typically with gradient accumulation, you would need to adjust the number of steps to reflect the change in total batches you are | |
| training on. 🤗 Accelerate automagically does this for you by default. Behind the scenes we instantiate a <code>GradientAccumulationPlugin</code> configured to do this.</p></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-lg50r9">The <a href="/docs/accelerate/main/en/package_reference/state#accelerate.state.GradientState">state.GradientState</a> is sync’d with the active dataloader being iterated upon. As such it assumes naively that when we have reached the end of the dataloader everything will sync and a step will be performed. To disable this, set <code>sync_with_dataloader</code> to be <code>False</code> in the <code>GradientAccumulationPlugin</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: 
transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate import Accelerator | |
| <span class="hljs-keyword">from</span> accelerate.utils import GradientAccumulationPlugin | |
| plugin = GradientAccumulationPlugin(<span class="hljs-attribute">sync_with_dataloader</span>=<span class="hljs-literal">False</span>) | |
| accelerator = Accelerator(<span class="hljs-built_in">..</span>., <span class="hljs-attribute">gradient_accumulation_plugin</span>=plugin)<!-- HTML_TAG_END --></pre></div></div> <h2 class="relative group"><a id="the-finished-code" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-finished-code"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The finished code</span></h2> <p data-svelte-h="svelte-9alw5p">Below is the finished implementation for performing gradient accumulation with 🤗 Accelerate</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| accelerator = Accelerator(gradient_accumulation_steps=<span class="hljs-number">2</span>) | |
| model, optimizer, training_dataloader, scheduler = accelerator.prepare( | |
| model, optimizer, training_dataloader, scheduler | |
| ) | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> training_dataloader: | |
| <span class="hljs-keyword">with</span> accelerator.accumulate(model): | |
| inputs, targets = batch | |
| outputs = model(inputs) | |
| loss = loss_function(outputs, targets) | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| scheduler.step() | |
| optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1kw2co7">It’s important that <strong>only one forward/backward</strong> should be done inside the context manager <code>with accelerator.accumulate(model)</code>.</p></div> <p data-svelte-h="svelte-aw0h59">To learn more about what magic this wraps around, read the <a href="../concept_guides/gradient_synchronization">Gradient Synchronization concept guide</a></p> <h2 class="relative group"><a id="self-contained-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#self-contained-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Self-contained example</span></h2> <p data-svelte-h="svelte-1wysb4w">Here is a self-contained example that you can run to see gradient accumulation in action with 🤗 Accelerate:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button 
class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">import</span> copy | |
| <span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> set_seed | |
| <span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> TensorDataset, DataLoader | |
| <span class="hljs-comment"># seed</span> | |
| set_seed(<span class="hljs-number">0</span>) | |
| <span class="hljs-comment"># define toy inputs and labels</span> | |
| x = torch.tensor([<span class="hljs-number">1.</span>, <span class="hljs-number">2.</span>, <span class="hljs-number">3.</span>, <span class="hljs-number">4.</span>, <span class="hljs-number">5.</span>, <span class="hljs-number">6.</span>, <span class="hljs-number">7.</span>, <span class="hljs-number">8.</span>]) | |
| y = torch.tensor([<span class="hljs-number">2.</span>, <span class="hljs-number">4.</span>, <span class="hljs-number">6.</span>, <span class="hljs-number">8.</span>, <span class="hljs-number">10.</span>, <span class="hljs-number">12.</span>, <span class="hljs-number">14.</span>, <span class="hljs-number">16.</span>]) | |
| gradient_accumulation_steps = <span class="hljs-number">4</span> | |
| batch_size = <span class="hljs-built_in">len</span>(x) // gradient_accumulation_steps | |
| <span class="hljs-comment"># define dataset and dataloader</span> | |
| dataset = TensorDataset(x, y) | |
| dataloader = DataLoader(dataset, batch_size=batch_size) | |
| <span class="hljs-comment"># define model, optimizer and loss function</span> | |
| model = torch.zeros((<span class="hljs-number">1</span>, <span class="hljs-number">1</span>), requires_grad=<span class="hljs-literal">True</span>) | |
| model_clone = copy.deepcopy(model) | |
| criterion = torch.nn.MSELoss() | |
| model_optimizer = torch.optim.SGD([model], lr=<span class="hljs-number">0.02</span>) | |
| accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) | |
| model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader) | |
| model_clone_optimizer = torch.optim.SGD([model_clone], lr=<span class="hljs-number">0.02</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"initial model weight is <span class="hljs-subst">{model.mean().item():<span class="hljs-number">.5</span>f}</span>"</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"initial model weight is <span class="hljs-subst">{model_clone.mean().item():<span class="hljs-number">.5</span>f}</span>"</span>) | |
| <span class="hljs-keyword">for</span> i, (inputs, labels) <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataloader): | |
|     <span class="hljs-keyword">with</span> accelerator.accumulate(model): | |
|         inputs = inputs.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>) | |
|         <span class="hljs-built_in">print</span>(i, inputs.flatten()) | |
|         labels = labels.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>) | |
|         outputs = inputs @ model | |
|         loss = criterion(outputs, labels) | |
|         accelerator.backward(loss) | |
|         model_optimizer.step() | |
|         model_optimizer.zero_grad() | |
| loss = criterion(x.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>) @ model_clone, y.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>)) | |
| model_clone_optimizer.zero_grad() | |
| loss.backward() | |
| model_clone_optimizer.step() | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"w/ accumulation, the final model weight is <span class="hljs-subst">{model.mean().item():<span class="hljs-number">.5</span>f}</span>"</span>) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">f"w/o accumulation, the final model weight is <span class="hljs-subst">{model_clone.mean().item():<span class="hljs-number">.5</span>f}</span>"</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attribute">initial</span> model weight is <span class="hljs-number">0</span>.<span class="hljs-number">00000</span> | |
| <span class="hljs-attribute">initial</span> model weight is <span class="hljs-number">0</span>.<span class="hljs-number">00000</span> | |
| <span class="hljs-attribute">0</span> tensor([<span class="hljs-number">1</span>., <span class="hljs-number">2</span>.]) | |
| <span class="hljs-attribute">1</span> tensor([<span class="hljs-number">3</span>., <span class="hljs-number">4</span>.]) | |
| <span class="hljs-attribute">2</span> tensor([<span class="hljs-number">5</span>., <span class="hljs-number">6</span>.]) | |
| <span class="hljs-attribute">3</span> tensor([<span class="hljs-number">7</span>., <span class="hljs-number">8</span>.]) | |
| <span class="hljs-attribute">w</span>/ accumulation, the final model weight is <span class="hljs-number">2</span>.<span class="hljs-number">04000</span> | |
| <span class="hljs-attribute">w</span>/o accumulation, the final model weight is <span class="hljs-number">2</span>.<span class="hljs-number">04000</span><!-- HTML_TAG_END --></pre></div> <p></p> | |
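<p>The matching final weights in the output above can also be checked by hand. The following is a minimal pure-Python sketch, not part of 🤗 Accelerate, that reuses the toy data, learning rate, and accumulation settings from the self-contained example. The helper <code>mse_grad</code> and the variable names are hypothetical and chosen only for illustration. It assumes equally sized micro-batches and a mean-reduced MSE loss, which is the setting in which dividing each micro-batch loss by <code>gradient_accumulation_steps</code> reproduces the full-batch gradient.</p>

```python
# Pure-Python check of the toy example: fit y = w * x with MSE loss, starting at w = 0.
# No torch or accelerate is needed; mse_grad is a hypothetical helper for illustration.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]
lr = 0.02
gradient_accumulation_steps = 4
batch_size = len(xs) // gradient_accumulation_steps


def mse_grad(w, batch_x, batch_y):
    """Derivative with respect to w of mean((w * x - y) ** 2) over one batch."""
    return sum(2.0 * x * (w * x - y) for x, y in zip(batch_x, batch_y)) / len(batch_x)


w = 0.0

# Full-batch gradient: a single backward pass over all 8 samples.
full_grad = mse_grad(w, xs, ys)

# Accumulated gradient: each micro-batch loss is divided by the number of
# accumulation steps, so its gradient is scaled the same way before being summed.
accumulated_grad = 0.0
for start in range(0, len(xs), batch_size):
    batch_x = xs[start:start + batch_size]
    batch_y = ys[start:start + batch_size]
    accumulated_grad += mse_grad(w, batch_x, batch_y) / gradient_accumulation_steps

# One SGD step with each gradient.
w_full = w - lr * full_grad
w_accumulated = w - lr * accumulated_grad

print(f"full-batch gradient:  {full_grad:.5f}")
print(f"accumulated gradient: {accumulated_grad:.5f}")
print(f"weight after one step: {w_full:.5f} vs {w_accumulated:.5f}")
```

<p>Both gradients come out to -102.00000, so a single step with a learning rate of 0.02 moves the weight from 0 to 2.04000 in either case, which agrees with the final weights printed by the self-contained example.</p>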
| <script> | |
| { | |
| __sveltekit_1fyccrg = { | |
| assets: "/docs/accelerate/main/en", | |
| base: "/docs/accelerate/main/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/accelerate/main/en/_app/immutable/entry/start.2ea03080.js"), | |
| import("/docs/accelerate/main/en/_app/immutable/entry/app.e6812672.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 42], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |