| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """A unified interface for model parallelism and data parallelism. |
| | |
| | Supports model parallelism types: |
| | - mp_replicate: Replicate model across multiple devices. |
| | - mp_shard: Shard model across multiple devices. |
| | |
| | And data parallelism types: |
| | - dp: Data parallelism. |
| | - cp: Context parallelism. |
| | """ |
| |
|
| | from dataclasses import dataclass |
| | from datetime import timedelta |
| | from enum import Enum |
| | from typing import Any, Optional |
| |
|
| | from torch.distributed import barrier, destroy_process_group, init_process_group |
| | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
| |
|
| | from ..utils import logging |
| | from ..utils.types import DistributedConfig, ProcessGroup, TensorLike |
| | from . import helper |
| |
|
| |
|
# Module-level logger from the project's logging utilities, keyed by this module's name.
logger = logging.get_logger(__name__)
| |
|
| |
|
class Dim(str, Enum):
    """Names of the parallel dimensions used to label device-mesh axes.

    Members are also plain strings, so they compare equal to their values.
    """

    # Model-parallel axes.
    MP_REPLICATE = "mp_replicate"  # replicate the model across devices
    MP_SHARD = "mp_shard"  # shard the model across devices

    # Data-parallel axes.
    DP = "dp"  # data parallelism
    CP = "cp"  # context parallelism
| |
|
| |
|
@dataclass
class DistributedStrategy:
    """Sizes of the model-parallel and data-parallel device meshes.

    Unset sizes are derived from the world size in ``__post_init__`` and
    explicit sizes are validated, so that each mesh covers exactly
    ``world_size`` ranks. When the process is not distributed,
    ``mp_shard_size`` and ``dp_size`` are forced to 1.
    """

    mp_replicate_size: int = 1
    """Model parallel replicate size, default to 1."""
    mp_shard_size: int | None = None
    """Model parallel shard size, default to world_size // mp_replicate_size."""
    dp_size: int | None = None
    """Data parallel size, default to world_size // cp_size."""
    cp_size: int = 1
    """Context parallel size, default to 1."""

    def __post_init__(self) -> None:
        """Resolve unset sizes and validate both meshes against the world size.

        Raises:
            ValueError: If an explicit pair of sizes does not multiply to the
                world size, or if the size used to derive a default is not a
                positive divisor of the world size.
        """
        if not helper.is_distributed():
            # Single-process run: the derived sizes collapse to 1 and nothing is validated.
            self.mp_shard_size = 1
            self.dp_size = 1
            return

        world_size = helper.get_world_size()

        # Model-parallel mesh: mp_replicate_size * mp_shard_size == world_size.
        if self.mp_shard_size is None:
            # Floor division would silently truncate and yield a mesh that does
            # not cover every rank, so reject non-divisors up front.
            if self.mp_replicate_size <= 0 or world_size % self.mp_replicate_size != 0:
                raise ValueError(
                    f"mp_replicate_size must be a positive divisor of world_size, "
                    f"got mp_replicate_size={self.mp_replicate_size}, world_size={world_size}."
                )
            self.mp_shard_size = world_size // self.mp_replicate_size
        elif self.mp_replicate_size * self.mp_shard_size != world_size:
            raise ValueError(
                f"mp_replicate_size * mp_shard_size must equal to world_size, "
                f"got {self.mp_replicate_size} * {self.mp_shard_size} != {world_size}."
            )

        # Data-parallel mesh: dp_size * cp_size == world_size.
        if self.dp_size is None:
            if self.cp_size <= 0 or world_size % self.cp_size != 0:
                raise ValueError(
                    f"cp_size must be a positive divisor of world_size, "
                    f"got cp_size={self.cp_size}, world_size={world_size}."
                )
            self.dp_size = world_size // self.cp_size
        elif self.dp_size * self.cp_size != world_size:
            raise ValueError(
                f"dp_size * cp_size must equal to world_size, "
                f"got {self.dp_size} * {self.cp_size} != {world_size}."
            )

    @property
    def model_mesh_shape(self) -> tuple[int, int]:
        """Model parallel mesh shape as (mp_replicate_size, mp_shard_size)."""
        return (self.mp_replicate_size, self.mp_shard_size)

    @property
    def model_mesh_dim_names(self) -> tuple[str, str]:
        """Model parallel mesh dimension names, in the same order as the shape."""
        return (Dim.MP_REPLICATE.value, Dim.MP_SHARD.value)

    @property
    def data_mesh_shape(self) -> tuple[int, int]:
        """Data parallel mesh shape as (dp_size, cp_size)."""
        return (self.dp_size, self.cp_size)

    @property
    def data_mesh_dim_names(self) -> tuple[str, str]:
        """Data parallel mesh dimension names, in the same order as the shape."""
        return (Dim.DP.value, Dim.CP.value)
| |
|
| |
|
class DistributedInterface:
    """Process-wide singleton wrapping the distributed runtime.

    The first construction records rank, world-size and device information
    obtained from ``helper`` and, when running distributed, initializes the
    default process group plus two device meshes:

    - ``model_device_mesh`` with dims (mp_replicate, mp_shard)
    - ``data_device_mesh`` with dims (dp, cp)

    Later constructions return the same instance and ignore ``config``.
    """

    _instance: Optional["DistributedInterface"] = None
    _initialized: bool = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)

        return cls._instance

    def __init__(self, config: DistributedConfig | None = None) -> None:
        """Initialize the runtime once; subsequent calls are no-ops.

        Args:
            config: Optional mapping read with ``.get`` for the keys
                ``mp_replicate_size``, ``mp_shard_size``, ``dp_size``,
                ``cp_size`` and ``timeout`` (seconds, default 18000).
                ``None`` behaves like an empty mapping.
        """
        if self._initialized:
            return

        # NOTE(review): presumably selects this process's device - confirm in helper.
        helper.set_device_index()
        self._is_distributed = helper.is_distributed()
        self._rank = helper.get_rank()
        self._world_size = helper.get_world_size()
        self._local_rank = helper.get_local_rank()
        self._local_world_size = helper.get_local_world_size()
        self.current_device = helper.get_current_device()
        self.device_count = helper.get_device_count()

        # A missing config yields exactly the defaults below, so treat it as empty.
        settings = config if config is not None else {}
        self.strategy = DistributedStrategy(
            mp_replicate_size=settings.get("mp_replicate_size", 1),
            mp_shard_size=settings.get("mp_shard_size", None),
            dp_size=settings.get("dp_size", None),
            cp_size=settings.get("cp_size", 1),
        )
        timeout = settings.get("timeout", 18000)

        if self._is_distributed:
            init_process_group(timeout=timedelta(seconds=timeout))
            device_type = self.current_device.type
            self.model_device_mesh = init_device_mesh(
                device_type=device_type,
                mesh_shape=self.strategy.model_mesh_shape,
                mesh_dim_names=self.strategy.model_mesh_dim_names,
            )
            self.data_device_mesh = init_device_mesh(
                device_type=device_type,
                mesh_shape=self.strategy.data_mesh_shape,
                mesh_dim_names=self.strategy.data_mesh_dim_names,
            )
        else:
            self.model_device_mesh = None
            self.data_device_mesh = None

        # Set last so that a failed initialization can be retried.
        self._initialized = True
        logger.info_rank0(f"DistributedInterface initialized: {self}.")

    def __str__(self) -> str:
        """Return a one-line summary of the strategy, device, ranks and meshes."""
        return (
            f"DistributedInterface(strategy={self.strategy}, is_distributed={self._is_distributed}, "
            f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
            f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh})"
        )

    def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
        """Get device mesh for specified dimension.

        Args:
            dim: Dimension to look up, as a ``Dim`` member or its string value.

        Returns:
            The sub-mesh for ``dim``, or ``None`` when not running distributed.

        Raises:
            ValueError: If ``dim`` is ``None`` or not a valid dimension name
                (the latter only when running distributed).
        """
        if dim is None:
            raise ValueError("dim must be specified.")

        if not self._is_distributed:
            return None

        # Accept plain string names as well as Dim members.
        dim = Dim(dim)
        if dim in self.strategy.data_mesh_dim_names:
            return self.data_device_mesh[dim.value]

        return self.model_device_mesh[dim.value]

    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
        """Get process group for specified dimension.

        Returns ``None`` when not running distributed or when ``dim`` is ``None``.
        """
        if not self._is_distributed or dim is None:
            return None

        return self.get_device_mesh(dim).get_group()

    def get_rank(self, dim: Dim | None = None) -> int:
        """Get parallel rank for specified dimension.

        Returns 0 when not distributed, the global rank when ``dim`` is
        ``None``, otherwise the local rank within the mesh for ``dim``.
        """
        if not self._is_distributed:
            return 0

        if dim is None:
            return self._rank

        return self.get_device_mesh(dim).get_local_rank()

    def get_world_size(self, dim: Dim | None = None) -> int:
        """Get parallel size for specified dimension.

        Returns 1 when not distributed, the global world size when ``dim`` is
        ``None``, otherwise the size of the mesh for ``dim``.
        """
        if not self._is_distributed:
            return 1

        if dim is None:
            return self._world_size

        return self.get_device_mesh(dim).size()

    def get_local_rank(self) -> int:
        """Get parallel local rank."""
        return self._local_rank

    def get_local_world_size(self) -> int:
        """Get parallel local world size."""
        return self._local_world_size

    def all_gather(self, data: TensorLike, dim: Dim | None = Dim.DP) -> TensorLike:
        """Gather tensor across specified parallel group.

        Returns ``data`` unchanged when not running distributed.
        """
        if not self._is_distributed:
            return data

        return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))

    def all_reduce(
        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
    ) -> TensorLike:
        """Reduce tensor across specified parallel group.

        Returns ``data`` unchanged when not running distributed.
        """
        if not self._is_distributed:
            return data

        return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))

    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
        """Broadcast tensor across specified parallel group.

        Returns ``data`` unchanged when not running distributed.
        """
        if not self._is_distributed:
            return data

        return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))

    def sync(self) -> None:
        """Synchronize all processes (no-op when not distributed)."""
        if self._is_distributed:
            helper.synchronize()

    def barrier(self) -> None:
        """Barrier all processes (no-op when not distributed)."""
        if self._is_distributed:
            barrier()

    def destroy(self) -> None:
        """Destroy the process group (no-op when not distributed)."""
        if self._is_distributed:
            destroy_process_group()
| |
|
| |
|
if __name__ == "__main__":
    """
    python -m llamafactory.v1.accelerator.interface
    """
    # Smoke check: build the singleton and print its summary.
    interface = DistributedInterface()
    print(interface)
| |
|