# BryanW's picture
# Upload folder using huggingface_hub
# 3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.distributed import ProcessGroup
if torch.__version__ >= "2.4.0":
    # torch >= 2.4 ships the custom-op registration API; use it directly.
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    # Older torch: stand-ins accepting the same call shapes that hand the
    # decorated function back untouched, so decorated functions stay plain
    # Python callables.

    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        """Mimic ``torch.library.custom_op`` without registering anything."""
        if fn is not None:
            return fn
        return lambda func: func

    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        """Mimic ``torch.library.register_fake`` without registering anything."""
        if fn is not None:
            return fn
        return lambda func: func

    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper
# Module-wide sequence-parallel process group; populated once by set_sp_comm_group.
__sp_comm_group__ = None


def set_sp_comm_group(group=None):
    """Store ``group`` as the module-wide sequence-parallel process group.

    May only be called once, and ``group`` must not be None; both conditions
    are enforced with ``assert``.
    """
    global __sp_comm_group__
    assert __sp_comm_group__ is None
    assert group is not None
    __sp_comm_group__ = group
def get_sp_comm_group():
    """Return the group registered by ``set_sp_comm_group``; asserts one has been set."""
    group = __sp_comm_group__
    assert group is not None
    return group
# ======================================================
# Model
# ======================================================
def model_sharding(model: torch.nn.Module):
    """Replace every parameter of ``model`` in place with this rank's flat shard.

    Each parameter is flattened, zero-padded at the end to a multiple of the
    default group's world size, cut into ``world_size`` equal pieces, and
    ``param.data`` is rebound to the piece at index ``dist.get_rank()``.
    """
    rank = dist.get_rank()
    num_shards = dist.get_world_size()
    for param in model.parameters():
        flat = param.data.view(-1)
        remainder = param.numel() % num_shards
        if remainder:
            # pad so the flat tensor splits evenly across ranks
            flat = F.pad(flat, [0, num_shards - remainder])
        param.data = flat.split(flat.numel() // num_shards)[rank]
# ======================================================
# AllGather & ReduceScatter
# ======================================================
class AsyncAllGatherForTwo(torch.autograd.Function):
    """Fused all-gather + ``F.linear`` for a sequence-parallel group of exactly two ranks.

    Each rank passes its local sequence slice. ``forward`` all-gathers the slices,
    projects the local and the peer slice with ``F.linear(x, weight, bias)`` and
    returns both projections concatenated along the sequence dimension in rank
    order (rank 0's slice first). ``backward`` accumulates the weight/bias
    gradients from both slices and hands the peer slice's input gradient back to
    its owner through a reduce-scatter.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        sp_rank: int,
        sp_size: int,
        group: Optional[ProcessGroup] = None,
    ) -> Tensor:
        """
        Args:
            inputs: this rank's slice, 3-D ``(b, n, c_in)`` (required by the
                ``"sp b n c"`` rearrange below).
            weight: projection weight passed to ``F.linear``.
            bias: projection bias passed to ``F.linear``; may be None.
            sp_rank: this rank's index in ``group`` (0 or 1).
            sp_size: number of ranks in ``group``; must be 2.
            group: process group used for the collectives.

        Returns:
            qkv: ``(b, 2 * n, c_out)`` projection of the gathered sequence.

        Raises:
            ValueError: if ``sp_size`` is not 2.
        """
        from torch.distributed._functional_collectives import all_gather_tensor

        # ``1 - sp_rank`` below and the two-way chunk in backward are only valid
        # for exactly two ranks; fail loudly instead of producing wrong gradients.
        if sp_size != 2:
            raise ValueError(f"AsyncAllGatherForTwo requires sp_size == 2, got {sp_size}")
        ctx.group = group
        ctx.sp_rank = sp_rank
        ctx.sp_size = sp_size
        # all gather inputs -> (sp, b, n, c_in)
        all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
        # compute local qkv before touching the gathered tensor, so it can
        # presumably overlap with the (asynchronous) functional all-gather
        local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
        # peer's slice, reshaped to (1, b, n, c_in)
        remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
        # compute remote qkv
        remote_qkv = F.linear(remote_inputs, weight, bias)
        # concat local and remote qkv in rank order
        if sp_rank == 0:
            qkv = torch.cat([local_qkv, remote_qkv], dim=0)
        else:
            qkv = torch.cat([remote_qkv, local_qkv], dim=0)
        qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
        ctx.save_for_backward(inputs, weight, remote_inputs)
        return qkv

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, Tensor, Optional[Tensor], None, None, None]:
        """Return gradients for ``(inputs, weight, bias, sp_rank, sp_size, group)``."""
        from torch.distributed._functional_collectives import reduce_scatter_tensor

        group = ctx.group
        sp_rank = ctx.sp_rank
        sp_size = ctx.sp_size
        inputs, weight, remote_inputs = ctx.saved_tensors
        # split qkv_grad back into the local and the peer slice, each (1, b, n, c_out)
        qkv_grad = grad_outputs[0]
        qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
        qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
        if sp_rank == 0:
            local_qkv_grad, remote_qkv_grad = qkv_grad
        else:
            remote_qkv_grad, local_qkv_grad = qkv_grad
        # Autograd rejects a non-None gradient for a forward argument that was
        # None, so only produce a bias gradient when bias is a tensor needing one.
        need_bias_grad = ctx.needs_input_grad[2]
        # compute remote grad
        remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
        weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
        bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0) if need_bias_grad else None
        # Place this rank's gradient for the peer's slice in the peer's half of a
        # (2b, n, c_in) buffer (zeros in our own half); the sum reduce-scatter
        # along dim 0 then gives every rank the gradient for its own slice.
        remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
        if sp_rank == 0:
            remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
        else:
            remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
        remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
        # compute local grad while the reduce-scatter is (presumably) in flight;
        # the addition at the end is the first use of its result
        local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
        weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
        if bias_grad is not None:
            bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
        # sum remote and local grad
        inputs_grad = remote_inputs_grad + local_input_grad
        return inputs_grad, weight_grad, bias_grad, None, None, None
class AllGather(torch.autograd.Function):
    """Autograd-aware all-gather that stacks every rank's tensor on a new leading dim.

    ``forward`` returns ``(outputs, handle)`` where ``outputs`` has shape
    ``(world_size, *inputs.shape)``. The backward pass reduce-scatters the
    incoming gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # ctx is None when invoked directly from ReduceScatter.backward, which
        # keeps only element [0]; an async handle would presumably be dropped there.
        assert ctx is not None or not overlap
        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            return inputs.unsqueeze(0), None

        gathered = torch.empty((world, *inputs.shape), dtype=inputs.dtype, device=inputs.device)
        # per-rank slots are views into `gathered`, so the collective fills it in place
        slots = list(torch.chunk(gathered, world, dim=0))
        if overlap:
            return gathered, dist.all_gather(slots, inputs, group=group, async_op=True)
        dist.all_gather(slots, inputs, group=group)
        return gathered, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # gradient of all-gather = reduce-scatter of the output gradient
        grad_inputs = ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0]
        return grad_inputs, None, None
class ReduceScatter(torch.autograd.Function):
    """Autograd-aware reduce-scatter over the leading dim of ``inputs``.

    ``inputs`` is split into ``world_size`` chunks along dim 0; rank ``r``
    receives the reduction of every rank's chunk ``r``, shaped ``inputs.shape[1:]``.
    ``forward`` returns ``(outputs, handle)``. The backward pass all-gathers the
    gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # ctx is None when invoked directly from AllGather.backward, which keeps
        # only element [0]; an async handle would presumably be dropped there.
        assert ctx is not None or not overlap
        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            return inputs.squeeze(0), None

        inputs = inputs.contiguous()
        reduced = torch.empty(inputs.shape[1:], dtype=inputs.dtype, device=inputs.device)
        pieces = list(torch.chunk(inputs, world, dim=0))
        if overlap:
            return reduced, dist.reduce_scatter(reduced, pieces, group=group, async_op=True)
        dist.reduce_scatter(reduced, pieces, group=group)
        return reduced, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        # gradient of reduce-scatter = all-gather of the output gradient
        grad_inputs = AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0]
        return grad_inputs, None, None
# ======================================================
# AlltoAll
# ======================================================
@_torch_custom_op_wrapper("distributed::_all_to_all_func", mutates_args=(), device_types="cuda")
def _all_to_all_func(input_: torch.Tensor, world_size: int = 1, scatter_dim: int = 0, gather_dim: int = 0) -> torch.Tensor:
    """All-to-all over the sequence-parallel group returned by ``get_sp_comm_group()``.

    ``input_`` is cut into ``world_size`` pieces along ``scatter_dim``; piece i
    is sent to rank i, and the pieces received from all ranks are concatenated
    along ``gather_dim``.

    NOTE(review): every receive buffer is shaped like the first piece, so this
    assumes ``input_.size(scatter_dim)`` divides evenly by ``world_size``.
    """
    send_chunks = [piece.contiguous() for piece in torch.tensor_split(input_, world_size, scatter_dim)]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(world_size)]
    dist.all_to_all(recv_chunks, send_chunks, group=get_sp_comm_group())
    return torch.cat(recv_chunks, dim=gather_dim).contiguous()
@_torch_register_fake_wrapper("distributed::_all_to_all_func")
def _all_to_all_func_fake(input_: torch.Tensor, world_size: int = 1, scatter_dim: int = 0, gather_dim: int = 0) -> torch.Tensor:
    """Shape-only (fake) implementation of ``distributed::_all_to_all_func``.

    Mirrors the real kernel's output shape: ``scatter_dim`` shrinks by
    ``world_size`` and ``gather_dim`` grows by it. It uses the ``world_size``
    argument, which is the value the real kernel splits by, rather than
    re-deriving it from the process group. It always returns a fresh tensor,
    because a custom op's outputs must not alias its inputs.
    """
    out_shape = list(input_.shape)
    out_shape[gather_dim] = out_shape[gather_dim] * world_size
    out_shape[scatter_dim] = out_shape[scatter_dim] // world_size
    return torch.empty(torch.Size(out_shape), dtype=input_.dtype, device=input_.device, layout=input_.layout)
class _AllToAll(torch.autograd.Function):
    """Autograd wrapper around the all-to-all op.

    Forward scatters ``input_`` along ``scatter_dim`` and gathers along
    ``gather_dim``. Backward applies the inverse exchange, with the two dims
    swapped.

    Args:
        input_: input matrix
        process_group: communication group; only its size is read here (it
            sets the number of chunks)
        scatter_dim: scatter dimension
        gather_dim: gather dimension

    NOTE(review): the exchange itself runs on ``get_sp_comm_group()`` inside
    ``_all_to_all_func``, so ``process_group`` presumably must match that group.
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group, ctx.scatter_dim, ctx.gather_dim = process_group, scatter_dim, gather_dim
        n_chunks = dist.get_world_size(process_group)
        return _wrapper_all_to_all_func(input_, n_chunks, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        # Undo the forward exchange by swapping the scatter and gather dims.
        grad_input = _AllToAll.apply(*grad_output, ctx.process_group, ctx.gather_dim, ctx.scatter_dim)
        return grad_input, None, None, None
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
    """Differentiable all-to-all: scatter along ``scatter_dim``, gather along ``gather_dim``."""
    exchanged = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
    return exchanged
# ======================================================
# Sequence Gather & Split
# ======================================================
def _split_sequence_func(inputs, pg: dist.ProcessGroup, dim=-1):
    """Return this rank's equal-sized slice of ``inputs`` along ``dim``.

    The tensor is cut into ``world_size(pg)`` equal pieces along ``dim`` (not
    necessarily the last dimension) and piece number ``rank(pg)`` is returned,
    as produced by ``split``. With a single-rank group the input is returned
    unchanged. Asserts that ``inputs.size(dim)`` divides evenly.
    """
    num_ranks = dist.get_world_size(pg)
    if num_ranks == 1:
        return inputs
    my_rank = dist.get_rank(pg)
    total = inputs.size(dim)
    assert total % num_ranks == 0, (
        f"The dimension to split ({total}) is not a multiple of world size ({num_ranks}), "
        f"cannot split tensor evenly"
    )
    piece = total // num_ranks
    return inputs.split(piece, dim=dim)[my_rank]
@_torch_custom_op_wrapper("distributed::_gather_sequence_func", mutates_args=(), device_types="cuda")
def _gather_sequence_func(inputs: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """All-gather ``inputs`` over the sequence-parallel group and concatenate along ``dim``.

    The group comes from ``get_sp_comm_group()``. The result is a new tensor,
    never ``inputs`` itself.
    """
    pg = get_sp_comm_group()
    world_size = dist.get_world_size(pg)
    if world_size == 1:
        # A custom op must not return one of its inputs (or an alias of it);
        # torch.library.custom_op raises at runtime if it does. Return a copy,
        # contiguous like the torch.cat result of the multi-rank path.
        return inputs.clone(memory_format=torch.contiguous_format)
    # all gather
    inputs = inputs.contiguous()
    outputs = [torch.empty_like(inputs) for _ in range(world_size)]
    dist.all_gather(outputs, inputs, group=pg)
    # concat
    return torch.cat(outputs, dim=dim)
@_torch_register_fake_wrapper("distributed::_gather_sequence_func")
def _gather_sequence_func_fake(inputs: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Shape-only (fake) implementation of ``distributed::_gather_sequence_func``.

    Output matches ``inputs`` except that ``dim`` is multiplied by the
    sequence-parallel world size. It always returns a fresh tensor, including
    for a single-rank group, because a custom op's outputs must not alias its
    inputs.
    """
    pg = get_sp_comm_group()
    world_size = dist.get_world_size(pg)
    out_shape = list(inputs.shape)
    out_shape[dim] = out_shape[dim] * world_size
    return torch.empty(torch.Size(out_shape), dtype=inputs.dtype, device=inputs.device, layout=inputs.layout)
if torch.__version__ < "2.4.0":
    # Older torch: no custom-op registry, so call the Python functions directly.
    _wrapper_all_to_all_func = _all_to_all_func
    _wrapper_gather_sequence_func = _gather_sequence_func
else:
    # torch >= 2.4: dispatch through the registered custom ops.
    _wrapper_all_to_all_func = torch.ops.distributed._all_to_all_func
    _wrapper_gather_sequence_func = torch.ops.distributed._gather_sequence_func
class _GatherForwardSplitBackward(torch.autograd.Function):
    """
    Gather the input sequence in forward; split the gradient in backward.
    Args:
        input_: input matrix.
        process_group: group used to split (and optionally rescale) the gradient
            in backward; the forward gather runs on ``get_sp_comm_group()``.
        dim: dimension to gather along / split along
        grad_scale: "up" multiplies the gradient by the group size, "down"
            divides it; any other value leaves it unscaled.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _wrapper_gather_sequence_func(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        ctx.process_group, ctx.dim, ctx.grad_scale = process_group, dim, grad_scale
        return _wrapper_gather_sequence_func(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        mode = ctx.grad_scale
        group = ctx.process_group
        if mode == "up":
            grad_output = grad_output * dist.get_world_size(group)
        elif mode == "down":
            grad_output = grad_output / dist.get_world_size(group)
        grad_input = _split_sequence_func(grad_output, group, ctx.dim)
        return grad_input, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the sequence in forward; gather the gradient in backward.
    Args:
        input_: input matrix.
        process_group: group whose rank/size select this rank's slice in forward
            and size the optional gradient rescale in backward.
        dim: dimension to split along / gather along
        grad_scale: "up" multiplies the gradient by the group size, "down"
            divides it; any other value leaves it unscaled.

    NOTE(review): the backward gather runs on ``get_sp_comm_group()`` (via the
    gather custom op), not on ``process_group``; presumably they are the same.
    """

    @staticmethod
    def symbolic(graph, input_):
        # _split_sequence_func requires an explicit process group. Use the
        # sequence-parallel group, the same group the gather op communicates over.
        return _split_sequence_func(input_, get_sp_comm_group())

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        return _split_sequence_func(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)
        return _wrapper_gather_sequence_func(grad_output, ctx.dim), None, None, None
def split_sequence(input_, process_group, dim, grad_scale=1.0):
    """Keep this rank's slice of ``input_`` along ``dim``; gradients are gathered back.

    ``grad_scale`` is handed to the backward pass: "up"/"down" rescale the
    gradient by the group size, and the default 1.0 leaves it unscaled.
    """
    shard = _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale)
    return shard
def gather_sequence(input_, process_group, dim, grad_scale=None):
    """Concatenate every rank's ``input_`` along ``dim``; gradients are split back.

    ``process_group`` is used to split the gradient in backward; ``grad_scale``
    ("up"/"down"/None) optionally rescales that gradient by the group size.
    """
    gathered = _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale)
    return gathered