| from typing import Optional
|
|
|
| import torch
|
| from torch import Tensor
|
| from torch.distributed import ProcessGroup
|
|
|
|
|
|
|
|
|
|
|
# Compatibility aliases: when torch.distributed does not expose the public
# collective names, point them at the underscore-prefixed implementations
# (presumably for older PyTorch releases -- confirm the minimum supported version).
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if not hasattr(torch.distributed, "reduce_scatter_tensor"):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
|
|
|
|
|
|
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-gather ``input_`` along dim 0 across ``process_group``.

    Args:
        input_: Local tensor to contribute; it is made contiguous before the call.
        process_group: Group whose ranks take part in the collective.
        async_op: Forwarded unchanged to ``all_gather_into_tensor``.

    Returns:
        ``(output, handle)`` where ``output`` has shape
        ``(world_size * input_.shape[0], *input_.shape[1:])`` with the dtype and
        device of ``input_``, and ``handle`` is whatever
        ``all_gather_into_tensor`` returned (per torch.distributed semantics, a
        work handle when ``async_op`` is true, otherwise ``None``).
    """
    num_ranks = torch.distributed.get_world_size(process_group)
    gathered_shape = (num_ranks * input_.shape[0],) + tuple(input_.shape[1:])
    # Uninitialized buffer; the collective fills it.
    gathered = torch.empty(gathered_shape, dtype=input_.dtype, device=input_.device)
    work = torch.distributed.all_gather_into_tensor(
        gathered,
        input_.contiguous(),
        group=process_group,
        async_op=async_op,
    )
    return gathered, work
|
|
|
|
|
|
|
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """Reduce-scatter ``input_`` along dim 0 across ``process_group``.

    Args:
        input_: Local tensor whose first dimension must be divisible by the
            group's world size; it is made contiguous before the call.
        process_group: Group whose ranks take part in the collective.
        async_op: Forwarded unchanged to ``reduce_scatter_tensor``.

    Returns:
        ``(output, handle)`` where ``output`` has shape
        ``(input_.shape[0] // world_size, *input_.shape[1:])`` with the dtype and
        device of ``input_``, and ``handle`` is whatever
        ``reduce_scatter_tensor`` returned.

    Raises:
        AssertionError: If ``input_.shape[0]`` is not a multiple of the world size.
    """
    num_ranks = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0 if False else input_.shape[0] % num_ranks == 0
    scattered_shape = (input_.shape[0] // num_ranks,) + tuple(input_.shape[1:])
    # Uninitialized buffer; the collective fills it with this rank's slice.
    scattered = torch.empty(scattered_shape, dtype=input_.dtype, device=input_.device)
    work = torch.distributed.reduce_scatter_tensor(
        scattered,
        input_.contiguous(),
        group=process_group,
        async_op=async_op,
    )
    return scattered, work
|
|
|
|
|
|
|
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    """All-reduce ``input_`` across ``process_group``.

    The reduction is applied in place to ``input_.contiguous()``. Per the
    PyTorch documentation ``contiguous()`` returns the tensor itself when it is
    already contiguous, so in that case the caller's tensor is overwritten.

    Args:
        input_: Tensor to reduce.
        process_group: Group whose ranks take part in the collective.
        async_op: Forwarded unchanged to ``torch.distributed.all_reduce``.

    Returns:
        ``(tensor, handle)``: the contiguous tensor handed to the collective and
        whatever ``torch.distributed.all_reduce`` returned.
    """
    buffer = input_.contiguous()
    work = torch.distributed.all_reduce(buffer, group=process_group, async_op=async_op)
    return buffer, work
|
|
|
|
|
class AllGatherFunc(torch.autograd.Function):
    """Autograd wrapper: all-gather in forward, reduce-scatter in backward.

    Forward concatenates every rank's input along dim 0 (via
    ``all_gather_raw``); backward reduce-scatters the incoming gradient over
    the same process group (via ``reduce_scatter_raw``).
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Remember the group so backward runs its collective on the same ranks.
        ctx.process_group = process_group
        # Synchronous call; the returned handle is not needed.
        return all_gather_raw(input_, process_group)[0]

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input = reduce_scatter_raw(grad_output, ctx.process_group)[0]
        # One gradient slot per forward argument; process_group gets None.
        return grad_input, None


# Functional entry point for the autograd-aware all-gather.
all_gather = AllGatherFunc.apply
|
|
|
|
|
class ReduceScatterFunc(torch.autograd.Function):
    """Autograd wrapper: reduce-scatter in forward, all-gather in backward.

    Forward reduce-scatters the input along dim 0 (via
    ``reduce_scatter_raw``); backward all-gathers the incoming gradient over
    the same process group (via ``all_gather_raw``).
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        # Remember the group so backward runs its collective on the same ranks.
        ctx.process_group = process_group
        # Synchronous call; the returned handle is not needed.
        return reduce_scatter_raw(input_, process_group)[0]

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input = all_gather_raw(grad_output, ctx.process_group)[0]
        # One gradient slot per forward argument; process_group gets None.
        return grad_input, None


# Functional entry point for the autograd-aware reduce-scatter.
reduce_scatter = ReduceScatterFunc.apply
|
|
|
|
|
class AllReduceFunc(torch.autograd.Function):
    """Autograd wrapper: all-reduce in forward, identity in backward.

    Forward all-reduces the input across the process group (via
    ``all_reduce_raw``); backward returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        # Synchronous call; the returned handle is not needed.
        return all_reduce_raw(input_, process_group)[0]

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # Gradient passes straight through; process_group gets None.
        return grad_output, None


# Functional entry point for the autograd-aware all-reduce.
all_reduce = AllReduceFunc.apply
|
|
|
|
|
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    """Broadcast every parameter flagged ``_shared_params`` from the group's rank 0.

    Parameters of ``model`` whose ``_shared_params`` attribute is truthy are
    visited in sorted-name order and overwritten in place with the values held
    by the rank at position 0 of ``process_group``. Nothing is communicated
    when no parameter carries the flag.

    Args:
        model: Module whose named parameters are inspected.
        process_group: Group over which the broadcasts run.
    """
    shared = {
        name: param
        for name, param in model.named_parameters()
        if getattr(param, "_shared_params", False)
    }
    if not shared:
        return

    # broadcast() takes a global rank as the source, so translate group-rank 0 once.
    src_rank = torch.distributed.get_global_rank(process_group, 0)
    with torch.no_grad():
        # Sorted by name so the broadcast order is deterministic
        # (presumably so all ranks issue matching collectives -- confirm).
        for name in sorted(shared):
            torch.distributed.broadcast(shared[name], src=src_rank, group=process_group)
|
|
|
|
|
|
|
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    """All-reduce the gradients of parameters flagged ``_sequence_parallel``.

    Gradients of every parameter whose ``_sequence_parallel`` attribute is
    truthy are collected in sorted-name order, flattened into one buffer,
    all-reduced over ``process_group`` (default reduction op), and copied back
    into the original ``.grad`` tensors in place. Nothing is communicated when
    no parameter carries the flag.

    Args:
        model: Module whose named parameters are inspected.
        process_group: Group over which the all-reduce runs.
    """
    flagged = {
        name: param
        for name, param in model.named_parameters()
        if getattr(param, "_sequence_parallel", False)
    }
    # Sorted by name so the flattened layout is deterministic
    # (presumably so it matches across ranks -- confirm).
    grads = [flagged[name].grad for name in sorted(flagged)]
    if not grads:
        return

    with torch.no_grad():
        # A single collective on one flat buffer instead of one per gradient.
        flat = torch._utils._flatten_dense_tensors(grads)
        torch.distributed.all_reduce(flat, group=process_group)
        reduced_views = torch._utils._unflatten_dense_tensors(flat, grads)
        for grad, reduced in zip(grads, reduced_views):
            grad.copy_(reduced)
|
|
|
|
|
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Return the share of ``dim`` assigned to ``local_rank`` out of ``world_size`` ranks.

    ``dim`` is split in whole units of ``multiple_of``. Each rank receives
    ``units // world_size`` units and the first ``units % world_size`` ranks
    receive one extra, so the split may be uneven. Any remainder of
    ``dim % multiple_of`` is not assigned to any rank.

    Args:
        dim: Total size to split.
        world_size: Number of ranks sharing ``dim``.
        local_rank: Rank whose share is requested.
        multiple_of: Granularity of the split; the result is a multiple of it.

    Returns:
        The size owned by ``local_rank``.
    """
    base_units, extra_units = divmod(dim // multiple_of, world_size)
    # Lower-numbered ranks absorb the leftover units, one each.
    local_units = base_units + 1 if local_rank < extra_units else base_units
    return local_units * multiple_of
|
|
|