import inspect
import itertools
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, Callable, cast, Optional

import torch
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import _mesh_resources
from torch.distributed.tensor.placement_types import _StridedShard, Placement

from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
    _chunk_with_empty,
    _from_local_no_grad,
    _get_dim_chunked_size,
    _raise_assert_with_print,
    _to_dtype_if_needed,
    compiled_autograd_enabled,
    FSDPMeshInfo,
    HSDPMeshInfo,
)


"""
[Note: FSDP tensors]
FSDP considers the following tensors:
- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
  on the module when applying FSDP
- Sharded parameter: sharding the original parameter on dim-0 (or a
  user-specified dim) as a DTensor over the main mesh
- All-gather inputs: the ``torch.Tensor`` s (one or more) passed to
  all-gather, derived from the sharded parameter
- All-gather output: the ``torch.Tensor`` s resulting from all-gathering the
  all-gather inputs
- Unsharded parameter: parameter used for forward/backward computation, derived
  from the all-gather output; autograd leaf

We define these tensors to describe the general framework that can accommodate
extensions, where:
- all-gather-inputs = pre-all-gather-transform(sharded-parameter)
- unsharded-parameter = post-all-gather-transform(all-gather-outputs)

For the default ``torch.Tensor`` case, there is only one all-gather input, and
it shares the same underlying tensor data as the sharded parameter, meaning
that they can be thought of as the same tensors. The same applies to the
all-gather output and the unsharded parameter. For non-``torch.Tensor``
extensions, these equivalences may no longer hold due to the
pre/post-all-gather transforms, and some extensions may have multiple
all-gather inputs/outputs (e.g. quantized data and scales).

[Note: FSDP and autograd]
FSDP dynamically frees and allocates the unsharded parameter. Since autograd
can pack a reference to it or a view of it to save for backward, we use storage
resizing to implement the freeing/allocation, since that preserves the
aliasing. This implies that we construct the unsharded parameter object once
and write to it in-place thereafter. For the default ``torch.Tensor`` original
parameter case, the all-gather output and the unsharded parameter share the
same data, so we use storage resizing on the all-gather output.
"""

lib = torch.library.Library("fsdp", "FRAGMENT")

lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()")


@torch.library.impl(lib, "copy_", "Meta")
@torch.library.impl(lib, "copy_", "CUDA")
@torch.library.impl(lib, "copy_", "XPU")
@torch.library.impl(lib, "copy_", "HPU")
@torch.library.impl(lib, "copy_", "CPU")
@torch.library.impl(lib, "copy_", "MTIA")
def copy_(tensor, data):
    tensor.copy_(data)
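
# `torch.ops.fsdp.copy_(dst, src)` behaves like `dst.copy_(src)` but is routed
# through this custom op so that the compiler can special-case it (see the note
# below). Illustrative usage, assuming two same-shaped tensors:
#
#     torch.ops.fsdp.copy_(unsharded_param, all_gather_output_view)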


"""
[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_]

Currently we don't functionalize the `fsdp.copy_` op or the
`inductor.resize_storage_bytes_` op (i.e. they show up as mutation ops in the
middle of the AOT joint graph).

Reason:
A traceable FSDP2 compiled autograd backward graph has the following traits:
(1) Two inputs of the graph are aliased to each other (one from hook
closed-over tensors and one from forward saved tensors).
(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param).
(3) They are both subclasses.
The combination of these traits is not supported by AOTAutograd (it's difficult
to reason about subclass aliasing), so this doesn't work at all for traceable
FSDP2.

The compromise we use is to avoid functionalization for the FSDP2 copy_ and
resize_ ops. This avoids the problem above, because from AOTAutograd's point of
view there are no mutations that functionalization needs to handle. (Although
we need to be careful not to DCE those mutable ops.)

We can avoid this functionalization because:
(1) The nn.Parameter is never used before its .copy_() is called in eager code
(i.e. no alias of it is created), so it's safe to call .copy_() in the middle
of the graph to update its content and start using the nn.Parameter downstream.
(2) We always re-allocate the buffer for the nn.Parameter to store the
all-gather output and to be used in downstream user ops. So calling
resize-to-0 in the middle of the graph to free the nn.Parameter memory after
use should always be okay (since we always allocate anew the next time we need
it, we strictly don't need to keep the old tensor storage around).

Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance?
A: Yes, they would. As an optimization, we have an Inductor post-grad FX pass
that removes those resize_ and copy_ ops for unsharded params that have this
pattern: resize_(full) -> copy_ -> resize_(0).

TODO:
Now that we are maintaining the invariant of "no aliased + mutated graph
inputs" in both the forward and backward, it is more feasible to functionalize
all of the mutable FSDP ops. Some of the pros and cons are:

Cons (of functionalizing those ops):
(1) By not functionalizing them as we do today, we make it more likely that
they run at the "correct" time in the generated code. If we start to
functionalize them, we will need to make sure that Inductor reinplaces them in
a way that properly moves the mutations back to exactly where they should have
run, or we risk suffering worse peak memory than eager. (We probably already
need to do something similar in Inductor's reinplacing for copy_:
https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089)

Pros (of functionalizing):
(1) Better safety: we don't need to worry quite as much about the graph passes
in Inductor/partitioning handling input mutations mid-graph (to be fair, we've
already done some amount of auditing, but we might have to do some more).
(2) Better perf: each mutation midway through the graph prevents Inductor from
pattern matching across it. But maybe there are few enough mutations induced by
FSDP for this not to matter.
"""


@torch.library.impl(lib, "copy_", "Functionalize")
def copy__functionalize(tensor, data):
    torch._sync(tensor)
    torch._sync(data)
    tensor_inner = torch._from_functional_tensor(tensor)
    data_inner = torch._from_functional_tensor(data)
    with torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
    ):
        torch.ops.fsdp.copy_.default(tensor_inner, data_inner)


if not torch._running_with_deploy():
    torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default)


class ShardedState(Enum):
    """
    - ``SHARDED``: The sharded parameter is registered to the module. It is the
      only contributor to parameter memory.
    - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
      smaller world size. Since this data should not be used for computation,
      we do not register it to the module. Users should reshard the module
      before any in-place modifications. Both it and the sharded parameter
      contribute to parameter memory.
    - ``UNSHARDED``: The unsharded parameter is registered to the module. Both
      it and the sharded parameter contribute to parameter memory.
    """

    SHARDED = auto()
    SHARDED_POST_FORWARD = auto()
    UNSHARDED = auto()
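
    # A typical lifecycle (illustrative), assuming post-forward resharding to
    # a smaller mesh is enabled (see `FSDPParam.to_unsharded`,
    # `to_sharded_post_forward`, and `to_sharded` below):
    #     SHARDED -> UNSHARDED (pre-forward all-gather)
    #             -> SHARDED_POST_FORWARD (post-forward reshard)
    #             -> UNSHARDED (pre-backward all-gather)
    #             -> SHARDED (post-backward reshard)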


@dataclass
class ParamModuleInfo:
    """
    For a parameter, this stores the module and the parameter name to be able
    to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
    the parameter via ``getattr(module, param_name)``. We additionally save
    shared modules and shared parameter names to update them accordingly.
    """

    # Parameter names are unprefixed, e.g. "weight", not "lin.weight"
    module: nn.Module
    param_name: str
    shared_modules: list[nn.Module] = field(default_factory=list)
    shared_param_names: list[str] = field(default_factory=list)
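
    # For example (illustrative): with weight tying such as
    #     emb = nn.Embedding(num_embeddings, dim)
    #     lm_head = nn.Linear(dim, num_embeddings, bias=False)
    #     lm_head.weight = emb.weight
    # the tied parameter would be recorded once with module=emb,
    # param_name="weight", shared_modules=[lm_head],
    # shared_param_names=["weight"].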


@dataclass
class ExtensionsData:
    # User-defined metadata passed from pre- to post-all-gather
    all_gather_metadata: Optional[Any] = None
    # Save the all-gather input sizes to unflatten the all-gather outputs to
    # reconstruct the per-parameter all-gather inputs
    all_gather_input_sizes: Sequence[torch.Size] = ()

    def clear(self):
        self.all_gather_metadata = None
        self.all_gather_input_sizes = ()


class FSDPParam:
    """
    This class manages a parameter with FSDP or FSDP variants applied,
    implementing dim-0 per-parameter sharding.
    """

    orig_dtype: torch.dtype
    param_dtype: Optional[torch.dtype]
    reduce_dtype: Optional[torch.dtype]
    _orig_size: torch.Size
    sharded_size: torch.Size
    contiguous_sharded_stride: tuple[int, ...]
    padded_sharded_param_size: torch.Size
    sharded_post_forward_size: torch.Size
    contiguous_sharded_post_forward_stride: tuple[int, ...]
    _sharded_param_data: torch.Tensor  # 1D
    sharded_param: nn.Parameter
    _sharded_post_forward_param_data: Optional[torch.Tensor]  # 1D
    _sharded_post_forward_param: Optional[nn.Parameter]
    _unsharded_param: nn.Parameter
    unsharded_accumulated_grad: Optional[torch.Tensor]
    _sharding_spec: DTensorSpec
    # DTensor attributes (only defined for DTensor ``param``):
    _tp_spec: DTensorSpec
    all_gather_outputs: list[torch.Tensor]  # 1D
    # All-gather extension attributes (only defined if the sharded local
    # tensor defines the pre/post-all-gather hooks):
    _extensions_data: ExtensionsData
    _unsharded_inner_tensors: list[torch.Tensor]

    def __init__(
        self,
        param: nn.Parameter,
        module_info: ParamModuleInfo,
        mesh_info: FSDPMeshInfo,
        post_forward_mesh_info: Optional[FSDPMeshInfo],
        device: torch.device,
        shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]],
        mp_policy: MixedPrecisionPolicy,
        offload_policy: OffloadPolicy,
    ):
        self._module_info: ParamModuleInfo = module_info
        self.mesh_info = mesh_info
        self.post_forward_mesh_info = post_forward_mesh_info
        self.device = device
        self.mp_policy = mp_policy
        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
        self.pin_memory = (
            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
        )
        self.grad_offload_event: Optional[torch.Event] = None
        self._init_sharded_param(param, device, shard_placement_fn)
        if self.post_forward_mesh_info:
            self._init_sharded_post_forward_param_metadata(param)
        self._init_extensions()
        self.all_gather_outputs: list[torch.Tensor] = []
        self.unsharded_accumulated_grad = None
        self._param_fqn: Optional[str] = None  # prefixed from root module
        # Re-pad and re-save the sharded parameter reference after a state
        # dict load since the loaded local tensor may not be padded
        self._post_load_hook_handle = (
            module_info.module.register_load_state_dict_post_hook(
                lambda *args, **kwargs: self.reset_sharded_param()
            )
        )

    @torch.no_grad()
    def _init_sharded_param(
        self,
        param: nn.Parameter,
        device: torch.device,
        shard_placement_fn: Optional[Callable],
    ):
        if param.device != device and param.device.type != "meta":
            raise AssertionError(
                f"Expects the parameter to already be moved to device {device} but got {param.device}"
            )
        if not param.is_contiguous():
            raise NotImplementedError(
                f"FSDP does not support non-contiguous parameters yet: {param.shape=} {param.stride()=}"
            )
        fsdp_placement = shard_placement_fn(param) if shard_placement_fn else None
        if fsdp_placement is None:
            fsdp_placement = Shard(0)
        elif fsdp_placement.dim < 0:
            fsdp_placement = Shard(fsdp_placement.dim + param.ndim)
        assert isinstance(fsdp_placement, Shard), f"{fsdp_placement}"
        self.fsdp_placement = fsdp_placement
        shard_dim = fsdp_placement.dim
        # If the original parameter is a DTensor (e.g. from TP), then FSDP
        # shards on top of the existing placements over a combined SPMD mesh
        self.is_dtensor = isinstance(param, DTensor)
        if self.is_dtensor:
            self._tp_spec = cast(DTensor, param)._spec
            dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
            dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
            tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
            if dp_global_mesh != tp_global_mesh or (
                dp_global_mesh is None or tp_global_mesh is None
            ):
                raise AssertionError(
                    "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
                    f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
                )
            name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
            assert dp_mesh.mesh_dim_names is not None, name_dims_error
            assert tp_mesh.mesh_dim_names is not None, name_dims_error
            submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
            self._spmd_mesh = dp_global_mesh[submesh_names]
            if len(self._tp_spec.placements) != 1:
                raise NotImplementedError(
                    f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
                )
            split_factor = self._tp_spec.num_shards_map[shard_dim]
            assert 2 <= self._spmd_mesh.ndim <= 3, (
                f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
            )
            self._spmd_placements: tuple[Placement, ...]
            dp_shard_tp_placement = (
                (
                    _StridedShard(shard_dim, split_factor=split_factor)
                    if split_factor > 1
                    else fsdp_placement
                ),
                self._tp_spec.placements[0],
            )
            if self._spmd_mesh.ndim == 2:
                self._spmd_placements = dp_shard_tp_placement
            else:
                assert self.mesh_info.replicate_mesh_dim == 0
                self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
            self._sharding_spec = DTensorSpec(
                self._spmd_mesh,
                self._spmd_placements,
                tensor_meta=self._tp_spec.tensor_meta,
            )
            # FSDP+TP with `_StridedShard` requires even sharding on dim 0
            if split_factor > 1:
                num_shards = self._sharding_spec.num_shards_map[0]
                tensor_size_dim_0 = self._sharding_spec.shape[0]
                if tensor_size_dim_0 % num_shards != 0:
                    raise NotImplementedError(
                        "FSDP+TP sharding does not support uneven sharding for now: "
                        f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
                        f"evenly sharded into {num_shards} shards."
                    )
            param_data = cast(DTensor, param)._local_tensor
        else:
            self._spmd_mesh = self.mesh_info.mesh
            if isinstance(self.mesh_info, HSDPMeshInfo):
                self._spmd_placements = (Replicate(), fsdp_placement)
            else:
                self._spmd_placements = (fsdp_placement,)
            self._sharding_spec = DTensorSpec(
                self._spmd_mesh,
                self._spmd_placements,
                tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype),
            )
            param_data = param
        assert param_data.is_contiguous(), f"{param_data.shape=} {param_data.stride()=}"
        shard_dim = fsdp_placement.dim
        if shard_dim >= param_data.ndim:
            raise AssertionError(
                f"Shard dim {shard_dim} is invalid for {param_data.ndim}D tensor: {param.shape}"
            )
        self._orig_size = param_data.size()
        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
        shard_rank = self.mesh_info.shard_mesh_rank
        shard_world_size = self.mesh_info.shard_mesh_size
        if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0:
            # If sharding on a nonzero dim, require even sharding for now since
            # the uneven case would require extra copies before/after FSDP
            # collectives and extra complexity to handle padding/unpadding
            raise NotImplementedError(
                f"FSDP does not support uneven sharding on dim {shard_dim}: "
                f"{param_data.size()} (world size: {shard_world_size})"
            )
        chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
        sharded_param = chunks[shard_rank]
        self.sharded_size = _get_dim_chunked_size(
            sharded_param, param_data.size(), dim=shard_dim
        )
        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
        padded_sharded_size = chunks[0].size()  # chunk 0 always has the maximal size
        self.padded_sharded_param_size = padded_sharded_size
        # Pre-pad the sharded parameter to avoid padding before all-gather
        padded_sharded_param = param_data.new_zeros(padded_sharded_size)
        if sharded_param.numel() > 0:
            padded_sharded_param.narrow(
                dim=shard_dim, start=0, length=sharded_param.size(shard_dim)
            ).copy_(sharded_param)
        if self.offload_to_cpu and not padded_sharded_param.is_meta:
            padded_sharded_param = padded_sharded_param.cpu()
            if self.pin_memory:
                padded_sharded_param = padded_sharded_param.pin_memory(
                    device=self.device
                )
        self._sharded_param_data = padded_sharded_param.view(-1)
        length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0
        sharded_param = padded_sharded_param.narrow(
            dim=shard_dim, start=0, length=length
        )
        assert sharded_param.is_contiguous(), f"{self.fsdp_placement=}"
        self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
        self.sharded_param.requires_grad_(param.requires_grad)
        # Let `param_data` be freed normally when its ref count reaches 0 after
        # the `fully_shard` call returns
        self._setattr_on_modules(self.sharded_param)
        self.sharded_state = ShardedState.SHARDED

    def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
        mesh_info = self.post_forward_mesh_info
        assert mesh_info is not None
        param_data = param._local_tensor if isinstance(param, DTensor) else param
        chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
        self.sharded_post_forward_size = _get_dim_chunked_size(
            chunks[mesh_info.shard_mesh_rank],
            param_data.size(),
            dim=self.fsdp_placement.dim,
        )
        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
            self.sharded_post_forward_size
        )

    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
        self.orig_dtype = self.sharded_param.dtype
        # Clamp `reduce_dtype` to `None` if no casting is required: since
        # gradients are computed in `param_dtype`, if `reduce_dtype` matches,
        # then we do not need extra casting
        if reduce_dtype == param_dtype:
            reduce_dtype = None
        # Clamp `param_dtype` to `None` if no casting is needed
        if param_dtype == self.orig_dtype:
            param_dtype = None
        self.param_dtype = param_dtype
        self.reduce_dtype = reduce_dtype
        # None indicates that mixed precision is not enabled

    def _init_extensions(self) -> None:
        inner_tensor = self._sharded_local_tensor
        has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather")
        has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather")
        if has_fsdp_pre_all_gather != has_fsdp_post_all_gather:
            raise AssertionError(
                "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
                f"if using all-gather extensions: {inner_tensor}"
            )
        if has_fsdp_pre_all_gather:
            self._extensions_data = ExtensionsData()
        self._unsharded_inner_tensors: list[torch.Tensor] = []
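
    # Sketch of the all-gather extension contract checked above (illustrative;
    # `QuantTensor`, `_data`, `_scale`, and `_from_data` are hypothetical
    # names). The hooks are defined on the sharded parameter's local tensor
    # subclass:
    #
    #     class QuantTensor(torch.Tensor):
    #         def fsdp_pre_all_gather(self, mesh):
    #             # Return (all_gather_inputs, metadata); here the quantized
    #             # data is all-gathered and the scale rides along as metadata
    #             return (self._data,), self._scale
    #
    #         def fsdp_post_all_gather(
    #             self, all_gather_outputs, metadata, param_dtype, *, out=None
    #         ):
    #             (data,) = all_gather_outputs
    #             if out is not None:  # steady state: write in place
    #                 out._data.copy_(data)
    #                 return
    #             # First all-gather: return the unsharded tensor plus the
    #             # inner tensors whose storages FSDP should free/allocate
    #             return QuantTensor._from_data(data, metadata), (data,)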

    def init_all_gather_outputs(
        self,
        all_gather_input_numels: list[int],
        all_gather_input_dtypes: list[torch.dtype],
        world_size: int,
        device: torch.device,
        force_recreate: bool = False,
    ):
        if not force_recreate and len(self.all_gather_outputs) > 0:
            return  # already initialized
        self.all_gather_outputs = [
            torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
            for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
        ]
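
    # For example (illustrative): a parameter whose padded shard has 16
    # elements per rank with world_size=8 gets a single flat all-gather output
    # of 128 elements, which `init_unsharded_param()` later views as the
    # unsharded parameter.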

    def init_unsharded_param(self):
        """
        [Note: Invariants for torch.compile Traceable FSDP2]
        1. Under compile, we always re-populate the content of `self._unsharded_param`
           per AllGather using the slow path.
        2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
           This is to ensure the buffer creation is internal to the graph and
           to avoid `self.all_gather_outputs` being captured as a graph input.
        3. Under compile, at the end of `free_unsharded_param()`, we always clean up
           `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
           to avoid them being captured as graph outputs.

        With these invariants, only these tensors will be inputs to the graph:
        - Sharded parameters
        - Placeholders for the `self._unsharded_param` nn.Parameter
        """
        if not compiled_autograd_enabled() and hasattr(
            self, "_unsharded_param"
        ):  # after the 1st all-gather
            inner_tensor = self._sharded_local_tensor
            if not hasattr(inner_tensor, "fsdp_post_all_gather"):
                return  # already initialized
            for tensor in self._unsharded_inner_tensors:
                alloc_storage(tensor)
            all_gather_outputs = self._unflatten_all_gather_outputs()
            inner_tensor.fsdp_post_all_gather(
                all_gather_outputs,
                self._extensions_data.all_gather_metadata,
                self.param_dtype or self.orig_dtype,
                out=self._unsharded_param,
            )
            self._extensions_data.clear()
            return
        inner_tensor = self._sharded_local_tensor
        if not compiled_autograd_enabled() and hasattr(
            inner_tensor, "fsdp_post_all_gather"
        ):
            all_gather_outputs = self._unflatten_all_gather_outputs()
            (
                unsharded_tensor,
                self._unsharded_inner_tensors,
            ) = inner_tensor.fsdp_post_all_gather(
                all_gather_outputs,
                self._extensions_data.all_gather_metadata,
                self.param_dtype or self.orig_dtype,
            )
            self._extensions_data.clear()
        else:
            # For the default `torch.Tensor` path, the all-gather output
            # directly gives the unsharded parameter data
            assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}"
            unsharded_tensor = self.all_gather_outputs[0]
        unsharded_param = torch.as_strided(
            unsharded_tensor,
            self._orig_size,
            self._contiguous_orig_stride,
            storage_offset=0,
        )
        if self.is_dtensor:
            unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
        if hasattr(self, "_unsharded_param"):
            assert compiled_autograd_enabled()
            with (
                torch.no_grad(),
                torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
            ):
                # Resize the storage back to full size and copy in the freshly
                # all-gathered data; under compile, the
                # resize_(full) -> copy_ -> resize_(0) pattern is removed by an
                # Inductor post-grad pass when possible (see the module-level note)
                self._unsharded_param.untyped_storage().resize_(
                    self._unsharded_param.numel() * self._unsharded_param.itemsize
                )
                torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param)
        else:
            self._unsharded_param = nn.Parameter(
                unsharded_param, requires_grad=self.sharded_param.requires_grad
            )

    def _unflatten_all_gather_outputs(self) -> tuple[torch.Tensor, ...]:
        return tuple(
            t.view(-1, *s[1:])
            for t, s in zip(
                self.all_gather_outputs, self._extensions_data.all_gather_input_sizes
            )
        )
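
    # For example (illustrative): if a per-rank all-gather input had size
    # (8, 4), the corresponding flat all-gather output has
    # 8 * 4 * world_size elements, and `t.view(-1, *s[1:])` recovers shape
    # (8 * world_size, 4).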

    def to_sharded(self) -> None:
        self._setattr_on_modules(self.sharded_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED

    def to_sharded_post_forward(self) -> None:
        if self.is_dtensor:
            raise NotImplementedError(
                "Resharding to smaller mesh with TP is not supported yet"
            )
        self._assert_in_states(ShardedState.UNSHARDED)
        assert self.post_forward_mesh_info is not None
        assert len(self.all_gather_outputs) == 1
        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
        if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0:
            _raise_assert_with_print(
                f"All-gather output size ({numel}) must be divisible by the shard "
                f"world size ({shard_world_size})"
            )
        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
        sharded_numel = numel // shard_world_size
        self._sharded_post_forward_param_data = (
            self.all_gather_outputs[0].narrow(
                0, sharded_numel * shard_rank, sharded_numel
            )
        ).clone()  # clone so that the all-gather output can be freed
        sharded_post_forward_tensor = torch.as_strided(
            self._sharded_post_forward_param_data,
            size=self.sharded_post_forward_size,
            stride=self.contiguous_sharded_post_forward_stride,
            storage_offset=0,
        )
        self._sharded_post_forward_param = nn.Parameter(
            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
        )
        self._setattr_on_modules(self._sharded_post_forward_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED_POST_FORWARD

    def to_unsharded(self) -> None:
        # Assume that the data has been allocated and all-gathered
        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
        self._setattr_on_modules(self._unsharded_param)
        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            # The all-gather copy-in has already consumed this data, so the
            # smaller-world-size sharded data can be freed here
            self._sharded_post_forward_param = None
            self._sharded_post_forward_param_data = None
        self.sharded_state = ShardedState.UNSHARDED

    def _setattr_on_modules(self, param: nn.Parameter) -> None:
        unsafe_setattr_param(
            self._module_info.module, self._module_info.param_name, param
        )
        for shared_module, shared_param_name in zip(
            self._module_info.shared_modules, self._module_info.shared_param_names
        ):
            unsafe_setattr_param(shared_module, shared_param_name, param)

    def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
        """
        Converts a local tensor representing either the sharded parameter or
        the sharded gradient to a DTensor.
        """
        if tensor.shape != self.sharded_size:
            _raise_assert_with_print(
                f"Expects size {self.sharded_size} but got {tensor.shape}"
            )
        return _from_local_no_grad(
            tensor,
            self._sharding_spec,
        )

    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
        if tensor.shape != self.sharded_post_forward_size:
            _raise_assert_with_print(
                f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
            )
        assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
        post_forward_sharding_spec = DTensorSpec(
            self.post_forward_mesh_info.mesh,
            (Replicate(), Shard(0)),
            tensor_meta=self._sharding_spec.tensor_meta,
        )
        return _from_local_no_grad(tensor, post_forward_sharding_spec)

    def to_accumulated_grad_if_needed(self) -> None:
        # Access `_unsharded_param` directly to bypass the sharded state check
        # since we prefer to reshard before upcasting the gradient to save memory
        if (
            self.reduce_dtype is None
            or self._unsharded_param.grad is None
            or self._unsharded_param.grad.dtype == self.reduce_dtype
        ):
            return
        unsharded_grad = self._unsharded_param.grad
        self._unsharded_param.grad = None
        self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)

    def accumulate_unsharded_grad_if_needed(self) -> None:
        if (
            self.unsharded_accumulated_grad is not None
            and self.unsharded_param.grad is not None
        ):
            self.unsharded_accumulated_grad += self.unsharded_param.grad
            self.unsharded_param.grad = None

    def alloc_all_gather_outputs(self) -> None:
        for tensor in self.all_gather_outputs:
            alloc_storage(tensor)

    def free_unsharded_param(self) -> None:
        if compiled_autograd_enabled():
            """
            Assumptions under compile:
            - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`.
              Instead, we resize the `self._unsharded_param` storage to full size and
              then explicitly *copy* the data from `self.all_gather_outputs` to
              `self._unsharded_param` in `init_unsharded_param()`. (In the full-graph
              FSDP2 case, we then remove the resize_ and copy_ ops in a compiler graph
              pass to recover performance.)
            - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT
              graph inputs. They are created within the graph and are guaranteed to be
              freed by the end of the graph. They do not leak outside of the graph.
            """
            self._unsharded_param.untyped_storage().resize_(0)
            self.all_gather_outputs = []
            self._unsharded_inner_tensors = []
        else:
            for tensor in itertools.chain(
                self.all_gather_outputs, self._unsharded_inner_tensors
            ):
                free_storage(tensor)

    @property
    def all_gather_inputs(self) -> list[torch.Tensor]:  # 1D
        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
        if self.sharded_state == ShardedState.SHARDED:
            if not compiled_autograd_enabled() and hasattr(
                self._sharded_local_tensor, "fsdp_pre_all_gather"
            ):
                sharded_local_tensor = self._sharded_local_tensor
                if self.offload_to_cpu:
                    sharded_local_tensor = sharded_local_tensor.to(
                        self.device, non_blocking=True
                    )
                pre_all_gather_signature = inspect.signature(
                    sharded_local_tensor.fsdp_pre_all_gather
                )
                num_fn_params = len(pre_all_gather_signature.parameters)
                # The old signature only passes the mesh; keep it for backward
                # compatibility
                assert num_fn_params in (1, 5), (
                    f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n"
                    "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh) or "
                    "fsdp_pre_all_gather(self, mesh: DeviceMesh, "
                    "outer_size: torch.Size, outer_stride: tuple[int, ...], "
                    "module: nn.Module, mp_policy: MixedPrecisionPolicy)"
                )
                if num_fn_params == 1:
                    (
                        all_gather_inputs,
                        self._extensions_data.all_gather_metadata,
                    ) = sharded_local_tensor.fsdp_pre_all_gather(
                        self.shard_mesh_from_root
                    )
                else:
                    (
                        all_gather_inputs,
                        self._extensions_data.all_gather_metadata,
                    ) = sharded_local_tensor.fsdp_pre_all_gather(
                        self.shard_mesh_from_root,
                        self._orig_size,
                        self._contiguous_orig_stride,
                        self._module_info.module,
                        self.mp_policy,
                    )
                if (
                    sharded_local_tensor.size() != self.padded_sharded_param_size
                    and any(
                        all_gather_input.size() != self.padded_sharded_param_size
                        for all_gather_input in all_gather_inputs
                    )
                ):
                    # NOTE: Since this error is only raised on the ranks that
                    # have padding, it can manifest as a collective hang or
                    # watchdog timeout, as the other ranks do not error
                    raise AssertionError(
                        "When a parameter is unevenly sharded by FSDP "
                        f"(orig size={self._orig_size}, FSDP world size={self.mesh_info.mesh.size()}), "
                        "fsdp_pre_all_gather must return all-gather inputs with the padded sharded size "
                        f"{self.padded_sharded_param_size} but got {[t.size() for t in all_gather_inputs]}"
                    )
                self._extensions_data.all_gather_input_sizes = [
                    t.size() for t in all_gather_inputs
                ]
                return [t.view(-1) for t in all_gather_inputs]
            sharded_param_data = self._sharded_param_data
            if self.offload_to_cpu:
                sharded_param_data = sharded_param_data.to(
                    self.device, non_blocking=True
                )
            return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
        elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            if not compiled_autograd_enabled() and hasattr(
                self._sharded_local_tensor, "fsdp_pre_all_gather"
            ):
                raise NotImplementedError
            all_gather_input = _to_dtype_if_needed(
                cast(torch.Tensor, self._sharded_post_forward_param_data),
                self.param_dtype,
            )
            return [all_gather_input]
        return [torch.empty(0)]  # mypy

    @property
    def unsharded_param(self) -> nn.Parameter:
        return self._unsharded_param

    @property
    def unsharded_grad_data(self) -> torch.Tensor:
        grad = self.unsharded_param.grad
        assert grad is not None, "Expects unsharded_param.grad to not be None"
        return self._get_grad_inner_tensor(grad)

    @property
    def unsharded_accumulated_grad_data(self) -> torch.Tensor:
        grad = self.unsharded_accumulated_grad
        assert grad is not None, "Expects unsharded_accumulated_grad to not be None"
        return self._get_grad_inner_tensor(grad)

    def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
        if self.is_dtensor:
            if isinstance(grad, AsyncCollectiveTensor):
                grad = grad.wait()
            assert isinstance(grad, DTensor), f"{type(grad)}"
            placements = self._tp_spec.placements
            if placements != grad.placements:
                assert len(self._tp_spec.placements) == len(grad.placements), (
                    f"{self._tp_spec=} {grad.placements=}"
                )
                grad = grad.redistribute(placements=placements)
            grad = grad._local_tensor
        return grad

    @property
    def _sharded_local_tensor(self) -> torch.Tensor:
        return cast(DTensor, self.sharded_param)._local_tensor

    @property
    def shard_mesh(self):
        mesh = self.mesh_info.mesh
        if mesh.ndim == 1:
            return mesh
        elif mesh.ndim == 2:
            assert mesh.mesh_dim_names is not None
            return mesh[mesh.mesh_dim_names[-1]]
        raise ValueError(f"Invalid mesh: {mesh}")
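
    # For example (illustrative): for HSDP with a 2D mesh such as
    #     init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
    # `shard_mesh` returns the 1D submesh along the trailing "shard" dim, i.e.
    # the mesh over which this parameter is sharded.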

    @property
    def shard_mesh_from_root(self):
        mesh = self.mesh_info.mesh
        if mesh.ndim == 1:
            return mesh
        else:
            assert mesh.mesh_dim_names is not None
            shard_dim_name = mesh.mesh_dim_names[-1]
            root_mesh = _mesh_resources.get_root_mesh(mesh)
            return root_mesh[shard_dim_name]

    def _assert_in_states(self, *states: ShardedState) -> None:
        if self.sharded_state not in states:
            _raise_assert_with_print(
                f"Expects to be in one of {states}, not {self.sharded_state}"
            )

    def reset_sharded_param(self):
        # For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
        # that change the sharded parameter tensor, we may need to re-pad the
        # sharded local tensor and re-save the reference
        module_info = self._module_info
        new_param = getattr(module_info.module, module_info.param_name)
        if new_param is not self.sharded_param:
            if torch.__future__.get_swap_module_params_on_conversion():
                raise AssertionError(
                    f"Expects swap_tensors to preserve object but got {new_param} "
                    f"instead of {self.sharded_param}"
                )
            self.sharded_param = new_param
        local_tensor = new_param._local_tensor
        if local_tensor.is_meta:
            return
        updated_local_tensor = False
        padded_sharded_size = self.padded_sharded_param_size
        shard_dim = self.fsdp_placement.dim
        length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
        if local_tensor.size() != padded_sharded_size:
            assert shard_dim == 0, (
                f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
            )
            padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
            padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
                local_tensor
            )
            local_tensor = padded_local_tensor
            updated_local_tensor = True
        if self.pin_memory and not local_tensor.is_pinned():
            local_tensor = local_tensor.cpu().pin_memory(device=self.device)
            updated_local_tensor = True
        self._sharded_param_data = local_tensor.view(-1)
        assert isinstance(self.sharded_param, DTensor)  # mypy
        if updated_local_tensor:
            # Only change the local tensor object if needed
            self.sharded_param._local_tensor = local_tensor.narrow(
                dim=shard_dim, start=0, length=length
            )
            assert self.sharded_param._local_tensor.is_contiguous()
        self._sharding_spec = self.sharded_param._spec

    def __repr__(self):
        return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"


def alloc_storage(tensor: torch.Tensor) -> None:
    size = tensor.numel() * tensor.itemsize
    if (storage := tensor.untyped_storage()).size() != size:
        storage.resize_(size)


def free_storage(tensor: torch.Tensor) -> None:
    if (storage := tensor.untyped_storage()).size() != 0:
        storage.resize_(0)
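
# Storage resizing preserves tensor identity and aliasing, which is why FSDP
# uses it to free/allocate the unsharded parameter (see the autograd note at
# the top of this file). Illustrative behavior:
#
#     t = torch.empty(4)
#     v = t.view(2, 2)  # alias of `t`
#     free_storage(t)   # frees the bytes; `t` and `v` remain valid objects
#     alloc_storage(t)  # reallocates; contents are uninitialized
#     assert v.untyped_storage().size() == t.numel() * t.itemsize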


# NOTE: "Unsafe" here means that this bypasses the `nn.Module.__setattr__`
# checks, which incur non-trivial CPU overhead, whenever the module does not
# override `__setattr__` itself
def unsafe_setattr_param(
    module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
    if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__:
        module._parameters[param_name] = param
    else:  # slow path
        setattr(module, param_name, param)


def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    # Only call `requires_grad_` if needed to avoid the Python <-> C++ context
    # switch, which has some CPU overhead
    if src_tensor.requires_grad != dst_tensor.requires_grad:
        dst_tensor.requires_grad_(src_tensor.requires_grad)