Buckets:

download
raw
40.7 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Fully Sharded Data Parallel&quot;,&quot;local&quot;:&quot;fully-sharded-data-parallel&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;How it works out of the box&quot;,&quot;local&quot;:&quot;how-it-works-out-of-the-box&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Saving and loading&quot;,&quot;local&quot;:&quot;saving-and-loading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;State Dict&quot;,&quot;local&quot;:&quot;state-dict&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages&quot;,&quot;local&quot;:&quot;mapping-between-fsdp-sharding-strategies-and-deepspeed-zero-stages&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;A few caveats to be aware of&quot;,&quot;local&quot;:&quot;a-few-caveats-to-be-aware-of&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/accelerate/pr_4021/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/scheduler.b9285784.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/singletons.7547c222.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.6d423e5c.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/paths.d42c9205.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/preload-helper.b0bd19d1.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/index.26bc89a1.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/0.0e7c56e8.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/nodes/46.7959cc08.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/Tip.e4eba3d6.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.7a0ae628.js">
<link rel="modulepreload" href="/docs/accelerate/pr_4021/en/_app/immutable/chunks/CodeBlock.844ff9c3.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Fully Sharded Data Parallel&quot;,&quot;local&quot;:&quot;fully-sharded-data-parallel&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;How it works out of the box&quot;,&quot;local&quot;:&quot;how-it-works-out-of-the-box&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Saving and loading&quot;,&quot;local&quot;:&quot;saving-and-loading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;State Dict&quot;,&quot;local&quot;:&quot;state-dict&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages&quot;,&quot;local&quot;:&quot;mapping-between-fsdp-sharding-strategies-and-deepspeed-zero-stages&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;A few caveats to be aware of&quot;,&quot;local&quot;:&quot;a-few-caveats-to-be-aware-of&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" 
fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="fully-sharded-data-parallel" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fully-sharded-data-parallel"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 
11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fully Sharded Data Parallel</span></h1> <p data-svelte-h="svelte-qnvjvw">To accelerate training huge models on larger batch sizes, we can use a fully sharded data parallel model.
This type of data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters.
To read more about it and the benefits, check out the <a href="https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/" rel="nofollow">Fully Sharded Data Parallel blog</a>.
We have integrated the latest PyTorch Fully Sharded Data Parallel (FSDP) training feature.
All you need to do is enable it through the config.</p> <h2 class="relative group"><a id="how-it-works-out-of-the-box" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-it-works-out-of-the-box"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How it works out of the box</span></h2> <p data-svelte-h="svelte-3pof6s">On your machine(s) just run:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none 
transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ovz84c">and answer the questions asked. This will generate a config file that will be used automatically to properly set the
default options when doing</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch my_script.py --args_to_my_script<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-109sebi">For instance, here is how you would run <code>examples/nlp_example.py</code> (from the root of the repo) with FSDP enabled:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->compute_environment: LOCAL_MACHINE
debug: <span class="hljs-literal">false</span>
distributed_type: FSDP
downcast_bf16: <span class="hljs-string">&#x27;no&#x27;</span>
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_forward_prefetch: <span class="hljs-literal">false</span>
fsdp_cpu_ram_efficient_loading: <span class="hljs-literal">true</span>
fsdp_offload_params: <span class="hljs-literal">false</span>
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: <span class="hljs-literal">true</span>
fsdp_transformer_layer_cls_to_wrap: BertLayer
fsdp_use_orig_params: <span class="hljs-literal">true</span>
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: <span class="hljs-literal">true</span>
tpu_env: []
tpu_use_cluster: <span class="hljs-literal">false</span>
tpu_use_sudo: <span class="hljs-literal">false</span>
use_cpu: <span class="hljs-literal">false</span><!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch examples/nlp_example.py<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-aporbc">Currently, <code>Accelerate</code> supports the following config through the CLI:</p> <p data-svelte-h="svelte-mtvmo1"><code>fsdp_sharding_strategy</code>: [1] FULL_SHARD (shards optimizer states, gradients and parameters), [2] SHARD_GRAD_OP (shards optimizer states and gradients), [3] NO_SHARD (DDP), [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy), [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). 
For more information, please refer to the official <a href="https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy" rel="nofollow">PyTorch docs</a>.</p> <p data-svelte-h="svelte-173mpds"><code>fsdp_offload_params</code> : Decides whether to offload parameters and gradients to CPU</p> <p data-svelte-h="svelte-56dvve"><code>fsdp_auto_wrap_policy</code>: [1] TRANSFORMER_BASED_WRAP, [2] SIZE_BASED_WRAP, [3] NO_WRAP</p> <p data-svelte-h="svelte-1tuw09a"><code>fsdp_transformer_layer_cls_to_wrap</code>: Only applicable for Transformers. When using <code>fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP</code>, a user may provide a comma-separated string of transformer layer class names (case-sensitive) to wrap, e.g., <code>BertLayer</code>, <code>GPTJBlock</code>, <code>T5Block</code>, <code>BertLayer,BertEmbeddings,BertSelfOutput</code>. This is important because submodules that share weights (e.g., embedding layers) should not end up in different FSDP wrapped units. Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers. Remaining layers including the shared embeddings are conveniently wrapped in the same outermost FSDP unit. Therefore, use this for transformer-based models. You can use the <code>model._no_split_modules</code> for Transformer models by answering <code>yes</code> to <code>Do you want to use the model&#39;s _no_split_modules to wrap?</code>. It will try to use <code>model._no_split_modules</code> when possible.</p> <p data-svelte-h="svelte-fbrtfo"><code>fsdp_min_num_params</code>: minimum number of parameters when using <code>fsdp_auto_wrap_policy=SIZE_BASED_WRAP</code>.</p> <p data-svelte-h="svelte-gnni23"><code>fsdp_backward_prefetch_policy</code>: [1] BACKWARD_PRE, [2] BACKWARD_POST, [3] NO_PREFETCH</p> <p data-svelte-h="svelte-1b8s7mr"><code>fsdp_forward_prefetch</code>: if True, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. 
Should only be used for static-graph models since the prefetching follows the first iteration’s execution order, i.e., if the sub-modules’ order changes dynamically during the model’s execution, do not enable this feature.</p> <p data-svelte-h="svelte-1bw4fgg"><code>fsdp_state_dict_type</code>: [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT</p> <p data-svelte-h="svelte-1hsz8jb"><code>fsdp_use_orig_params</code>: If True, allows non-uniform <code>requires_grad</code> during init, which means support for interspersed frozen and trainable parameters. This setting is useful in cases such as parameter-efficient fine-tuning as discussed in <a href="https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019" rel="nofollow">this post</a>. This option also allows one to have multiple optimizer param groups. This should be <code>True</code> when creating an optimizer before preparing/wrapping the model with FSDP.</p> <p data-svelte-h="svelte-w8koju"><code>fsdp_cpu_ram_efficient_loading</code>: Only applicable for Transformers models. If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. This should be set to False if you experience errors when loading the pretrained Transformers model via <code>from_pretrained</code> method. When this setting is True, <code>fsdp_sync_module_states</code> must also be True, otherwise all the processes except the main process would have random weights leading to unexpected behaviour during training. For this to work, make sure the distributed process group is initialized before calling Transformers <code>from_pretrained</code> method. 
When using Trainer API, the distributed process group is initialized when you create an instance of <code>TrainingArguments</code> class.</p> <p data-svelte-h="svelte-9z96c5"><code>fsdp_sync_module_states</code>: If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0.</p> <p data-svelte-h="svelte-1mlu7qv">For additional and more nuanced control, you can specify other FSDP parameters via <code>FullyShardedDataParallelPlugin</code>.
When creating the <code>FullyShardedDataParallelPlugin</code> object, pass it the parameters that weren’t part of the accelerate config, or the ones you want to override.
The FSDP parameters will be picked based on the accelerate config file or launch command arguments and other parameters that you will pass directly through the <code>FullyShardedDataParallelPlugin</code> object will set/override that.</p> <p data-svelte-h="svelte-ks9e2y">Below is an example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> FullyShardedDataParallelPlugin
<span class="hljs-keyword">from</span> torch.distributed.fsdp.fully_sharded_data_parallel <span class="hljs-keyword">import</span> FullOptimStateDictConfig, FullStateDictConfig
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=<span class="hljs-literal">False</span>, rank0_only=<span class="hljs-literal">False</span>),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=<span class="hljs-literal">False</span>, rank0_only=<span class="hljs-literal">False</span>),
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="saving-and-loading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#saving-and-loading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Saving and loading</span></h2> <p data-svelte-h="svelte-105rpu1">The new recommended way of checkpointing when using FSDP models is to use <code>SHARDED_STATE_DICT</code> as <code>StateDictType</code> when setting up the accelerate config.
Below is the code snippet to save using <code>save_state</code> utility of accelerate.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerator.save_state(<span class="hljs-string">&quot;ckpt&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zz881x">Inspect the checkpoint folder to see model and optimizer as shards per process:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 
0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">ls</span> ckpt
<span class="hljs-comment"># optimizer_0 pytorch_model_0 random_states_0.pkl random_states_1.pkl scheduler.bin</span>
<span class="hljs-built_in">cd</span> ckpt
<span class="hljs-built_in">ls</span> optimizer_0
<span class="hljs-comment"># __0_0.distcp __1_0.distcp</span>
<span class="hljs-built_in">ls</span> pytorch_model_0
<span class="hljs-comment"># __0_0.distcp __1_0.distcp</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1v7i9ht">To load them back for resuming the training, use the <code>load_state</code> utility of accelerate</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerator.load_state(<span class="hljs-string">&quot;ckpt&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1y7f54p">When using transformers <code>save_pretrained</code>, pass <code>state_dict=accelerator.get_state_dict(model)</code> to save the model state dict.
Below is an example:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --> unwrapped_model.save_pretrained(
args.output_dir,
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
<span class="hljs-addition">+ state_dict=accelerator.get_state_dict(model),</span>
)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="state-dict" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#state-dict"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>State Dict</span></h3> <p data-svelte-h="svelte-y5myy"><code>accelerator.get_state_dict</code> will call the underlying <code>model.state_dict</code> implementation using <code>FullStateDictConfig(offload_to_cpu=True, rank0_only=True)</code> context manager to get the state dict only for rank 0 and it will be offloaded to CPU.</p> <p data-svelte-h="svelte-57zs50">You can then pass <code>state</code> into the <code>save_pretrained</code> method. There are several modes for <code>StateDictType</code> and <code>FullStateDictConfig</code> that you can use to control the behavior of <code>state_dict</code>. 
For more information, see the <a href="https://pytorch.org/docs/stable/fsdp.html" rel="nofollow">PyTorch documentation</a>.</p> <p data-svelte-h="svelte-ik1th2">If you choose to use <code>StateDictType.SHARDED_STATE_DICT</code>, the weights of the model during <code>Accelerator.save_state</code> will be split into <code>n</code> files for each sub-split on the model. To merge them back into
a single dictionary to load back into the model later after training you can use the <code>merge_weights</code> utility:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> accelerate.utils <span class="hljs-keyword">import</span> merge_fsdp_weights
<span class="hljs-comment"># Our weights are saved usually in a `pytorch_model_fsdp_{model_number}` folder</span>
merge_fsdp_weights(<span class="hljs-string">&quot;pytorch_model_fsdp_0&quot;</span>, <span class="hljs-string">&quot;output_path&quot;</span>, safe_serialization=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qedos9">The final output will then either be saved to <code>model.safetensors</code> or <code>pytorch_model.bin</code> (if <code>safe_serialization=False</code> is passed).</p> <p data-svelte-h="svelte-1duybas">This can also be called using the CLI:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate merge-weights pytorch_model_fsdp_0/ output_path<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="mapping-between-fsdp-sharding-strategies-and-deepspeed-zero-stages" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute 
with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#mapping-between-fsdp-sharding-strategies-and-deepspeed-zero-stages"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Mapping between FSDP sharding strategies and DeepSpeed ZeRO Stages</span></h2> <ul data-svelte-h="svelte-1mojn5u"><li><code>FULL_SHARD</code> maps to the DeepSpeed <code>ZeRO Stage-3</code>. Shards optimizer states, gradients and parameters.</li> <li><code>SHARD_GRAD_OP</code> maps to the DeepSpeed <code>ZeRO Stage-2</code>. Shards optimizer states and gradients.</li> <li><code>NO_SHARD</code> maps to <code>ZeRO Stage-0</code>. No sharding wherein each GPU has full copy of model, optimizer states and gradients.</li> <li><code>HYBRID_SHARD</code> maps to <code>ZeRO++ Stage-3</code> wherein <code>zero_hpz_partition_size=&lt;num_gpus_per_node&gt;</code>. 
Here, this will shard optimizer states, gradients and parameters within each node while each node has full copy.</li></ul> <h2 class="relative group"><a id="a-few-caveats-to-be-aware-of" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#a-few-caveats-to-be-aware-of"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>A few caveats to be aware of</span></h2> <ul data-svelte-h="svelte-1cyfpa6"><li>In case of multiple models, pass the optimizers to the prepare call in the same order as corresponding models else <code>accelerator.save_state()</code> and <code>accelerator.load_state()</code> will result in wrong/unexpected behaviour.</li> <li>This feature is incompatible with <code>--predict_with_generate</code> in the <code>run_translation.py</code> script of <code>Transformers</code> library.</li></ul> <p data-svelte-h="svelte-1dlu8nc">For more control, users can leverage the <code>FullyShardedDataParallelPlugin</code>. After creating an instance of this class, users can pass it to the Accelerator class instantiation.
For more information on these options, please refer to the PyTorch <a href="https://github.com/pytorch/pytorch/blob/0df2e863fbd5993a7b9e652910792bd21a516ff3/torch/distributed/fsdp/fully_sharded_data_parallel.py#L236" rel="nofollow">FullyShardedDataParallel</a> code.</p> <blockquote class="tip"><p data-svelte-h="svelte-we3qam">For those interested in the similarities and differences between FSDP and DeepSpeed, please check out the <a href="../concept_guides/fsdp_and_deepspeed">concept guide here</a>!</p></blockquote> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/accelerate/blob/main/docs/source/usage_guides/fsdp.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
// SvelteKit client-side bootstrap (auto-generated by the docs build).
// Hydrates the statically rendered page above into an interactive app.
{
// Global config object read by the SvelteKit runtime; the suffix
// (_1q7nz6m) is a build-specific hash so multiple apps can coexist.
__sveltekit_1q7nz6m = {
assets: "/docs/accelerate/pr_4021/en",
base: "/docs/accelerate/pr_4021/en",
env: {}
};
// Mount point: the parent element of this <script> tag.
const element = document.currentScript.parentElement;
// Serialized per-route load data; [null,null] means no server data
// was embedded for the two route nodes hydrated below.
const data = [null,null];
// Load the runtime entry ("start") and the app manifest ("app")
// in parallel, then hand control to the SvelteKit client runtime.
Promise.all([
import("/docs/accelerate/pr_4021/en/_app/immutable/entry/start.8a49e72b.js"),
import("/docs/accelerate/pr_4021/en/_app/immutable/entry/app.1df4d18e.js")
]).then(([kit, app]) => {
kit.start(app, element, {
// node_ids: indices into the app's route-node list identifying the
// layout (0) and page (46) components to hydrate on this route.
node_ids: [0, 46],
data,
// No pending form action result and no load-time error to replay.
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
40.7 kB
·
Xet hash:
c4c33ded571a4b87f83154082d87610aa2667d4821d03e67b5d3d1acd7d862bc

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.