import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from einops import rearrange

# Cached state for the 2D [data-parallel (dp) x context-parallel (cp)] process-group layout.
dp_size = None
cp_size = None
dp_group = None
cp_group = None
cp_stream = None
dp_ranks = None
cp_ranks = None
dp_rank = None
cp_rank = None


def init_context_parallel(context_parallel_size: int = 1, global_rank: int = 0, world_size: int = 1):
    """Build a [dp x cp] device mesh and cache the resulting groups and ranks in module globals."""
    global dp_size, cp_size, dp_group, cp_group, dp_ranks, cp_ranks, dp_rank, cp_rank

    if world_size % context_parallel_size != 0:
        raise RuntimeError(
            f"world_size {world_size} must be a multiple of context_parallel_size {context_parallel_size}!"
        )

    cp_size = context_parallel_size
    dp_size = world_size // context_parallel_size
    print(f"[rank {global_rank}] init_device_mesh [dp_size x cp_size]: [{dp_size} x {cp_size}]")

    mesh_2d = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp"))
    print(f"[rank {global_rank}] mesh_2d: {mesh_2d}")

    dp_group = mesh_2d.get_group(mesh_dim="dp")
    cp_group = mesh_2d.get_group(mesh_dim="cp")
    dp_ranks = torch.distributed.get_process_group_ranks(dp_group)
    cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
    dp_rank = dist.get_rank(group=dp_group)
    cp_rank = dist.get_rank(group=cp_group)

    global_rank_1 = torch.distributed.get_rank()
    print(
        f"[rank {global_rank_1}] [dp_rank, cp_rank]: [{dp_rank}, {cp_rank}], "
        f"dp_ranks: {dp_ranks}, cp_ranks: {cp_ranks}"
    )


def get_cp_size():
    return cp_size


def get_dp_size():
    return dp_size


def get_cp_stream():
    # Lazily create a dedicated CUDA stream for overlapping cp communication.
    global cp_stream
    if cp_stream is None:
        cp_stream = torch.cuda.Stream()
    return cp_stream


def get_dp_group():
    return dp_group


def get_cp_group():
    return cp_group


def get_dp_rank():
    return dp_rank


def get_cp_rank():
    return cp_rank


def get_cp_rank_list():
    global cp_ranks
    if cp_ranks is None:
        cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
    return cp_ranks


def cp_broadcast(tensor, cp_index=0):
    """Broadcast ``tensor`` from the cp rank at position ``cp_index`` to every rank in the cp group."""
    cp_ranks = get_cp_rank_list()
    torch.distributed.broadcast(tensor, cp_ranks[cp_index], group=cp_group)


def cp_broadcast_objects(tensor):
    raise NotImplementedError("cp_broadcast_objects method is not yet implemented!")


def split_tensor_in_cp(input, seq_dim):
    """Return this cp rank's contiguous chunk of ``input`` along ``seq_dim``."""
    seq_size = input.shape[seq_dim]
    if seq_size % cp_size != 0:
        raise RuntimeError(f"seq_length {seq_size} in dim {seq_dim} must be a multiple of cp_size {cp_size}!")
    split_seq_size = seq_size // cp_size
    tensor_splits = input.split(split_seq_size, dim=seq_dim)
    split_tensor = tensor_splits[get_cp_rank()]
    return split_tensor
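
# A minimal usage sketch (illustration only, not part of the module): how the helpers above
# would typically be initialized under torchrun. The cp_size of 2 and the 8-rank layout are
# assumptions for the example.
#
#   dist.init_process_group(backend="nccl")
#   torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
#   init_context_parallel(
#       context_parallel_size=2,
#       global_rank=dist.get_rank(),
#       world_size=dist.get_world_size(),
#   )
#   # With 8 ranks and cp_size=2 this gives dp_size=4; get_dp_group() can then be handed to
#   # DDP/FSDP while get_cp_group() drives the sequence-sharding utilities below.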
class GatherFunction(torch.autograd.Function):
    """All-gather the per-frame spatial dimension across the cp group in the forward pass and
    split the gradient back to the local shard in the backward pass."""

    @staticmethod
    def forward(ctx, input, process_group, seq_dim, frames):
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.frames = frames
        ctx.cp_size = get_cp_size()
        # (B, T*S_local, C) -> (B, T, S_local, C) so the gather runs along the spatial dim.
        input = rearrange(input, "B (T S) C -> B T S C", T=frames)
        with torch.no_grad():
            input = input.contiguous()
            output_tensors = [torch.zeros_like(input) for _ in range(ctx.cp_size)]
            dist.all_gather(output_tensors, input, group=ctx.cp_group)
            output_tensor = torch.cat(output_tensors, dim=seq_dim)
        output_tensor = rearrange(output_tensor, "B T S C -> B (T S) C", T=frames)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            # Scale up so the sharded gradient matches the non-parallel computation.
            grad_output = grad_output * ctx.cp_size
            grad_output = rearrange(grad_output, "B (T S) C -> B T S C", T=ctx.frames)
            grad_input = split_tensor_in_cp(grad_output, ctx.seq_dim)
            grad_input = rearrange(grad_input, "B T S C -> B (T S) C", T=ctx.frames)
        return grad_input, None, None, None


class SplitFunction(torch.autograd.Function):
    """Keep the local shard along ``seq_dim`` in the forward pass and all-gather the gradient
    back to the full sequence in the backward pass."""

    @staticmethod
    def forward(ctx, input, process_group, seq_dim):
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.cp_size = get_cp_size()
        output_tensor = split_tensor_in_cp(input, ctx.seq_dim)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            # Scale down to undo the scaling applied in GatherFunction.backward.
            grad_output = grad_output / ctx.cp_size
            output_tensors = [torch.zeros_like(grad_output) for _ in range(ctx.cp_size)]
            dist.all_gather(output_tensors, grad_output, group=ctx.cp_group)
            grad_input = torch.cat(output_tensors, dim=ctx.seq_dim)
        return grad_input, None, None


def gather_cp(input, frames):
    """Gather a (B, T*S_local, C) tensor along its spatial dimension across the cp group."""
    cp_process_group = get_cp_group()
    output_tensor = GatherFunction.apply(input, cp_process_group, 2, frames)
    return output_tensor


def split_cp(input, seq_dim):
    """Keep only this cp rank's shard of ``input`` along ``seq_dim``."""
    cp_process_group = get_cp_group()
    output_tensor = SplitFunction.apply(input, cp_process_group, seq_dim)
    return output_tensor


class ReduceFunction(torch.autograd.Function):
    """All-reduce in the forward pass; pass the gradient through unchanged in the backward pass."""

    @staticmethod
    def forward(ctx, input, process_group):
        ctx.cp_group = process_group
        output = input.detach().clone()
        dist.all_reduce(output, group=ctx.cp_group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.detach().clone()
        return grad_input, None


class ReplicateFunction(torch.autograd.Function):
    """Identity in the forward pass; all-reduce the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input, process_group):
        ctx.cp_group = process_group
        output = input.detach().clone()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.detach().clone()
        dist.all_reduce(grad_input, group=ctx.cp_group)
        return grad_input, None


def reduce_cp(partial_sum, partial_square_sum):
    """Sum per-shard statistics (e.g. for cp-aware normalization) across the cp group."""
    cp_process_group = get_cp_group()
    all_sum = ReduceFunction.apply(partial_sum, cp_process_group)
    all_square_sum = ReduceFunction.apply(partial_square_sum, cp_process_group)
    return all_sum, all_square_sum


def replicate_cp(all_mean, all_var):
    """Mark replicated statistics so their gradients are summed across the cp group."""
    cp_process_group = get_cp_group()
    all_mean = ReplicateFunction.apply(all_mean, cp_process_group)
    all_var = ReplicateFunction.apply(all_var, cp_process_group)
    return all_mean, all_var
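
# A minimal usage sketch (illustration only): round-tripping a sharded video-token sequence
# with split_cp and gather_cp. The shapes and frame count are assumptions for the example.
#
#   B, T, S, C = 1, 4, 16, 64                       # batch, frames, tokens per frame, channels
#   x = torch.randn(B, T * S, C, device="cuda", requires_grad=True)
#   x_local = rearrange(x, "B (T S) C -> B T S C", T=T)
#   x_local = split_cp(x_local, seq_dim=2)          # keep S / cp_size tokens per frame
#   x_local = rearrange(x_local, "B T S C -> B (T S) C")
#   # ... cp-sharded transformer blocks operate on x_local ...
#   x_full = gather_cp(x_local, frames=T)           # (B, T*S, C) restored on every cp rank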
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    # Split along scatter_dim, exchange the chunks, and re-concatenate along gather_dim.
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)
        return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        # The backward of an all-to-all is another all-to-all with the two dims swapped.
        process_group = ctx.process_group
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)


def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    scatter_pad: int = 0,
    gather_pad: int = 0,
):
    """All-to-all that optionally zero-pads the scatter dim first and trims the gather dim afterwards."""
    if scatter_pad > 0:
        pad_shape = list(input_.shape)
        pad_shape[scatter_dim] = scatter_pad
        pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
        input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)

    assert (
        input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
    ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"

    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)

    if gather_pad > 0:
        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)

    return input_


def dynamic_switch(x, scatter_dim, gather_dim):
    """Switch which dimension of ``x`` is sharded across the cp group via an all-to-all."""
    scatter_pad = 0
    gather_pad = 0
    cp_process_group = get_cp_group()
    x = all_to_all_with_pad(
        x,
        cp_process_group,
        scatter_dim=scatter_dim,
        gather_dim=gather_dim,
        scatter_pad=scatter_pad,
        gather_pad=gather_pad,
    )
    return x
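
# A minimal usage sketch (illustration only): one common pattern for dynamic_switch is to trade a
# sequence shard for a head shard around an attention kernel; that this is the intended use here
# is an assumption. Shapes are illustrative.
#
#   # q: (B, S_local, H, D) with the sequence sharded across the cp group
#   q = dynamic_switch(q, scatter_dim=2, gather_dim=1)      # -> (B, S, H / cp_size, D)
#   # ... attention over the full sequence with H / cp_size heads per rank ...
#   out = dynamic_switch(out, scatter_dim=1, gather_dim=2)  # -> (B, S_local, H, D)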