# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
    precision at the module level, not op level, which means low-precision
    activations are saved for backward and high-to-low-precision casts are
    incurred only at module boundaries.

    FSDP works well with module-level mixed precision since it keeps the
    high-precision sharded parameters in memory anyway. In other words, FSDP
    does not require any extra memory to keep a high-precision copy of the
    parameters for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for
            the unsharded parameter and hence the dtype for forward/backward
            computation and the parameter all-gather. If this is ``None``,
            then the unsharded parameter uses the original dtype. The
            optimizer step uses the sharded parameter in the original dtype.
            (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this
            is ``None`` but ``param_dtype`` is not ``None``, then the
            reduction uses the compute dtype. This can be used to run
            gradient reduction in full precision while using low precision
            for compute. If gradient reduction is also disabled via
            :meth:`set_requires_gradient_sync`, then FSDP will accumulate
            gradients using ``reduce_dtype``. (Default: ``None``)
        output_dtype (Optional[torch.dtype]): This specifies the dtype for
            casting floating-point forward outputs. This can be used to
            help implement cases where different modules have different
            mixed precision policies. (Default: ``None``)
        cast_forward_inputs (bool): This specifies whether FSDP should cast
            the forward's floating-point input tensors to ``param_dtype``
            or not. (Default: ``True``)
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True


@dataclass
class OffloadPolicy:
    """
    This base class represents the policy of no offloading and is only used as
    the default value for the ``offload_policy`` arg.
    """


@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    This offload policy offloads parameters, gradients, and optimizer states to
    CPU. Sharded parameters are copied host-to-device before all-gather. The
    all-gathered parameters are freed according to ``reshard_after_forward``.
    Sharded gradients are copied device-to-host in backward, and the optimizer
    step runs on CPU with CPU optimizer states.

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows both more efficient H2D/D2H copies
            and for the copies to overlap with compute. However, the pinned
            memory cannot be used by other processes. Set this to ``False``
            if you have insufficient CPU memory. (Default: ``True``)
    """

    pin_memory: bool = True