| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
|
|
| import math |
| import os |
| from dataclasses import dataclass |
| from functools import wraps |
| from typing import TYPE_CHECKING, Callable, Literal, Optional |
|
|
| import torch |
| from torch import distributed as dist |
|
|
| from ..utils import logging |
| from ..utils.import_utils import is_torch_npu_available, is_torch_version_greater_than |
|
|
|
|
| if is_torch_version_greater_than("2.4"): |
| from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
|
|
|
|
| if TYPE_CHECKING: |
| from torch.distributed import ProcessGroup |
| from torch.distributed.device_mesh import DeviceMesh |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| _PARALLEL_STATE: "ParallelState" = None |
|
|
|
|
| def requires_mesh(fn: Callable) -> Callable: |
| @wraps(fn) |
| def _inner(self: "ParallelState", *args, **kwargs): |
| if self.device_mesh is None: |
| raise ValueError("Device mesh is not initialized.") |
|
|
| return fn(self, *args, **kwargs) |
|
|
| return _inner |
|
|
|
|
| def init_ep_mesh_matrix(ep_size: int, ep_fsdp_size: int, ep_outside: bool = False) -> "DeviceMesh": |
| """ |
| Initialize the device mesh matrix for the EP. |
| Args: |
| ep_size (int): The size of the EP. |
| ep_fsdp_size (int): The size of the EP-FSDP. |
| ep_outside (bool): Whether the EP is outside in ep-fsdp group. |
| """ |
| if ep_outside: |
| with torch.device("cpu"): |
| mesh = torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int).view(ep_size, ep_fsdp_size) |
| else: |
| with torch.device("cpu"): |
| mesh = ( |
| torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int) |
| .view(ep_fsdp_size, ep_size) |
| .transpose(0, 1) |
| ) |
| return mesh |
|
|
|
|
| @dataclass(frozen=True) |
| class ParallelState: |
| dp_size: int = 1 |
| dp_replicate_size: int = 1 |
| dp_shard_size: int = 1 |
| tp_size: int = 1 |
| ep_size: int = 1 |
| pp_size: int = 1 |
| cp_size: int = 1 |
| ulysses_size: int = 1 |
| dp_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp1" |
| device_type: str = "npu" if is_torch_npu_available() else "cuda" |
| include_sp_in_fsdp: bool = True |
| device_mesh: Optional["DeviceMesh"] = None |
| ep_fsdp_device_mesh: Optional["DeviceMesh"] = None |
|
|
| def __post_init__(self): |
| if not self.include_sp_in_fsdp: |
| raise NotImplementedError("Decoupled sequence parallel has not been implemented.") |
|
|
| if self.cp_size > 1: |
| raise NotImplementedError("Ring attention is not supported yet.") |
|
|
| if self.pp_size * self.dp_size * self.cp_size * self.ulysses_size * self.tp_size != self.world_size: |
| raise ValueError("The product of parallel sizes should be equal to the world size.") |
|
|
| if self.dp_replicate_size * self.dp_shard_size != self.dp_size: |
| raise ValueError( |
| f"The product of dp_replicate_size: {self.dp_replicate_size} and dp_shard_size: {self.dp_shard_size} should be equal to dp_size: {self.dp_size}." |
| ) |
|
|
| if self.sp_enabled: |
| from ..distributed.sequence_parallel import ( |
| init_sequence_parallel, |
| set_context_parallel_group, |
| set_data_parallel_group, |
| set_ulysses_sequence_parallel_group, |
| set_unified_sequence_parallel_group, |
| ) |
|
|
| if self.device_mesh is not None: |
| set_data_parallel_group(self.device_mesh.get_group("dp")) |
| if self.ulysses_size > 1: |
| set_ulysses_sequence_parallel_group(self.device_mesh.get_group("ulysses")) |
| if self.cp_size > 1: |
| set_context_parallel_group(self.device_mesh.get_group("cp")) |
| |
| set_unified_sequence_parallel_group(self.device_mesh.get_group("sp")) |
| else: |
| init_sequence_parallel( |
| ulysses_size=self.ulysses_size, |
| sep_dp=True, |
| ulysses_group_key="default", |
| cp_size=self.cp_size, |
| ) |
|
|
| @property |
| def is_initialized(self) -> bool: |
| return dist.is_initialized() |
|
|
| @property |
| def local_rank(self) -> int: |
| return int(os.getenv("LOCAL_RANK", "-1")) |
|
|
| @property |
| def global_rank(self) -> int: |
| if self.is_initialized: |
| return dist.get_rank() |
| return -1 |
|
|
| @property |
| def world_size(self) -> int: |
| if self.is_initialized: |
| return dist.get_world_size() |
| return 1 |
|
|
| |
| @property |
| def dp_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("dp") |
|
|
| if self.sp_enabled: |
| from ..distributed.sequence_parallel import get_data_parallel_group |
|
|
| return get_data_parallel_group() |
|
|
| return self.fsdp_group |
|
|
| @property |
| def dp_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("dp") |
|
|
| if self.sp_enabled: |
| from ..distributed.sequence_parallel import get_data_parallel_rank |
|
|
| return get_data_parallel_rank() |
|
|
| return self.fsdp_rank |
|
|
| @property |
| @requires_mesh |
| def dp_mesh(self) -> "DeviceMesh": |
| if self.device_mesh is not None: |
| return self.device_mesh["dp"] |
|
|
| raise self.fsdp_mesh |
|
|
| @property |
| def dp_enabled(self) -> bool: |
| return self.dp_size > 1 |
|
|
| |
| @property |
| def dp_replicate_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("dp_replicate") |
|
|
| @property |
| def dp_replicate_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("dp_replicate") |
|
|
| @property |
| @requires_mesh |
| def dp_replicate_mesh(self) -> "DeviceMesh": |
| if self.device_mesh is not None: |
| return self.device_mesh["dp_replicate"] |
|
|
| @property |
| def dp_replicate_enabled(self) -> bool: |
| return self.dp_replicate_size > 1 |
|
|
| |
| @property |
| def dp_shard_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("dp_shard") |
|
|
| @property |
| def dp_shard_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("dp_shard") |
|
|
| @property |
| @requires_mesh |
| def dp_shard_mesh(self) -> "DeviceMesh": |
| if self.device_mesh is not None: |
| return self.device_mesh["dp_shard"] |
|
|
| @property |
| def dp_shard_enabled(self) -> bool: |
| return self.dp_shard_size >= 1 |
|
|
| |
| @property |
| def fsdp_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("dp_sp") |
|
|
| @property |
| def fsdp_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("dp_sp") |
|
|
| return self.global_rank |
|
|
| @property |
| def dp_shard_sp_enabled(self) -> bool: |
| return self.dp_shard_enabled and self.sp_enabled |
|
|
| @property |
| @requires_mesh |
| def fsdp_mesh(self) -> "DeviceMesh": |
| if self.dp_replicate_enabled: |
| |
| if self.dp_shard_sp_enabled: |
| return self.device_mesh["dp_replicate", "dp_shard_sp"] |
| elif self.dp_shard_enabled: |
| return self.device_mesh["dp_replicate", "dp_shard"] |
| else: |
| |
| return self.device_mesh["dp_replicate"] |
| |
| elif self.dp_shard_sp_enabled: |
| return self.device_mesh["dp_shard_sp"] |
| elif self.dp_shard_enabled: |
| return self.device_mesh["dp_shard"] |
| else: |
| return self.device_mesh["dp"] |
|
|
| @property |
| def fsdp_enabled(self) -> bool: |
| return self.fsdp_size > 1 |
|
|
| @property |
| def fsdp_size(self) -> int: |
| return self.world_size // (self.pp_size * self.tp_size) |
|
|
| |
| @property |
| @requires_mesh |
| def tp_rank(self) -> int: |
| return self.device_mesh.get_local_rank("tp") |
|
|
| @property |
| @requires_mesh |
| def tp_mesh(self) -> "DeviceMesh": |
| return self.device_mesh["tp"] |
|
|
| @property |
| def tp_enabled(self) -> bool: |
| return self.tp_size > 1 |
|
|
| |
| @property |
| @requires_mesh |
| def pp_rank(self) -> int: |
| return self.device_mesh.get_local_rank("pp") |
|
|
| @property |
| @requires_mesh |
| def pp_mesh(self) -> "DeviceMesh": |
| return self.device_mesh["pp"] |
|
|
| @property |
| def pp_enabled(self) -> bool: |
| return self.pp_size > 1 |
|
|
| @property |
| @requires_mesh |
| def is_first_pp_stage(self) -> bool: |
| return self.pp_rank == 0 |
|
|
| @property |
| @requires_mesh |
| def is_last_pp_stage(self) -> bool: |
| return self.pp_rank == (self.pp_size - 1) |
|
|
| |
| @property |
| @requires_mesh |
| def ep_mesh(self) -> "DeviceMesh": |
| return self.ep_fsdp_device_mesh["ep"] |
|
|
| @property |
| @requires_mesh |
| def ep_fsdp_mesh(self) -> "DeviceMesh": |
| return self.ep_fsdp_device_mesh["ep", "ep_fsdp"] |
|
|
| @property |
| @requires_mesh |
| def ep_group(self) -> "ProcessGroup": |
| return self.ep_mesh.get_group() |
|
|
| @property |
| def ep_enabled(self) -> bool: |
| return self.ep_size > 1 |
|
|
| @property |
| def ep_rank(self) -> int: |
| return self.ep_fsdp_device_mesh.get_local_rank("ep") |
|
|
| |
| @property |
| def sp_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("sp") |
|
|
| if self.sp_enabled: |
| from .sequence_parallel import get_unified_sequence_parallel_group |
|
|
| return get_unified_sequence_parallel_group() |
|
|
| return None |
|
|
| @property |
| def sp_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("sp") |
|
|
| if self.sp_enabled: |
| from .sequence_parallel import get_unified_sequence_parallel_rank |
|
|
| return get_unified_sequence_parallel_rank() |
|
|
| return -1 |
|
|
| @property |
| def sp_enabled(self) -> bool: |
| return self.cp_size > 1 or self.ulysses_size > 1 |
|
|
| @property |
| def sp_size(self) -> int: |
| return self.ulysses_size * self.cp_size |
|
|
| @property |
| def ulysses_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("ulysses") |
|
|
| if self.sp_enabled: |
| from .sequence_parallel import get_ulysses_sequence_parallel_group |
|
|
| return get_ulysses_sequence_parallel_group() |
|
|
| return None |
|
|
| @property |
| def ulysses_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("ulysses") |
|
|
| if self.sp_enabled: |
| from .sequence_parallel import get_ulysses_sequence_parallel_rank |
|
|
| return get_ulysses_sequence_parallel_rank() |
|
|
| return -1 |
|
|
| @property |
| def ulysses_enabled(self) -> bool: |
| return self.ulysses_size > 1 |
|
|
| @property |
| def cp_group(self) -> Optional["ProcessGroup"]: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_group("cp") |
|
|
| if self.sp_enabled: |
| from .sequence_parallel import get_context_parallel_group |
|
|
| return get_context_parallel_group() |
|
|
| return None |
|
|
| @property |
| def cp_rank(self) -> int: |
| if self.device_mesh is not None: |
| return self.device_mesh.get_local_rank("cp") |
|
|
| if self.sp_enabled: |
| from .sequence_parallel import get_context_parallel_rank |
|
|
| return get_context_parallel_rank() |
|
|
| return -1 |
|
|
| @property |
| def cp_enabled(self) -> bool: |
| return self.cp_size > 1 |
|
|
|
|
| def init_parallel_state( |
| dp_size: int = 1, |
| dp_replicate_size: int = 1, |
| dp_shard_size: int = 1, |
| tp_size: int = 1, |
| ep_size: int = 1, |
| pp_size: int = 1, |
| cp_size: int = 1, |
| ulysses_size: int = 1, |
| dp_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp1", |
| device_type: str = None, |
| include_sp_in_fsdp: bool = True, |
| ep_outside: bool = False, |
| ) -> None: |
| """ |
| Initializes global parallel state. |
| """ |
| global _PARALLEL_STATE |
| if _PARALLEL_STATE is not None: |
| logger.warning("Parallel state has already been initialized.") |
| return |
|
|
| if device_type is None: |
| device_type = "npu" if is_torch_npu_available() else "cuda" |
|
|
| |
| if dp_size > 1 and dp_shard_size == 1 and dp_replicate_size == 1: |
| dp_shard_size = dp_size |
|
|
| logger.info_rank0( |
| f"Initializing parallel state... dp_size {dp_size}, dp_replicate_size {dp_replicate_size}, dp_shard_size {dp_shard_size},tp_size {tp_size}, pp_size {pp_size}, cp_size {cp_size}, ulysses_size {ulysses_size}" |
| ) |
|
|
| device_mesh, ep_fsdp_device_mesh = None, None |
| if is_torch_version_greater_than("2.4"): |
| mesh_shape = [] |
| mesh_dim_names = [] |
| for d, name in zip( |
| [pp_size, dp_replicate_size, dp_shard_size, ulysses_size, cp_size, tp_size], |
| ["pp", "dp_replicate", "dp_shard", "ulysses", "cp", "tp"], |
| ): |
| if d > 1 or name in ["dp_shard"]: |
| mesh_shape.append(d) |
| mesh_dim_names.append(name) |
|
|
| device_mesh = init_device_mesh( |
| device_type=device_type, |
| mesh_shape=tuple(mesh_shape), |
| mesh_dim_names=tuple(mesh_dim_names), |
| ) |
|
|
| |
| dp_mesh_dim_names = [] |
| |
| dp_shard_sp_mesh_dim_names = [] |
| |
| dp_sp_mesh_dim_names = [] |
| |
| sp_mesh_dim_names = [] |
|
|
| if dp_replicate_size > 1: |
| dp_mesh_dim_names.append("dp_replicate") |
| dp_sp_mesh_dim_names.append("dp_replicate") |
| if dp_shard_size >= 1: |
| dp_mesh_dim_names.append("dp_shard") |
| dp_shard_sp_mesh_dim_names.append("dp_shard") |
| dp_sp_mesh_dim_names.append("dp_shard") |
| if ulysses_size > 1: |
| dp_shard_sp_mesh_dim_names.append("ulysses") |
| sp_mesh_dim_names.append("ulysses") |
| dp_sp_mesh_dim_names.append("ulysses") |
| if cp_size > 1: |
| dp_shard_sp_mesh_dim_names.append("cp") |
| sp_mesh_dim_names.append("cp") |
| dp_sp_mesh_dim_names.append("cp") |
|
|
| if dp_mesh_dim_names != []: |
| device_mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") |
|
|
| if dp_shard_sp_mesh_dim_names != []: |
| device_mesh[tuple(dp_shard_sp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_sp") |
|
|
| if dp_sp_mesh_dim_names != []: |
| device_mesh[tuple(dp_sp_mesh_dim_names)]._flatten(mesh_dim_name="dp_sp") |
|
|
| if sp_mesh_dim_names != []: |
| device_mesh[tuple(sp_mesh_dim_names)]._flatten(mesh_dim_name="sp") |
|
|
| if ep_size > 1: |
| world_size = dist.get_world_size() |
| assert world_size % ep_size == 0, "ep_size must be a factor of world_size" |
| ep_fsdp_size = world_size // ep_size |
|
|
| mesh = init_ep_mesh_matrix(ep_size=ep_size, ep_fsdp_size=ep_fsdp_size, ep_outside=ep_outside) |
| ep_fsdp_device_mesh = DeviceMesh( |
| device_type=device_type, |
| mesh=mesh, |
| mesh_dim_names=("ep", "ep_fsdp"), |
| ) |
|
|
| logger.info_rank0(f"Device mesh: {device_mesh}") |
| logger.info_rank0(f"EP FSDP device mesh: {ep_fsdp_device_mesh}") |
|
|
| _PARALLEL_STATE = ParallelState( |
| dp_size=dp_size, |
| dp_replicate_size=dp_replicate_size, |
| dp_shard_size=dp_shard_size, |
| tp_size=tp_size, |
| ep_size=ep_size, |
| pp_size=pp_size, |
| cp_size=cp_size, |
| ulysses_size=ulysses_size, |
| dp_mode=dp_mode, |
| device_type=device_type, |
| include_sp_in_fsdp=include_sp_in_fsdp, |
| device_mesh=device_mesh, |
| ep_fsdp_device_mesh=ep_fsdp_device_mesh, |
| ) |
|
|
|
|
| def get_parallel_state() -> "ParallelState": |
| """ |
| Returns global parallel state. |
| """ |
| if _PARALLEL_STATE is None: |
| logger.warning_once("Parallel state has not been initialized. returning default Single-process state.") |
| return ParallelState() |
|
|
| return _PARALLEL_STATE |
|
|