import os
import math
import random
import argparse
import datetime
import logging
import inspect
import subprocess

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from einops import rearrange, repeat

# Module-level state for the data-parallel (dp) x context-parallel (cp) layout,
# populated by init_context_parallel().
dp_size = None
cp_size = None
dp_group = None
cp_group = None
cp_stream = None
dp_ranks = None
cp_ranks = None
dp_rank = None
cp_rank = None


def init_context_parallel(context_parallel_size: int = 1,
                          global_rank: int = 0,
                          world_size: int = 1):
    """Build a 2-D (dp x cp) device mesh and cache the groups and ranks globally."""
    global dp_size
    global cp_size
    global dp_group
    global cp_group
    global dp_ranks
    global cp_ranks
    global dp_rank
    global cp_rank

    if world_size % context_parallel_size != 0:
        raise RuntimeError(f'world_size {world_size} must be a multiple of '
                           f'context_parallel_size {context_parallel_size}!')

    cp_size = context_parallel_size
    dp_size = world_size // context_parallel_size

    print(f'[rank {global_rank}] init_device_mesh [dp_size x cp_size]: [{dp_size} x {cp_size}]')

    # The outer mesh dim is data parallel; the inner (fastest-varying) dim is
    # context parallel, so consecutive ranks share a cp group.
    mesh_2d = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))

    print(f'[rank {global_rank}] mesh_2d: {mesh_2d}')

    dp_group = mesh_2d.get_group(mesh_dim="dp")
    cp_group = mesh_2d.get_group(mesh_dim="cp")

    dp_ranks = dist.get_process_group_ranks(dp_group)
    cp_ranks = dist.get_process_group_ranks(cp_group)

    dp_rank = dist.get_rank(group=dp_group)
    cp_rank = dist.get_rank(group=cp_group)

    global_rank_1 = dist.get_rank()
    print(f'[rank {global_rank_1}] [dp_rank, cp_rank]: [{dp_rank}, {cp_rank}], '
          f'dp_ranks: {dp_ranks}, cp_ranks: {cp_ranks}')
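
# Usage sketch (illustrative, not part of this module): under torchrun, the
# default process group must exist before the mesh is built, e.g.
#
#   dist.init_process_group(backend="nccl")
#   torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
#   init_context_parallel(context_parallel_size=2,
#                         global_rank=dist.get_rank(),
#                         world_size=dist.get_world_size())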


def get_cp_size():
    return cp_size


def get_dp_size():
    return dp_size


def get_cp_stream():
    global cp_stream
    if cp_stream is None:
        # Lazily create a side stream for overlapping cp communication.
        cp_stream = torch.cuda.Stream()
    return cp_stream


def get_dp_group():
    return dp_group


def get_cp_group():
    return cp_group


def get_dp_rank():
    return dp_rank


def get_cp_rank():
    return cp_rank


def get_cp_rank_list():
    global cp_ranks
    if cp_ranks is None:
        cp_ranks = dist.get_process_group_ranks(cp_group)
    return cp_ranks


def cp_broadcast(tensor, cp_index=0):
    """Broadcast `tensor` in place from the cp_index-th rank of the cp group."""
    cp_ranks = get_cp_rank_list()
    dist.broadcast(tensor, cp_ranks[cp_index], group=cp_group)
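
# Illustrative use (an assumption about the caller, not defined here): cp ranks
# process slices of the same sample, so per-sample randomness such as diffusion
# noise or timesteps is typically drawn once and synchronized, e.g.
#
#   noise = torch.randn(shape, device="cuda")
#   cp_broadcast(noise)       # every cp rank now holds cp-rank-0's noise
#   cp_broadcast(timesteps)
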
def cp_broadcast_objects(tensor):
    raise NotImplementedError("cp_broadcast_objects is not implemented yet!")


def split_tensor_in_cp(input, seq_dim):
    """Return this cp rank's contiguous chunk of `input` along `seq_dim`."""
    seq_size = input.shape[seq_dim]

    if seq_size % cp_size != 0:
        raise RuntimeError(f'seq_length {seq_size} in dim {seq_dim} must be a multiple of cp_size {cp_size}!')

    split_seq_size = seq_size // cp_size
    tensor_splits = input.split(split_seq_size, dim=seq_dim)
    cp_rank = get_cp_rank()
    return tensor_splits[cp_rank]
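
# Shape example (hypothetical values): with cp_size = 2, a [B, 8, C] tensor
# split along seq_dim=1 yields two [B, 4, C] chunks; cp rank 0 keeps tokens
# 0..3 and cp rank 1 keeps tokens 4..7.
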
class GatherFunction(torch.autograd.Function):
    """All-gather sequence chunks across the cp group; re-split in backward."""

    @staticmethod
    def forward(ctx, input, process_group, seq_dim, frames):
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.frames = frames
        ctx.cp_size = get_cp_size()

        # Unflatten the token dim so the gather concatenates along the spatial
        # dim S (seq_dim is given in the B T S C layout).
        input = rearrange(input, "B (T S) C -> B T S C", T=frames)

        with torch.no_grad():
            input = input.contiguous()
            output_tensors = [torch.zeros_like(input) for _ in range(ctx.cp_size)]
            dist.all_gather(output_tensors, input, group=ctx.cp_group)
            output_tensor = torch.cat(output_tensors, dim=seq_dim)

        output_tensor = rearrange(output_tensor, "B T S C -> B (T S) C", T=frames)

        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            # Mirror of SplitFunction.backward's 1/cp_size scaling: rescale,
            # then return to each rank the gradient slice of its own chunk.
            grad_output = grad_output * ctx.cp_size
            grad_output = rearrange(grad_output, "B (T S) C -> B T S C", T=ctx.frames)
            grad_input = split_tensor_in_cp(grad_output, ctx.seq_dim)
            grad_input = rearrange(grad_input, "B T S C -> B (T S) C", T=ctx.frames)

        return grad_input, None, None, None


class SplitFunction(torch.autograd.Function):
    """Keep this cp rank's sequence chunk; all-gather gradients in backward."""

    @staticmethod
    def forward(ctx, input, process_group, seq_dim):
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.cp_size = get_cp_size()

        output_tensor = split_tensor_in_cp(input, ctx.seq_dim)

        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            # Gather each rank's chunk gradient back into the full-sequence
            # layout, scaled by 1/cp_size (paired with the *cp_size in
            # GatherFunction.backward).
            grad_output = grad_output / ctx.cp_size

            output_tensors = [torch.zeros_like(grad_output) for _ in range(ctx.cp_size)]
            dist.all_gather(output_tensors, grad_output, group=ctx.cp_group)
            grad_input = torch.cat(output_tensors, dim=ctx.seq_dim)

        return grad_input, None, None


def gather_cp(input, frames):
    """Gather cp-sharded tokens of shape [B, (T S), C] into the full sequence."""
    cp_process_group = get_cp_group()
    # seq_dim=2 is the spatial dim S in the unflattened B T S C layout.
    output_tensor = GatherFunction.apply(input, cp_process_group, 2, frames)
    return output_tensor


def split_cp(input, seq_dim):
    """Differentiably keep this cp rank's chunk of `input` along `seq_dim`."""
    cp_process_group = get_cp_group()
    output_tensor = SplitFunction.apply(input, cp_process_group, seq_dim)
    return output_tensor
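
# Round-trip sketch (illustrative; assumes init_context_parallel was called
# and S divides cp_size). With frames=1 the (T S) layout reduces to plain
# tokens, so a split followed by a gather reconstructs the input:
#
#   x = torch.randn(B, S, C, device="cuda", requires_grad=True)
#   x_local = split_cp(x, seq_dim=1)        # [B, S / cp_size, C] on each rank
#   x_full = gather_cp(x_local, frames=1)   # [B, S, C], identical to x
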
class ReduceFunction(torch.autograd.Function):
    """Sum a tensor across the cp group; pass gradients through unchanged."""

    @staticmethod
    def forward(ctx, input, process_group):
        ctx.cp_group = process_group

        output = input.detach().clone()
        dist.all_reduce(output, group=ctx.cp_group)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank holds the same summed output and thus the same upstream
        # gradient, so the gradient w.r.t. each partial input is a pass-through.
        grad_input = grad_output.detach().clone()

        return grad_input, None


class ReplicateFunction(torch.autograd.Function):
    """Identity on replicated tensors; all-reduce gradients in backward."""

    @staticmethod
    def forward(ctx, input, process_group):
        ctx.cp_group = process_group

        output = input.detach().clone()

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Each cp rank consumed its own copy of the replicated value, so the
        # total gradient is the sum of the per-rank gradients.
        grad_input = grad_output.detach().clone()
        dist.all_reduce(grad_input, group=ctx.cp_group)

        return grad_input, None


def reduce_cp(partial_sum, partial_square_sum):
    """Sum per-chunk statistics (sum and sum of squares) across the cp group."""
    cp_process_group = get_cp_group()

    all_sum = ReduceFunction.apply(partial_sum, cp_process_group)
    all_square_sum = ReduceFunction.apply(partial_square_sum, cp_process_group)

    return all_sum, all_square_sum


def replicate_cp(all_mean, all_var):
    """Mark cp-replicated statistics so their gradients are summed in backward."""
    cp_process_group = get_cp_group()

    all_mean = ReplicateFunction.apply(all_mean, cp_process_group)
    all_var = ReplicateFunction.apply(all_var, cp_process_group)

    return all_mean, all_var
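
# Normalization sketch (illustrative; names like `x_local` are hypothetical).
# Each cp rank holds a sequence chunk, so global mean/variance over the full
# sequence come from summed first and second moments:
#
#   n_total = x_local.shape[1] * get_cp_size()
#   s, sq = reduce_cp(x_local.sum(dim=1), (x_local ** 2).sum(dim=1))
#   mean = s / n_total
#   var = sq / n_total - mean ** 2
#   mean, var = replicate_cp(mean, var)  # route grads back through all ranks
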
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    # Scatter chunks of `scatter_dim` to peers, then concatenate what was
    # received along `gather_dim`.
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)

        return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        # The gradient of an all-to-all is an all-to-all with the scatter and
        # gather dims swapped.
        process_group = ctx.process_group
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)


def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    scatter_pad: int = 0,
    gather_pad: int = 0,
):
    # Zero-pad the scatter dim so it divides the group size, run the
    # all-to-all, then trim the padding that reappears on the gather dim.
    if scatter_pad > 0:
        pad_shape = list(input_.shape)
        pad_shape[scatter_dim] = scatter_pad
        pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
        input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)

    assert (
        input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
    ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)

    if gather_pad > 0:
        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)

    return input_


def dynamic_switch(x, scatter_dim, gather_dim):
    """Re-shard `x` across the cp group: scatter one dim, gather another."""
    # Padding is fixed at 0 here; callers must pass dims already divisible by
    # the cp group size.
    scatter_pad = 0
    gather_pad = 0
    cp_process_group = get_cp_group()

    x = all_to_all_with_pad(
        x,
        cp_process_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim,
        scatter_pad=scatter_pad,
        gather_pad=gather_pad,
    )
    return x
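
# Attention sketch (illustrative; the shapes and the `attention` call are
# assumptions, not fixed by this module). In Ulysses-style context
# parallelism, dynamic_switch re-shards [B, S_local, H, D] activations from
# sequence-sharded to head-sharded before attention, and back afterwards:
#
#   qkv = dynamic_switch(qkv, scatter_dim=2, gather_dim=1)  # full S, H/cp heads
#   out = attention(qkv)                                    # hypothetical kernel
#   out = dynamic_switch(out, scatter_dim=1, gather_dim=2)  # back to S/cp, full H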