Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Gradient synchronization&quot;,&quot;local&quot;:&quot;gradient-synchronization&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;The slowdown in gradient accumulation&quot;,&quot;local&quot;:&quot;the-slowdown-in-gradient-accumulation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Solving the slowdown problem&quot;,&quot;local&quot;:&quot;solving-the-slowdown-problem&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Just how much of a slowdown is there, and easy mistakes you can make&quot;,&quot;local&quot;:&quot;just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;no_sync requires additional GPU memory when using FSDP&quot;,&quot;local&quot;:&quot;nosync-requires-additional-gpu-memory-when-using-fsdp&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"> | |
| <link href="/docs/accelerate/pr_4021/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/scheduler.b9285784.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/singletons.7547c222.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.6d423e5c.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/paths.d42c9205.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/preload-helper.b0bd19d1.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.26bc89a1.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/0.0e7c56e8.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/15.819e6979.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.7a0ae628.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/CodeBlock.844ff9c3.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Gradient synchronization","local":"gradient-synchronization","sections":[{"title":"The slowdown in gradient accumulation","local":"the-slowdown-in-gradient-accumulation","sections":[],"depth":2},{"title":"Solving the slowdown problem","local":"solving-the-slowdown-problem","sections":[],"depth":2},{"title":"Just how much of a slowdown is there, and easy mistakes you can make","local":"just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make","sections":[{"title":"no_sync requires additional GPU memory when using FSDP","local":"nosync-requires-additional-gpu-memory-when-using-fsdp","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" 
width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="gradient-synchronization" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#gradient-synchronization"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Gradient synchronization</span></h1> <p data-svelte-h="svelte-1bq702f">PyTorch’s distributed module operates by communicating back and forth between all of the GPUs in your system. | |
| This communication takes time, and ensuring all processes know the states of each other happens at particular triggerpoints | |
| when using the <code>ddp</code> module.</p> <p data-svelte-h="svelte-fqoolu">These triggerpoints are added to the PyTorch model, specifically their <code>forward()</code> and <code>backward()</code> methods. | |
| This happens when the model is wrapped with <code>DistributedDataParallel</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch.nn <span class="hljs-keyword">as</span> nn | |
| <span class="hljs-keyword">from</span> torch.nn.parallel <span class="hljs-keyword">import</span> DistributedDataParallel | |
| model = nn.Linear(<span class="hljs-number">10</span>, <span class="hljs-number">10</span>) | |
| ddp_model = DistributedDataParallel(model)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17sr0v0">In Accelerate this conversion happens automatically when calling <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.prepare">prepare()</a> and passing in your model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-addition">+ from accelerate import Accelerator</span> | |
| <span class="hljs-addition">+ accelerator = Accelerator()</span> | |
| import torch.nn as nn | |
| <span class="hljs-deletion">- from torch.nn.parallel import DistributedDataParallel</span> | |
| model = nn.Linear(10,10) | |
| <span class="hljs-addition">+ model = accelerator.prepare(model)</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="the-slowdown-in-gradient-accumulation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#the-slowdown-in-gradient-accumulation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>The slowdown in gradient accumulation</span></h2> <p data-svelte-h="svelte-byuygg">You now understand that PyTorch adds hooks to the <code>forward</code> and <code>backward</code> method of your PyTorch model when | |
| training in a distributed setup. But how does this risk slowing down your code?</p> <p data-svelte-h="svelte-w9p09z">In DDP (distributed data parallel), processes are expected to be performed and run in a specific order | |
| at specific points, and these must also occur at roughly the same time before moving on.</p> <p data-svelte-h="svelte-1qb001k">The most direct example is when you update model parameters through | |
| <code>optimizer.step()</code>. | |
| Without gradient accumulation, all instances of the model need to have | |
| their gradients computed, collated, and updated before moving on to the next | |
| batch of data. | |
| When performing gradient accumulation, you accumulate <code>n</code> loss gradients and | |
| skip <code>optimizer.step()</code> until <code>n</code> batches have been reached. As all training | |
| processes only need to synchronize by the time <code>optimizer.step()</code> is called, | |
| without any modification to your training step they still synchronize on every batch, and this needless inter-process | |
| communication can cause a significant slowdown.</p> <p data-svelte-h="svelte-14tkwrb">How can you avoid this overhead?</p> <h2 class="relative group"><a id="solving-the-slowdown-problem" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#solving-the-slowdown-problem"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Solving the slowdown problem</span></h2> <p data-svelte-h="svelte-olrkxe">Since you are skipping model parameter updates when training on these batches, their gradients do not need to be synchronized until the point where <code>optimizer.step()</code> is actually called. | |
| PyTorch cannot automagically tell when you need to do this, but they do provide a tool to help through the <a href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync" rel="nofollow"><code>no_sync</code></a> context manager | |
| that is added to your model after converting it to DDP.</p> <p data-svelte-h="svelte-cn8x56">Under this context manager, PyTorch will skip synchronizing the gradients when | |
| <code>.backward()</code> is called, and the first call to <code>.backward()</code> outside this | |
| context manager will trigger the synchronization. See an example below:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer) | |
| <span class="hljs-keyword">for</span> index, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataloader): | |
| inputs, targets = batch | |
| <span class="hljs-comment"># Trigger gradient synchronization on the last batch</span> | |
| <span class="hljs-keyword">if</span> index != (<span class="hljs-built_in">len</span>(dataloader) - <span class="hljs-number">1</span>): | |
| <span class="hljs-keyword">with</span> ddp_model.no_sync(): | |
| <span class="hljs-comment"># Gradients only accumulate</span> | |
| outputs = ddp_model(inputs) | |
| loss = loss_func(outputs, targets) | |
| accelerator.backward(loss) | |
| <span class="hljs-keyword">else</span>: | |
| <span class="hljs-comment"># Gradients finally sync</span> | |
| outputs = ddp_model(inputs) | |
| loss = loss_func(outputs, targets) | |
| accelerator.backward(loss) | |
| optimizer.step()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1odb5lu">In Accelerate to make this an API that can be called no matter the training device (though it may not do anything if you are not in a distributed system!), | |
| <code>ddp_model.no_sync</code> gets replaced with <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.no_sync">no_sync()</a> and operates the same way:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer) | |
| for index, batch in enumerate(dataloader): | |
| inputs, targets = batch | |
| # Trigger gradient synchronization on the last batch | |
| if index != (len(dataloader)-1): | |
| <span class="hljs-deletion">- with ddp_model.no_sync():</span> | |
| <span class="hljs-addition">+ with accelerator.no_sync(model):</span> | |
| # Gradients only accumulate | |
| outputs = ddp_model(inputs) | |
| loss = loss_func(outputs, targets) | |
| accelerator.backward(loss) | |
| else: | |
| # Gradients finally sync | |
| outputs = ddp_model(inputs) | |
| loss = loss_func(outputs, targets) | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-a0cagt">As you may expect, the <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> function wraps around this conditional check by keeping track of the current batch number, leaving you with the final | |
| gradient accumulation API:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->ddp_model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer) | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> dataloader: | |
| <span class="hljs-keyword">with</span> accelerator.accumulate(model): | |
| optimizer.zero_grad() | |
| inputs, targets = batch | |
| outputs = model(inputs) | |
| loss = loss_function(outputs, targets) | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1b3vx79">As a result, you should either use <em><code>accelerator.accumulate</code> or <code>accelerator.no_sync</code></em> when it comes to API choice.</p> <h2 class="relative group"><a id="just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#just-how-much-of-a-slowdown-is-there-and-easy-mistakes-you-can-make"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Just how much of a slowdown is there, and easy mistakes you can make</span></h2> <p data-svelte-h="svelte-1hdyapv">To set up a realistic example, consider the following setup:</p> <ul data-svelte-h="svelte-1d218dd"><li>Two single-GPU T4 nodes and one node with two GPUs</li> <li>Each GPU is a T4, and are hosted on GCP</li> <li>The script used is a modification of the <a href="https://github.com/muellerzr/timing_experiments/blob/main/baseline.py" rel="nofollow">NLP Example</a> script</li> <li>Batch size per GPU is 16, and gradients are accumulated every 4 steps</li></ul> <p data-svelte-h="svelte-5qwyq6">All scripts are available 
in <a href="https://github.com/muellerzr/timing_experiments" rel="nofollow">this repository</a>.</p> <p data-svelte-h="svelte-1yxh0lm">If not careful about gradient synchronization and GPU communication, a <em>large</em> amount of time can be wasted | |
| from when these GPUs communicate to each other during unnecessary periods.</p> <p data-svelte-h="svelte-1n0bzbn">By how much?</p> <p data-svelte-h="svelte-1i6grvn">Reference:</p> <ul data-svelte-h="svelte-1pa8kyp"><li>Baseline: uses no synchronization practices discussed here</li> <li><code>no_sync</code> improperly: <code>no_sync</code> only around the <code>backward</code> call, not the <code>forward</code></li> <li><code>no_sync</code>: using the <code>no_sync</code> pattern properly</li> <li><code>accumulate</code>: using <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> properly</li></ul> <p data-svelte-h="svelte-vdaq6p">Below are the average seconds per batch iterating over 29 batches of data for each setup on both a single node and on the dual-node setup:</p> <table data-svelte-h="svelte-sjfg3j"><thead><tr><th align="center"></th> <th align="center">Baseline</th> <th align="center"><code>no_sync</code> improperly</th> <th align="center"><code>no_sync</code></th> <th align="center"><code>accumulate</code></th></tr></thead> <tbody><tr><td align="center">Multi-Node</td> <td align="center">2±0.01s</td> <td align="center">2.13±0.08s</td> <td align="center"><strong>0.91±0.11s</strong></td> <td align="center"><strong>0.91±0.11s</strong></td></tr> <tr><td align="center">Single Node</td> <td align="center">0.50±0.01s</td> <td align="center">0.50±0.01s</td> <td align="center"><strong>0.41±0.015s</strong></td> <td align="center"><strong>0.41±0.015s</strong></td></tr></tbody></table> <p data-svelte-h="svelte-yk2mxh">As you can see, if you are not careful about how you set up your gradient synchronization, you can get upwards of more than a 2x slowdown during training!</p> <p data-svelte-h="svelte-1ep09o2">If you are worried about making sure everything is done properly, we highly recommend utilizing the <a 
href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator.accumulate">accumulate()</a> function and passing in | |
| <code>gradient_accumulation_steps</code> or <code>gradient_accumulation_plugin</code> to the <a href="/docs/accelerate/pr_4021/en/package_reference/accelerator#accelerate.Accelerator">Accelerator</a> object so Accelerate can handle this for you.</p> <h3 class="relative group"><a id="nosync-requires-additional-gpu-memory-when-using-fsdp" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#nosync-requires-additional-gpu-memory-when-using-fsdp"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>no_sync requires additional GPU memory when using FSDP</span></h3> <p data-svelte-h="svelte-rzu0c8">Be aware that not syncing gradients can have adverse effects while performing FSDP training. 
As warned in <code>torch</code>, the <a href="https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.no_sync" rel="nofollow"><code>no_sync</code> context manager for FSDP</a> will require additional memory.</p> <p data-svelte-h="svelte-zuof06">Therefore, in memory-intensive situations while using FSDP, we recommend setting <code>sync_each_batch</code> to <code>True</code> in the <a href="/docs/accelerate/pr_4021/en/package_reference/utilities#accelerate.utils.GradientAccumulationPlugin">GradientAccumulationPlugin</a> to disable <code>no_sync</code>.</p> <p data-svelte-h="svelte-1s75xk0">See the example below where we fine-tune Mixtral (47B parameters) on 8 A100-80GB GPUs. We see that even for a modest <code>gradient_accumulation_steps=2</code> we quickly go out-of-memory (OOM) if <code>no_sync</code> is enabled. Again, this is caused by the additional memory overhead of FSDP’s <code>no_sync</code>. However, if <code>no_sync</code> is disabled via <code>sync_each_batch=True</code>, then the memory consumption for <code>gradient_accumulation_steps=16</code> reverts to that of <code>gradient_accumulation_steps=1</code>.</p> <table data-svelte-h="svelte-iqaw47"><thead><tr><th align="center">Model</th> <th align="center"><code>no_sync</code> (accum=1)</th> <th align="center"><code>no_sync</code> (accum=2)</th> <th align="center"><code>no_sync</code> disabled (accum=16)</th></tr></thead> <tbody><tr><td align="center">mixtral 8x7B</td> <td align="center">69G</td> <td align="center">OOM</td> <td align="center">69G</td></tr></tbody></table> <blockquote class="warning" data-svelte-h="svelte-1be6loi"><p>Disabling <code>no_sync</code> means there <em>will be a slowdown</em> due to the extra data syncs, as explained by the earlier sections of this guide.</p></blockquote> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" 
href="https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/gradient_synchronization.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
// Client bootstrap for this page: publish the path configuration on a global,
// then load the two client entry modules and start the app inside the element
// that contains this <script> tag.
//
// The assignment below has no declaration keyword, so in this classic
// (non-module) script it creates/overwrites a global variable.
// NOTE(review): presumably read by the imported entry modules - not visible here.
| __sveltekit_1q7nz6m = { | |
// `assets` and `base` are both set to the same docs path prefix for this build.
| assets: "/docs/accelerate/pr_4021/en", | |
| base: "/docs/accelerate/pr_4021/en", | |
| env: {} | |
| }; | |
// Captured synchronously, before the asynchronous imports resolve:
// `document.currentScript` is only set while the script is running synchronously,
// so it would not be available inside the `.then` callback.
| const element = document.currentScript.parentElement; | |
// Two null entries, passed through unchanged to `kit.start` below.
// NOTE(review): looks like one slot per id in `node_ids` - confirm.
| const data = [null,null]; | |
// Load both entry modules in parallel; the callback runs once both have resolved.
| Promise.all([ | |
| import("/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"), | |
| import("/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js") | |
| ]).then(([kit, app]) => { | |
// `kit` is the module from entry/start, `app` the module from entry/app.
| kit.start(app, element, { | |
// NOTE(review): presumably these match the nodes/0 and nodes/15 chunks
// preloaded in the document head - confirm.
| node_ids: [0, 15], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 30.6 kB
- Xet hash:
- 509b9e5cfb28bb07176ce56c5998c1d9a7680b17b64ffb3c88a779018c8d9c76
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.