# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup


class AllToAll(torch.autograd.Function):
    """Dispatches the input tensor [e, c, h] to all experts via the
    all_to_all_single operation in torch.distributed.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()
        if dist.get_world_size(group) == 1:
            return inputs, None

        output = torch.empty_like(inputs)
        if not overlap:
            dist.all_to_all_single(output, inputs, group=group)
            return output, None
        else:
            handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
            return output, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # all_to_all with equal splits is self-inverse, so the backward pass is
        # another all_to_all on the incoming gradient
        return (
            AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
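
# Usage sketch (illustrative, not part of this module): dispatching MoE expert
# tokens with AllToAll. It assumes torch.distributed is initialized with a GPU
# backend, one expert per rank, and that every rank holds a tensor of the same
# shape [e, c, h] (experts, capacity, hidden). The function name and sizes
# below are hypothetical.
def _dispatch_demo(group: ProcessGroup) -> Tensor:
    e = dist.get_world_size(group)  # one expert per rank in this sketch
    c, h = 4, 16  # arbitrary capacity and hidden size for the example
    local_tokens = torch.randn(e, c, h, device="cuda")
    # Row i of local_tokens is sent to rank i; each rank receives one row
    # from every peer, so the output keeps the [e, c, h] shape.
    dispatched, _ = AllToAll.apply(local_tokens, group, False)
    return dispatched
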

class AsyncAllGatherForTwo(torch.autograd.Function):
    """All-gather across a sequence-parallel group of size two, fused with the
    qkv projection so that the remote projection overlaps the communication.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        weight: Tensor,
        bias: Tensor,
        sp_rank: int,
        sp_size: int,
        group: Optional[ProcessGroup] = None,
    ) -> Tensor:
        """
        Returns:
            qkv: Tensor holding the projected activations of both ranks,
                concatenated along the sequence dimension.
        """
        from torch.distributed._functional_collectives import all_gather_tensor

        ctx.group = group
        ctx.sp_rank = sp_rank
        ctx.sp_size = sp_size

        # all gather inputs
        all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
        # compute local qkv
        local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)

        # select the other rank's inputs and compute its qkv locally
        remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
        remote_qkv = F.linear(remote_inputs, weight, bias)

        # concat local and remote qkv
        if sp_rank == 0:
            qkv = torch.cat([local_qkv, remote_qkv], dim=0)
        else:
            qkv = torch.cat([remote_qkv, local_qkv], dim=0)
        qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
        ctx.save_for_backward(inputs, weight, remote_inputs)
        return qkv

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, Tensor, Tensor, None, None, None]:
        from torch.distributed._functional_collectives import reduce_scatter_tensor

        group = ctx.group
        sp_rank = ctx.sp_rank
        sp_size = ctx.sp_size
        inputs, weight, remote_inputs = ctx.saved_tensors

        # split qkv_grad
        qkv_grad = grad_outputs[0]
        qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
        qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
        if sp_rank == 0:
            local_qkv_grad, remote_qkv_grad = qkv_grad
        else:
            remote_qkv_grad, local_qkv_grad = qkv_grad

        # compute remote grad
        remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
        weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
        bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)

        # launch async reduce scatter
        remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
        if sp_rank == 0:
            remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
        else:
            remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
        remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)

        # compute local grad and wait for reduce scatter
        local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
        weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
        bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)

        # sum remote and local grad
        inputs_grad = remote_inputs_grad + local_input_grad

        return inputs_grad, weight_grad, bias_grad, None, None, None


class AllGather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.unsqueeze(0), None

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        else:
            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )


class ReduceScatter(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.squeeze(0), None

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        output_shape = inputs.shape[1:]
        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
        if not overlap:
            dist.reduce_scatter(outputs, buffer_list, group=group)
            return outputs, None
        else:
            handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        return (
            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
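
# Round-trip sketch (illustrative): AllGather and ReduceScatter are adjoint
# collectives, which is why each one's backward simply calls the other's
# forward. Assuming an initialized process group in which every rank passes
# the same [n, h] tensor x, gathering to [w, n, h] and reduce-scattering back
# yields w * x on every rank, since each rank contributed one copy to the sum.
# The function name `_roundtrip_demo` is hypothetical.
def _roundtrip_demo(x: Tensor, group: ProcessGroup) -> Tensor:
    gathered, _ = AllGather.forward(None, x, group, False)            # [w, n, h]
    reduced, _ = ReduceScatter.forward(None, gathered, group, False)  # [n, h]
    return reduced  # equals w * x when every rank passes the same x
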

# Use the all_to_all_single API to perform all-to-all communication.
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
    inp_shape = list(input_.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    if scatter_dim < 2:
        input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
    else:
        input_t = (
            input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
            .transpose(0, 1)
            .contiguous()
        )

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    if scatter_dim < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[:gather_dim]
        + [
            inp_shape[gather_dim] * seq_world_size,
        ]
        + inp_shape[gather_dim + 1 :]
    ).contiguous()


# Use the all_to_all API to perform all-to-all communication.
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)
        bsz, _, _ = input_.shape

        # TODO: try to make all_to_all_single compatible with a large batch size
        if bsz == 1:
            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
        else:
            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        process_group = ctx.process_group
        # the backward pass swaps the scatter and gather dimensions
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)


def model_sharding(model: torch.nn.Module):
    """Shards every parameter across all ranks: each rank keeps a
    1 / world_size slice of each flattened, padded parameter."""
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for _, param in model.named_parameters():
        padding_size = (world_size - param.numel() % world_size) % world_size
        if padding_size > 0:
            padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
        else:
            padding_param = param.data.view(-1)
        split_params = padding_param.split(padding_param.numel() // world_size)
        param.data = split_params[global_rank]


def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)


def _gather(input_, dim=-1, process_group=None):
    # skip if only one rank is involved
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # all gather
    input_ = input_.contiguous()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_, group=process_group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()
    return output


def _split(input_, dim=-1, process_group=None):
    # skip if only one rank is involved
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # split along the given dimension
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, (
        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
        f"cannot split tensor evenly"
    )

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = dist.get_rank(process_group)
    output = tensor_list[rank].clone().contiguous()
    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix
        dim: dimension to gather along
        process_group: communication group
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.process_group = process_group
        ctx.dim = dim
        return _gather(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.process_group), None, None


def gather_forward_split_backward(input_, dim, process_group):
    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
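
# Usage sketch (illustrative): the sequence-parallel attention pattern these
# helpers support. Each rank holds a [b, s / sp, h] slice of the sequence,
# where h folds heads * head_dim; all_to_all_comm trades hidden channels for
# the full sequence before attention and trades them back afterwards. The
# function name `_sp_attention_demo` and the attention placeholder are
# hypothetical.
def _sp_attention_demo(x: Tensor, sp_group: ProcessGroup) -> Tensor:
    # scatter hidden, gather sequence: [b, s / sp, h] -> [b, s, h / sp]
    x = all_to_all_comm(x, sp_group, scatter_dim=2, gather_dim=1)
    # ... attention over the full sequence with h / sp local channels ...
    # inverse mapping: [b, s, h / sp] -> [b, s / sp, h]
    x = all_to_all_comm(x, sp_group, scatter_dim=1, gather_dim=2)
    return x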