| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING, Literal |
|
|
| import torch |
| import torch.distributed as dist |
|
|
| from ..utils import get_logger |
|
|
|
|
| if TYPE_CHECKING: |
| pass |
|
|
|
|
| logger = get_logger(__name__) |
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
| @dataclass |
| class ContextParallelConfig: |
| """ |
| Configuration for context parallelism. |
| |
| Args: |
| ring_degree (`int`, *optional*, defaults to `1`): |
| Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes |
| attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N |
| of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best |
| for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a |
| context parallel region. Must be a divisor of the total number of devices in the context parallel mesh. |
| ulysses_degree (`int`, *optional*, defaults to `1`): |
| Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes |
| local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all |
| KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with |
| good interconnect bandwidth. |
| convert_to_fp32 (`bool`, *optional*, defaults to `True`): |
| Whether to convert output and LSE to float32 for ring attention numerical stability. |
| rotate_method (`str`, *optional*, defaults to `"allgather"`): |
| Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"` |
| is supported. |
| |
| """ |
|
|
| ring_degree: int | None = None |
| ulysses_degree: int | None = None |
| convert_to_fp32: bool = True |
| |
| rotate_method: Literal["allgather", "alltoall"] = "allgather" |
| |
| |
| ulysses_anything: bool = False |
|
|
| _rank: int = None |
| _world_size: int = None |
| _device: torch.device = None |
| _mesh: torch.distributed.device_mesh.DeviceMesh = None |
| _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None |
| _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None |
| _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None |
| _ring_local_rank: int = None |
| _ulysses_local_rank: int = None |
|
|
| def __post_init__(self): |
| if self.ring_degree is None: |
| self.ring_degree = 1 |
| if self.ulysses_degree is None: |
| self.ulysses_degree = 1 |
|
|
| if self.ring_degree == 1 and self.ulysses_degree == 1: |
| raise ValueError( |
| "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference" |
| ) |
| if self.ring_degree < 1 or self.ulysses_degree < 1: |
| raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") |
| if self.rotate_method != "allgather": |
| raise NotImplementedError( |
| f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." |
| ) |
| if self.ulysses_anything: |
| if self.ulysses_degree == 1: |
| raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") |
| if self.ring_degree > 1: |
| raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") |
|
|
| @property |
| def mesh_shape(self) -> tuple[int, int]: |
| return (self.ring_degree, self.ulysses_degree) |
|
|
| @property |
| def mesh_dim_names(self) -> tuple[str, str]: |
| """Dimension names for the device mesh.""" |
| return ("ring", "ulysses") |
|
|
| def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh): |
| self._rank = rank |
| self._world_size = world_size |
| self._device = device |
| self._mesh = mesh |
|
|
| if self.ulysses_degree * self.ring_degree > world_size: |
| raise ValueError( |
| f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})." |
| ) |
|
|
| self._flattened_mesh = self._mesh._flatten() |
| self._ring_mesh = self._mesh["ring"] |
| self._ulysses_mesh = self._mesh["ulysses"] |
| self._ring_local_rank = self._ring_mesh.get_local_rank() |
| self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() |
|
|
|
|
@dataclass
class ParallelConfig:
    """
    Configuration for applying different parallelisms.

    Args:
        context_parallel_config (`ContextParallelConfig`, *optional*):
            Configuration for context parallelism.
    """

    context_parallel_config: ContextParallelConfig | None = None

    # Runtime state, populated by `setup()` once the distributed environment is known.
    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None
    _mesh: torch.distributed.device_mesh.DeviceMesh | None = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        *,
        mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
    ):
        """Bind this config (and any nested parallelism configs) to the live distributed environment."""
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        # Propagate the environment to the context-parallel sub-config so it can
        # derive its sub-meshes and local ranks.
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, mesh)
|
|
|
|
| @dataclass(frozen=True) |
| class ContextParallelInput: |
| """ |
| Configuration for splitting an input tensor across context parallel region. |
| |
| Args: |
| split_dim (`int`): |
| The dimension along which to split the tensor. |
| expected_dims (`int`, *optional*): |
| The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the |
| tensor has the expected number of dimensions before splitting. |
| split_output (`bool`, *optional*, defaults to `False`): |
| Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor. |
| This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex: |
| RoPE). |
| """ |
|
|
| split_dim: int |
| expected_dims: int | None = None |
| split_output: bool = False |
|
|
| def __repr__(self): |
| return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})" |
|
|
|
|
| @dataclass(frozen=True) |
| class ContextParallelOutput: |
| """ |
| Configuration for gathering an output tensor across context parallel region. |
| |
| Args: |
| gather_dim (`int`): |
| The dimension along which to gather the tensor. |
| expected_dims (`int`, *optional*): |
| The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the |
| tensor has the expected number of dimensions before gathering. |
| """ |
|
|
| gather_dim: int |
| expected_dims: int | None = None |
|
|
| def __repr__(self): |
| return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})" |
|
|
|
|
| |
| |
| |
| |
| |
# Maps an argument identifier (keyword name or positional index) to the split spec(s)
# describing how that argument's tensor(s) are sharded across the context parallel region.
# A list/tuple presumably covers arguments carrying multiple tensors — TODO confirm against consumers.
ContextParallelInputType = dict[
    str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
]

# Gather spec(s) for a layer's output tensor(s); a list/tuple when several outputs are gathered.
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]

# Full per-model plan: string keys (presumably fully-qualified submodule names — verify against
# the code that applies the plan) mapped to either an input-split or an output-gather spec.
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> list[int]:
    r"""Collect each rank's local ``size`` via an all-gather over ``group``.

    Args:
        size (`int`): The local size on the calling rank.
        group (`dist.ProcessGroup`): Process group over which to communicate.

    Returns:
        `list[int]`: The sizes contributed by every rank, indexed by rank.
    """
    num_ranks = dist.get_world_size(group=group)
    backend_str = str(dist.get_backend(group=group))
    # Backends whose name mentions "cpu" communicate via CPU tensors; otherwise stage
    # the scalar on the current accelerator device.
    comm_device = "cpu" if "cpu" in backend_str else torch.accelerator.current_accelerator()

    local_size = torch.tensor([size], device=comm_device, dtype=torch.int64)
    buckets = [torch.empty_like(local_size) for _ in range(num_ranks)]
    dist.all_gather(buckets, local_size, group=group)

    # Convert the gathered one-element tensors back into plain Python ints.
    return [bucket[0].item() for bucket in buckets]
|
|