from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    This configures FSDP's mixed precision. Unlike autocast, this applies
    mixed precision at the module level, not the op level, which means
    low-precision activations are saved for backward and high-to-low-precision
    casts are incurred only at module boundaries.

    FSDP works well with module-level mixed precision since it keeps the
    high-precision sharded parameters in memory anyway. In other words, FSDP
    does not require any extra memory to keep a high-precision copy of the
    parameters for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for
            the unsharded parameter and hence the dtype for forward/backward
            computation and the parameter all-gather. If this is ``None``, then
            the unsharded parameter uses the original dtype. The optimizer step
            uses the sharded parameter in the original dtype. (Default:
            ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then the reduction
            uses the compute dtype. This can be used to run gradient reduction
            in full precision while using low precision for compute. If
            gradient reduction is also disabled via
            :meth:`set_requires_gradient_sync`, then FSDP accumulates
            gradients using ``reduce_dtype``. (Default: ``None``)
        output_dtype (Optional[torch.dtype]): This specifies the dtype for
            casting floating-point forward outputs. This can be used to
            help implement cases where different modules have different mixed
            precision policies. (Default: ``None``)
        cast_forward_inputs (bool): This specifies whether FSDP should cast
            the forward's floating-point input tensors to ``param_dtype``.
            (Default: ``True``)
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True
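

# A minimal usage sketch, not part of this module's API: it assumes the FSDP2
# ``fully_shard`` entry point and a toy two-layer model, and shows a common
# policy of bf16 compute with fp32 gradient reduction. Treat the import path
# and keyword arguments as assumptions to verify against your PyTorch version.
def _example_mixed_precision_usage() -> None:
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard  # assumed FSDP2 entry point

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # all-gather and compute in bf16
        reduce_dtype=torch.float32,  # reduce-scatter gradients in fp32
    )
    # Shard each layer first, then the root, so layers get their own groups.
    for layer in model:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    # For gradient accumulation: per the docstring above, disabling reduction
    # (e.g. via ``set_requires_gradient_sync``) makes FSDP accumulate
    # gradients in ``reduce_dtype``:
    # model.set_requires_gradient_sync(False)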


@dataclass
class OffloadPolicy:
    """
    This base class represents the policy of no offloading and is only used as
    the default value for the ``offload_policy`` arg.
    """


@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    This offload policy offloads parameters, gradients, and optimizer states
    to CPU. Sharded parameters are copied host-to-device before all-gather.
    The all-gathered parameters are freed according to
    ``reshard_after_forward``. Sharded gradients are copied device-to-host in
    backward, and the optimizer step runs on CPU with CPU optimizer states.

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows more efficient H2D/D2H copies and
            lets those copies overlap with compute. However, the pinned
            memory cannot be used by other processes. Set this to ``False``
            if you have insufficient CPU memory. (Default: ``True``)
    """

    pin_memory: bool = True
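

# A minimal usage sketch, assuming the same ``fully_shard`` entry point as
# above: CPU offloading keeps sharded parameters, gradients, and optimizer
# states on CPU, and ``pin_memory=False`` trades copy efficiency for lower
# pinned-host-memory pressure.
def _example_cpu_offload_usage() -> None:
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard  # assumed FSDP2 entry point

    model = nn.Linear(8, 8)
    offload_policy = CPUOffloadPolicy(pin_memory=False)  # if CPU memory is tight
    fully_shard(model, offload_policy=offload_policy)
    # Construct the optimizer after fully_shard so its states follow the
    # (CPU-resident) sharded parameters; the step then runs on CPU.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)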