import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional, Union

from accelerate.utils.dataclasses import (
    DeepSpeedSequenceParallelConfig,
    DistributedType,
    TorchContextParallelConfig,
    TorchTensorParallelConfig,
)
from accelerate.utils.versions import is_torch_version


if TYPE_CHECKING:
    from accelerate import Accelerator


@dataclass
class ParallelismConfig:
    """
    A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`:
    https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py

    Args:
        dp_replicate_size (`int`, defaults to `1`):
            The size of the data parallel replication group. If `dp_replicate_size` is set to 1, the data parallel
            replication group will not be used.
        dp_shard_size (`int`, defaults to `1`):
            The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
            be greater than 1, as composing DDP + TP is currently not supported.
        tp_size (`int`, defaults to `1`):
            The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
            used.
        tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
            The handler for the tensor parallel group.
        cp_size (`int`, defaults to `1`):
            The size of the context parallel group. Currently not supported, but reserved for future use and enabled
            for downstream libraries.
        cp_backend (`str`, defaults to `"torch"`):
            Which context parallel backend to use; currently only `"torch"` (FSDP2) is supported.
        sp_size (`int`, defaults to `1`):
            The size of the sequence parallel group.
        sp_backend (`str`, defaults to `"deepspeed"`):
            Which sequence parallel backend to use; currently only `"deepspeed"` (ALST/Ulysses) is supported.

    You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
    together:
    - If `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
    - If `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
    - `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration; to use pure DP, use
      `DistributedDataParallelKwargs` instead.
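
    Example (illustrative sketch; assumes a launch with 8 processes, and the config is typically passed to
    `Accelerator(parallelism_config=...)`):

    ```python
    from accelerate.parallelism_config import ParallelismConfig

    # 2 replicas x 4 shards = 8 processes -> HSDP
    pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=4)
    pc.total_size        # 8
    pc.active_mesh_dims  # ["dp_replicate", "dp_shard"]
    ```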
| | """ |
| |
|
    dp_replicate_size: Optional[int] = None
    dp_shard_size: Optional[int] = None
    tp_size: Optional[int] = None
    cp_size: Optional[int] = None
    cp_backend: Optional[Literal["torch"]] = None
    sp_size: Optional[int] = None
    sp_backend: Optional[Literal["deepspeed"]] = None

    tp_handler: Union[None, TorchTensorParallelConfig] = None
    cp_handler: Union[None, TorchContextParallelConfig] = None
    sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None

    # Intentionally un-annotated so it stays a plain class attribute (not a dataclass field); populated lazily by
    # `get_device_mesh`.
    device_mesh = None

    def __repr__(self):
        return (
            "ParallelismConfig(\n"
            f"\tdp_replicate_size={self.dp_replicate_size},\n"
            f"\tdp_shard_size={self.dp_shard_size},\n"
            f"\ttp_size={self.tp_size},\n"
            f"\tcp_size={self.cp_size},\n"
            f"\tcp_backend={self.cp_backend},\n"
            f"\tsp_size={self.sp_size},\n"
            f"\tsp_backend={self.sp_backend},\n"
            f"\ttotal_size={self.total_size},\n"
            f"\ttp_handler={self.tp_handler},\n"
            f"\tcp_handler={self.cp_handler},\n"
            f"\tsp_handler={self.sp_handler})\n"
        )

    def to_json(self):
        import copy

        # `device_mesh` holds process-group state and is not JSON-serializable.
        _non_serializable_fields = ["device_mesh"]

        return copy.deepcopy(
            {
                k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
                for k, v in self.__dict__.items()
                if k not in _non_serializable_fields
            }
        )

    @property
    def dp_dim_names(self):
        """Names of enabled dimensions across which data parallelism is applied."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        return dims

    @property
    def non_dp_dim_names(self):
        """Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
        dims = []
        if self.tp_enabled:
            dims += ["tp"]
        if self.cp_enabled:
            dims += ["cp"]
        if self.sp_enabled:
            dims += ["sp"]
        return dims

    @property
    def dp_shard_cp_dim_names(self):
        """Names of enabled dimensions which will be flattened into a joint mesh across which the model is sharded in FSDP."""
        dims = []
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def dp_cp_dim_names(self):
        """Names of enabled dimensions across which loss should be averaged."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        if self.dp_shard_enabled:
            dims += ["dp_shard"]
        if self.cp_enabled:
            dims += ["cp"]
        return dims

    @property
    def fsdp_dim_names(self):
        """Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
        dims = []
        if self.dp_replicate_enabled:
            dims += ["dp_replicate"]
        dims += ["dp_shard_cp"]
        return dims

    @property
    def total_size(self):
        """The total size of the parallelism configuration, which is the product of all sizes."""
        return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size * self.sp_size

    @property
    def non_data_parallel_size(self):
        """The size of the non-data parallel dimensions, which is the product of the tensor, context and sequence parallel sizes."""
        return self.tp_size * self.cp_size * self.sp_size

    @property
    def data_parallel_size(self):
        """The size of the data parallel dimensions, which is the product of the data parallel replication and sharding sizes."""
        return self.dp_replicate_size * self.dp_shard_size

    @property
    def dp_replicate_enabled(self):
        """True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
        return self.dp_replicate_size > 1

    @property
    def dp_shard_enabled(self):
        """True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
        return self.dp_shard_size > 1

    @property
    def tp_enabled(self):
        """True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
        return self.tp_size > 1

    @property
    def cp_enabled(self):
        """True if context parallelism is enabled, i.e. `cp_size > 1`."""
        return self.cp_size > 1

    @property
    def sp_enabled(self):
        """True if sequence parallelism is enabled, i.e. `sp_size > 1`."""
        return self.sp_size > 1

    @property
    def active_mesh_dims(self):
        """Names of all active mesh dimensions."""
        return self.dp_dim_names + self.non_dp_dim_names

    def build_device_mesh(self, device_type: str):
        """Builds a device mesh for the given device type based on the parallelism configuration.
        This method will also create the required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).

        Args:
            device_type (`str`): The type of device for which to build the mesh, e.g. `"cuda"`.
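
        Example (illustrative sketch; assumes a distributed launch with 8 processes, e.g. via `torchrun`, and that
        `"cuda"` devices are available):

        ```python
        pc = ParallelismConfig(dp_shard_size=4, tp_size=2)
        mesh = pc.build_device_mesh("cuda")
        mesh.mesh_dim_names  # ("dp_shard", "tp"); flattened "dp", "dp_shard_cp" and "dp_cp" joint meshes are also created
        ```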
| | """ |
| | if is_torch_version(">=", "2.2.0"): |
| | from torch.distributed.device_mesh import init_device_mesh |
| | else: |
| | raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0") |
| |
|
        mesh = self._get_mesh()
        if len(mesh) == 0:
            return None
        mesh_dim_names, mesh_shape = mesh
        device_mesh = init_device_mesh(
            device_type,
            mesh_shape,
            mesh_dim_names=mesh_dim_names,
        )
        # Create the flattened joint meshes (used for FSDP sharding and loss averaging)
        if self.dp_dim_names:
            device_mesh[self.dp_dim_names]._flatten("dp")
        if self.dp_shard_cp_dim_names:
            device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
        if self.dp_cp_dim_names:
            device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")

        return device_mesh

    def get_device_mesh(self, device_type: Optional[str] = None):
        """Returns the cached device mesh, building it with `build_device_mesh` on first use (requires `device_type`)."""
        if self.device_mesh is None:
            if device_type is not None:
                self.device_mesh = self.build_device_mesh(device_type)
            else:
                raise ValueError("You need to pass a device_type, e.g. `cuda`, to build the device mesh")
        else:
            if device_type is not None:
                if self.device_mesh.device_type != device_type:
                    raise ValueError(
                        f"The device_mesh was already created with device type {self.device_mesh.device_type}, but you are requesting a device mesh with device_type {device_type}. Please check that you initialized your device_mesh correctly."
                    )
        return self.device_mesh

    def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
        """Generate mesh dimension names and mesh shape for `torch.distributed.init_device_mesh()`."""

        # Keep only the enabled dimensions and their sizes
        mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}

        # Apply the canonical ordering of mesh dimensions, from outermost to innermost
        mesh_order = ["dp_replicate", "dp_shard", "cp", "sp", "tp"]
        sorted_items = sorted(
            mesh_dims.items(),
            key=lambda x: (mesh_order.index(x[0])),
        )
        # e.g. dp_shard_size=4, tp_size=2 -> (("dp_shard", "tp"), (4, 2))
        return tuple(zip(*sorted_items))

    def __post_init__(self):
        # Fall back to environment variables (if set) for any size or backend not passed explicitly.
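        # Illustrative example (hypothetical command; any launcher works as long as the variables reach each
        # process's environment):
        #   PARALLELISM_CONFIG_DP_SHARD_SIZE=4 PARALLELISM_CONFIG_TP_SIZE=2 accelerate launch train.py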
        if self.dp_replicate_size is None:
            self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
        if self.dp_shard_size is None:
            self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
        if self.tp_size is None:
            self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
        if self.cp_size is None:
            self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
        if self.cp_backend is None:
            self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch")
        if self.sp_size is None:
            self.sp_size = int(os.environ.get("PARALLELISM_CONFIG_SP_SIZE", "1"))
        if self.sp_backend is None:
            self.sp_backend = os.environ.get("PARALLELISM_CONFIG_SP_BACKEND", "deepspeed")

        if self.tp_size > 1:
            if self.tp_handler is None:
                self.tp_handler = TorchTensorParallelConfig()

        if self.cp_size > 1:
            if self.cp_handler is None:
                self.cp_handler = TorchContextParallelConfig()
            else:
                cp_backends_config_map = dict(
                    torch=TorchContextParallelConfig,
                )
                if not isinstance(self.cp_handler, cp_backends_config_map[self.cp_backend]):
                    raise ValueError(
                        f"ParallelismConfig's cp_backend={self.cp_backend} requires {cp_backends_config_map[self.cp_backend]}, but cp_handler was set to {type(self.cp_handler)}"
                    )

        if self.sp_size > 1:
            if self.sp_handler is None:
                self.sp_handler = DeepSpeedSequenceParallelConfig()

        if self.dp_replicate_size < 1:
            raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
        if self.dp_shard_size < 1:
            raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
        if self.tp_size < 1:
            raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
        if self.cp_size < 1:
            raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")

        valid_cp_backends = ["torch"]
        if self.cp_backend not in valid_cp_backends:
            raise ValueError(f"cp_backend must be one of {valid_cp_backends}, but got {self.cp_backend}")

        if self.sp_size < 1:
            raise ValueError(f"sp_size must be at least 1, but got {self.sp_size}")

        valid_sp_backends = ["deepspeed"]
        if self.sp_backend not in valid_sp_backends:
            raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}")

        if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
            raise ValueError(
                "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
                "Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
                "or set both dp_replicate_size > 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
            )

        self._sizes = {
            "dp_replicate": self.dp_replicate_size,
            "dp_shard": self.dp_shard_size,
            "tp": self.tp_size,
            "cp": self.cp_size,
            "sp": self.sp_size,
        }

    def _set_size(self, parallelism: str, size: int):
        assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
        self._sizes[parallelism] = size
        setattr(self, f"{parallelism}_size", size)

    def _validate_accelerator(self, accelerator: "Accelerator"):
        _warnings = set()
        if not accelerator.multi_device and self.total_size == 1:
            # Not a multi-device run and no parallelism was requested, so there is nothing to validate
            return

        # No explicit parallelism was requested, default to data parallel replication across all processes
        if self.total_size == 1:
            self._set_size("dp_replicate", accelerator.num_processes)

        if self.total_size != accelerator.num_processes:
            raise ValueError(
                f"ParallelismConfig total_size ({self.total_size}) does not match "
                f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
                f"dp_shard_size/tp_size/cp_size/sp_size."
            )

        if self.total_size > 1 and not (
            accelerator.is_fsdp2
            or accelerator.multi_device
            or accelerator.distributed_type == DistributedType.DEEPSPEED
        ):
            raise ValueError(
                f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2), DistributedType.Multi{{Device}} or DistributedType.DEEPSPEED, but got {accelerator.distributed_type}."
            )

        for parallelism, size in self._sizes.items():
            if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
                _warnings.add(
                    f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
                )

        if _warnings and accelerator.is_main_process:
            warnings.warn(
                "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
                UserWarning,
            )