| import itertools |
| from typing import Optional, Union |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| from torch.distributed.device_mesh import _get_device_handle |
| from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh |
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
|
|
| from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo |
| from ._fsdp_state import _get_module_fsdp_state |
|
|
|
|
| def _get_post_forward_mesh_info( |
| reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo |
| ) -> Optional[FSDPMeshInfo]: |
|     """
|     Translate ``reshard_after_forward`` (a bool or an int group size) into the
|     mesh info to reshard to after forward, or ``None`` for no resharding.
|     """
|     shard_mesh_size = mesh_info.shard_mesh_size
| if not isinstance(reshard_after_forward, (bool, int)): |
| raise ValueError( |
| "reshard_after_forward should be a bool or an int representing the " |
| f"group size to reshard to, not {reshard_after_forward}" |
| ) |
|     # NOTE: ``isinstance(False, int)`` is ``True``, so rule out bool first
| if not isinstance(reshard_after_forward, bool) and isinstance( |
| reshard_after_forward, int |
| ): |
| if ( |
| reshard_after_forward < 1 |
| or reshard_after_forward > shard_mesh_size |
| or shard_mesh_size % reshard_after_forward != 0 |
| ): |
| raise ValueError( |
| "If passing reshard_after_forward as an int, it should be a " |
| f"factor of {shard_mesh_size}, not {reshard_after_forward}" |
| ) |
| elif reshard_after_forward == 1: |
| reshard_after_forward = False |
| elif reshard_after_forward == shard_mesh_size: |
| reshard_after_forward = True |
| post_forward_mesh_info = None |
| if reshard_after_forward is True: |
| post_forward_mesh_info = mesh_info |
|     elif reshard_after_forward is not False:  # remaining int case
|         # Reshard to a smaller group size by viewing the 1-D shard mesh as a
|         # 2-D (replicate, shard) mesh, i.e. an HSDP mesh
| post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward) |
| post_forward_mesh = DeviceMesh( |
| mesh_info.mesh.device_type, post_forward_mesh_tensor |
| ) |
| post_forward_mesh_info = HSDPMeshInfo( |
| post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0 |
| ) |
| return post_forward_mesh_info |
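|
| # Example (an illustrative sketch, not executed here): with an 8-way shard
| # mesh, ``reshard_after_forward=2`` folds the mesh into shape (4, 2), so
| # parameters are resharded to groups of 2 ranks after forward:
| #
| #   mesh = init_device_mesh("cuda", (8,))  # assumes an 8-rank job
| #   info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
| #   post = _get_post_forward_mesh_info(2, info)
| #   assert isinstance(post, HSDPMeshInfo)
| #   assert tuple(post.mesh.mesh.shape) == (4, 2)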
|
|
|
|
| def _init_default_fully_shard_mesh() -> DeviceMesh:
|     """Default to a global mesh on the current accelerator if one is
|     available, else a global CPU mesh."""
| if not dist.distributed_c10d.is_initialized(): |
| dist.distributed_c10d.init_process_group() |
| default_pg = dist.distributed_c10d._get_default_group() |
| device = torch._C._get_accelerator() |
| mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),)) |
| return mesh |
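|
| # For example (illustrative): under ``torchrun --nproc_per_node=4`` on a CUDA
| # machine, this is equivalent to:
| #
| #   mesh = init_device_mesh("cuda", mesh_shape=(4,))  # one rank per GPU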
|
|
|
|
| def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device: |
| if mesh.device_type == "cpu": |
| return torch.device("cpu") |
| device_handle = _get_device_handle(mesh.device_type) |
| return torch.device(mesh.device_type, device_handle.current_device()) |
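|
| # For example (illustrative): a CPU mesh yields ``torch.device("cpu")``, while
| # a CUDA mesh yields this rank's current device, e.g.:
| #
| #   _get_device_from_mesh(init_device_mesh("cuda", (2,)))  # -> cuda:<current>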
|
|
|
|
| def _ignore_module( |
| module: nn.Module, |
| ignored_params: set[nn.Parameter], |
| ignore_decision: dict[nn.Module, bool], |
| ) -> bool: |
| """ |
| Decide if it is safe to ignore a module for applying fully_shard. |
| """ |
| if module in ignore_decision: |
| return ignore_decision[module] |
|
|
| if len(list(module.buffers(recurse=False))) > 0: |
|         # A module that owns any buffer cannot be ignored
| ignore_decision[module] = False |
| return False |
|
|
| for _, param in module.named_parameters(recurse=False): |
| if param not in ignored_params: |
|             # At least one parameter is not ignored, so keep the module managed
| ignore_decision[module] = False |
| return False |
|
|
|     # A module is ignorable only if all of its children are ignorable
| for child in list(module.children()): |
| ignore_child = _ignore_module(child, ignored_params, ignore_decision) |
| if not ignore_child: |
|             # Cannot ignore a module with a child that cannot be ignored
| ignore_decision[module] = False |
| return False |
|
|
|     # No buffers, all parameters ignored, and every child ignorable
| ignore_decision[module] = True |
| return True |
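|
| # Illustration (hypothetical, not executed): a leaf module whose only
| # parameter is ignored and that owns no buffers can be skipped, while any
| # buffer keeps the module managed:
| #
| #   lin = nn.Linear(4, 4, bias=False)
| #   _ignore_module(lin, {lin.weight}, {})         # -> True
| #   bn = nn.BatchNorm1d(4)  # owns running-stat buffers
| #   _ignore_module(bn, set(bn.parameters()), {})  # -> False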
|
|
|
|
| def _adjust_managed_modules( |
| modules: list[nn.Module], ignored_params: set[nn.Parameter] |
| ) -> list[nn.Module]: |
| """ |
|     Adjust the given list of managed modules by removing those whose entire
|     subtree has only ignored parameters and no buffers.
| """ |
| ignore_decision: dict[nn.Module, bool] = {} |
| new_modules = [] |
| for module in modules: |
| ignored = _ignore_module(module, ignored_params, ignore_decision) |
| if not ignored: |
| new_modules.append(module) |
| return new_modules |
|
|
|
|
| def _get_managed_modules( |
| root_modules: tuple[nn.Module, ...], |
| ignored_params: Optional[set[nn.Parameter]] = None, |
| ) -> list[nn.Module]: |
| modules: list[nn.Module] = [] |
| root_modules_set = set(root_modules) |
|     # Track visited modules to avoid visiting shared modules multiple times
| visited_modules: set[nn.Module] = set() |
|
|
| def dfs(module: nn.Module) -> None: |
| """ |
| Runs a DFS to collect managed modules, not recursing into modules with |
| a non-composable API or ``fully_shard`` already applied. |
| """ |
| if not _is_composable_with_fsdp(module): |
| return |
| elif ( |
| module not in root_modules_set |
| and _get_module_fsdp_state(module) is not None |
| ): |
| return |
| visited_modules.add(module) |
| for submodule in module.children(): |
| if submodule not in visited_modules: |
| dfs(submodule) |
| modules.append(module) |
|
|
| for root_module in root_modules: |
| dfs(root_module) |
|
|
| if ignored_params is None: |
| return modules |
|
|
| adjusted_modules = _adjust_managed_modules(modules, ignored_params) |
| return adjusted_modules |
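|
| # Example (illustrative): with ``fully_shard`` already applied to an inner
| # module, the DFS does not recurse into it, so the outer call manages only
| # the remaining modules (assumes ``fully_shard`` is imported):
| #
| #   model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
| #   fully_shard(model[0])
| #   _get_managed_modules((model,))  # -> [model[1], model]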
|
|
|
|
| def _verify_managed_param(name: str, param: nn.Parameter) -> None: |
| """ |
|     Verify if the parameter is accepted by fully_shard. The only restriction
|     now is that the parameter cannot be a scalar (0-dim) tensor, i.e.
|     ``param.dim() == 0``, since we need at least one dim to shard.
| """ |
| if len(param.shape) == 0: |
| raise ValueError( |
| "fully_shard doesn't support scalar parameters. " |
| f"Change {name} to a 1D tensor with numel equal to 1." |
| ) |
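|
| # Example (illustrative): a 0-dim parameter is rejected, while the 1-D
| # single-element equivalent is accepted:
| #
| #   _verify_managed_param("scale", nn.Parameter(torch.tensor(1.0)))    # raises
| #   _verify_managed_param("scale", nn.Parameter(torch.tensor([1.0])))  # ok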
|
|
|
|
| def _get_managed_states( |
| modules: list[nn.Module], ignored_params: Optional[set[nn.Parameter]] = None |
| ) -> tuple[list[nn.Parameter], list[torch.Tensor]]: |
| params: list[nn.Parameter] = [] |
| buffers: list[torch.Tensor] = [] |
|     # Track visited parameters/buffers so that shared (tied) states are
|     # collected only once
| visited_params: set[nn.Parameter] = set() |
| visited_buffers: set[torch.Tensor] = set() |
| if ignored_params is None: |
| ignored_params = set() |
|
|
| for module in modules: |
| for name, param in module.named_parameters(recurse=False): |
| if param in ignored_params: |
|                 # Skip parameters that the user chose to ignore
| continue |
| if param not in visited_params: |
| _verify_managed_param(name, param) |
| params.append(param) |
| visited_params.add(param) |
| for buffer in module.buffers(recurse=False): |
| if buffer not in visited_buffers: |
| buffers.append(buffer) |
| visited_buffers.add(buffer) |
| return params, buffers |
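|
| # Example (illustrative): tied parameters are returned once, e.g. a weight
| # shared between an embedding and an output projection:
| #
| #   emb = nn.Embedding(10, 4)
| #   out = nn.Linear(4, 10, bias=False)
| #   out.weight = emb.weight  # weight tying
| #   params, buffers = _get_managed_states([emb, out])
| #   assert len(params) == 1 and len(buffers) == 0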
|
|
|
|
| def _move_states_to_device( |
| params: list[nn.Parameter], |
| buffers: list[torch.Tensor], |
| device: torch.device, |
| ) -> None: |
| """ |
| We have FSDP move states to device for simpler and faster initialization |
| since FSDP almost always uses CUDA for training. We move parameters/buffers |
|     rather than whole modules so that ignoring individual parameters/buffers
|     can be supported in the future.
| """ |
|     # Move tensors one by one, following the swap logic in ``nn.Module._apply``
| for tensor in itertools.chain(params, buffers): |
| if tensor.device == device or tensor.device.type == "meta": |
|             # Keep meta-device tensors on meta for deferred initialization
| continue |
| if isinstance(tensor, DTensor): |
| if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type: |
| raise ValueError( |
| "Requires DTensor to have mesh of the same type as the FSDP mesh " |
| f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP" |
| ) |
| raise AssertionError( |
| f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}" |
| ) |
|         if is_traceable_wrapper_subclass(tensor):
|             # Wrapper subclasses cannot be moved via ``.data``; swap the
|             # tensor payload in place instead
|             with torch.no_grad():  # avoid autograd increasing the refcount
|                 tensor_on_device = nn.Parameter(tensor.to(device))
|             torch.utils.swap_tensors(tensor, tensor_on_device)
|         else:
|             tensor.data = tensor.to(device)
|
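| # Example (illustrative, assumes CUDA is available): meta-device states stay
| # on meta so that deferred initialization can materialize them later, while
| # CPU states are moved onto the FSDP device:
| #
| #   with torch.device("meta"):
| #       meta_lin = nn.Linear(4, 4)
| #   cpu_lin = nn.Linear(4, 4)
| #   params = list(meta_lin.parameters()) + list(cpu_lin.parameters())
| #   _move_states_to_device(params, [], torch.device("cuda", 0))
| #   # meta_lin remains on "meta"; cpu_lin's parameters now live on cuda:0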
|