# Infinite-World: infworld/context_parallel/context_parallel_util.py
# Uploaded by wuruiqi0722 using huggingface_hub (commit 01c7703, verified).
import os
import math
import random
import argparse
import datetime
import logging
import inspect
import subprocess
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from einops import rearrange, repeat
# Module-level context-parallel state. Everything stays None until
# init_context_parallel() fills it in; cp_stream is created lazily by
# get_cp_stream().
dp_size = cp_size = None      # sizes of the "dp" and "cp" mesh dimensions
dp_group = cp_group = None    # process groups for the "dp" and "cp" dimensions
cp_stream = None              # shared torch.cuda.Stream for context-parallel work
dp_ranks = cp_ranks = None    # global ranks belonging to this rank's dp / cp group
dp_rank = cp_rank = None      # this process's rank inside its dp / cp group
def init_context_parallel(context_parallel_size: int = 1,
global_rank: int = 1,
world_size: int = 1,):
global dp_size
global cp_size
global dp_group
global cp_group
global dp_ranks
global cp_ranks
global dp_rank
global cp_rank
if world_size%context_parallel_size != 0:
raise RuntimeError(f'world_size {world_size} must be multiple of context_parallel_size {context_parallel_size}!!!')
cp_size = context_parallel_size
dp_size = world_size//context_parallel_size
print(f'[rank {global_rank}] init_device_mesh [dp_size x cp_size]: [{dp_size} x {cp_size}]')
mesh_2d = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
print(f'[rank {global_rank}] mesh_2d: {mesh_2d}')
dp_group = mesh_2d.get_group(mesh_dim="dp")
cp_group = mesh_2d.get_group(mesh_dim="cp")
dp_ranks = torch.distributed.get_process_group_ranks(dp_group)
cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
dp_rank = dist.get_rank(group=dp_group)
cp_rank = dist.get_rank(group=cp_group)
global_rank_1 = torch.distributed.get_rank()
print(f'[rank {global_rank_1}] [dp_rank, cp_rank]: [{dp_rank}, {cp_rank}], dp_ranks: {dp_ranks}, cp_ranks: {cp_ranks}')
def get_cp_size():
    """Return the context-parallel group size (None until init_context_parallel() runs)."""
    # Read-only access to a module global needs no `global` declaration.
    return cp_size
def get_dp_size():
    """Return the data-parallel group size (None until init_context_parallel() runs)."""
    # Read-only access to a module global needs no `global` declaration.
    return dp_size
def get_cp_stream():
    """Return the shared stream for context-parallel work, creating it on first use.

    The stream is cached in the module-level ``cp_stream``, so every caller
    receives the same ``torch.cuda.Stream`` object.
    """
    global cp_stream
    # Identity test instead of `== None`: it is the idiomatic None check and
    # does not dispatch to the cached object's __eq__.
    if cp_stream is None:
        cp_stream = torch.cuda.Stream()
    return cp_stream
def get_dp_group():
    """Return the data-parallel process group (None until init_context_parallel() runs)."""
    # Read-only access to a module global needs no `global` declaration.
    return dp_group
def get_cp_group():
    """Return the context-parallel process group (None until init_context_parallel() runs)."""
    # Read-only access to a module global needs no `global` declaration.
    return cp_group
def get_dp_rank():
    """Return this process's rank inside its data-parallel group.

    None until init_context_parallel() runs.
    """
    return dp_rank
def get_cp_rank():
    """Return this process's rank inside its context-parallel group.

    None until init_context_parallel() runs.
    """
    return cp_rank
def get_cp_rank_list():
    """Return the global ranks that make up this process's context-parallel group.

    Returns the list cached by init_context_parallel(). If nothing is cached
    yet, it is computed from the module-level ``cp_group`` and cached.
    """
    global cp_ranks
    # Identity test instead of `== None` (idiomatic, and never dispatches to
    # the cached object's __eq__).
    if cp_ranks is None:
        cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
    return cp_ranks
def cp_broadcast(tensor, cp_index=0):
    """Broadcast ``tensor`` in place across the context-parallel group.

    Args:
        tensor: tensor to send on the source rank and to receive into on the
            other ranks of the cp group.
        cp_index: position of the source rank inside the cp group. It is mapped
            to a global rank through get_cp_rank_list(), because the ``src``
            argument of ``torch.distributed.broadcast`` is a global rank.
    """
    source_global_rank = get_cp_rank_list()[cp_index]
    torch.distributed.broadcast(tensor, source_global_rank, group=cp_group)
def cp_broadcast_objects(tensor):
    """Placeholder for broadcasting objects across the context-parallel group.

    Not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("cp_broadcast_objects method is not yet implemented!!!")
def split_tensor_in_cp(input, seq_dim):
    """Return this rank's shard of ``input`` along ``seq_dim``.

    The dimension is cut into ``cp_size`` equal pieces with ``Tensor.split``,
    and the piece at index get_cp_rank() is returned.

    Raises:
        RuntimeError: if the size of ``seq_dim`` is not a multiple of cp_size.
    """
    seq_size = input.shape[seq_dim]
    if seq_size % cp_size:
        raise RuntimeError(f'seq_length {seq_size} in dim {seq_dim} must be multiple of cp_size {cp_size}!!!')
    shard_len = seq_size // cp_size
    return input.split(shard_len, dim=seq_dim)[get_cp_rank()]
class GatherFunction(torch.autograd.Function):
    """Autograd wrapper that all-gathers a ``B (T S) C`` tensor over a process group.

    ``forward(input, process_group, seq_dim, frames)`` views ``input`` as
    ``B T S C`` (``T = frames``), all-gathers one copy from each of the
    ``get_cp_size()`` ranks of ``process_group``, concatenates them along
    ``seq_dim``, and flattens the result back to ``B (T S) C``. gather_cp()
    calls it with ``seq_dim=2``, i.e. the ``S`` axis of the 4-D view.

    ``backward`` scales the incoming gradient by the cp size and returns this
    rank's slice of it along ``seq_dim`` (via split_tensor_in_cp).
    """

    @staticmethod
    def forward(ctx, input, process_group, seq_dim, frames):
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.frames = frames
        ctx.cp_size = get_cp_size()
        input = rearrange(input, "B (T S) C -> B T S C", T=frames)
        with torch.no_grad():
            # Collectives need a contiguous send buffer.
            input = input.contiguous()
            # all_gather overwrites every receive buffer in full, so there is
            # no need to zero-fill them first.
            output_tensors = [torch.empty_like(input) for _ in range(ctx.cp_size)]
            dist.all_gather(output_tensors, input, group=ctx.cp_group)
            output_tensor = torch.cat(output_tensors, dim=seq_dim)
            output_tensor = rearrange(output_tensor, "B T S C -> B (T S) C", T=frames)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            # NOTE(review): the cp_size factor mirrors the division by cp_size
            # in SplitFunction.backward; presumably it compensates for how the
            # loss is averaged across cp ranks - confirm with the training loop.
            grad_output = grad_output * ctx.cp_size
            grad_output = rearrange(grad_output, "B (T S) C -> B T S C", T=ctx.frames)
            grad_input = split_tensor_in_cp(grad_output, ctx.seq_dim)
            grad_input = rearrange(grad_input, "B T S C -> B (T S) C", T=ctx.frames)
        return grad_input, None, None, None
class SplitFunction(torch.autograd.Function):
    """Autograd wrapper that keeps only this rank's shard of a tensor along ``seq_dim``.

    ``forward(input, process_group, seq_dim)`` returns
    ``split_tensor_in_cp(input, seq_dim)``. ``backward`` divides the incoming
    gradient by the cp size, all-gathers the shards from every rank of
    ``process_group``, and concatenates them along ``seq_dim`` to rebuild a
    full-size gradient.
    """

    @staticmethod
    def forward(ctx, input, process_group, seq_dim):
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.cp_size = get_cp_size()
        output_tensor = split_tensor_in_cp(input, ctx.seq_dim)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            # The gradient handed to backward is not guaranteed to be
            # contiguous, and collective backends such as NCCL reject strided
            # buffers; make it contiguous, as GatherFunction does for its input.
            grad_output = (grad_output / ctx.cp_size).contiguous()
            # all_gather overwrites each receive buffer completely.
            output_tensors = [torch.empty_like(grad_output) for _ in range(ctx.cp_size)]
            dist.all_gather(output_tensors, grad_output, group=ctx.cp_group)
            grad_input = torch.cat(output_tensors, dim=ctx.seq_dim)
        return grad_input, None, None
def gather_cp(input, frames):
    """All-gather ``input`` over the context-parallel group via GatherFunction.

    ``seq_dim`` is fixed to 2, which is the ``S`` axis once GatherFunction
    views ``input`` as ``B T S C`` with ``T = frames``.
    """
    return GatherFunction.apply(input, get_cp_group(), 2, frames)
def split_cp(input, seq_dim):
    """Keep this rank's shard of ``input`` along ``seq_dim`` via SplitFunction.

    Runs over the module's context-parallel group; the backward pass
    all-gathers the gradient shards back into a full-size gradient.
    """
    return SplitFunction.apply(input, get_cp_group(), seq_dim)
class ReduceFunction(torch.autograd.Function):
    """All-reduce (sum) over a process group in forward; identity gradient in backward.

    forward returns a detached copy of ``input`` summed across
    ``process_group``; backward returns a detached copy of the incoming
    gradient unchanged.
    """

    @staticmethod
    def forward(ctx, input, process_group):
        ctx.cp_group = process_group
        reduced = input.detach().clone()
        # SUM is all_reduce's default op; spelled out for readability.
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=process_group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.detach().clone(), None
class ReplicateFunction(torch.autograd.Function):
    """Identity in forward; all-reduce (sum) of the gradient over a process group in backward.

    forward returns a detached copy of ``input`` and remembers
    ``process_group``; backward sums a copy of the incoming gradient across
    that group.
    """

    @staticmethod
    def forward(ctx, input, process_group):
        ctx.cp_group = process_group
        return input.detach().clone()

    @staticmethod
    def backward(ctx, grad_output):
        summed = grad_output.detach().clone()
        # SUM is all_reduce's default op; spelled out for readability.
        dist.all_reduce(summed, op=dist.ReduceOp.SUM, group=ctx.cp_group)
        return summed, None
def reduce_cp(partial_sum, partial_square_sum):
    """Sum two tensors across the context-parallel group.

    Each argument goes through ReduceFunction: its forward value is summed
    over the cp group, and its gradient passes through unchanged.

    Returns:
        ``(all_sum, all_square_sum)``, in the same order as the arguments.
    """
    group = get_cp_group()
    return tuple(ReduceFunction.apply(part, group) for part in (partial_sum, partial_square_sum))
def replicate_cp(all_mean, all_var):
    """Pass two tensors through ReplicateFunction on the context-parallel group.

    The forward values are returned as copies; in backward, their gradients
    are summed across the cp group.

    Returns:
        ``(all_mean, all_var)``, in the same order as the arguments.
    """
    group = get_cp_group()
    return tuple(ReplicateFunction.apply(value, group) for value in (all_mean, all_var))
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    """Perform one ``dist.all_to_all`` exchange over ``group``.

    ``input_`` is cut into ``world_size`` pieces along ``scatter_dim``, and
    piece ``i`` is sent to the i-th rank of ``group``. The pieces received
    from all ranks are concatenated along ``gather_dim``. Every receive buffer
    is shaped like the first send piece, so the pieces are assumed to be equal
    in size (all_to_all_with_pad checks divisibility before calling this).

    Returns:
        A contiguous tensor holding the concatenated received pieces.
    """
    send_pieces = [piece.contiguous() for piece in torch.tensor_split(input_, world_size, scatter_dim)]
    recv_pieces = [torch.empty_like(send_pieces[0]) for _ in range(world_size)]
    dist.all_to_all(recv_pieces, send_pieces, group=group)
    gathered = torch.cat(recv_pieces, dim=gather_dim)
    return gathered.contiguous()
class _AllToAll(torch.autograd.Function):
    """Differentiable all-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim, ctx.gather_dim = scatter_dim, gather_dim
        return _all_to_all_func(
            input_,
            dist.get_world_size(process_group),
            process_group,
            scatter_dim,
            gather_dim,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        # The gradient is routed back with the scatter and gather dims swapped.
        # Going through .apply keeps the backward itself differentiable.
        grad_input = _AllToAll.apply(*grad_output, ctx.process_group, ctx.gather_dim, ctx.scatter_dim)
        return grad_input, None, None, None
def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    scatter_pad: int = 0,
    gather_pad: int = 0,
):
    """Differentiable all-to-all with optional zero padding before and trimming after.

    Args:
        input_: tensor to redistribute across ``process_group``.
        process_group: group to communicate over.
        scatter_dim: dimension split across the ranks.
        gather_dim: dimension along which the received pieces are concatenated.
        scatter_pad: number of zero slices appended along ``scatter_dim``
            before the exchange (e.g. to make it divisible by the group size).
        gather_pad: number of trailing slices removed along ``gather_dim``
            after the exchange.

    Returns:
        The redistributed tensor.

    Raises:
        RuntimeError: if the (padded) size of ``scatter_dim`` is not divisible
            by the size of ``process_group``.
    """
    world_size = dist.get_world_size(process_group)
    if scatter_pad > 0:
        pad_shape = list(input_.shape)
        pad_shape[scatter_dim] = scatter_pad
        pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
        input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
    # Explicit check instead of `assert`, which is stripped under `python -O`;
    # RuntimeError matches the file's other divisibility checks.
    if input_.shape[scatter_dim] % world_size != 0:
        raise RuntimeError(
            f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({world_size})"
        )
    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
    if gather_pad > 0:
        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
    return input_
def dynamic_switch(x, scatter_dim, gather_dim):
    """Re-shard ``x`` across the context-parallel group with an all-to-all.

    ``x`` is scattered along ``scatter_dim`` and gathered along
    ``gather_dim`` using all_to_all_with_pad over the cp group, with no
    padding on either side.
    """
    return all_to_all_with_pad(
        x,
        get_cp_group(),
        scatter_dim=scatter_dim,
        gather_dim=gather_dim,
        scatter_pad=0,
        gather_pad=0,
    )