| | |
| | |
| | |
| | |
| | |
| |
|
| | from collections.abc import Callable |
| | from dataclasses import dataclass |
| | from functools import cached_property |
| |
|
| | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
| |
|
| | from torchtitan.tools.logging import logger |
| |
|
| |
|
| | __all__ = ["ParallelDims"] |
| |
|
| |
|
| | @dataclass |
| | class ParallelDims: |
| | dp_replicate: int |
| | dp_shard: int |
| | cp: int |
| | tp: int |
| | pp: int |
| | world_size: int |
| | enable_loss_parallel: bool |
| |
|
| | def __post_init__(self): |
| | self._validate() |
| |
|
| | def _validate(self): |
| | dp_replicate, dp_shard, cp, tp, pp = ( |
| | self.dp_replicate, |
| | self.dp_shard, |
| | self.cp, |
| | self.tp, |
| | self.pp, |
| | ) |
| | for d in (dp_replicate, cp, tp, pp): |
| | assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" |
| |
|
| | assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." |
| | if dp_shard < 0: |
| | self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp) |
| | assert dp_shard >= 1 |
| |
|
| | assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, ( |
| | f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " |
| | f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" |
| | ) |
| |
|
| | def build_mesh(self, device_type: str) -> DeviceMesh: |
| | dims = [] |
| | names = [] |
| | for d, name in zip( |
| | [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], |
| | ["pp", "dp_replicate", "dp_shard", "cp", "tp"], |
| | ): |
| | if d > 1: |
| | dims.append(d) |
| | names.append(name) |
| |
|
| | return self._build_mesh(device_type, dims, names, init_device_mesh) |
| |
|
| | def _build_mesh( |
| | self, |
| | device_type: str, |
| | dims: list[int], |
| | names: list[str], |
| | init_device_mesh_fn: Callable, |
| | ) -> DeviceMesh: |
| | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") |
| | mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names) |
| |
|
| | |
| | |
| | |
| | dp_mesh_dim_names = [] |
| | |
| | dp_shard_cp_mesh_dim_names = [] |
| | |
| | dp_cp_mesh_dim_names = [] |
| |
|
| | if self.dp_replicate_enabled: |
| | dp_mesh_dim_names.append("dp_replicate") |
| | dp_cp_mesh_dim_names.append("dp_replicate") |
| | if self.dp_shard_enabled: |
| | dp_mesh_dim_names.append("dp_shard") |
| | dp_shard_cp_mesh_dim_names.append("dp_shard") |
| | dp_cp_mesh_dim_names.append("dp_shard") |
| | if self.cp_enabled: |
| | dp_shard_cp_mesh_dim_names.append("cp") |
| | dp_cp_mesh_dim_names.append("cp") |
| |
|
| | if dp_mesh_dim_names != []: |
| | mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") |
| | if dp_shard_cp_mesh_dim_names != []: |
| | mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( |
| | mesh_dim_name="dp_shard_cp" |
| | ) |
| | if dp_cp_mesh_dim_names != []: |
| | mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") |
| |
|
| | return mesh |
| |
|
| | @property |
| | def dp_enabled(self): |
| | return self.dp_replicate > 1 or self.dp_shard > 1 |
| |
|
| | @property |
| | def dp_replicate_enabled(self): |
| | return self.dp_replicate > 1 |
| |
|
| | @property |
| | def dp_shard_enabled(self): |
| | return self.dp_shard > 1 |
| |
|
| | @property |
| | def cp_enabled(self): |
| | return self.cp > 1 |
| |
|
| | @property |
| | def tp_enabled(self): |
| | return self.tp > 1 |
| |
|
| | @property |
| | def pp_enabled(self): |
| | return self.pp > 1 |
| |
|
| | @property |
| | def loss_parallel_enabled(self): |
| | return self.tp > 1 and self.enable_loss_parallel |
| |
|
| | @cached_property |
| | def non_data_parallel_size(self): |
| | return self.cp * self.tp * self.pp |
| |
|