Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Context Parallel in 🤗 accelerate","local":"context-parallel-in--accelerate","sections":[{"title":"Why context parallelism?","local":"why-context-parallelism","sections":[],"depth":2},{"title":"How to use context parallelism?","local":"how-to-use-context-parallelism","sections":[],"depth":2},{"title":"Accelerate’s interface","local":"accelerates-interface","sections":[],"depth":2},{"title":"Configurable options","local":"configurable-options","sections":[],"depth":2},{"title":"Technical details","local":"technical-details","sections":[],"depth":2},{"title":"So how does it work?","local":"so-how-does-it-work","sections":[],"depth":2},{"title":"all-to-all vs all-gather","local":"all-to-all-vs-all-gather","sections":[{"title":"all-gather","local":"all-gather","sections":[],"depth":3},{"title":"all-to-all","local":"all-to-all","sections":[],"depth":3}],"depth":2},{"title":"How to choose the right rotation method?","local":"how-to-choose-the-right-rotation-method","sections":[],"depth":2},{"title":"Why only FSDP2?","local":"why-only-fsdp2","sections":[],"depth":2},{"title":"Data dispatching in joint mesh","local":"data-dispatching-in-joint-mesh","sections":[],"depth":2}],"depth":1}"> | |
| <link href="/docs/accelerate/pr_4021/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/scheduler.b9285784.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/singletons.7547c222.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.6d423e5c.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/paths.d42c9205.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/preload-helper.b0bd19d1.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.26bc89a1.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/0.0e7c56e8.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/11.39a528aa.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.7a0ae628.js"> | |
| <link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/CodeBlock.844ff9c3.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Context Parallel in 🤗 accelerate","local":"context-parallel-in--accelerate","sections":[{"title":"Why context parallelism?","local":"why-context-parallelism","sections":[],"depth":2},{"title":"How to use context parallelism?","local":"how-to-use-context-parallelism","sections":[],"depth":2},{"title":"Accelerate’s interface","local":"accelerates-interface","sections":[],"depth":2},{"title":"Configurable options","local":"configurable-options","sections":[],"depth":2},{"title":"Technical details","local":"technical-details","sections":[],"depth":2},{"title":"So how does it work?","local":"so-how-does-it-work","sections":[],"depth":2},{"title":"all-to-all vs all-gather","local":"all-to-all-vs-all-gather","sections":[{"title":"all-gather","local":"all-gather","sections":[],"depth":3},{"title":"all-to-all","local":"all-to-all","sections":[],"depth":3}],"depth":2},{"title":"How to choose the right rotation method?","local":"how-to-choose-the-right-rotation-method","sections":[],"depth":2},{"title":"Why only FSDP2?","local":"why-only-fsdp2","sections":[],"depth":2},{"title":"Data dispatching in joint mesh","local":"data-dispatching-in-joint-mesh","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span 
class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="context-parallel-in--accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#context-parallel-in--accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 
0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Context Parallel in 🤗 accelerate</span></h1> <p data-svelte-h="svelte-1vuwxvs">This guide covers the basics of using context parallelism in 🤗<code>accelerate</code>. For the more curious readers, we also cover some technicalities in the later sections.</p> <p data-svelte-h="svelte-1otubnm">See also the closely related <a href="./sequence_parallelism">Guide to Sequence Parallelism</a>.</p> <h2 class="relative group"><a id="why-context-parallelism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#why-context-parallelism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Why context parallelism?</span></h2> <p data-svelte-h="svelte-2sleob">With the advent of large language models, and recently reasoning models, the sequence length has been 
growing rapidly. This, combined with the quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences. | |
| With sequence length of 128k, the memory requirement of the attention matrix is <code>128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads</code> for <code>bf16</code> precision, given vanilla attention implementation. Granted, with usage of <code>flash attention</code> or <code>SDPA</code> which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.</p> <p data-svelte-h="svelte-1u7kh92">Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.</p> <h2 class="relative group"><a id="how-to-use-context-parallelism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-use-context-parallelism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to use context parallelism?</span></h2> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer 
focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->from accelerate.utils import ParallelismConfig, TorchContextParallelConfig | |
| <span class="hljs-addition">+ cp_config = TorchContextParallelConfig(</span> | |
| <span class="hljs-addition">+ cp_comm_strategy="alltoall", # no need to use cp_config at all, if you want to use the default "allgather"</span> | |
| <span class="hljs-addition">+ )</span> | |
| <span class="hljs-addition">+ parallelism_config = ParallelismConfig(</span> | |
| <span class="hljs-addition">+ cp_size=8,</span> | |
| <span class="hljs-addition">+ cp_handler=cp_config, # or just cp_size=8, if you want to use the default "allgather"</span> | |
| <span class="hljs-addition">+ )</span> | |
| accelerator = Accelerator( | |
| ..., | |
| parallelism_config=parallelism_config, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1e2e60n">As with any other feature in 🤗<code>accelerate</code>, you can also enable context parallelism by passing the corresponding flags to <code>accelerate launch</code>. | |
| In this case, it’s no different:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-1ncmwz6"><p>You can also set the <code>cp_size</code> and <code>cp_comm_strategy</code> in the <code>accelerate config</code> command, which will save them in your <code>accelerate</code> configuration file, so you don’t have to pass them every time you launch your script.</p></blockquote> <blockquote class="tip" data-svelte-h="svelte-1io1dbt"><p>Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2. | |
| You can simply combine them by setting your parallelism sizes to the desired values, e.g. <code>--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8</code>. Or you can use the <code>ParallelismConfig</code> class to set them programmatically.</p></blockquote> <blockquote class="warning" data-svelte-h="svelte-1r2slm7"><p>Context parallelism is tightly coupled with <code>FSDP2</code>, which you can learn more about in the <a href="fsdp1_vs_fsdp2">FSDP2 introduction</a>. This means context parallelism only works if you use <code>FullyShardedDataParallelPlugin</code> or <code>--use-fsdp</code> with the FSDP version set to 2 in your | |
| program. If <code>FSDP2</code> is not used, an error will be raised.</p></blockquote> <blockquote class="warning" data-svelte-h="svelte-1vriq5e"><p>Context parallelism works only with <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">SDPA</a>, and only with no mask or a causal mask. We can’t reliably detect the mask type for you, so it’s your responsibility to ensure that you are using <code>SDPA</code> with no mask or a causal mask. If you use any other attention implementation, an error will be raised.</p></blockquote> <p data-svelte-h="svelte-q27r1v">After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around <a href="https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel" rel="nofollow"><code>torch.distributed.tensor.experimental.context_parallel</code></a> that you can use in your training loop, which abstracts away some of the complexity of using it (more on this later). To minimize the changes you have to make in your training loop, we provide a context manager that is a <code>noop</code> if context parallelism is not enabled, and applies context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration. | |
| You can use it as follows:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> dataloader: | |
| <span class="hljs-keyword">with</span> accelerator.maybe_context_parallel( | |
| buffers=[batch[<span class="hljs-string">"input_ids"</span>], batch[<span class="hljs-string">"attention_mask"</span>]], | |
| buffer_seq_dims=[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>], | |
| no_restore_buffers={batch[<span class="hljs-string">"input_ids"</span>], batch[<span class="hljs-string">"labels"</span>]}, | |
| ): | |
| outputs = model(**batch) | |
| ...<!-- HTML_TAG_END --></pre></div> <blockquote class="warning" data-svelte-h="svelte-1git15w"><p>This context manager has to be recreated for each training step, as shown in the example above. It’s crucial to do so.</p></blockquote> <p data-svelte-h="svelte-l1hp35">This can potentially scale your context size to 1M+ sequence length. Below, we showcase the speed and memory usage of context parallelism for context sizes of up to 256k. We can see that when we double both the context size and the number of GPUs, memory usage per GPU stays consistent, potentially enabling practically unbounded context length scaling.</p> <p align="center" data-svelte-h="svelte-zw16sw"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage"> <br> <em>Figure 1: Memory usage and speed of context parallelism for context sizes of up to 256k.</em></p> <blockquote class="tip"><p data-svelte-h="svelte-7kbx6m">These examples were created with a script you can find <a href="https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py" rel="nofollow">in the examples folder</a>. 
To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000<!-- HTML_TAG_END --></pre></div></blockquote> <h2 class="relative group"><a id="accelerates-interface" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerates-interface"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 
11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accelerate’s interface</span></h2> <p data-svelte-h="svelte-jcewir">The context manager takes a few arguments, that are used to configure the context parallelism.</p> <ul data-svelte-h="svelte-1npdv1z"><li><code>buffers</code>: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.</li> <li><code>buffer_seq_dims</code>: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the <code>buffers</code> list. If you pass <code>buffers=[input_ids, shift_labels]</code> with both having shape <code>[batch_size, sequence_length]</code>, you would pass <code>buffer_seq_dims=[1, 1]</code>. | |
| as the sequence dimension is the second dimension of the tensors. This is required for correct computation of the model outputs.</li> <li><code>no_restore_buffers</code>: The implementation of context parallelism modifies the buffers in-place, converting them to <code>torch.distributed.tensor.DTensor</code>s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is recommended to pass the same tensors as in the <code>buffers</code> argument here, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.</li></ul> <blockquote class="warning" data-svelte-h="svelte-1kfcjij"><p>Context parallelism is not compatible with <code>labels</code> that are a copy of <code>input_ids</code>, which models from 🤗 transformers can shift to enable causal language modeling themselves. | |
| Imagine this case: | |
| labels = [l1, l2, l3, l4, … li] | |
| if we apply context parallelism, each rank would end up with a part of labels, such as this: | |
| labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], … | |
| after transformers modelling code shifts the labels, it would end up with: | |
| labels_rank_0 = [l2, PAD], labels_rank_1 = [l4, PAD], … | |
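| where <code>PAD</code> is a padding token. The last position on rank 0 should predict <code>l3</code>, but that label lives on rank 1, so it gets <code>PAD</code> instead. This results in incorrect loss computation, as the labels are no longer aligned with the inputs. | |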
| Because of this, you need to manually shift the labels before passing them into the model.</p></blockquote> <h2 class="relative group"><a id="configurable-options" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#configurable-options"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Configurable options</span></h2> <p data-svelte-h="svelte-tidq8o">Accelerate provides only a single option to configure context parallelism (apart from <code>cp_size</code>):</p> <ul data-svelte-h="svelte-19ddlaf"><li><code>cp_comm_strategy</code>: The rotation method to use for the shards. We strongly recommend keeping this as <code>"allgather"</code>, as it will very likely outperform <code>"alltoall"</code> in most cases.</li></ul> <p data-svelte-h="svelte-pt2d1o">Context parallel size (<code>cp_size</code>) is rather self-explanatory: it’s the number of ranks across which the inputs are sharded. | |
| The communication strategy (<code>cp_comm_strategy</code>) defines how the shards of the inputs are rotated across ranks. We’ll cover the 2 options in more detail in the <a href="#all-to-all-vs-all-gather">all-to-all vs all-gather</a> section.</p> <p data-svelte-h="svelte-8bfpem">You can see an end-to-end example in the <a href="https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py" rel="nofollow">ND parallel example</a> file, where you can train an 8B model with up to 128k context length on a single 8xH100 node. Using multi-node training, you can scale this to 1M+ sequence length on multiple GPUs. You can also seamlessly combine it with other parallelism strategies to fit your needs.</p> <h2 class="relative group"><a id="technical-details" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#technical-details"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Technical details</span></h2> <blockquote class="tip" data-svelte-h="svelte-1f1d7d3"><p>This section is fairly technical, so if you don’t need to learn the internals of context parallelism, you can skip it and start building 🚀</p></blockquote> <p data-svelte-h="svelte-5gcsxu">We’re going to be using the word <code>shard</code> 
extensively in the following sections, so let’s define it first. If we call tensor <code>sharded</code> across <code>Dth</code> dimension, across <code>N</code> ranks, we mean that this tensor is split into <code>N</code> parts, where each part of the tensor has shape <code>[..., D//N, ...]</code>.</p> <h2 class="relative group"><a id="so-how-does-it-work" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#so-how-does-it-work"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>So how does it work?</span></h2> <p data-svelte-h="svelte-1q8d3hy">Context parallelism works on sharding the <code>Q, K and V</code> matrices across the sequence dimension. Each rank has its assigned shard of <code>Q</code>, let’s call it <code>Q_i</code>. This matrix stays only on this rank, during the whole computation. Similarly, each rank has its own shard of <code>K</code> and <code>V</code>, let’s call them <code>K_i</code> and <code>V_i</code>. Then, each rank calculates attention with its own shard of <code>Q_i</code>, <code>K_i</code> and <code>V_i</code>, let’s call it <code>attn_i</code>. 
| During this computation, a communication kernel is launched to gather the <code>Ks</code> and <code>Vs</code> from all other ranks. Which communication primitive is used depends on the <code>cp_comm_strategy</code> option (called <code>context_parallel_shard_rotation</code> in the pseudocode below). | |
| This way, each rank gets to calculate local attention, first with <code>Q_i</code>, <code>K_i</code> and <code>V_i</code>, then with <code>K_j</code> and <code>V_j</code> from all other ranks. As each rank holds <code>Q, K and V</code> matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.</p> <p data-svelte-h="svelte-17jpam7">We can formalize this in the following pseudocode:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->comm_kernel = {<span class="hljs-string">"allgather"</span>: allgather, <span class="hljs-string">"alltoall"</span>: alltoall}[context_parallel_shard_rotation] | |
| Qi, Ki, Vi = shard(Q, K, V, seq_dim) | |
| attn[i] = attn(Qi, Ki, Vi) | |
| <span class="hljs-keyword">for</span> j <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(context_parallel_size): | |
| Kj, Vj = comm_kernel() | |
| attn[j] = attn(Qi, Kj, Vj) <span class="hljs-comment"># [batch, num_heads, seq_len // context_parallel_size, head_dim]</span> | |
| final_attn = combine(attn)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="all-to-all-vs-all-gather" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#all-to-all-vs-all-gather"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>all-to-all vs all-gather</span></h2> <h3 class="relative group"><a id="all-gather" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#all-gather"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 
56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>all-gather</span></h3> <p data-svelte-h="svelte-6a32kj">So what’s the difference between all-to-all and all-gather? With all-gather, the communication is very simple. After (well, before, as it usually takes longer) we compute the local attention <code>attn_i</code> we launch an all-gather to gather all other <code>Ks</code> and <code>Vs</code> from all other ranks. As this communication is done, each rank has all the <code>Ks</code> and <code>Vs</code> from all other ranks, and can compute the attention with them sequentially. | |
| In the ideal scenario, the all-gather finishes at exactly the moment the calculation of <code>attn_i</code> is done. In practice this never happens, so the best achievable overlap is when the whole of <code>attn_i</code> is overlapped with part of the communication; then, before starting the computation with <code>K_j</code> and <code>V_j</code>, we wait for the all-gather to finish.</p> <h3 class="relative group"><a id="all-to-all" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#all-to-all"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>all-to-all</span></h3> <p data-svelte-h="svelte-jzgpzc">All-to-all, sometimes called <code>ring-rotation</code>, uses a ring-like communication pattern. After finishing the <code>attn_i</code> computation, an all-to-all is launched to send <code>K_i</code> and <code>V_i</code> to the neighbouring ranks. We repeat this <code>context_parallel_size-1</code> times, so that each rank sees every shard of <code>K</code> and <code>V</code> from all other ranks exactly once.
In the ideal scenario, we prefetch the shards <code>K_{i+1}</code> and <code>V_{i+1}</code> from the neighbouring rank, and this communication is exactly overlapped with the computation of our current <code>attn_i</code>. Again, realistically, this perfect overlap never happens. Because the communication step is repeated <code>context_parallel_size-1</code> times per attention call, any imperfect overlap is paid on every step, so the penalty is much larger than with all-gather.</p> <h2 class="relative group"><a id="how-to-choose-the-right-rotation-method" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-choose-the-right-rotation-method"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to choose the right rotation method?</span></h2> <p data-svelte-h="svelte-1kj70ku">In theory, all-to-all should be the better choice. In practice, however, it rarely is. Therefore, we default to all-gather, as it is more likely to achieve better performance.
Extensive <a href="https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082" rel="nofollow">benchmarks</a> from the <code>torchtitan</code> team also show that all-to-all rarely outperforms all-gather. Nevertheless, we provide both options, as one of them may work better for your use case.</p> <p data-svelte-h="svelte-1ldmtvt">You can see this issue directly in the profiler output below:</p> <p align="center" data-svelte-h="svelte-1c3ofsn"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_all_to_all.png" alt="all-to-all profiler output"> <br> <em>Figure 1: The idle time spent waiting for the all-to-all kernel to finish is shown in red. As highlighted in the first blue bar, each all-to-all takes ~250 µs to finish, and this is repeated N-1 times for each attention call, where N is the context parallel size.</em></p> <h2 class="relative group"><a id="why-only-fsdp2" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#why-only-fsdp2"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z"
fill="currentColor"></path></svg></span></a> <span>Why only FSDP2?</span></h2> <p data-svelte-h="svelte-1ikc9rn">We only support context parallelism with <code>FSDP2</code>, as we create a joint mesh of <code>context_parallel_size</code> and <code>dp_shard_size</code> to | |
| utilize its full potential. | |
| How it works is: we shard the model across the joint mesh of size <code>cp_size*dp_shard_size</code>, which maximizes the memory savings. | |
| This is a “free lunch” of sorts, as <code>FSDP</code> communication is fully overlapped with the computation of attention, as shown in the image below.</p> <p align="center" data-svelte-h="svelte-1jnhmn3"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="Profiler output for FSDP2 with context parallelism (all-gather rotation)"> <br> <em>Figure 2: In the blue rectangles (Stream 23), you can see that the prefetch of the <code>FSDP</code> shard is fully overlapped with the computation of attention (Stream 7), while in the red rectangles (Stream 24), you can see that the all-gather kernel results in a bubble of idle time, during which our compute stream (7) is idle.</em></p> <p data-svelte-h="svelte-10ujnkx">Comparing the two figures above also shows the difference between all-to-all and all-gather. With all-to-all (Figure 1), we launch a communication kernel N-1 times for each attention call, whereas with all-gather (Figure 2), we launch a communication kernel only once. This results in a bigger bubble, but it happens only once per attention call, while with all-to-all it happens N-1 times.</p> <h2 class="relative group"><a id="data-dispatching-in-joint-mesh" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#data-dispatching-in-joint-mesh"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0
0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Data dispatching in joint mesh</span></h2> <p data-svelte-h="svelte-14bdjd0">We make sure to dispatch the same batch of data to the whole <code>cp</code> subgroup, so that the results are correct. (Meaning each rank in <code>cp</code> subgroup gets the same batch of data.) However, we also dispatch different batches to each rank of <code>dp_shard</code> group. | |
| Imagine it like this:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 8 GPUS, --dp_shard_size 4, --cp_size 2</span> | |
| <span class="hljs-comment"># mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]</span> | |
| <span class="hljs-comment"># model is sharded across the whole mesh (each GPU holds 1/8 of the model)</span> | |
| <span class="hljs-comment"># GPUs 0,1 = batch 0</span> | |
| <span class="hljs-comment"># GPUs 2,3 = batch 1</span> | |
| ... <span class="hljs-keyword">and</span> so <span class="hljs-keyword">on</span>.<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/context_parallelism.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_1q7nz6m = { | |
| assets: "/docs/accelerate/pr_4021/en", | |
| base: "/docs/accelerate/pr_4021/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"), | |
| import("/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 11], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 48.6 kB
- Xet hash:
- 0a2dfc30463ed131c891f2bb0d9852cb5499cb4345a2cdbc3592240c27610236
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.