File size: 13,072 Bytes
# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
# Experimental changes are subject to change and APIs may break without warning.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
import torch
import torch.distributed as dist
from ..utils import get_logger
# Logger shared by everything in this module.
logger = get_logger(__name__)  # pylint: disable=invalid-name

if TYPE_CHECKING:
    # Placeholder for imports needed only by static type checkers; currently empty.
    pass
# TODO(aryan): add support for the following:
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel
@dataclass
class ContextParallelConfig:
"""
Configuration for context parallelism.
Args:
ring_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
ulysses_degree (`int`, *optional*, defaults to `1`):
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
good interconnect bandwidth.
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
Whether to convert output and LSE to float32 for ring attention numerical stability.
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
"""
ring_degree: int | None = None
ulysses_degree: int | None = None
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
_rank: int = None
_world_size: int = None
_device: torch.device = None
_mesh: torch.distributed.device_mesh.DeviceMesh = None
_flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_local_rank: int = None
_ulysses_local_rank: int = None
def __post_init__(self):
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.ring_degree == 1 and self.ulysses_degree == 1:
raise ValueError(
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self.ulysses_anything:
if self.ulysses_degree == 1:
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
@property
def mesh_shape(self) -> tuple[int, int]:
return (self.ring_degree, self.ulysses_degree)
@property
def mesh_dim_names(self) -> tuple[str, str]:
"""Dimension names for the device mesh."""
return ("ring", "ulysses")
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ulysses_degree * self.ring_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
@dataclass
class ParallelConfig:
    """
    Top-level configuration bundling the parallelism strategies to apply.

    Args:
        context_parallel_config (`ContextParallelConfig`, *optional*):
            Settings for context parallelism. When `None`, no context-parallel setup is performed.
    """

    context_parallel_config: ContextParallelConfig | None = None

    # Runtime state recorded by `setup()`; all `None` until then.
    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None
    _mesh: torch.distributed.device_mesh.DeviceMesh | None = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        *,
        mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
    ):
        """
        Record the process-level distributed state and forward it to the nested parallelism config(s).

        Args:
            rank (`int`): Rank of the current process.
            world_size (`int`): Total number of processes.
            device (`torch.device`): Device used by the current process.
            mesh (`DeviceMesh`, *optional*): Device mesh handed on to the context-parallel config.
        """
        self._rank, self._world_size, self._device, self._mesh = rank, world_size, device, mesh

        cp_config = self.context_parallel_config
        if cp_config is None:
            return
        cp_config.setup(rank, world_size, device, mesh)
@dataclass(frozen=True)
class ContextParallelInput:
"""
Configuration for splitting an input tensor across context parallel region.
Args:
split_dim (`int`):
The dimension along which to split the tensor.
expected_dims (`int`, *optional*):
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
tensor has the expected number of dimensions before splitting.
split_output (`bool`, *optional*, defaults to `False`):
Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex:
RoPE).
"""
split_dim: int
expected_dims: int | None = None
split_output: bool = False
def __repr__(self):
return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
@dataclass(frozen=True)
class ContextParallelOutput:
"""
Configuration for gathering an output tensor across context parallel region.
Args:
gather_dim (`int`):
The dimension along which to gather the tensor.
expected_dims (`int`, *optional*):
The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
tensor has the expected number of dimensions before gathering.
"""
gather_dim: int
expected_dims: int | None = None
def __repr__(self):
return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
# A dictionary where keys denote the input to be split across the context parallel region, and the
# value denotes the sharding configuration.
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across the context parallel region.
ContextParallelInputType = dict[
    str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
]
# Describes how a module's output(s) are gathered across the context parallel region. Unlike
# `ContextParallelInputType`, this is NOT a dictionary: it is either a single `ContextParallelOutput`,
# or a list/tuple of them (presumably one per module output - confirm against the hook implementation).
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]
# A dictionary where keys are submodule names (e.g. "pos_embed"; "" refers to the model itself), and the
# value denotes how the inputs/outputs of that module should be split/gathered across the context
# parallel region.
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
#
# Each model should define a _cp_plan attribute that contains information on how to shard/gather
# tensors at different stages of the forward:
#
# ```python
# _cp_plan = {
# "": {
# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
# },
# "pos_embed": {
# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
# },
# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
# }
# ```
#
# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
# split/gathered according to this at the respective module level. Here, the following happens:
# - "":
# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
# the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
# - "pos_embed":
# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
# we can individually specify how they should be split
# - "proj_out":
# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
# layer forward has run).
#
# ContextParallelInput:
# specifies how to split a tensor of the layer it is attached to: its input in the pre-forward hook, or its
# output in the post-forward hook when `split_output=True`
#
# ContextParallelOutput:
# specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
# Below are utility functions for distributed communication in context parallelism.
def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> list[int]:
    r"""Collect an integer size from every rank of ``group``.

    Args:
        size (`int`):
            The size contributed by the calling rank (e.g. its local sequence length).
        group (`dist.ProcessGroup`):
            The process group to gather over.

    Returns:
        `list[int]`: One entry per rank in ``group``, in rank order.
    """
    # NOTE(Serving/CP Safety):
    # Do NOT cache this collective result.
    #
    # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
    # may legitimately differ across ranks. If we cache based on the *local* `size`,
    # different ranks can have different cache hit/miss patterns across time.
    #
    # That can lead to a catastrophic distributed hang:
    #   - some ranks hit cache and *skip* dist.all_gather()
    #   - other ranks miss cache and *enter* dist.all_gather()
    # This mismatched collective participation will stall the process group and
    # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
    # timeouts in Ulysses attention).
    num_ranks = dist.get_world_size(group=group)

    # HACK: if the group has a CPU backend configured (e.g.
    # dist.init_process_group(backend="cpu:gloo,cuda:nccl")), gather on CPU to avoid
    # H2D and D2H overhead; otherwise gather on the current accelerator.
    backend_desc = str(dist.get_backend(group=group))
    if "cpu" in backend_desc:
        target_device = "cpu"
    else:
        target_device = torch.accelerator.current_accelerator()

    slots = [torch.empty((1,), device=target_device, dtype=torch.int64) for _ in range(num_ranks)]
    local_size = torch.tensor([size], device=target_device, dtype=torch.int64)
    dist.all_gather(slots, local_size, group=group)

    # NOTE: DON'T use tolist here due to graph break - Explanation:
    # Backend compiler `inductor` failed with aten._local_scalar_dense.default
    return [slot[0].item() for slot in slots]
|