| | """ |
| | This file includes public APIs for FSDP such as the classes used for the |
| | constructor arguments. |
| | """ |
| |
|
| | from collections.abc import Sequence |
| | from dataclasses import dataclass |
| | from enum import auto, Enum |
| | from typing import Optional |
| |
|
| | import torch |
| | from torch.nn.modules.batchnorm import _BatchNorm |
| |
|
| |
|
| | __all__ = [ |
| | "ShardingStrategy", |
| | "BackwardPrefetch", |
| | "MixedPrecision", |
| | "CPUOffload", |
| | "StateDictType", |
| | "StateDictConfig", |
| | "FullStateDictConfig", |
| | "LocalStateDictConfig", |
| | "ShardedStateDictConfig", |
| | "OptimStateDictConfig", |
| | "FullOptimStateDictConfig", |
| | "LocalOptimStateDictConfig", |
| | "ShardedOptimStateDictConfig", |
| | "StateDictSettings", |
| | ] |
| |
|
| |
|
class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      expensive all-gathers and reduce-scatters are only done within a node,
      which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
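
    Example (a minimal sketch of passing a strategy to the FSDP constructor;
    assumes ``model`` and the default process group are already set up)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import ShardingStrategy
        >>> # `model` is assumed to be an nn.Module defined elsewhere
        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)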
| | """ |
| |
|
| | FULL_SHARD = auto() |
| | SHARD_GRAD_OP = auto() |
| | NO_SHARD = auto() |
| | HYBRID_SHARD = auto() |
| | _HYBRID_SHARD_ZERO2 = auto() |
| |
|
| |
|
class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory.
      This prefetches the next set of parameters *after* the current set of
      parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and current set of
      gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: For a single process group using the NCCL
    backend, any collectives, even if issued from different streams, contend
    for the same per-device NCCL stream, which implies that the relative order
    in which the collectives are issued matters for overlapping. The two
    backward prefetching values correspond to different issue orders.
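
    Example (a minimal sketch; assumes ``model`` is already defined and FSDP
    has been imported)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import BackwardPrefetch
        >>> # `model` is assumed to be an nn.Module defined elsewhere
        >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)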
| | """ |
| |
|
| | |
| | |
| | |
| | |
| | BACKWARD_PRE = auto() |
| | BACKWARD_POST = auto() |
| |
|
| |
|
@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside forward and backward, the
            *sharded* parameters are kept in full precision (e.g. for the
            optimizer step), and for model checkpointing, the parameters are
            always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.
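
    For that typical case, a minimal sketch (with ``model`` assumed to be
    defined elsewhere) is::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp_model = FSDP(
        >>>     model,
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        >>> )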

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
        >>> model[1] = FSDP(
        >>>     model[1],
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
        >>> )
        >>> model = FSDP(
        >>>     model,
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
        >>> )

    The above shows a working example. On the other hand, if ``model[1]``
    were replaced with ``model[0]``, meaning that the submodule using
    different ``MixedPrecision`` ran its forward first, then ``model[1]``
    would incorrectly see ``float16`` activations instead of ``bfloat16``
    ones.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    _module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
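
    Example (a minimal sketch; assumes ``model`` is already defined and FSDP
    has been imported)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> # `model` is assumed to be an nn.Module defined elsewhere
        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))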
| | """ |
| |
|
| | offload_params: bool = False |
| |
|
| |
|
class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is FULL_STATE_DICT to comply with the PyTorch convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict/load_state_dict``: this pair of APIs return and load
           the non-sharded, unflattened parameters. The semantics are the
           same as using DDP.
        2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return
           and load local sharded, flattened parameters. The values returned
           by ``_local_state_dict`` can be directly used by FSDP and are only
           meaningful to FSDP (because parameters are flattened). Note that
           these APIs are meant for use via the :func:`state_dict_type`
           context manager as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
           ...     state = fsdp.state_dict()  # gets the local state dict

        3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
           return and load sharded, unflattened parameters. The ``state_dict``
           returned by ``sharded_state_dict`` can be used by all other parallel
           schemes (resharding may be required).
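
           A minimal sketch of saving a sharded state dict via the same
           :func:`state_dict_type` context manager (assuming ``fsdp`` is an
           FSDP-wrapped module) is as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with FSDP.state_dict_type(fsdp, StateDictType.SHARDED_STATE_DICT):
           ...     sharded_state = fsdp.state_dict()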
| | """ |
| |
|
| | FULL_STATE_DICT = auto() |
| | LOCAL_STATE_DICT = auto() |
| | SHARDED_STATE_DICT = auto() |
| |
|
| |
|
@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(
        ...     model,
        ...     device_id=torch.cuda.current_device(),
        ...     auto_wrap_policy=...,
        ...     sync_module_states=True,
        ... )
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig`
        and it is used by FSDP to determine the type of state dict values. Users should not
        manually modify ``_use_dtensor``.
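
    A minimal usage sketch via the :func:`state_dict_type` context manager
    (assuming ``fsdp`` is an FSDP-wrapped module and ``FSDP``/``StateDictType``
    are imported)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> cfg = ShardedStateDictConfig(offload_to_cpu=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.SHARDED_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()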
| | """ |
| |
|
| | _use_dtensor: bool = False |
| |
|
| |
|
@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
    """

    offload_to_cpu: bool = True


@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    """
    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
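
    A minimal sketch of saving a full optimizer state dict with this config,
    using ``FSDP.set_state_dict_type`` and ``FSDP.optim_state_dict`` (assuming
    ``model`` is FSDP-wrapped and ``optim`` is its optimizer)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        >>>     FullOptimStateDictConfig(rank0_only=True),
        >>> )
        >>> optim_state = FSDP.optim_state_dict(model, optim)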
| | """ |
| |
|
| | rank0_only: bool = False |
| |
|
| |
|
@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    offload_to_cpu: bool = False


@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    """
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig`
        and it is used by FSDP to determine the type of state dict values. Users should not
        manually modify ``_use_dtensor``.
    """

    _use_dtensor: bool = False


@dataclass
class StateDictSettings:
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig