import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

from torch.distributed.device_mesh import init_device_mesh

from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig


if TYPE_CHECKING:
    from accelerate import Accelerator


@dataclass
class ParallelismConfig:
    """
    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py

    Args:
        dp_replicate_size (`int`, defaults to `1`):
            The size of the data parallel replication group. If `dp_replicate_size` is set to 1, the data parallel
            replication group will not be used.
        dp_shard_size (`int`, defaults to `1`):
            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must
            also be greater than 1, as composing DDP + TP is currently not supported.
        tp_size (`int`, defaults to `1`):
            The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
            used.
        cp_size (`int`, defaults to `1`):
            The size of the context parallel group. Currently not supported, but reserved for future use and enabled
            for downstream libraries.
        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
            The handler for the tensor parallel group.
        cp_handler (`~utils.TorchContextParallelConfig`, defaults to `None`):
            The handler for the context parallel group.

    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and
    `dp_shard_size` together:
        - `dp_replicate_size == 1` and `dp_shard_size > 1`: Fully Sharded Data Parallel (FSDP).
        - `dp_replicate_size > 1` and `dp_shard_size > 1`: Hybrid Sharded Data Parallel (HSDP).
        - `dp_replicate_size > 1` and `dp_shard_size == 1`: invalid; to use pure data parallelism, use
          `DistributedDataParallelKwargs` instead.

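    Example:

    A minimal illustrative sketch (it assumes an 8-process distributed run; the top-level import path is an
    assumption and may differ between Accelerate versions):

    ```python
    from accelerate import ParallelismConfig

    # 8 processes total: 2-way replication x 2-way sharding x 2-way tensor parallelism (HSDP + TP)
    pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
    assert pc.total_size == 8

    # Building the mesh requires torch.distributed to be initialized with `total_size` processes
    device_mesh = pc.build_device_mesh("cuda")
    ```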
| """ |

    dp_replicate_size: Optional[int] = None
    dp_shard_size: Optional[int] = None
    tp_size: Optional[int] = None
    cp_size: Optional[int] = None

    tp_handler: Union[None, TorchTensorParallelConfig] = None
    cp_handler: Union[None, TorchContextParallelConfig] = None

    device_mesh = None

    def __repr__(self):
        return (
            "ParallelismConfig(\n "
            f"\tdp_replicate_size={self.dp_replicate_size},\n"
            f"\tdp_shard_size={self.dp_shard_size},\n"
            f"\ttp_size={self.tp_size},\n"
            f"\tcp_size={self.cp_size},\n"
            f"\ttotal_size={self.total_size},\n"
            f"\ttp_handler={self.tp_handler},\n"
            f"\tcp_handler={self.cp_handler})\n"
        )

    @property
    def dp_dim_names(self):
        """Names of enabled dimensions across which data parallelism is applied."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        return dims

    @property
    def non_dp_dim_names(self):
        """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
        dims = []
        if self.tp_enabled:
            dims += ["tp"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def dp_shard_cp_dim_names(self):
        """Names of enabled dimensions which will be flattened into a joint mesh across which the model is sharded in FSDP."""
        dims = []
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def dp_cp_dim_names(self):
        """Names of enabled dimensions across which loss should be averaged."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def fsdp_dim_names(self):
        """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        dims += ["dp_shard_cp"]
        return dims

    @property
    def total_size(self):
        """The total size of the parallelism configuration, which is the product of all sizes."""
        return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size

    @property
    def non_data_parallel_size(self):
        """The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
        return self.tp_size * self.cp_size

    @property
    def data_parallel_size(self):
        """The size of the data parallel dimensions, which is the product of the data parallel replication and sharding sizes."""
        return self.dp_replicate_size * self.dp_shard_size

    @property
    def dp_replicate_enabled(self):
        """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
        return self.dp_replicate_size > 1

    @property
    def dp_shard_enabled(self):
        """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
        return self.dp_shard_size > 1

    @property
    def tp_enabled(self):
        """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
        return self.tp_size > 1

    @property
    def cp_enabled(self):
        """True if context parallelism is enabled, i.e. `cp_size > 1`."""
        return self.cp_size > 1

    @property
    def active_mesh_dims(self):
        """Names of all active mesh dimensions."""
        return self.dp_dim_names + self.non_dp_dim_names

    def build_device_mesh(self, device_type: str):
        """Builds a device mesh for the given device type based on the parallelism configuration.
        This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).

        Args:
            device_type (`str`): The type of device for which to build the mesh, e.g. `"cuda"`.
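
        Example (an illustrative sketch, assuming a distributed run whose world size equals `total_size`):

        ```python
        pc = ParallelismConfig(dp_shard_size=4, tp_size=2)  # 8 processes: FSDP across 4, TP across 2
        mesh = pc.build_device_mesh("cuda")
        # `mesh` has the dims ("dp_shard", "tp") plus flattened aliases such as "dp", "dp_shard_cp" and "dp_cp",
        # e.g. `mesh["dp_shard_cp"]` addresses the joint mesh that FSDP shards across.
        ```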
| """ |
        mesh = self._get_mesh()
        if len(mesh) == 0:
            return None
        mesh_dim_names, mesh_shape = mesh
        device_mesh = init_device_mesh(
            device_type,
            mesh_shape,
            mesh_dim_names=mesh_dim_names,
        )
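        # Create flattened joint sub-meshes so downstream code can address the combined groups by a single name,
        # e.g. "dp" for all data parallel ranks, "dp_shard_cp" for FSDP sharding and "dp_cp" for loss averaging.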
        if self.dp_dim_names:
            device_mesh[self.dp_dim_names]._flatten("dp")
        if self.dp_shard_cp_dim_names:
            device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
        if self.dp_cp_dim_names:
            device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")

        return device_mesh

    def get_device_mesh(self, device_type: Optional[str] = None):
        """Returns the cached device mesh, building it first if `device_type` is provided and no mesh exists yet."""
        if self.device_mesh is None:
            if device_type is not None:
                self.device_mesh = self.build_device_mesh(device_type)
            else:
                raise ValueError("You need to pass a device_type (e.g. `cuda`) to build the device mesh")
        else:
            if device_type is not None:
                if self.device_mesh.device_type != device_type:
                    raise ValueError(
                        f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
                    )
        return self.device_mesh

    def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
        """Generate mesh dimension names and the corresponding mesh shape for `torch.distributed.init_device_mesh()`."""
        mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}

        mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
        sorted_items = sorted(
            mesh_dims.items(),
            key=lambda x: (mesh_order.index(x[0])),
        )
        return tuple(zip(*sorted_items))

    def __post_init__(self):
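        # Fall back to environment variables (e.g. set by the launcher) when sizes were not passed explicitly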
        if self.dp_replicate_size is None:
            self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
        if self.dp_shard_size is None:
            self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
        if self.tp_size is None:
            self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
        if self.cp_size is None:
            self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))

        if self.tp_size > 1:
            if self.tp_handler is None:
                self.tp_handler = TorchTensorParallelConfig()

        if self.cp_size > 1:
            if self.cp_handler is None:
                self.cp_handler = TorchContextParallelConfig()

        if self.dp_replicate_size < 1:
            raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
        if self.dp_shard_size < 1:
            raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
        if self.tp_size < 1:
            raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
        if self.cp_size < 1:
            raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")

        if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
            raise ValueError(
                "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
                "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
                "or set dp_replicate_size > 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
            )
        self._sizes = {
            "dp_replicate": self.dp_replicate_size,
            "dp_shard": self.dp_shard_size,
            "tp": self.tp_size,
            "cp": self.cp_size,
        }

    def _set_size(self, parallelism: str, size: int):
        assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
        self._sizes[parallelism] = size
        setattr(self, f"{parallelism}_size", size)

| def _validate_accelerator(self, accelerator: "Accelerator"): |
| _warnings = set() |
| if not accelerator.multi_device and self.total_size == 1: |
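            # Not a multi-device setup and no parallelism requested, nothing to validate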
            return

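        # No explicit sizes were given, default to pure data parallel replication across all processes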
        if self.total_size == 1:
            self._set_size("dp_replicate", accelerator.num_processes)

        if self.total_size != accelerator.num_processes:
            raise ValueError(
                f"ParallelismConfig total_size ({self.total_size}) does not match "
                f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
                f"dp_shard_size/tp_size/cp_size."
            )

        if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
            raise ValueError(
                f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
            )

        for parallelism, size in self._sizes.items():
            if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
                _warnings.add(
                    f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
                )

        if _warnings and accelerator.is_main_process:
            warnings.warn(
                "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
                UserWarning,
            )