Buckets:

hf-doc-build/doc-dev / accelerate /pr_4021 /en /concept_guides /gradient_synchronization.html
download
raw
30.6 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Gradient synchronization&quot;,&quot;local&quot;:&quot;gradient-synchronization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;The slowdown in gradient accumulation&quot;,&quot;local&quot;:&quot;the-slowdown-in-gradient-accumulation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Solving the slowdown problem&quot;,&quot;local&quot;:&quot;solving-the-slowdown-problem&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Just how much of a slowdown is there, and easy mistakes you can make&quot;,&quot;local&quot;:&quot;just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;no_sync requires additional GPU memory when using FSDP&quot;,&quot;local&quot;:&quot;nosync-requires-additional-gpu-memory-when-using-fsdp&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link rel="stylesheet" href="/docs/accelerate/pr_4021/en/_app/immutable/assets/0.e3b0c442.css">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/scheduler.b9285784.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/singletons.7547c222.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.6d423e5c.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/paths.d42c9205.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/preload-helper.b0bd19d1.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.26bc89a1.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/0.0e7c56e8.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/15.819e6979.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.7a0ae628.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/CodeBlock.844ff9c3.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Gradient synchronization&quot;,&quot;local&quot;:&quot;gradient-synchronization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;The slowdown in gradient accumulation&quot;,&quot;local&quot;:&quot;the-slowdown-in-gradient-accumulation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Solving the slowdown problem&quot;,&quot;local&quot;:&quot;solving-the-slowdown-problem&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Just how much of a slowdown is there, and easy mistakes you can make&quot;,&quot;local&quot;:&quot;just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;no_sync requires additional GPU memory when using FSDP&quot;,&quot;local&quot;:&quot;nosync-requires-additional-gpu-memory-when-using-fsdp&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" 
width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="gradient-synchronization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gradient-synchronization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 
28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Gradient synchronization</span></h1> <p data-svelte-h="svelte-1bq702f">PyTorch’s distributed module operates by communicating back and forth between all of the GPUs in your system.
This communication takes time, and ensuring all processes know each other's states happens at particular trigger points
when using the <code>ddp</code> module.</p> <p data-svelte-h="svelte-fqoolu">These trigger points are added to the PyTorch model, specifically to its <code>forward()</code> and <code>backward()</code> methods.
This happens when the model is wrapped with <code>DistributedDataParallel</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch.nn <span class="hljs-keyword">as</span> nn
<span class="hljs-keyword">from</span> torch.nn.parallel <span class="hljs-keyword">import</span> DistributedDataParallel
model = nn.Linear(<span class="hljs-number">10</span>, <span class="hljs-number">10</span>)
ddp_model = DistributedDataParallel(model)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17sr0v0">In Accelerate this conversion happens automatically when calling <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.prepare">prepare()</a> and passing in your model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-addition">+ from accelerate import Accelerator</span>
<span class="hljs-addition">+ accelerator = Accelerator()</span>
import torch.nn as nn
<span class="hljs-deletion">- from torch.nn.parallel import DistributedDataParallel</span>
model = nn.Linear(10,10)
<span class="hljs-deletion">- ddp_model = DistributedDataParallel(model)</span>
<span class="hljs-addition">+ model = accelerator.prepare(model)</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="the-slowdown-in-gradient-accumulation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-slowdown-in-gradient-accumulation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The slowdown in gradient accumulation</span></h2> <p data-svelte-h="svelte-byuygg">You now understand that PyTorch adds hooks to the <code>forward</code> and <code>backward</code> methods of your PyTorch model when
training in a distributed setup. But how does this risk slowing down your code?</p> <p data-svelte-h="svelte-w9p09z">In DDP (distributed data parallel), every process is expected to reach the same synchronization points in the same order,
and all processes must arrive at each of these points at roughly the same time before any of them can move on.</p> <p data-svelte-h="svelte-1qb001k">The most direct example is when you update model parameters through
<code>optimizer.step()</code>.
Without gradient accumulation, all instances of the model need to have their
gradients computed, collated, and applied before moving on to the next
batch of data.
When performing gradient accumulation, you accumulate <code>n</code> loss gradients and
skip <code>optimizer.step()</code> until <code>n</code> batches have been reached. As all training
processes only need to synchronize by the time <code>optimizer.step()</code> is called,
the synchronization that DDP still performs on every intermediate <code>backward()</code> call is unnecessary. Without any modification to your training step, this needless inter-process
communication can cause a significant slowdown.</p> <p data-svelte-h="svelte-14tkwrb">How can you avoid this overhead?</p> <h2 class="relative group"><a id="solving-the-slowdown-problem" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#solving-the-slowdown-problem"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Solving the slowdown problem</span></h2> <p data-svelte-h="svelte-olrkxe">Since you are skipping model parameter updates when training on these batches, their gradients do not need to be synchronized until the point where <code>optimizer.step()</code> is actually called.
PyTorch cannot automagically tell when you need to do this, but it does provide a tool to help: the <a href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync" rel="nofollow"><code>no_sync</code></a> context manager
that is added to your model after converting it to DDP.</p> <p data-svelte-h="svelte-cn8x56">Under this context manager, PyTorch will skip synchronizing the gradients when
<code>.backward()</code> is called, and the first call to <code>.backward()</code> outside this
context manager will trigger the synchronization. See an example below:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
<span class="hljs-keyword">for</span> index, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataloader):
    inputs, targets = batch
    <span class="hljs-comment"># Trigger gradient synchronization on the last batch</span>
    <span class="hljs-keyword">if</span> index != (<span class="hljs-built_in">len</span>(dataloader) - <span class="hljs-number">1</span>):
        <span class="hljs-keyword">with</span> ddp_model.no_sync():
            <span class="hljs-comment"># Gradients only accumulate</span>
            outputs = ddp_model(inputs)
            loss = loss_func(outputs, targets)
            accelerator.backward(loss)
    <span class="hljs-keyword">else</span>:
        <span class="hljs-comment"># Gradients finally sync</span>
        outputs = ddp_model(inputs)
        loss = loss_func(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1odb5lu">In Accelerate, to make this an API that can be called no matter the training device (though it may not do anything if you are not in a distributed system!),
<code>ddp_model.no_sync</code> gets replaced with <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.no_sync">no_sync()</a> and operates the same way:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
 for index, batch in enumerate(dataloader):
     inputs, targets = batch
     # Trigger gradient synchronization on the last batch
     if index != (len(dataloader)-1):
<span class="hljs-deletion">-         with ddp_model.no_sync():</span>
<span class="hljs-addition">+         with accelerator.no_sync(ddp_model):</span>
             # Gradients only accumulate
             outputs = ddp_model(inputs)
             loss = loss_func(outputs, targets)
             accelerator.backward(loss)
     else:
         # Gradients finally sync
         outputs = ddp_model(inputs)
         loss = loss_func(outputs, targets)
         accelerator.backward(loss)
         optimizer.step()
         optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-a0cagt">As you may expect, the <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> function wraps around this conditional check by keeping track of the current batch number, leaving you with the final
gradient accumulation API:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> dataloader:
    <span class="hljs-keyword">with</span> accelerator.accumulate(ddp_model):
        inputs, targets = batch
        outputs = ddp_model(inputs)
        loss = loss_func(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1b3vx79">As a result, when it comes to API choice, you should use either <em><code>accelerator.accumulate</code> or <code>accelerator.no_sync</code></em>.</p> <h2 class="relative group"><a id="just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Just how much of a slowdown is there, and easy mistakes you can make</span></h2> <p data-svelte-h="svelte-1hdyapv">To set up a realistic example, consider the following setup:</p> <ul data-svelte-h="svelte-1d218dd"><li>Two single-GPU T4 nodes and one node with two GPUs</li> <li>Each GPU is a T4, and all are hosted on GCP</li> <li>The script used is a modification of the <a href="https://github.com/muellerzr/timing_experiments/blob/main/baseline.py" rel="nofollow">NLP Example</a> script</li> <li>Batch size per GPU is 16, and gradients are accumulated every 4 steps</li></ul> <p data-svelte-h="svelte-5qwyq6">All scripts are available 
in <a href="https://github.com/muellerzr/timing_experiments" rel="nofollow">this repository</a>.</p> <p data-svelte-h="svelte-1yxh0lm">If you are not careful about gradient synchronization and GPU communication, a <em>large</em> amount of time can be wasted
on GPUs communicating with each other when they do not need to.</p> <p data-svelte-h="svelte-1n0bzbn">By how much?</p> <p data-svelte-h="svelte-1i6grvn">Reference:</p> <ul data-svelte-h="svelte-1pa8kyp"><li>Baseline: uses no synchronization practices discussed here</li> <li><code>no_sync</code> improperly: <code>no_sync</code> wrapping only the <code>backward</code> call, not the <code>forward</code> call</li> <li><code>no_sync</code>: using the <code>no_sync</code> pattern properly</li> <li><code>accumulate</code>: using <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> properly</li></ul> <p data-svelte-h="svelte-vdaq6p">Below are the average seconds per batch, measured over 29 batches of data, for each setup on both a single node and the dual-node setup:</p> <table data-svelte-h="svelte-sjfg3j"><thead><tr><th align="center"></th> <th align="center">Baseline</th> <th align="center"><code>no_sync</code> improperly</th> <th align="center"><code>no_sync</code></th> <th align="center"><code>accumulate</code></th></tr></thead> <tbody><tr><td align="center">Multi-Node</td> <td align="center">2±0.01s</td> <td align="center">2.13±0.08s</td> <td align="center"><strong>0.91±0.11s</strong></td> <td align="center"><strong>0.91±0.11s</strong></td></tr> <tr><td align="center">Single Node</td> <td align="center">0.50±0.01s</td> <td align="center">0.50±0.01s</td> <td align="center"><strong>0.41±0.015s</strong></td> <td align="center"><strong>0.41±0.015s</strong></td></tr></tbody></table> <p data-svelte-h="svelte-yk2mxh">As you can see, if you are not careful about how you set up your gradient synchronization, you can get more than a 2x slowdown during training!</p> <p data-svelte-h="svelte-1ep09o2">If you are worried about making sure everything is done properly, we highly recommend utilizing the <a 
href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> function and passing in
<code>gradient_accumulation_steps</code> or <code>gradient_accumulation_plugin</code> to the <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> object so Accelerate can handle this for you.</p> <h3 class="relative group"><a id="nosync-requires-additional-gpu-memory-when-using-fsdp" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#nosync-requires-additional-gpu-memory-when-using-fsdp"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>no_sync requires additional GPU memory when using FSDP</span></h3> <p data-svelte-h="svelte-rzu0c8">Be aware that not syncing gradients can have adverse effects while performing FSDP training. 
As the <code>torch</code> documentation warns, the <a href="https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync" rel="nofollow"><code>no_sync</code> context manager for FSDP</a> will require additional memory.</p> <p data-svelte-h="svelte-zuof06">Therefore, in memory-intensive situations while using FSDP, we recommend setting <code>sync_each_batch</code> to <code>True</code> in the <a href="/docs/accelerate/pr_4021/en/package_reference/utilities#accelerate.utils.GradientAccumulationPlugin">GradientAccumulationPlugin</a> to disable <code>no_sync</code>.</p> <p data-svelte-h="svelte-1s75xk0">See the example below where we fine-tune Mixtral (47B parameters) on 8 A100-80GB GPUs. We see that even for a modest <code>gradient_accumulation_steps=2</code> we quickly go out-of-memory (OOM) if <code>no_sync</code> is enabled. Again, this is due to the additional memory overhead of FSDP’s <code>no_sync</code>. However, if <code>no_sync</code> is disabled via <code>sync_each_batch=True</code>, then the memory consumption for <code>gradient_accumulation_steps=16</code> reverts to that of <code>gradient_accumulation_steps=1</code>.</p> <table data-svelte-h="svelte-iqaw47"><thead><tr><th align="center">Model</th> <th align="center"><code>no_sync</code> (accum=1)</th> <th align="center"><code>no_sync</code> (accum=2)</th> <th align="center"><code>no_sync</code> disabled (accum=16)</th></tr></thead> <tbody><tr><td align="center">mixtral 8x7B</td> <td align="center">69G</td> <td align="center">OOM</td> <td align="center">69G</td></tr></tbody></table> <blockquote class="warning" data-svelte-h="svelte-1be6loi"><p>Disabling <code>no_sync</code> means there <em>will be a slowdown</em> due to the extra data syncs, as explained in the earlier sections of this guide.</p></blockquote> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" 
href="https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/gradient_synchronization.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Client-side bootstrap for this pre-rendered SvelteKit page.
// The surrounding bare block keeps `element` and `data` block-scoped (const),
// so only the assignment below leaks into the global scope.
// Assigned without a declaration, so in a classic (non-module) script this
// creates a global. The hash-suffixed name presumably must match what the app
// bundle looks up — do not rename (TODO confirm against the generated bundle).
// `assets` and `base` are the URL prefixes for this doc build; `env` is empty.
__sveltekit_1q7nz6m = {
assets: "/docs/accelerate/pr_4021/en",
base: "/docs/accelerate/pr_4021/en",
env: {}
};
// The element that contains this <script>, handed to kit.start below
// (presumably as the hydration target — TODO confirm).
// document.currentScript is only set while this script is executing
// synchronously, so it must be read here, not inside the promise callback.
const element = document.currentScript.parentElement;
// Server-provided data; both entries are null on this page. Presumably one slot
// per entry in `node_ids` below (TODO confirm).
const data = [null,null];
// Load the runtime entry (`kit`) and the application entry (`app`) in parallel,
// then start the client on `element`.
Promise.all([
import("/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"),
import("/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js")
]).then(([kit, app]) => {
kit.start(app, element, {
// Route node indices to start with; these correspond to the nodes/0.*.js and
// nodes/15.*.js chunks modulepreloaded in <head> (assumed mapping — TODO confirm).
node_ids: [0, 15],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
30.6 kB
·
Xet hash:
509b9e5cfb28bb07176ce56c5998c1d9a7680b17b64ffb3c88a779018c8d9c76

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.