Buckets:

hf-doc-build/doc-dev / accelerate /main /en /usage_guides /gradient_accumulation.html
rtrm's picture
download
raw
35.5 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Performing gradient accumulation with 🤗 Accelerate&quot;,&quot;local&quot;:&quot;performing-gradient-accumulation-with--accelerate&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Converting it to 🤗 Accelerate&quot;,&quot;local&quot;:&quot;converting-it-to--accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Letting 🤗 Accelerate handle gradient accumulation&quot;,&quot;local&quot;:&quot;letting--accelerate-handle-gradient-accumulation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;The finished code&quot;,&quot;local&quot;:&quot;the-finished-code&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Self-contained example&quot;,&quot;local&quot;:&quot;self-contained-example&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link rel="stylesheet" href="/docs/accelerate/main/en/_app/immutable/assets/0.e3b0c442.css">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/entry/start.2ea03080.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/scheduler.defa9a21.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/singletons.aff0b9fc.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/index.beade68d.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/paths.2c85d1a6.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/entry/app.e6812672.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/index.fe795e71.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/nodes/0.39c84d5d.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/nodes/42.dfc30789.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/Tip.179eb360.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/CodeBlock.42404125.js">
<link rel="modulepreload" href="/docs/accelerate/main/en/_app/immutable/chunks/EditOnGithub.0f575778.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Performing gradient accumulation with 🤗 Accelerate&quot;,&quot;local&quot;:&quot;performing-gradient-accumulation-with--accelerate&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Converting it to 🤗 Accelerate&quot;,&quot;local&quot;:&quot;converting-it-to--accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Letting 🤗 Accelerate handle gradient accumulation&quot;,&quot;local&quot;:&quot;letting--accelerate-handle-gradient-accumulation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;The finished code&quot;,&quot;local&quot;:&quot;the-finished-code&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Self-contained example&quot;,&quot;local&quot;:&quot;self-contained-example&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="performing-gradient-accumulation-with--accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#performing-gradient-accumulation-with--accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 
11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Performing gradient accumulation with 🤗 Accelerate</span></h1> <p data-svelte-h="svelte-1762c2w">Gradient accumulation is a technique where you can train on bigger batch sizes than
your machine would normally be able to fit into memory. This is done by accumulating gradients over
several batches, and only stepping the optimizer after a certain number of batches have been performed.</p> <p data-svelte-h="svelte-1u2vc4p">While technically standard gradient accumulation code would work fine in a distributed setup, it is not the most efficient
method for doing so, and you may experience considerable slowdowns!</p> <p data-svelte-h="svelte-elnr4k">In this tutorial you will see how to quickly set up gradient accumulation and perform it with the utilities provided in 🤗 Accelerate,
which can total to adding just one new line of code!</p> <p data-svelte-h="svelte-n9bxrd">This example will use a very simplistic PyTorch training loop that performs gradient accumulation every two batches:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->device = <span class="hljs-string">&quot;cuda&quot;</span>
model.to(device)
gradient_accumulation_steps = <span class="hljs-number">2</span>
<span class="hljs-keyword">for</span> index, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(training_dataloader):
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss = loss / gradient_accumulation_steps
    loss.backward()
    <span class="hljs-keyword">if</span> (index + <span class="hljs-number">1</span>) % gradient_accumulation_steps == <span class="hljs-number">0</span>:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="converting-it-to--accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#converting-it-to--accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Converting it to 🤗 Accelerate</span></h2> <p data-svelte-h="svelte-41b2zq">First the code shown earlier will be converted to utilize 🤗 Accelerate without the special gradient accumulation helper:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" 
transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-addition">+ from accelerate import Accelerator</span>
<span class="hljs-addition">+ accelerator = Accelerator()</span>
<span class="hljs-addition">+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(</span>
<span class="hljs-addition">+ model, optimizer, training_dataloader, scheduler</span>
<span class="hljs-addition">+ )</span>
for index, batch in enumerate(training_dataloader):
inputs, targets = batch
<span class="hljs-deletion">- inputs = inputs.to(device)</span>
<span class="hljs-deletion">- targets = targets.to(device)</span>
outputs = model(inputs)
loss = loss_function(outputs, targets)
loss = loss / gradient_accumulation_steps
<span class="hljs-addition">+ accelerator.backward(loss)</span>
if (index+1) % gradient_accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-s0dec7">In its current state, this code is not going to perform gradient accumulation efficiently due to a process called gradient synchronization. Read more about that in the <a href="../concept_guides/gradient_synchronization">Concepts tutorial</a>!</p></div> <h2 class="relative group"><a id="letting--accelerate-handle-gradient-accumulation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#letting--accelerate-handle-gradient-accumulation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Letting 🤗 Accelerate handle gradient accumulation</span></h2> <p data-svelte-h="svelte-70ndtd">All that is left now is to let 🤗 Accelerate handle the gradient accumulation for us. 
To do so you should pass in a <code>gradient_accumulation_steps</code> parameter to <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a>, dictating the number
of steps to perform before each call to <code>step()</code> and how to automatically adjust the loss during the call to <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.backward">backward()</a>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> from accelerate import Accelerator
<span class="hljs-deletion">- accelerator = Accelerator()</span>
<span class="hljs-addition">+ accelerator = Accelerator(gradient_accumulation_steps=2)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-j5vx6o">Alternatively, you can pass in a <code>gradient_accumulation_plugin</code> parameter to the <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> object’s <code>__init__</code>, which will allow you to further customize the gradient accumulation behavior.
Read more about that in the <a href="../package_reference/accelerator#accelerate.utils.GradientAccumulationPlugin">GradientAccumulationPlugin</a> docs.</p> <p data-svelte-h="svelte-1yxrf5e">From here you can use the <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> context manager from inside your training loop to automatically perform the gradient accumulation for you!
You just wrap it around the entire training part of our code:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-deletion">- for index, batch in enumerate(training_dataloader):</span>
<span class="hljs-addition">+ for batch in training_dataloader:</span>
<span class="hljs-addition">+ with accelerator.accumulate(model):</span>
inputs, targets = batch
outputs = model(inputs)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1mcvl9w">You can remove all the special checks for the step number and the loss adjustment:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-deletion">- loss = loss / gradient_accumulation_steps</span>
accelerator.backward(loss)
<span class="hljs-deletion">- if (index+1) % gradient_accumulation_steps == 0:</span>
optimizer.step()
scheduler.step()
optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1g13ex2">As you can see the <a href="/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> is able to keep track of the batch number you are on and it will automatically know whether to step through the prepared optimizer and how to adjust the loss.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1niv7yj">Typically with gradient accumulation, you would need to adjust the number of steps to reflect the change in total batches you are
training on. 🤗 Accelerate automagically does this for you by default. Behind the scenes we instantiate a <code>GradientAccumulationPlugin</code> configured to do this.</p></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-lg50r9">The <a href="/docs/accelerate/main/en/package_reference/state#accelerate.state.GradientState">state.GradientState</a> is sync’d with the active dataloader being iterated upon. As such it assumes naively that when we have reached the end of the dataloader everything will sync and a step will be performed. To disable this, set <code>sync_with_dataloader</code> to be <code>False</code> in the <code>GradientAccumulationPlugin</code>:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: 
transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate import Accelerator
<span class="hljs-keyword">from</span> accelerate.utils import GradientAccumulationPlugin
plugin = GradientAccumulationPlugin(<span class="hljs-attribute">sync_with_dataloader</span>=<span class="hljs-literal">False</span>)
accelerator = Accelerator(<span class="hljs-built_in">..</span>., <span class="hljs-attribute">gradient_accumulation_plugin</span>=plugin)<!-- HTML_TAG_END --></pre></div></div> <h2 class="relative group"><a id="the-finished-code" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-finished-code"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The finished code</span></h2> <p data-svelte-h="svelte-9alw5p">Below is the finished implementation for performing gradient accumulation with 🤗 Accelerate</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
accelerator = Accelerator(gradient_accumulation_steps=<span class="hljs-number">2</span>)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
model, optimizer, training_dataloader, scheduler
)
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> training_dataloader:
    <span class="hljs-keyword">with</span> accelerator.accumulate(model):
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1kw2co7">It’s important that <strong>only one forward/backward pass</strong> is done inside the context manager <code>with accelerator.accumulate(model)</code>.</p></div> <p data-svelte-h="svelte-aw0h59">To learn more about what magic this wraps around, read the <a href="../concept_guides/gradient_synchronization">Gradient Synchronization concept guide</a>.</p> <h2 class="relative group"><a id="self-contained-example" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#self-contained-example"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Self-contained example</span></h2> <p data-svelte-h="svelte-1wysb4w">Here is a self-contained example that you can run to see gradient accumulation in action with 🤗 Accelerate:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button 
class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> copy
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> set_seed
<span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> TensorDataset, DataLoader
<span class="hljs-comment"># seed</span>
set_seed(<span class="hljs-number">0</span>)
<span class="hljs-comment"># define toy inputs and labels</span>
x = torch.tensor([<span class="hljs-number">1.</span>, <span class="hljs-number">2.</span>, <span class="hljs-number">3.</span>, <span class="hljs-number">4.</span>, <span class="hljs-number">5.</span>, <span class="hljs-number">6.</span>, <span class="hljs-number">7.</span>, <span class="hljs-number">8.</span>])
y = torch.tensor([<span class="hljs-number">2.</span>, <span class="hljs-number">4.</span>, <span class="hljs-number">6.</span>, <span class="hljs-number">8.</span>, <span class="hljs-number">10.</span>, <span class="hljs-number">12.</span>, <span class="hljs-number">14.</span>, <span class="hljs-number">16.</span>])
gradient_accumulation_steps = <span class="hljs-number">4</span>
batch_size = <span class="hljs-built_in">len</span>(x) // gradient_accumulation_steps
<span class="hljs-comment"># define dataset and dataloader</span>
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=batch_size)
<span class="hljs-comment"># define model, optimizer and loss function</span>
model = torch.zeros((<span class="hljs-number">1</span>, <span class="hljs-number">1</span>), requires_grad=<span class="hljs-literal">True</span>)
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
model_optimizer = torch.optim.SGD([model], lr=<span class="hljs-number">0.02</span>)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)
model_clone_optimizer = torch.optim.SGD([model_clone], lr=<span class="hljs-number">0.02</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;initial model weight is <span class="hljs-subst">{model.mean().item():<span class="hljs-number">.5</span>f}</span>&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;initial model weight is <span class="hljs-subst">{model_clone.mean().item():<span class="hljs-number">.5</span>f}</span>&quot;</span>)
<span class="hljs-keyword">for</span> i, (inputs, labels) <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataloader):
    <span class="hljs-keyword">with</span> accelerator.accumulate(model):
        inputs = inputs.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>)
        <span class="hljs-built_in">print</span>(i, inputs.flatten())
        labels = labels.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>)
        outputs = inputs @ model
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()
loss = criterion(x.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>) @ model_clone, y.view(-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>))
model_clone_optimizer.zero_grad()
loss.backward()
model_clone_optimizer.step()
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;w/ accumulation, the final model weight is <span class="hljs-subst">{model.mean().item():<span class="hljs-number">.5</span>f}</span>&quot;</span>)
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;w/o accumulation, the final model weight is <span class="hljs-subst">{model_clone.mean().item():<span class="hljs-number">.5</span>f}</span>&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attribute">initial</span> model weight is <span class="hljs-number">0</span>.<span class="hljs-number">00000</span>
<span class="hljs-attribute">initial</span> model weight is <span class="hljs-number">0</span>.<span class="hljs-number">00000</span>
<span class="hljs-attribute">0</span> tensor([<span class="hljs-number">1</span>., <span class="hljs-number">2</span>.])
<span class="hljs-attribute">1</span> tensor([<span class="hljs-number">3</span>., <span class="hljs-number">4</span>.])
<span class="hljs-attribute">2</span> tensor([<span class="hljs-number">5</span>., <span class="hljs-number">6</span>.])
<span class="hljs-attribute">3</span> tensor([<span class="hljs-number">7</span>., <span class="hljs-number">8</span>.])
<span class="hljs-attribute">w</span>/ accumulation, the final model weight is <span class="hljs-number">2</span>.<span class="hljs-number">04000</span>
<span class="hljs-attribute">w</span>/o accumulation, the final model weight is <span class="hljs-number">2</span>.<span class="hljs-number">04000</span><!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/accelerate/blob/main/docs/source/usage_guides/gradient_accumulation.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// Client-side bootstrap emitted by the SvelteKit build (note the
// `__sveltekit_*` global and the `_app/immutable` entry modules). It starts
// the client app on the server-rendered markup of this page. Statement
// order below matters; this is generated code, so regenerate rather than
// hand-edit.
{
// The enclosing bare block scopes `element` and `data`, so these `const`s do
// not land in the global lexical scope shared by every classic <script>.
// The undeclared assignment creates a global (classic, non-strict script).
// It must exist before the entry modules below are imported; presumably the
// runtime reads `assets`/`base` from it to resolve URLs (confirm against the
// SvelteKit version in use).
__sveltekit_1fyccrg = {
assets: "/docs/accelerate/main/en",
base: "/docs/accelerate/main/en",
env: {}
};
// `document.currentScript` is only non-null while this script executes
// synchronously, so the mount target (this script's parent element) is
// captured here rather than inside the promise callback below.
const element = document.currentScript.parentElement;
// One slot per entry in `node_ids` below; null presumably means "no
// server-loaded data" for that route node.
const data = [null,null];
// Load the `start` and `app` entry modules in parallel. Both URLs are also
// modulepreloaded in the document head.
Promise.all([
import("/docs/accelerate/main/en/_app/immutable/entry/start.2ea03080.js"),
import("/docs/accelerate/main/en/_app/immutable/entry/app.e6812672.js")
]).then(([kit, app]) => {
// node_ids presumably identify the root layout node (0) and this page's
// node (42), matching the preloaded nodes/0.*.js and nodes/42.*.js chunks.
kit.start(app, element, {
node_ids: [0, 42],
data,
form: null,
error: null
});
});
// NOTE(review): no rejection handler is attached. If either import fails,
// the error is unhandled and `kit.start` never runs, leaving only the
// static HTML.
}
</script>

Xet Storage Details

Size:
35.5 kB
·
Xet hash:
d00ff7eacb2ffaaeefdefb8bc41ca32192bd899b0d813ec9020babb9d3436fc0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.