|
|
|
|
|
|
|
|
from typing import Any, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange |
|
|
from torch import Tensor |
|
|
from torch.distributed import ProcessGroup |
|
|
|
|
|
|
|
|
class AllToAll(torch.autograd.Function):
    """Differentiable wrapper around ``dist.all_to_all_single``.

    The (contiguous) input is exchanged among all ranks of ``group`` with a single
    ``all_to_all_single`` call. The gradient is routed back through the same
    exchange, always synchronously.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Exchange ``inputs`` across ``group``.

        Args:
            ctx: autograd context, or ``None`` when called directly (as ``backward``
                does). Asynchronous mode (``overlap=True``) requires a real context.
            inputs: tensor to exchange; a contiguous copy is made if needed.
            group: process group to communicate over.
            overlap: when True, launch the collective asynchronously.

        Returns:
            ``(outputs, handle)`` where ``handle`` is the async work object when
            ``overlap`` is True and the group has more than one rank, else ``None``.
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        src = inputs if inputs.is_contiguous() else inputs.contiguous()

        # With a single rank there is nobody to exchange with.
        if dist.get_world_size(group) == 1:
            return src, None

        dst = torch.empty_like(src)
        work = dist.all_to_all_single(dst, src, group=group, async_op=overlap)
        return dst, (work if overlap else None)

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """Send the gradient of ``outputs`` back through the same all-to-all."""
        grad_inputs, _ = AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)
        return grad_inputs, None, None
|
|
|
|
|
|
|
|
class AsyncAllGatherForTwo(torch.autograd.Function):
    """Fused all-gather + linear projection for a two-rank sequence-parallel group.

    Each of the two ranks holds one half of the sequence. ``forward`` gathers the
    peer's input half, applies ``F.linear(., weight, bias)`` to both halves and
    returns the two projections concatenated along the sequence dimension in rank
    order. ``backward`` accumulates the weight/bias gradients of both halves
    locally and returns the input gradient of the remote half to its owner with a
    reduce-scatter.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        sp_rank: int,
        sp_size: int,
        group: Optional[ProcessGroup] = None,
    ) -> Tensor:
        """Project the full two-rank sequence.

        Args:
            ctx: autograd context.
            inputs: this rank's sequence half; must be 3-D, read as
                ``(b, n, c_in)`` by the rearrange pattern below.
            weight: ``F.linear`` weight.
            bias: optional ``F.linear`` bias (``None`` is allowed).
            sp_rank: this rank's index in the sequence-parallel group (0 or 1).
            sp_size: size of the sequence-parallel group; must be 2.
            group: process group used for the all-gather and reduce-scatter.

        Returns:
            Tensor of shape ``(b, 2 * n, c_out)``; rank 0's half comes first.

        Raises:
            ValueError: if ``sp_size != 2`` or ``sp_rank`` is not 0 or 1.
        """
        # The peer is addressed as ``1 - sp_rank`` and the gradient is split into
        # exactly two halves, so any other configuration would silently misbehave.
        if sp_size != 2:
            raise ValueError(f"AsyncAllGatherForTwo requires sp_size == 2, got {sp_size}")
        if sp_rank not in (0, 1):
            raise ValueError(f"AsyncAllGatherForTwo requires sp_rank in (0, 1), got {sp_rank}")

        from torch.distributed._functional_collectives import all_gather_tensor

        ctx.group = group
        ctx.sp_rank = sp_rank
        ctx.sp_size = sp_size
        # Autograd rejects a non-None gradient for a ``None`` input, so backward
        # must know whether a bias was supplied.
        ctx.has_bias = bias is not None

        # Gather both halves along a new leading axis: (2, b, n, c_in). The local
        # projection is issued before the gathered tensor is first used
        # (presumably to overlap it with the functional collective).
        all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)

        local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)

        # Peer's half, reshaped to (1, b, n, c_in) to match local_qkv's leading dims.
        remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
        remote_qkv = F.linear(remote_inputs, weight, bias)

        # Concatenate in rank order, then fold the rank axis into the sequence axis.
        if sp_rank == 0:
            qkv = torch.cat([local_qkv, remote_qkv], dim=0)
        else:
            qkv = torch.cat([remote_qkv, local_qkv], dim=0)
        qkv = rearrange(qkv, "sp b n c -> b (sp n) c")

        ctx.save_for_backward(inputs, weight, remote_inputs)
        return qkv

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, Tensor, Optional[Tensor], None, None, None]:
        """Return gradients for (inputs, weight, bias, sp_rank, sp_size, group).

        The bias gradient is ``None`` when forward received ``bias=None``.
        """
        from torch.distributed._functional_collectives import reduce_scatter_tensor

        group = ctx.group
        sp_rank = ctx.sp_rank
        inputs, weight, remote_inputs = ctx.saved_tensors

        # (b, 2n, c_out) -> (2, b, n, c_out): one slice per rank's sequence half.
        qkv_grad = rearrange(grad_outputs[0], "b (sp n) c -> sp b n c", sp=ctx.sp_size)
        first_half, second_half = torch.chunk(qkv_grad, 2, dim=0)
        if sp_rank == 0:
            local_qkv_grad, remote_qkv_grad = first_half, second_half
        else:
            remote_qkv_grad, local_qkv_grad = first_half, second_half

        # Gradients contributed by the peer's half.
        remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
        weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)

        # Place the peer's input gradient in the peer's slot (zeros in ours) and
        # reduce-scatter so each rank receives the gradient of its own input half.
        zeros = torch.zeros_like(remote_inputs_grad)
        if sp_rank == 0:
            pieces = [zeros, remote_inputs_grad]
        else:
            pieces = [remote_inputs_grad, zeros]
        remote_inputs_grad = reduce_scatter_tensor(torch.cat(pieces, dim=0), "sum", 0, group)

        # Gradients contributed by this rank's own half.
        local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
        weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)

        bias_grad = None
        if ctx.has_bias:
            # Sum over batch and sequence of both halves.
            bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
            bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)

        inputs_grad = remote_inputs_grad + local_input_grad
        return inputs_grad, weight_grad, bias_grad, None, None, None
|
|
|
|
|
|
|
|
class AllGather(torch.autograd.Function):
    """Differentiable all-gather that stacks every rank's tensor on a new leading dim.

    The result has shape ``(world_size, *inputs.shape)``. Entry ``i`` is filled by
    ``dist.all_gather`` from rank ``i`` of ``group``. The backward pass
    reduce-scatters the gradient (via ``ReduceScatter``).
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Gather ``inputs`` from every rank of ``group``.

        Args:
            ctx: autograd context, or ``None`` when called directly (as
                ``ReduceScatter.backward`` does). ``overlap=True`` requires a context.
            inputs: this rank's tensor.
            group: process group; ``None`` means the default group.
            overlap: when True, launch the collective asynchronously.

        Returns:
            ``(outputs, handle)`` where ``handle`` is the async work object when
            ``overlap`` is True and the group has more than one rank, else ``None``.
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)

        if comm_size == 1:
            return inputs.unsqueeze(0), None

        # Collective backends (e.g. NCCL) require contiguous tensors.
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        outputs = torch.empty((comm_size,) + inputs.shape, dtype=inputs.dtype, device=inputs.device)
        # ``unbind`` yields views of ``outputs`` shaped exactly like ``inputs``.
        # ``torch.chunk`` would give ``(1, *inputs.shape)`` slices, which backends
        # that check exact sizes (e.g. Gloo) reject.
        buffer_list = list(outputs.unbind(0))

        handle = dist.all_gather(buffer_list, inputs, group=group, async_op=overlap)
        return outputs, (handle if overlap else None)

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """Reduce-scatter the gradient so each rank gets the sum for its own slot."""
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
|
|
|
|
|
|
|
|
class ReduceScatter(torch.autograd.Function):
    """Differentiable reduce-scatter over the leading dimension.

    ``inputs[i]`` is this rank's contribution for rank ``i`` of ``group``. Each rank
    receives the reduction (``dist.reduce_scatter``'s default op, sum) of the
    contributions addressed to it, shaped ``inputs.shape[1:]``. The backward pass
    is an ``AllGather``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Reduce-scatter ``inputs`` across ``group``.

        Args:
            ctx: autograd context, or ``None`` when called directly (as
                ``AllGather.backward`` does). ``overlap=True`` requires a context.
            inputs: tensor whose leading dimension indexes destination ranks.
            group: process group to communicate over.
            overlap: when True, launch the collective asynchronously.

        Returns:
            ``(outputs, handle)`` where ``handle`` is the async work object when
            ``overlap`` is True and the group has more than one rank, else ``None``.
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.squeeze(0), None

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        outputs = torch.empty(inputs.shape[1:], dtype=inputs.dtype, device=inputs.device)
        # One entry per destination rank, each shaped exactly like ``outputs``.
        # ``torch.chunk`` would keep a leading size-1 dim that differs from the
        # output shape.
        buffer_list = list(inputs.unbind(0))

        handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=overlap)
        return outputs, (handle if overlap else None)

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """All-gather the gradient back into the ``(world_size, ...)`` layout."""
        return (
            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
|
|
|
|
|
|
|
|
|
|
|
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
    """All-to-all over ``group`` using a single ``dist.all_to_all_single`` call.

    Packs ``input_`` into one buffer whose leading axis has ``seq_world_size``
    entries (the equal pieces of ``scatter_dim``), exchanges that buffer with
    ``dist.all_to_all_single``, and reshapes the received buffer so that
    ``gather_dim`` grows by ``seq_world_size``.

    Args:
        input_: tensor to exchange. ``input_.shape[scatter_dim]`` is assumed to be
            divisible by ``seq_world_size`` (floor division is used; no check).
        seq_world_size: number of ranks in ``group``.
        group: process group to communicate over.
        scatter_dim: dimension split across ranks.
        gather_dim: dimension along which the received pieces are concatenated.

    Returns:
        A contiguous tensor shaped like ``input_``, except that ``scatter_dim`` is
        divided by ``seq_world_size`` and then ``gather_dim`` is multiplied by it.

    NOTE(review): the ``scatter_dim < 2`` branch reshapes away every dimension
    before ``scatter_dim``, so the reshape only succeeds when those leading
    dimensions have size 1. In the code visible here it is only reached from
    ``_AllToAll.forward`` when the 3-D input's first dimension is 1.
    """
    # Per-rank shape: the scatter dimension shrinks by the world size.
    inp_shape = list(input_.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    if scatter_dim < 2:
        # Split scatter_dim into (world, chunk); piece i is already the leading
        # slice, so it is what all_to_all_single sends to rank i.
        input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
    else:
        # Collapse the dims before scatter_dim into one axis, split scatter_dim
        # into (world, chunk), then move the world axis to the front so slice i
        # is sent to rank i.
        input_t = (
            input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
            .transpose(0, 1)
            .contiguous()
        )

    # After the exchange, output[i] holds the piece received from rank i.
    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    if scatter_dim < 2:
        # Move the rank axis behind the per-rank chunk axis, so the reshape below
        # places received pieces side by side in rank order (presumably along a
        # gather_dim that follows scatter_dim — TODO confirm for other layouts).
        output = output.transpose(0, 1).contiguous()

    # Merge the rank axis into gather_dim.
    return output.reshape(
        inp_shape[:gather_dim]
        + [
            inp_shape[gather_dim] * seq_world_size,
        ]
        + inp_shape[gather_dim + 1 :]
    ).contiguous()
|
|
|
|
|
|
|
|
|
|
|
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
    """All-to-all via the list-based ``dist.all_to_all``.

    Splits ``input_`` into ``world_size`` pieces along ``scatter_dim`` (with
    ``torch.tensor_split``, so uneven sizes are allowed), sends piece ``i`` to
    rank ``i`` of ``group``, and concatenates the pieces received from all ranks
    along ``gather_dim``, in rank order.

    Returns:
        A contiguous tensor holding the concatenated received pieces.
    """
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    # Every peer sends us *its* piece at our rank index. With an uneven split that
    # piece may be smaller than piece 0, so size the receive buffers from it.
    rank = dist.get_rank(group)
    output_list = [torch.empty_like(input_list[rank]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()
|
|
|
|
|
|
|
|
class _AllToAll(torch.autograd.Function):
    """Differentiable all-to-all: scatter along one dimension, gather along another.

    Args:
        input_: 3-D input tensor.
        process_group: communication group.
        scatter_dim: dimension split across ranks.
        gather_dim: dimension along which received pieces are concatenated.
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group, ctx.scatter_dim, ctx.gather_dim = process_group, scatter_dim, gather_dim
        world_size = dist.get_world_size(process_group)
        # Unpacking enforces a 3-D input.
        bsz, _, _ = input_.shape

        # Batch size 1 uses the single-buffer all_to_all_single path; otherwise
        # fall back to the list-based all_to_all.
        exchange = _all_to_all_single if bsz == 1 else _all_to_all
        return exchange(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        # The adjoint of "scatter along a, gather along b" is
        # "scatter along b, gather along a".
        grad_input = _AllToAll.apply(*grad_output, ctx.process_group, ctx.gather_dim, ctx.scatter_dim)
        return grad_input, None, None, None
|
|
|
|
|
|
|
|
def model_sharding(model: torch.nn.Module):
    """Replace each parameter's ``.data`` in place with this rank's flat 1-D shard.

    Uses the default process group. Each parameter is flattened, right-padded
    with zeros to a multiple of the world size, and cut into ``world_size`` equal
    pieces; piece ``dist.get_rank()`` becomes the parameter's new ``.data``.

    Args:
        model: module whose parameters are sharded in place.
    """
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for param in model.parameters():
        # reshape (not view) so non-contiguous parameters can be flattened too;
        # for contiguous ones this is still a view, as before.
        flat = param.data.reshape(-1)
        padding_size = (world_size - flat.numel() % world_size) % world_size
        if padding_size > 0:
            flat = torch.nn.functional.pad(flat, [0, padding_size])
        shard_size = flat.numel() // world_size
        # narrow always yields a (possibly empty) view, even for zero-element params.
        param.data = flat.narrow(0, global_rank * shard_size, shard_size)
|
|
|
|
|
|
|
|
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
    """Differentiable all-to-all of ``input_`` (thin wrapper over ``_AllToAll``).

    By default splits dimension 2 across ranks and concatenates along dimension 1.
    """
    args = (input_, process_group, scatter_dim, gather_dim)
    return _AllToAll.apply(*args)
|
|
|
|
|
|
|
|
def _gather(input_, dim=-1, process_group=None): |
|
|
|
|
|
world_size = dist.get_world_size(process_group) |
|
|
if world_size == 1: |
|
|
return input_ |
|
|
|
|
|
|
|
|
input_ = input_.contiguous() |
|
|
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] |
|
|
torch.distributed.all_gather(tensor_list, input_, group=process_group) |
|
|
|
|
|
|
|
|
output = torch.cat(tensor_list, dim=dim).contiguous() |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
def _split(input_, dim=-1, process_group=None): |
|
|
|
|
|
world_size = dist.get_world_size(process_group) |
|
|
if world_size == 1: |
|
|
return input_ |
|
|
|
|
|
|
|
|
dim_size = input_.size(dim) |
|
|
assert dim_size % world_size == 0, ( |
|
|
f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " |
|
|
f"cannot split tensor evenly" |
|
|
) |
|
|
|
|
|
tensor_list = torch.split(input_, dim_size // world_size, dim=dim) |
|
|
rank = dist.get_rank(process_group) |
|
|
output = tensor_list[rank].clone().contiguous() |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather along ``dim`` in forward; keep only this rank's gradient slice in backward.

    Args:
        input_: this rank's tensor.
        dim: dimension to concatenate along (forward) and split along (backward).
        process_group: communication group.
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.dim, ctx.process_group = dim, process_group
        return _gather(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _split(grad_output, ctx.dim, ctx.process_group)
        return grad_input, None, None
|
|
|
|
|
|
|
|
def gather_forward_split_backward(input_, dim, process_group):
    """Concatenate ``input_`` from all ranks along ``dim``; split the gradient back."""
    gather_then_split = _GatherForwardSplitBackward.apply
    return gather_then_split(input_, dim, process_group)
|
|
|
|
|
|
|
|
|