Spaces:
Running on Zero
Running on Zero
| # 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨 | |
| # Experimental changes are subject to change and APIs may break without warning. | |
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Literal | |
| import torch | |
| import torch.distributed as dist | |
| from ..utils import get_logger | |
| if TYPE_CHECKING: | |
| pass | |
# Module-level logger for this file, obtained from the package's `get_logger` helper.
logger = get_logger(__name__) # pylint: disable=invalid-name
| # TODO(aryan): add support for the following: | |
| # - Unified Attention | |
| # - More dispatcher attention backends | |
| # - CFG/Data Parallel | |
| # - Tensor Parallel | |
| class ContextParallelConfig: | |
| """ | |
| Configuration for context parallelism. | |
| Args: | |
| ring_degree (`int`, *optional*, defaults to `1`): | |
| Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes | |
| attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N | |
| of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best | |
| for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a | |
| context parallel region. Must be a divisor of the total number of devices in the context parallel mesh. | |
| ulysses_degree (`int`, *optional*, defaults to `1`): | |
| Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes | |
| local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all | |
| KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with | |
| good interconnect bandwidth. | |
| convert_to_fp32 (`bool`, *optional*, defaults to `True`): | |
| Whether to convert output and LSE to float32 for ring attention numerical stability. | |
| rotate_method (`str`, *optional*, defaults to `"allgather"`): | |
| Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"` | |
| is supported. | |
| ulysses_anything (`bool`, *optional*, defaults to `False`): | |
| Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that | |
| are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and | |
| `ring_degree` must be 1. | |
| ring_anything (`bool`, *optional*, defaults to `False`): | |
| Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled, | |
| `ring_degree` must be greater than 1 and `ulysses_degree` must be 1. | |
| mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): | |
| A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of | |
| creating a new one. This is useful when combining context parallelism with other parallelism strategies | |
| (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and | |
| "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with | |
| `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP). | |
| """ | |
| ring_degree: int | None = None | |
| ulysses_degree: int | None = None | |
| convert_to_fp32: bool = True | |
| # TODO: support alltoall | |
| rotate_method: Literal["allgather", "alltoall"] = "allgather" | |
| mesh: torch.distributed.device_mesh.DeviceMesh | None = None | |
| # Whether to enable ulysses anything attention to support | |
| # any sequence lengths and any head numbers. | |
| ulysses_anything: bool = False | |
| # Whether to enable ring anything attention to support any sequence lengths. | |
| ring_anything: bool = False | |
| _rank: int = None | |
| _world_size: int = None | |
| _device: torch.device = None | |
| _mesh: torch.distributed.device_mesh.DeviceMesh = None | |
| _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None | |
| _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None | |
| _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None | |
| _ring_local_rank: int = None | |
| _ulysses_local_rank: int = None | |
| def __post_init__(self): | |
| if self.ring_degree is None: | |
| self.ring_degree = 1 | |
| if self.ulysses_degree is None: | |
| self.ulysses_degree = 1 | |
| if self.ring_degree == 1 and self.ulysses_degree == 1: | |
| raise ValueError( | |
| "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference" | |
| ) | |
| if self.ring_degree < 1 or self.ulysses_degree < 1: | |
| raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") | |
| if self.rotate_method != "allgather": | |
| raise NotImplementedError( | |
| f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." | |
| ) | |
| if self.ulysses_anything: | |
| if self.ulysses_degree == 1: | |
| raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") | |
| if self.ring_degree > 1: | |
| raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") | |
| if self.ring_anything: | |
| if self.ring_degree == 1: | |
| raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.") | |
| if self.ulysses_degree > 1: | |
| raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.") | |
| if self.ulysses_anything and self.ring_anything: | |
| raise ValueError("ulysses_anything and ring_anything cannot both be enabled.") | |
| def mesh_shape(self) -> tuple[int, int]: | |
| return (self.ring_degree, self.ulysses_degree) | |
| def mesh_dim_names(self) -> tuple[str, str]: | |
| """Dimension names for the device mesh.""" | |
| return ("ring", "ulysses") | |
| def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh): | |
| self._rank = rank | |
| self._world_size = world_size | |
| self._device = device | |
| self._mesh = mesh | |
| if self.ulysses_degree * self.ring_degree > world_size: | |
| raise ValueError( | |
| f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})." | |
| ) | |
| self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten() | |
| self._ring_mesh = self._mesh["ring"] | |
| self._ulysses_mesh = self._mesh["ulysses"] | |
| self._ring_local_rank = self._ring_mesh.get_local_rank() | |
| self._ulysses_local_rank = self._ulysses_mesh.get_local_rank() | |
@dataclass
class ParallelConfig:
    """
    Configuration for applying different parallelisms.

    Args:
        context_parallel_config (`ContextParallelConfig`, *optional*):
            Configuration for context parallelism. If set, [`~ParallelConfig.setup`] also sets it up with the same
            rank, world size, device and mesh.
    """

    context_parallel_config: ContextParallelConfig | None = None

    # Runtime state populated by `setup`; not meant to be passed by users.
    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None
    _mesh: torch.distributed.device_mesh.DeviceMesh | None = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        *,
        mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
    ):
        """
        Record the distributed environment and propagate it to the nested parallelism configs.

        Args:
            rank (`int`): Rank of the current process.
            world_size (`int`): Total number of processes.
            device (`torch.device`): Device of the current process.
            mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
                Device mesh forwarded to the nested configs.
        """
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, mesh)
| class ContextParallelInput: | |
| """ | |
| Configuration for splitting an input tensor across context parallel region. | |
| Args: | |
| split_dim (`int`): | |
| The dimension along which to split the tensor. | |
| expected_dims (`int`, *optional*): | |
| The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the | |
| tensor has the expected number of dimensions before splitting. | |
| split_output (`bool`, *optional*, defaults to `False`): | |
| Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor. | |
| This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex: | |
| RoPE). | |
| """ | |
| split_dim: int | |
| expected_dims: int | None = None | |
| split_output: bool = False | |
| def __repr__(self): | |
| return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})" | |
| class ContextParallelOutput: | |
| """ | |
| Configuration for gathering an output tensor across context parallel region. | |
| Args: | |
| gather_dim (`int`): | |
| The dimension along which to gather the tensor. | |
| expected_dims (`int`, *optional*): | |
| The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the | |
| tensor has the expected number of dimensions before gathering. | |
| """ | |
| gather_dim: int | |
| expected_dims: int | None = None | |
| def __repr__(self): | |
| return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})" | |
# Maps each input to be split across the context parallel region to its sharding configuration.
# - A string key is the name of a parameter of the module's forward function.
# - An integer key is the index of an output of the module to be split; the corresponding
#   ContextParallelInput(s) must set `split_output=True`.
ContextParallelInputType = dict[
    str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
]

# Gathering configuration for a module's output(s): a single ContextParallelOutput, or a list/tuple of them
# (presumably one per output of the module — confirm against the hook implementation).
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]

# Maps a module name within the model ("" denotes the model itself) to how the inputs/outputs of that
# module should be split/gathered across the context parallel region.
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
| # Example of a ContextParallelModelPlan (QwenImageTransformer2DModel): | |
| # | |
| # Each model should define a _cp_plan attribute that contains information on how to shard/gather | |
| # tensors at different stages of the forward: | |
| # | |
| # ```python | |
| # _cp_plan = { | |
| # "": { | |
| # "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), | |
| # "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), | |
| # "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), | |
| # }, | |
| # "pos_embed": { | |
| # 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), | |
| # 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), | |
| # }, | |
| # "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), | |
| # } | |
| # ``` | |
| # | |
| # The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be | |
| # split/gathered according to this at the respective module level. Here, the following happens: | |
| # - "": | |
#   we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
#   the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
| # - "pos_embed": | |
#   we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
#   we can individually specify how they should be split
| # - "proj_out": | |
| # before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear | |
| # layer forward has run). | |
| # | |
# ContextParallelInput:
#   specifies how to split a tensor in the pre-forward hook of the layer it is attached to, or in the post-forward
#   hook when `split_output=True` (i.e. the layer's output is split instead of its input)
#
# ContextParallelOutput:
#   specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
| # Below are utility functions for distributed communication in context parallelism. | |
def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> list[int]:
    r"""Gather an integer size from every rank of `group`.

    Args:
        size (`int`):
            The size contributed by the current rank (e.g. its local sequence length, which may differ across
            ranks in "Ulysses Anything" mode).
        group (`dist.ProcessGroup`):
            The process group to gather over. Every rank of the group must call this function, since it issues a
            collective.

    Returns:
        `list[int]`: One size per rank of `group`, in the order produced by `dist.all_gather`.
    """
    # NOTE(Serving/CP Safety):
    # Do NOT cache this collective result.
    #
    # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
    # may legitimately differ across ranks. If we cache based on the *local* `size`,
    # different ranks can have different cache hit/miss patterns across time.
    #
    # That can lead to a catastrophic distributed hang:
    # - some ranks hit cache and *skip* dist.all_gather()
    # - other ranks miss cache and *enter* dist.all_gather()
    # This mismatched collective participation will stall the process group and
    # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
    # timeouts in Ulysses attention).
    world_size = dist.get_world_size(group=group)

    # HACK: Gather on CPU whenever the group has a Gloo backend, to avoid H2D and D2H overhead.
    # A mixed backend such as dist.init_process_group(backend="cpu:gloo,cuda:nccl") contains "cpu";
    # a pure Gloo group is CPU-capable as well, so its sizes should not be staged on the accelerator.
    comm_backends = str(dist.get_backend(group=group))
    if "cpu" in comm_backends or comm_backends == dist.Backend.GLOO:
        gather_device = "cpu"
    else:
        gather_device = torch.accelerator.current_accelerator()

    local_size = torch.tensor([size], device=gather_device, dtype=torch.int64)
    gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered_sizes, local_size, group=group)

    # NOTE: DON'T use tolist here due to graph break - Explanation:
    # Backend compiler `inductor` failed with aten._local_scalar_dense.default
    return [s[0].item() for s in gathered_sizes]