# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup

from .comm import (
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_world_size,
)
from .utils import (
    pad_tensor,
    unpad_tensor,
)


def _all_gather(
    x: Tensor,
    group: Optional[dist.ProcessGroup],
):
    """All-gather tensors whose shapes may differ between ranks.

    The shapes are exchanged first so that each rank can allocate correctly
    sized receive buffers, then the data itself is gathered.

    Args:
        x: Local tensor to contribute.
        group: Process group to communicate over. ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.

    Returns:
        ``(tensor_list, size_list)``: the gathered tensors (one per rank) and
        the int64 shape tensors (one per rank) that were exchanged.
    """
    device = x.device
    dtype = x.dtype
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group)

    # Step 1: exchange shapes.
    x_size = torch.tensor(x.size()).to(device)
    size_list = [torch.zeros(x_size.size(), dtype=torch.int64, device=device) for _ in range(sp_world_size)]
    dist.all_gather(size_list, x_size, group=group)

    # Step 2: exchange data into buffers sized per rank.
    tensor_list = [torch.zeros(size.tolist(), dtype=dtype, device=device) for size in size_list]
    dist.all_gather(tensor_list, x, group=group)
    return tensor_list, size_list


def _all_gather_into_tensor(
    x: Tensor,
    group: Optional[dist.ProcessGroup],
):
    """All-gather equally shaped tensors into one tensor stacked along dim 0.

    Args:
        x: Local tensor; the output's dim 0 is ``x.size(0) * world_size``.
        group: Process group to communicate over. ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.

    Returns:
        The gathered tensor, allocated on the same device as ``x``.
    """
    dim_size = list(x.size())
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group)
    dim_size[0] = dim_size[0] * sp_world_size
    # Allocate next to the input rather than on the "current" CUDA device so
    # the buffers always share a device (and non-CUDA backends keep working).
    output = torch.empty(dim_size, dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(output, x.contiguous(), group=group)
    return output


def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
    async_op: bool = False,
):
    """All-to-all over arbitrary dims using per-rank tensor lists.

    ``local_input`` is split into ``world_size`` chunks along ``scatter_dim``;
    the received chunks are concatenated along ``gather_dim``.

    Args:
        local_input: Tensor to redistribute.
        scatter_dim: Dim that is split across ranks.
        gather_dim: Dim along which received chunks are concatenated.
        group: Process group; ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.
        async_op: If true, launch the collective asynchronously.

    Returns:
        The redistributed tensor, or, when ``async_op`` is true, a zero-arg
        callable that waits for the collective and returns that tensor.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    seq_world_size = dist.get_world_size(group)
    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
    # Every receive buffer is shaped like the first send chunk.
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)

    if async_op:

        def wait():
            comm.wait()
            return torch.cat(output_list, dim=gather_dim).contiguous()

        return wait

    return torch.cat(output_list, dim=gather_dim).contiguous()


def _all_to_all_single(
    x: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
    async_op: bool = False,
):
    """All-to-all restricted to the first two dims, using one flat buffer.

    Args:
        x: Tensor to redistribute.
        scatter_dim: Dim that is split across ranks; must be 0 or 1.
        gather_dim: Dim along which received chunks are joined; must be 0 or 1.
        group: Process group; ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.
        async_op: If true, launch the collective asynchronously.

    Returns:
        The redistributed tensor, or, when ``async_op`` is true, a zero-arg
        callable that waits for the collective and returns that tensor.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    sp_world_size = dist.get_world_size(group)
    assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!"
    assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!"

    if scatter_dim != 0:
        # all_to_all_single always splits dim 0, so move the scatter chunks to
        # the front: [G, S, ...] -> [G, sp, S/sp, ...] -> [sp * G, S/sp, ...].
        gather_dim_bef = x.shape[gather_dim]
        scatter_dim_bef = x.shape[scatter_dim]
        trailing = list(x.shape[2:])
        x = (
            x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + trailing)
            .transpose(0, 1)
            .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + trailing)
        )

    # Make the send buffer contiguous *before* cloning its layout, otherwise
    # empty_like would hand back a strided receive buffer for strided inputs.
    x = x.contiguous()
    output = torch.empty_like(x)
    comm = dist.all_to_all_single(output, x, group=group, async_op=async_op)

    def _finalize() -> Tensor:
        # When dim 0 was scattered, the received chunks are stacked on dim 0
        # and still have to be joined along gather_dim.
        if scatter_dim == 0:
            return torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
        return output

    if async_op:

        def wait():
            comm.wait()
            return _finalize()

        return wait

    return _finalize()


def all_to_all_tensor(
    x: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup],
    async_op: bool = False,
):
    """Dispatch an all-to-all to the single-buffer or list-based variant.

    Args:
        x: Tensor to redistribute.
        scatter_dim: Dim that is split across ranks (negative values allowed).
        gather_dim: Dim along which received chunks are joined (negative
            values allowed).
        group: Process group; ``None`` is resolved by the callee.
        async_op: If true, a wait-callable is returned instead of a tensor.

    Returns:
        Whatever the selected implementation returns.
    """
    # Normalise negative dims first: a raw "-1 <= 1" check would otherwise
    # send them down the first-two-dims fast path by mistake.
    ndim = x.dim()
    if scatter_dim < 0:
        scatter_dim += ndim
    if gather_dim < 0:
        gather_dim += ndim

    if scatter_dim <= 1 and gather_dim <= 1:
        return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op)
    return _all_to_all(x, scatter_dim, gather_dim, group, async_op)


class _SeqAllToAll(torch.autograd.Function):
    """Autograd wrapper: all-to-all forward, the reverse all-to-all backward."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.async_op:
            # NOTE(review): the async path rebuilds the gradient from
            # grad_output[1:]; confirm this matches how callers consume the
            # wait-callable returned by forward.
            input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
        else:
            input_t = grad_output[0]
        # One entry per forward input: group, local_input, scatter_dim,
        # gather_dim, async_op. Only local_input is differentiable. The
        # backward pass swaps scatter/gather dims and always runs synchronously.
        return (
            None,
            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
            None,
            None,
            None,
        )


class _Slice(torch.autograd.Function):
    """Keep this rank's slice along ``dim``; all-gather gradients backward."""

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        ctx.dim = dim
        ctx.scale_grad = scale_grad
        dim_size = local_input.shape[dim]
        return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # The gather stacks every rank's gradient along dim 0, so chunks of
        # the local dim-0 size are re-joined along the sliced dim afterwards.
        split_size = grad_output.size(0)
        output = _all_gather_into_tensor(grad_output.contiguous(), group=ctx.group)
        if ctx.scale_grad:
            output = output / ctx.seq_world_size
        return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None)


class _Gather(torch.autograd.Function):
    """All-gather along ``dim`` forward; return this rank's slice backward."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        dim: int,
        grad_scale: Optional[bool] = False,
    ) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.seq_world_size = dist.get_world_size(group)
        output, size_list = _all_gather(local_input.contiguous(), group=ctx.group)
        # Remember each rank's extent along `dim` so backward can split the
        # incoming gradient at the same boundaries.
        ctx.dim_size_list = [size[dim].item() for size in size_list]
        return torch.cat(output, dim=dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        if ctx.grad_scale:
            grad_output = grad_output * ctx.seq_world_size
        return (
            None,
            grad_output.split(ctx.dim_size_list, dim=ctx.dim)[ctx.rank].contiguous(),
            None,
            None,
        )


def gather_heads_scatter_seq(
    x: Tensor, head_dim: int, seq_dim: int, group: Optional[ProcessGroup] = None
) -> Tensor:
    """Sync attention results with all-to-all in sequence parallel.

    Scatters ``seq_dim`` and gathers ``head_dim``. If the size of ``seq_dim``
    is not divisible by the sequence-parallel world size, ``x`` is padded
    along ``seq_dim`` first.

    Args:
        x: Input tensor.
        head_dim: Dim that is gathered.
        seq_dim: Dim that is scattered.
        group: Process group; ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.

    Returns:
        The redistributed tensor, or ``x`` unchanged when no group is active.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x

    dim_size = x.size(seq_dim)
    sp_world = get_ulysses_sequence_parallel_world_size(group)
    remainder = dim_size % sp_world
    if remainder != 0:
        x = pad_tensor(x, seq_dim, sp_world - remainder)
    return _SeqAllToAll.apply(group, x, seq_dim, head_dim, False)


def gather_seq_scatter_heads(
    x: Tensor,
    seq_dim: int,
    head_dim: int,
    unpadded_dim_size: int = 0,
    async_op: bool = False,
    group: Optional[ProcessGroup] = None,
) -> Tensor:
    """Sync embedding input with all-to-all in sequence parallel.

    Scatters ``head_dim`` and gathers ``seq_dim``.

    Args:
        x: Input tensor.
        seq_dim: Dim that is gathered.
        head_dim: Dim that is scattered.
        unpadded_dim_size: Original (unpadded) size of ``seq_dim``. When it is
            non-zero and not divisible by the world size, the surplus is
            removed from the gathered result.
        async_op: If true, the raw async result is returned without unpadding.
        group: Process group; ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.

    Returns:
        The redistributed tensor, or ``x`` unchanged when no group is active.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return x

    sp_world = get_ulysses_sequence_parallel_world_size(group)
    x = _SeqAllToAll.apply(group, x, head_dim, seq_dim, async_op)
    if async_op:
        return x

    # Remove padding.
    if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
        padding_size = x.size(seq_dim) - unpadded_dim_size
        x = unpad_tensor(x, seq_dim, padding_size)
    return x


def gather_seq_scatter_heads_qkv(
    qkv_tensor: Tensor,
    seq_dim: int,
    unpadded_dim_size: Optional[int] = None,
    restore_shape: bool = True,
    async_op: bool = False,
    group: Optional[ProcessGroup] = None,
) -> Tensor:
    """Sync a fused qkv tensor with all-to-all in sequence parallel.

    Args:
        qkv_tensor: The tensor to run all-to-all on. Its last dim must be the
            fused projection dim, which is split into 3 parts; after the
            split, the scatter dim is the new last dim.
        seq_dim: Gather dim for the all-to-all.
        unpadded_dim_size: Original (unpadded) size of ``seq_dim``. When set
            and not divisible by the world size, the surplus is removed.
        restore_shape: If true, the output has the same number of dims as the
            input.
        async_op: If true, the raw async result is returned without
            reshaping or unpadding.
        group: Process group; ``None`` falls back to
            ``get_ulysses_sequence_parallel_group()``.

    Returns:
        The redistributed tensor, or ``qkv_tensor`` unchanged when no group is
        active.
    """
    group = get_ulysses_sequence_parallel_group() if group is None else group
    if not group:
        return qkv_tensor

    sp_world = get_ulysses_sequence_parallel_world_size(group)
    orig_shape = qkv_tensor.shape
    # Index of the last dim once the trailing dim is unfolded to (3, proj // 3).
    scatter_dim = qkv_tensor.dim()
    qkv_proj_dim = orig_shape[-1]
    bef_all2all_shape = list(orig_shape[:-1]) + [3, qkv_proj_dim // 3]
    qkv_tensor = qkv_tensor.view(bef_all2all_shape)

    qkv_tensor = _SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
    if async_op:
        return qkv_tensor

    if restore_shape:
        out_shape = list(orig_shape)
        out_shape[seq_dim] *= sp_world
        out_shape[-1] = qkv_proj_dim // sp_world
        qkv_tensor = qkv_tensor.view(out_shape)

    # Remove padding.
    if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
        padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
        qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)

    return qkv_tensor


class _AlltoAllRegion(torch.autograd.Function):
    """Balance the intermediate tensors in the sequence parallel region."""

    @staticmethod
    def forward(ctx, group, x, input_splits, output_splits):
        ctx.group = group
        ctx.input_splits = input_splits
        ctx.output_splits = output_splits

        input_tensor_list = [t.contiguous() for t in x.split(input_splits, dim=0)]
        output_tensor_list = [torch.empty([o, *x.shape[1:]], dtype=x.dtype, device=x.device) for o in output_splits]
        dist.all_to_all(output_tensor_list, input_tensor_list, group=group)
        return torch.cat(output_tensor_list, dim=0)

    @staticmethod
    def backward(ctx, dy):
        # Reverse exchange: send along output_splits, receive along input_splits.
        dx_list = [torch.empty([i, *dy.shape[1:]], dtype=dy.dtype, device=dy.device) for i in ctx.input_splits]
        # Mirror forward: chunks handed to the collective are made contiguous.
        dy_list = [t.contiguous() for t in dy.split(ctx.output_splits, dim=0)]
        dist.all_to_all(dx_list, dy_list, group=ctx.group)
        return None, torch.cat(dx_list, dim=0), None, None


def all_to_all_images(image_embeds, in_splits, out_splits):
    """Redistribute image embeddings across the sequence-parallel group.

    Args:
        image_embeds: Embeddings to exchange; only the first
            ``sum(in_splits)`` rows along dim 0 are sent.
        in_splits: Per-rank row counts to send. If empty/falsy, the input is
            returned untouched.
        out_splits: Per-rank row counts to receive.

    Returns:
        The exchanged embeddings, concatenated along dim 0.
    """
    if not in_splits:
        return image_embeds

    image_embeds = image_embeds[: sum(in_splits)]
    group = get_ulysses_sequence_parallel_group()
    return _AlltoAllRegion.apply(group, image_embeds, in_splits, out_splits)