| | import logging |
| | import math |
| | import os |
| | import threading |
| | import warnings |
| | from collections.abc import Iterator |
| | from functools import reduce |
| | from itertools import chain, zip_longest |
| | from typing import Optional, TYPE_CHECKING, Union |
| |
|
| | import torch |
| | from torch.distributed import is_available |
| | from torch.utils._typing_utils import not_none |
| |
|
| |
|
| | __all__ = ["init_device_mesh", "DeviceMesh"] |
| |
|
| |
|
| | if not is_available(): |
| | import sys |
| |
|
| | # Define stubs for DeviceMesh and init_device_mesh so that this module can |
| | # still be imported (e.g. for documentation builds) when torch.distributed |
| | # is not available. |
| |
|
| | class _DeviceMeshStub: |
| | pass |
| |
|
| | def _init_device_mesh_stub(): |
| | pass |
| |
|
| | sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub |
| | sys.modules[ |
| | "torch.distributed.device_mesh" |
| | ].init_device_mesh = _init_device_mesh_stub |
| |
|
| |
|
| | else: |
| | from torch._C._distributed_c10d import Backend as C10dBackend |
| | from torch.distributed.distributed_c10d import ( |
| | _get_default_group, |
| | _resolve_process_group, |
| | get_backend, |
| | get_process_group_ranks, |
| | get_rank, |
| | get_world_size, |
| | init_process_group, |
| | is_initialized, |
| | new_group, |
| | ProcessGroup, |
| | split_group, |
| | ) |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | if TYPE_CHECKING: |
| | try: |
| | from numpy.typing import ArrayLike |
| | except ImportError: |
| | logger.warning( |
| | "DeviceMesh requires numpy >= 1.21 to be installed for type checking" |
| | ) |
| |
|
| | class _MeshEnv(threading.local): |
| | def __init__(self) -> None: |
| | self.mesh_stack: list[DeviceMesh] = [] |
| | self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} |
| | self.mesh_dim_group_options: dict[ |
| | int, tuple[Optional[str], Optional[C10dBackend.Options]] |
| | ] = {} |
| | self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} |
| | # Record, for each root mesh, which root dims each flattened dim name covers. |
| | self.flatten_name_to_root_dims: dict[ |
| | DeviceMesh, dict[str, tuple[int, ...]] |
| | ] = {} |
| |
|
| | def get_current_mesh(self) -> "DeviceMesh": |
| | if len(self.mesh_stack) == 0: |
| | raise RuntimeError("No device mesh is currently active!") |
| | return self.mesh_stack[-1] |
| |
|
| | def create_sub_mesh( |
| | self, |
| | device_mesh: "DeviceMesh", |
| | submesh_dim_names: tuple[str, ...], |
| | submesh_dims: list[tuple[int, ...]], |
| | ) -> "DeviceMesh": |
| | # Compute the size of each sliced-out dim; a dim flattened from several |
| | # root dims gets the product of those root dims' sizes. |
| | slice_dim_size = [ |
| | reduce( |
| | lambda x, y: x * device_mesh.mesh.size(y), |
| | mesh_dim, |
| | 1, |
| | ) |
| | for mesh_dim in submesh_dims |
| | ] |
| |
|
| | mesh_tensor = device_mesh.mesh |
| | |
| | slice_dim_idx = [] |
| | slice_dim_group_name = [] |
| | |
| | |
| | num_dims_flatten = 0 |
| | for mesh_dim_indices, mesh_dim_name in zip(submesh_dims, submesh_dim_names): |
| | # A sliced-out dim that covers more than one root dim must have been |
| | # flattened beforehand; reuse the flattened dim's process group. |
| | if len(mesh_dim_indices) > 1: |
| | |
| | mesh_tensor = mesh_tensor.flatten( |
| | start_dim=mesh_dim_indices[0] - num_dims_flatten, |
| | end_dim=mesh_dim_indices[-1] - num_dims_flatten, |
| | ) |
| | |
| | |
| | |
| | slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) |
| | num_dims_flatten += len(mesh_dim_indices) - 1 |
| | slice_dim_group_name.append( |
| | self.root_to_flatten_mapping[device_mesh][ |
| | mesh_dim_name |
| | ]._dim_group_names[0] |
| | ) |
| | else: |
| | slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) |
| | slice_dim_group_name.append( |
| | device_mesh._dim_group_names[mesh_dim_indices[0]] |
| | ) |
| |
|
| | |
| | mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) |
| | for idx in slice_dim_idx: |
| | if idx not in mesh_dims_remained_idx: |
| | raise NotImplementedError( |
| | "Currently, this only allows slicing out a contiguous flattened dim." |
| | ) |
| | mesh_dims_remained_idx.remove(idx) |
| |
|
| | # Move the sliced-out dims to the back and flatten the rest, so that each |
| | # slice along dim 0 of pg_ranks_by_dim holds the ranks of one candidate submesh. |
| | pg_ranks_by_dim = mesh_tensor.permute( |
| | *mesh_dims_remained_idx, *slice_dim_idx |
| | ).reshape(-1, *slice_dim_size) |
| |
|
| | cur_rank = device_mesh.get_rank() |
| | for mesh_nd in pg_ranks_by_dim: |
| | submesh = DeviceMesh( |
| | device_mesh.device_type, |
| | mesh_nd, |
| | mesh_dim_names=submesh_dim_names, |
| | _init_backend=False, |
| | ) |
| | if cur_rank in mesh_nd: |
| | res_submesh = submesh |
| |
|
| | res_submesh._dim_group_names = slice_dim_group_name |
| | self.child_to_root_mapping[res_submesh] = device_mesh |
| |
|
| | return res_submesh |
| |
|
| | def create_flatten_mesh( |
| | self, |
| | device_mesh: "DeviceMesh", |
| | mesh_dim_name: Optional[str] = None, |
| | backend_override: tuple[Optional[str], Optional[C10dBackend.Options]] = ( |
| | None, |
| | None, |
| | ), |
| | ) -> "DeviceMesh": |
| | root_mesh = _mesh_resources.get_root_mesh(device_mesh) |
| |
|
| | flatten_dims_in_root = [ |
| | not_none(root_mesh.mesh_dim_names).index(flatten_mesh_dim_name) |
| | for flatten_mesh_dim_name in not_none(device_mesh.mesh_dim_names) |
| | ] |
| |
|
| | if not mesh_dim_name: |
| | mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names)) |
| |
|
| | # Check that the requested flattened dim name does not clash with existing dim names. |
| | self.flatten_name_to_root_dims.setdefault(root_mesh, {}) |
| | invalid_dim_names = list( |
| | chain( |
| | not_none(root_mesh.mesh_dim_names), |
| | self.flatten_name_to_root_dims[root_mesh].keys(), |
| | ) |
| | ) |
| | if mesh_dim_name in invalid_dim_names: |
| | raise RuntimeError( |
| | f"{mesh_dim_name} already exists for submesh of the {root_mesh}. " |
| | f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. " |
| | "Please specify another valid mesh_dim_name." |
| | ) |
| |
|
| | # Quick return if a flattened mesh with this name was already created for |
| | # the root mesh, so that we do not create duplicate process groups. |
| | if ( |
| | root_mesh in self.root_to_flatten_mapping |
| | and mesh_dim_name in self.root_to_flatten_mapping[root_mesh] |
| | ): |
| | return self.root_to_flatten_mapping[root_mesh][mesh_dim_name] |
| |
|
| | flattened_mesh_dim_size = math.prod(device_mesh.mesh.size()) |
| |
|
| | remained_dims_in_root = list(range(root_mesh.mesh.ndim)) |
| | for flatten_dim_in_root in flatten_dims_in_root: |
| | remained_dims_in_root.remove(flatten_dim_in_root) |
| |
|
| | pg_ranks_by_dim = root_mesh.mesh.permute( |
| | *remained_dims_in_root, *flatten_dims_in_root |
| | ).reshape(-1, flattened_mesh_dim_size) |
| |
|
| | cur_rank = root_mesh.get_rank() |
| | for mesh_nd in pg_ranks_by_dim: |
| | # The flattened dim's process group does not exist yet, so let DeviceMesh initialize a backend for it. |
| | flattened_mesh = DeviceMesh( |
| | root_mesh.device_type, |
| | mesh_nd, |
| | mesh_dim_names=(mesh_dim_name,), |
| | backend_override=(backend_override,), |
| | ) |
| | if cur_rank in mesh_nd: |
| | res_flattened_mesh = flattened_mesh |
| | self.child_to_root_mapping[res_flattened_mesh] = root_mesh |
| | self.root_to_flatten_mapping.setdefault(root_mesh, {})[mesh_dim_name] = ( |
| | res_flattened_mesh |
| | ) |
| | self.flatten_name_to_root_dims[root_mesh][mesh_dim_name] = tuple( |
| | flatten_dims_in_root |
| | ) |
| |
|
| | return res_flattened_mesh |
| |
|
| | def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": |
| | # A mesh that does not appear in child_to_root_mapping was never sliced or |
| | # flattened from another mesh, so it is its own root. |
| | root_mesh = self.child_to_root_mapping.get(device_mesh, None) |
| | return device_mesh if not root_mesh else root_mesh |
| |
|
| | def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: |
| | """ |
| | Returns the index of the mesh dim in the root mesh. |
| | The device_mesh passed in needs to be sliced out from the root mesh |
| | or submesh of the root mesh. |
| | """ |
| | root_mesh = self.get_root_mesh(device_mesh) |
| | child_mesh_dim_names = device_mesh.mesh_dim_names |
| | if root_mesh and child_mesh_dim_names: |
| | assert len(child_mesh_dim_names) == 1, ( |
| | "The submesh can only be a 1D mesh." |
| | ) |
| | child_mesh_dim_name = child_mesh_dim_names[0] |
| | return self.get_mesh_dim_by_name(root_mesh, child_mesh_dim_name) |
| | return None |
| |
|
| | @staticmethod |
| | def num_devices_per_host(device_type: str) -> int: |
| | return _get_device_handle(device_type).device_count() |
| |
|
| | @staticmethod |
| | def num_hosts(device_type: str) -> int: |
| | # ProcessGroup cannot tell us the number of hosts directly, so infer it |
| | # assuming a homogeneous cluster. |
| | return get_world_size() // _MeshEnv.num_devices_per_host(device_type) |
| |
|
| | def get_mesh_dim_by_name( |
| | self, device_mesh: "DeviceMesh", mesh_dim_name: str |
| | ) -> int: |
| | if ( |
| | device_mesh.mesh_dim_names is None |
| | or len(device_mesh.mesh_dim_names) == 0 |
| | ): |
| | raise KeyError( |
| | "No `mesh_dim_names` found.", |
| | ) |
| | if mesh_dim_name not in device_mesh.mesh_dim_names: |
| | raise KeyError( |
| | f"Mesh dimension '{mesh_dim_name}' does not exist.", |
| | f"Available mesh dimensions are: mesh_dim_names={device_mesh.mesh_dim_names}", |
| | ) |
| | return not_none(device_mesh.mesh_dim_names.index(mesh_dim_name)) |
| |
|
| | def _set_mesh_dim_group_options( |
| | self, |
| | dim: int, |
| | backend: Optional[str], |
| | pg_options: Optional[C10dBackend.Options] = None, |
| | ) -> None: |
| | self.mesh_dim_group_options[dim] = (backend, pg_options) |
| |
|
| | def _get_slice_mesh_dims( |
| | self, device_mesh, mesh_dim_names |
| | ) -> list[tuple[int, ...]]: |
| | """ |
| | Validate whether ``mesh_dim_names`` is valid for slicing the given device_mesh. |
| | If valid, return the dim indices of the sliced mesh within the device mesh. |
| | """ |
| | if device_mesh != self.get_root_mesh(device_mesh): |
| | warnings.warn( |
| | "You are attempting to slice a submesh from another submesh. While we support this operation, " |
| | "it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. " |
| | "If not, this may result in some ranks receiving the submesh while others encounter errors." |
| | ) |
| |
|
| | # A valid slice name is either one of the mesh's own dim names or the name |
| | # of a dim previously flattened from this mesh. |
| | self.flatten_name_to_root_dims.setdefault(device_mesh, {}) |
| | flatten_name_to_root_dims = self.flatten_name_to_root_dims[device_mesh] |
| | valid_mesh_dim_names = [ |
| | *device_mesh.mesh_dim_names, |
| | *flatten_name_to_root_dims, |
| | ] |
| |
|
| | if not all( |
| | mesh_dim_name in valid_mesh_dim_names |
| | for mesh_dim_name in mesh_dim_names |
| | ): |
| | raise KeyError( |
| | f"Invalid mesh_dim_names {mesh_dim_names} specified. " |
| | f"Valid mesh_dim_names are {valid_mesh_dim_names}." |
| | ) |
| |
|
| | # The requested dims must appear in the same (ascending) order as they do |
| | # in the mesh. |
| | curr_idx = -1 |
| | slice_mesh_dims = [] |
| | for mesh_dim_name in mesh_dim_names: |
| | if mesh_dim_name in flatten_name_to_root_dims: |
| | mesh_indices = flatten_name_to_root_dims[mesh_dim_name] |
| | |
| | |
| | next_idx = mesh_indices[-1] |
| | slice_mesh_dims.append(mesh_indices) |
| | else: |
| | next_idx = device_mesh.mesh_dim_names.index(mesh_dim_name) |
| | slice_mesh_dims.append((next_idx,)) |
| | if next_idx <= curr_idx: |
| | raise KeyError( |
| | f"Invalid mesh_dim_names {mesh_dim_names} specified. " |
| | f"Found mesh dim indices to slice: {slice_mesh_dims}. " |
| | "Mesh dim indices should be in ascending order." |
| | ) |
| | curr_idx = next_idx |
| |
|
| | return slice_mesh_dims |
| |
|
| | def _get_all_submeshes( |
| | self, device_mesh: "DeviceMesh", mesh_dim_name: str |
| | ) -> list["DeviceMesh"]: |
| | """ |
| | Return all the submeshes of a given mesh dimension of the device mesh. |
| | """ |
| | mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) |
| | pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( |
| | -1, device_mesh.mesh.size(mesh_dim) |
| | ) |
| |
|
| | cur_rank = device_mesh.get_rank() |
| | res_submeshes = [] |
| | for mesh_1d in pg_ranks_by_dim: |
| | submesh = DeviceMesh( |
| | device_mesh.device_type, |
| | mesh_1d, |
| | mesh_dim_names=(mesh_dim_name,), |
| | _init_backend=False, |
| | ) |
| | submesh._dim_group_names = ( |
| | [device_mesh._dim_group_names[mesh_dim]] |
| | if cur_rank in mesh_1d |
| | else [] |
| | ) |
| | res_submeshes.append(submesh) |
| |
|
| | return res_submeshes |
| |
|
| | _mesh_resources: _MeshEnv = _MeshEnv() |
| |
|
| | def _get_device_handle(device_type: str = "cuda"): |
| | """ |
| | Get the module corresponding to device_type, which is cuda or a cuda-like device. |
| | For example, when device_type is cuda, the module `torch.cuda` is returned. |
| | Return None when there is no corresponding module for device_type; otherwise, |
| | return the corresponding module. |
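| | |
| | Example (illustrative):: |
| | |
| | >>> _get_device_handle("cuda") is torch.cuda |
| | True |
| | >>> _get_device_handle("not_a_device") is None |
| | True |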
| | """ |
| | return getattr(torch, device_type, None) |
| |
|
| | class DeviceMesh: |
| | """ |
| | DeviceMesh represents a mesh of devices, where the layout of devices can be |
| | represented as an n-dimensional array, and each value of the n-dimensional |
| | array is the global rank of the default process group. |
| | |
| | DeviceMesh can be used to set up N-dimensional device connections across the cluster |
| | and to manage the ProcessGroups for N-dimensional parallelisms. Communications can happen on |
| | each dimension of the DeviceMesh separately. DeviceMesh respects the device that the user has |
| | already selected (i.e. if the user calls `torch.cuda.set_device` before the DeviceMesh initialization), |
| | and will select/set the device for the current process if the user has not set one |
| | beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization. |
| | |
| | DeviceMesh can also be used as a context manager when used together with DTensor APIs. |
| | |
| | .. note:: |
| | DeviceMesh follows the SPMD programming model, which means the same PyTorch Python program |
| | runs on all processes/ranks in the cluster. Therefore, users need to make sure the |
| | `mesh` array (which describes the layout of devices) is identical across all ranks. |
| | An inconsistent `mesh` will lead to silent hangs. |
| | |
| | Args: |
| | device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". |
| | mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout |
| | of devices, where the IDs are global IDs of the default process group. |
| | |
| | Returns: |
| | DeviceMesh: A :class:`DeviceMesh` object representing the device layout. |
| | |
| | The following program runs on each process/rank in an SPMD manner. In this example, we have 2 |
| | hosts with 4 GPUs each. |
| | A reduction over the first dimension of the mesh reduces across |
| | columns (0, 4), ..., (3, 7); a reduction over the second dimension |
| | of the mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). |
| | |
| | Example:: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> from torch.distributed.device_mesh import DeviceMesh |
| | >>> |
| | >>> # Initialize device mesh as (2, 4) to represent the topology |
| | >>> # of cross-host(dim 0), and within-host (dim 1). |
| | >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) |
| | """ |
| |
|
| | device_type: str |
| | mesh: torch.Tensor |
| | mesh_dim_names: Optional[tuple[str, ...]] |
| |
|
| | def __init__( |
| | self, |
| | device_type: str, |
| | mesh: Union[torch.Tensor, "ArrayLike"], |
| | *, |
| | mesh_dim_names: Optional[tuple[str, ...]] = None, |
| | backend_override: Optional[ |
| | tuple[tuple[Optional[str], Optional[C10dBackend.Options]], ...] |
| | ] = None, |
| | _init_backend: bool = True, |
| | ) -> None: |
| | self.device_type = device_type |
| | if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": |
| | raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") |
| | self.mesh = ( |
| | mesh.detach().to(dtype=torch.int) |
| | if isinstance(mesh, torch.Tensor) |
| | else torch.tensor(mesh, device="cpu", dtype=torch.int) |
| | ) |
| | self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None |
| | if backend_override is None: |
| | backend_override = ((None, None),) * self.mesh.ndim |
| |
|
| | # Cache a flattened copy of the mesh for cheap hashing and equality checks. |
| | self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) |
| | self._thread_id = None |
| |
|
| | # Skip process group initialization for XLA devices, which manage their own collectives. |
| | if device_type != "xla": |
| | # Set up the world group and the per-dimension sub-groups unless the caller |
| | # opts out via _init_backend=False (e.g. when the groups already exist). |
| | if _init_backend: |
| | self._setup_world_group_and_device() |
| | self._init_process_groups(backend_override) |
| |
|
| | if is_initialized() and get_backend() == "threaded": |
| | self._thread_id = threading.get_ident() |
| |
|
| | # Calculate the coordinate of the current global rank on the mesh. |
| | rank_coords = (self.mesh == get_rank()).nonzero() |
| | assert rank_coords.size(0) in (0, 1) |
| | self._coordinate_on_dim: Optional[list[int]] = ( |
| | rank_coords[0].tolist() if rank_coords.size(0) > 0 else None |
| | ) |
| |
|
| | def _setup_world_group_and_device(self): |
| | default_initialized = is_initialized() |
| | # Lazily initialize the default (world) process group if needed. |
| | if not default_initialized: |
| | init_process_group() |
| |
|
| | world_size = get_world_size() |
| | if self.mesh.numel() > world_size: |
| | raise RuntimeError( |
| | f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" |
| | ) |
| |
|
| | # If the device module (e.g. torch.cuda) has not been initialized yet, |
| | # select a device for the current process. |
| | device_handle = _get_device_handle(self.device_type) |
| | if device_handle and not device_handle.is_initialized(): |
| | # Prefer LOCAL_RANK from the launcher (e.g. torchrun) when it is available. |
| | if "LOCAL_RANK" in os.environ: |
| | local_rank = int(os.environ["LOCAL_RANK"]) |
| | logger.info( |
| | "Setting default device for the current process based on LOCAL_RANK=%s", |
| | local_rank, |
| | ) |
| | device_handle.set_device(local_rank) |
| | else: |
| | warnings.warn( |
| | "It seems like you did not set/select the default device for the current process before the DeviceMesh " |
| | "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. " |
| | "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that " |
| | "the underlying communicator (i.e. NCCL) can be initialized properly. " |
| | "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the " |
| | "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. " |
| | ) |
| | # Heuristic fallback described in the warning above: |
| | # device_id = global_rank % num_devices_per_host. |
| | num_devices_per_host = device_handle.device_count() |
| | if ( |
| | world_size > num_devices_per_host |
| | and world_size % num_devices_per_host != 0 |
| | ): |
| | raise RuntimeError( |
| | f"DeviceMesh only support homogeneous hardware, but found " |
| | f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" |
| | ) |
| | device_handle.set_device(get_rank() % num_devices_per_host) |
| |
|
| | return _get_default_group() |
| |
|
| | def _init_process_groups( |
| | self, |
| | backend_override: tuple[ |
| | tuple[Optional[str], Optional[C10dBackend.Options]], ... |
| | ], |
| | ): |
| | # Each mesh dimension gets exactly one process group per rank; record the |
| | # group name for each dimension here. |
| | dim_group_names: list[str] = [] |
| | default_group = _get_default_group() |
| |
|
| | if ( |
| | self.mesh.ndim == 1 |
| | and self.mesh.numel() == get_world_size() |
| | and _mesh_resources.mesh_dim_group_options.get(0, (None, None)) |
| | == (None, None) |
| | and backend_override[0] == (None, None) |
| | ): |
| | # Fast path: a 1-D mesh spanning the whole world reuses the default group, |
| | # unless that group is gloo-only on a CUDA machine, in which case a combined |
| | # cpu:gloo,cuda:nccl group is created instead. |
| | ranks = list(range(get_world_size())) |
| | dim_group = ( |
| | new_group( |
| | backend="cpu:gloo,cuda:nccl", |
| | ranks=ranks, |
| | group_desc="mesh_default", |
| | ) |
| | if torch.cuda.is_available() |
| | and get_backend(default_group) == "gloo" |
| | else default_group |
| | ) |
| | dim_group_names.append(dim_group.group_name) |
| | else: |
| | # General case: create one sub process group per mesh dimension. |
| | for dim in range(self.mesh.ndim): |
| | # Swap the current dim to the last position and flatten the rest, so each |
| | # row of pg_ranks_by_dim holds the ranks of one sub-group along this dim. |
| | pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( |
| | -1, self.mesh.size(dim) |
| | ) |
| |
|
| | # Resolve the backend and options for this dim from either the global |
| | # per-dim overrides or the backend_override argument, but not both. |
| | if dim in _mesh_resources.mesh_dim_group_options: |
| | if backend_override[dim] != (None, None): |
| | raise RuntimeError( |
| | f"Dimension {dim} present both in the backend_override argument " |
| | "and via _mesh_resources._set_mesh_dim_group_options" |
| | ) |
| | ( |
| | backend, |
| | pg_options, |
| | ) = _mesh_resources.mesh_dim_group_options[dim] |
| | else: |
| | backend, pg_options = backend_override[dim] |
| |
|
| | # Name the sub-group after its mesh dim name when available (e.g. |
| | # "mesh_dp"), otherwise after its index (e.g. "mesh_dim_0"). |
| | group_desc = ( |
| | f"mesh_{self.mesh_dim_names[dim]}" |
| | if self.mesh_dim_names |
| | else f"mesh_dim_{dim}" |
| | ) |
| |
|
| | # If the default group is bound to a device and its backend matches the |
| | # requested one, use split_group to eagerly create all sub-groups for this |
| | # dim at once; otherwise fall back to new_group per sub-group below. |
| | dim_group = None |
| | has_split_group = False |
| | if ( |
| | ( |
| | bound_device_id := getattr( |
| | default_group, "bound_device_id", None |
| | ) |
| | ) |
| | is not None |
| | and torch.cuda.is_available() |
| | and ( |
| | backend is None |
| | or default_group._get_backend(torch.device("cuda")).name() |
| | == backend |
| | ) |
| | ): |
| | dim_group = split_group( |
| | parent_pg=default_group, |
| | pg_options=pg_options, |
| | split_ranks=pg_ranks_by_dim.tolist(), |
| | group_desc=group_desc, |
| | ) |
| | has_split_group = True |
| |
|
| | # Iterate over every sub-group's ranks so that group creation below stays |
| | # a collective call across all ranks. |
| | for dim_mesh in pg_ranks_by_dim: |
| | subgroup_ranks = dim_mesh.tolist() |
| |
|
| | # Only create a new group here if split_group above was not used. |
| | if bound_device_id is None or not has_split_group: |
| | dim_group = new_group( |
| | ranks=subgroup_ranks, |
| | backend=backend, |
| | pg_options=pg_options, |
| | group_desc=group_desc, |
| | ) |
| |
|
| | # Only record the group that contains the current rank. |
| | if self.get_rank() in subgroup_ranks: |
| | if len(dim_group_names) > dim: |
| | raise RuntimeError( |
| | f"Each device mesh dimension should get only one process group, but got {self.get_rank()} " |
| | f"in {subgroup_ranks}!" |
| | ) |
| | dim_group_names.append(dim_group.group_name) |
| | self._dim_group_names = dim_group_names |
| |
|
| | def __enter__(self) -> "DeviceMesh": |
| | |
| | _mesh_resources.mesh_stack.append(self) |
| | return self |
| |
|
| | |
| | def __exit__(self, exc_type, exc_value, exc_traceback) -> None: |
| | |
| | _mesh_resources.mesh_stack.pop() |
| |
|
| | def __repr__(self) -> str: |
| | device_mesh_repr = ( |
| | f"({', '.join(f'{k}={v}' for k, v in zip(self.mesh_dim_names, self.mesh.shape))})" |
| | if self.mesh_dim_names |
| | else f"{tuple(self.mesh.shape)}" |
| | ) |
| | device_mesh_repr = f"DeviceMesh({device_mesh_repr}, device: '{self.device_type}', stride: {self.mesh.stride()}" |
| | |
| | if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": |
| | device_mesh_repr += f", Mesh: {self.mesh.tolist()}" |
| | return f"{device_mesh_repr})" |
| |
|
| | def __hash__(self): |
| | |
| | self._hash = getattr(self, "_hash", None) |
| | if not self._hash: |
| | self._hash = hash( |
| | ( |
| | self._flatten_mesh_list, |
| | self.mesh.shape, |
| | self.device_type, |
| | self.mesh_dim_names, |
| | self._thread_id, |
| | ) |
| | ) |
| | return self._hash |
| |
|
| | def __eq__(self, other: object) -> bool: |
| | if self is other: |
| | return True |
| | if not isinstance(other, DeviceMesh): |
| | return False |
| | return ( |
| | self._flatten_mesh_list == other._flatten_mesh_list |
| | and self.mesh.shape == other.mesh.shape |
| | and self.device_type == other.device_type |
| | and self.mesh_dim_names == other.mesh_dim_names |
| | and self._thread_id == other._thread_id |
| | ) |
| |
|
| | def __getitem__( |
| | self, mesh_dim_names: Union[str, tuple[str, ...]] |
| | ) -> "DeviceMesh": |
| | """ |
| | Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. |
| | The submesh created consists of the dimensions and the communicators indicated by |
| | ``mesh_dim_names`` |
| | |
| | Args: |
| | mesh_dim_names (Union[str, Tuple[str]]): the name or the tuple of names of the |
| | mesh dimension of the DeviceMesh to create the submesh for. |
| | Returns: |
| | A :class:`DeviceMesh` object |
| | |
| | The following program runs on each process/rank in an SPMD manner in a world size of 8. |
| | In the first example: |
| | Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). |
| | Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). |
| | Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). |
| | Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). |
| | Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). |
| | Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). |
| | |
| | In the second example: |
| | Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). |
| | Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). |
| | Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). |
| | Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). |
| | |
| | Example:: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> from torch.distributed.device_mesh import DeviceMesh |
| | >>> |
| | >>> # Initialize a 2D device mesh as (2, 4) to represent the topology |
| | >>> # of cross-host(dim 0), and within-host (dim 1). |
| | >>> mesh_2d = init_device_mesh(device_type="cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp")) |
| | >>> tp_mesh = mesh_2d["tp"] |
| | >>> dp_mesh = mesh_2d["dp"] |
| | >>> |
| | >>> # Initialize a 3D mesh. |
| | >>> mesh_3d = init_device_mesh(device_type="cuda", mesh_shape=(2, 2, 2), mesh_dim_names=("dp", "pp", "cp")) |
| | >>> # The order of the mesh_dim_names provided determines the order of dimensions in the submesh. |
| | >>> dp_cp_mesh = mesh_3d["dp", "cp"] |
| | >>> cp_dp_mesh = mesh_3d["cp", "dp"] |
| | """ |
| | if not self.mesh_dim_names: |
| | raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") |
| |
|
| | mesh_dim_names = ( |
| | (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names |
| | ) |
| |
|
| | if mesh_dim_names == self.mesh_dim_names: |
| | return self |
| | else: |
| | slice_mesh_dims = _mesh_resources._get_slice_mesh_dims( |
| | self, mesh_dim_names |
| | ) |
| | # Creating the submesh manipulates real mesh tensors, so temporarily leave |
| | # fake-tensor mode (e.g. when slicing happens while tracing under |
| | # FakeTensorMode). |
| | with torch._subclasses.fake_tensor.unset_fake_temporarily(): |
| | submesh = _mesh_resources.create_sub_mesh( |
| | self, mesh_dim_names, slice_mesh_dims |
| | ) |
| | return submesh |
| |
|
| | def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: |
| | """ |
| | Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the |
| | DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. |
| | |
| | Args: |
| | mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index |
| | of the mesh dimension. Default is None. |
| | |
| | Returns: |
| | A :class:`ProcessGroup` object. |
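| | |
| | Example (illustrative sketch; assumes a 2D mesh ``mesh_2d`` with |
| | ``mesh_dim_names=("dp", "tp")`` has already been initialized):: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> tp_group = mesh_2d.get_group("tp")  # look up by dimension name |
| | >>> dp_group = mesh_2d.get_group(0)  # or by dimension index |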
| | """ |
| | if not hasattr(self, "_dim_group_names"): |
| | raise RuntimeError("DeviceMesh process groups not initialized!") |
| |
|
| | if self.mesh.ndim > 1 and mesh_dim is None: |
| | raise RuntimeError( |
| | f"Found that the DeviceMesh has {self.mesh.ndim} dimensions. " |
| | "The optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1. " |
| | "If you want the list of all the ProcessGroups in the DeviceMesh, " |
| | "please use `get_all_groups()` instead." |
| | ) |
| |
|
| | # Quick return for a 1-D mesh: it has exactly one group. |
| | if self.mesh.ndim == 1 and mesh_dim is None: |
| | return not_none(_resolve_process_group(self._dim_group_names[0])) |
| |
|
| | root_mesh = _mesh_resources.get_root_mesh(self) |
| | root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get( |
| | root_mesh, None |
| | ) |
| | if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): |
| | dim_group_name = root_to_flatten_mapping[ |
| | mesh_dim |
| | ]._dim_group_names[0] |
| | return not_none(_resolve_process_group(dim_group_name)) |
| | else: |
| | mesh_dim = ( |
| | _mesh_resources.get_mesh_dim_by_name(self, mesh_dim) |
| | if isinstance(mesh_dim, str) |
| | else mesh_dim |
| | ) |
| | assert isinstance(mesh_dim, int) |
| | return not_none(_resolve_process_group(self._dim_group_names[mesh_dim])) |
| |
|
| | def get_all_groups(self) -> list[ProcessGroup]: |
| | """ |
| | Returns a list of ProcessGroups for all mesh dimensions. |
| | |
| | Returns: |
| | A list of :class:`ProcessGroup` object. |
| | """ |
| | return [self.get_group(i) for i in range(self.mesh.ndim)] |
| |
|
| | @staticmethod |
| | def from_group( |
| | group: Union[ProcessGroup, list[ProcessGroup]], |
| | device_type: str, |
| | mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, |
| | *, |
| | mesh_dim_names: Optional[tuple[str, ...]] = None, |
| | ) -> "DeviceMesh": |
| | """ |
| | Constructs a :class:`DeviceMesh` with ``device_type`` from an |
| | existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`. |
| | |
| | The constructed device mesh has number of dimensions equal to the |
| | number of groups passed. For example, if a single process group is passed in, |
| | the resulting DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, |
| | the resulting DeviceMesh is a 2D mesh. |
| | |
| | If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments |
| | are required. The order of the process groups passed in determines the topology of |
| | the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. |
| | The `mesh` tensor passed in must have the same number of dimensions as the number of process |
| | groups passed in, and the order of the dimensions in the `mesh` tensor must match the order |
| | in the process groups passed in. |
| | |
| | Args: |
| | group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup |
| | or a list of existing ProcessGroups. |
| | device_type (str): The device type of the mesh. Currently supports: "cpu", |
| | "cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0", |
| | is not allowed. |
| | mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an |
| | integer tensor describing the layout of devices, where the IDs are global IDs |
| | of the default process group. Default is None. |
| | mesh_dim_names (tuple[str], optional): A tuple of mesh dimension names to assign |
| | to each dimension of the multi-dimensional array describing the layout of devices. |
| | Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names` |
| | must be unique. Default is None. |
| | |
| | Returns: |
| | DeviceMesh: A :class:`DeviceMesh` object representing the device layout. |
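| | |
| | Example (illustrative sketch; assumes the default process group has already |
| | been initialized, e.g. via ``init_process_group``):: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> from torch.distributed.distributed_c10d import _get_default_group |
| | >>> pg = _get_default_group() |
| | >>> mesh_1d = DeviceMesh.from_group(pg, device_type="cuda") |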
| | """ |
| |
|
| | # Single ProcessGroup: build a 1-D mesh from its ranks. |
| | if isinstance(group, ProcessGroup): |
| | group_ranks = get_process_group_ranks(group) |
| | if ( |
| | isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks |
| | ) or ( |
| | mesh is not None |
| | and not isinstance(mesh, torch.Tensor) |
| | and mesh != group_ranks |
| | ): |
| | raise ValueError( |
| | f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" |
| | ) |
| | mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) |
| | device_mesh = DeviceMesh( |
| | device_type, |
| | mesh, |
| | mesh_dim_names=mesh_dim_names, |
| | _init_backend=False, |
| | ) |
| | device_mesh._dim_group_names = [group.group_name] |
| | return device_mesh |
| |
|
| | # A list of ProcessGroups: build an n-D mesh with one group per dimension. |
| | groups = list(group) |
| | if len(groups) == 0: |
| | raise ValueError("Expects at least one ProcessGroup to be passed") |
| | if mesh is None: |
| | raise ValueError("Must pass mesh if passing multiple ProcessGroups") |
| | if mesh_dim_names is None: |
| | raise ValueError( |
| | "Must pass mesh_dim_names if passing multiple ProcessGroups" |
| | ) |
| | mesh = ( |
| | mesh.detach().to(dtype=torch.int, device="cpu") |
| | if isinstance(mesh, torch.Tensor) |
| | else torch.tensor(mesh, device="cpu", dtype=torch.int) |
| | ) |
| | if mesh.ndim != len(groups): |
| | raise ValueError( |
| | "Expects mesh with ndim equal to number of ProcessGroups but got " |
| | f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" |
| | ) |
| | device_mesh = DeviceMesh( |
| | device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False |
| | ) |
| | device_mesh._dim_group_names = [group.group_name for group in groups] |
| | return device_mesh |
| |
|
| | def size(self, mesh_dim: Optional[int] = None) -> int: |
| | return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim) |
| |
|
| | @property |
| | def ndim(self) -> int: |
| | return self.mesh.ndim |
| |
|
| | @property |
| | def shape(self) -> tuple[int, ...]: |
| | return tuple(self.mesh.shape) |
| |
|
| | def get_rank(self) -> int: |
| | """ |
| | Returns the current global rank. |
| | """ |
| | return get_rank() |
| |
|
| | def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: |
| | """ |
| | Returns the local rank of the given mesh_dim of the DeviceMesh. |
| | |
| | Args: |
| | mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index |
| | of the mesh dimension. Default is None. |
| | |
| | Returns: |
| | An integer denoting the local rank. |
| | |
| | The following program runs on each process/rank in an SPMD manner. In this example, we have 2 |
| | hosts with 4 GPUs each. |
| | Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. |
| | Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. |
| | Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. |
| | Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. |
| | Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. |
| | Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. |
| | |
| | Example:: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> from torch.distributed.device_mesh import DeviceMesh |
| | >>> |
| | >>> # Initialize device mesh as (2, 4) to represent the topology |
| | >>> # of cross-host(dim 0), and within-host (dim 1). |
| | >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) |
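| | >>> # Illustrative sketch: on global rank 5, both calls below would return 1 |
| | >>> # (see the per-rank description above). |
| | >>> mesh.get_local_rank(mesh_dim=0) |
| | >>> mesh.get_local_rank(mesh_dim=1) |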
| | """ |
| | if self.ndim > 1 and mesh_dim is None: |
| | raise RuntimeError( |
| | f"Found that the DeviceMesh has {self.mesh.ndim} dimensions. " |
| | "The optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1." |
| | ) |
| | elif mesh_dim is None: |
| | mesh_dim = 0 |
| |
|
| | mesh_dim_group = not_none(self.get_group(mesh_dim)) |
| | assert isinstance(mesh_dim_group, ProcessGroup), ( |
| | "We expect ProcessGroup before calling `get_rank`!" |
| | ) |
| | return not_none(get_rank(mesh_dim_group)) |
| |
|
| | def get_coordinate(self) -> Optional[list[int]]: |
| | """ |
| | Return the coordinate of this rank along all dimensions of the mesh. |
| | If this rank is not part of the mesh, return None. |
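| | |
| | Example (illustrative sketch; on global rank 5 of the mesh below, this |
| | returns ``[1, 1]``):: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) |
| | >>> mesh.get_coordinate() |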
| | """ |
| | return self._coordinate_on_dim if self._coordinate_on_dim else None |
| |
|
| | def _flatten( |
| | self, |
| | mesh_dim_name: Optional[str] = None, |
| | backend_override: Union[ |
| | None, str, C10dBackend.Options, tuple[str, C10dBackend.Options] |
| | ] = None, |
| | ) -> "DeviceMesh": |
| | """ |
| | Returns a 1D DeviceMesh by flattening the current DeviceMesh. |
| | |
| | If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the |
| | given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh |
| | DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling |
| | mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 2, 4, 6], mesh_dim_names=("dp_cp",)) |
| | on rank 0, 2, 4, 6 and a 1D submesh DeviceMesh([1, 3, 5, 7], mesh_dim_names=("dp_cp",)) on rank 1, 3, 5, 7. |
| | |
| | After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the |
| | existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. |
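| | |
| | Example (illustrative sketch; assumes a world size of 8 and the 3D mesh |
| | described above):: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")) |
| | >>> dp_cp_mesh = mesh_3d["dp", "cp"]._flatten()  # 1D mesh named "dp_cp" |
| | >>> dp_cp_mesh = mesh_3d["dp_cp"]  # the flattened dim can now be sliced by name |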
| | """ |
| | if not self.mesh_dim_names: |
| | raise RuntimeError( |
| | "Cannot flatten a DeviceMesh without mesh_dim_names!" |
| | ) |
| |
|
| | if backend_override is not None: |
| | (backend_override_tuple,) = _normalize_backend_override( |
| | {0: backend_override}, 1 |
| | ) |
| | else: |
| | backend_override_tuple = (None, None) |
| |
|
| | return _mesh_resources.create_flatten_mesh( |
| | self, mesh_dim_name, backend_override_tuple |
| | ) |
| |
|
| | def _normalize_backend_override( |
| | backend_override: dict[ |
| | Union[int, str], |
| | Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], |
| | ], |
| | ndim: int, |
| | mesh_dim_names: Optional[tuple[str, ...]] = None, |
| | ) -> Iterator[tuple[Optional[str], Optional[C10dBackend.Options]]]: |
| | if mesh_dim_names is None: |
| | mesh_dim_names = () |
| | for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names): |
| | if dim_name is not None and dim_name in backend_override: |
| | if dim_idx in backend_override: |
| | raise RuntimeError( |
| | f"Found redundant dim index {dim_idx} and " |
| | f"name {dim_name} in backend_override" |
| | ) |
| | val = backend_override.pop(dim_name) |
| | elif dim_idx in backend_override: |
| | val = backend_override.pop(dim_idx) |
| | else: |
| | yield (None, None) |
| | continue |
| |
|
| | if isinstance(val, str): |
| | yield (val, None) |
| | elif isinstance(val, C10dBackend.Options): |
| | yield (None, val) |
| | else: |
| | yield val |
| |
|
| | if backend_override: |
| | raise RuntimeError( |
| | f"Found invalid keys in backend_override: got {list(backend_override.keys())}, " |
| | f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}" |
| | ) |
| |
|
| | def init_device_mesh( |
| | device_type: str, |
| | mesh_shape: tuple[int, ...], |
| | *, |
| | mesh_dim_names: Optional[tuple[str, ...]] = None, |
| | backend_override: Optional[ |
| | dict[ |
| | Union[int, str], |
| | Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], |
| | ] |
| | ] = None, |
| | ) -> DeviceMesh: |
| | """ |
| | Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. |
| | |
| | This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. |
| | If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. |
| | |
| | .. note:: |
| | `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program |
| | runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array |
| | describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging. |
| | |
| | .. note:: |
| | If no process group is found, init_device_mesh will initialize distributed process group/groups |
| | required for distributed communications behind the scenes. |
| | |
| | Args: |
| | device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu". |
| | Passing in a device type with a GPU index, such as "cuda:0", is not allowed. |
| | mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array |
| | describing the layout of devices. |
| | mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension |
| | of the multi-dimensional array describing the layout of devices. Its length must match the length |
| | of `mesh_shape`. Each string in `mesh_dim_names` must be unique. |
| | backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of |
| | the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a |
| | dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name |
| | of the backend and its options, or just one of these two components (in which case the other will be |
| | set to its default value). |
| | |
| | Returns: |
| | DeviceMesh: A :class:`DeviceMesh` object representing the device layout. |
| | |
| | Example:: |
| | |
| | >>> # xdoctest: +SKIP("no rank") |
| | >>> from torch.distributed.device_mesh import init_device_mesh |
| | >>> |
| | >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) |
| | >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) |
| | |
| | """ |
| | if mesh_dim_names is not None: |
| | if len(set(mesh_dim_names)) != len(mesh_dim_names): |
| | raise RuntimeError( |
| | "Each mesh_dim_name must be unique.", |
| | f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", |
| | ) |
| |
|
| | if len(mesh_shape) != len(mesh_dim_names): |
| | raise RuntimeError( |
| | "mesh_shape and mesh_dim_names should have same length!", |
| | f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", |
| | ) |
| |
|
| | if backend_override is not None: |
| | backend_override_tuple = tuple( |
| | _normalize_backend_override( |
| | backend_override, len(mesh_shape), mesh_dim_names |
| | ) |
| | ) |
| | else: |
| | backend_override_tuple = None |
| |
|
| | # Reject device types that carry an index (e.g. "cuda:0"); only bare device types are allowed. |
| | if device_type and not device_type.isalpha(): |
| | raise RuntimeError( |
| | f"Device type with index is not supported but got {device_type}. ", |
| | "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.", |
| | ) |
| |
|
| | # Always create the mesh tensor on CPU, regardless of the ambient default |
| | # device. |
| | with torch.device("cpu"): |
| | mesh = torch.arange(math.prod(mesh_shape), dtype=torch.int).view(mesh_shape) |
| | device_mesh = DeviceMesh( |
| | device_type=device_type, |
| | mesh=mesh, |
| | mesh_dim_names=mesh_dim_names, |
| | backend_override=backend_override_tuple, |
| | ) |
| |
|
| | return device_mesh |
| |
|