# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import math
import torch
import torch.nn as nn
import torch.distributed as dist

import infinity.models.videovae.utils.diffdist.functional as distops


class ContextParallelUtils:
    _CONTEXT_PARALLEL_GROUP = None
    _CONTEXT_PARALLEL_SIZE = 0
    _CONTEXT_PARALLEL_ON = False

    # Example CP_CONFIG:
    # {
    #     "cp_size": 2,
    # }
    CP_CONFIG = None

    @staticmethod
    def set_cp_on(on=True):
        ContextParallelUtils._CONTEXT_PARALLEL_ON = on

    @staticmethod
    def cp_on():
        return ContextParallelUtils._CONTEXT_PARALLEL_ON

    @staticmethod
    def get_cp_cfg():
        return ContextParallelUtils.CP_CONFIG

    @staticmethod
    def is_cp_initialized():
        return ContextParallelUtils._CONTEXT_PARALLEL_GROUP is not None

    @staticmethod
    def initialize_context_parallel(cp_config: dict):
        assert ContextParallelUtils._CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
        context_parallel_size = cp_config["cp_size"]
        if context_parallel_size > 1:
            ContextParallelUtils.CP_CONFIG = cp_config
        else:
            print(f"WARN: context parallel size must be > 1 but got {context_parallel_size}")
            return

        ContextParallelUtils._CONTEXT_PARALLEL_SIZE = context_parallel_size

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        # Split the global ranks into consecutive groups of `context_parallel_size`
        # and keep the group that contains the current rank.
        for i in range(0, world_size, context_parallel_size):
            ranks = range(i, i + context_parallel_size)
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                ContextParallelUtils._CONTEXT_PARALLEL_GROUP = group
                break

    @staticmethod
    def get_cp_group():
        return ContextParallelUtils._CONTEXT_PARALLEL_GROUP

    @staticmethod
    def get_cp_size():
        return ContextParallelUtils._CONTEXT_PARALLEL_SIZE

    @staticmethod
    def get_cp_world_size():
        # Number of context-parallel groups in the whole job.
        if ContextParallelUtils.is_cp_initialized():
            world_size = torch.distributed.get_world_size()
            return world_size // ContextParallelUtils._CONTEXT_PARALLEL_SIZE
        else:
            return 0

    @staticmethod
    def get_cp_rank():
        # Rank of this process inside its context-parallel group.
        if ContextParallelUtils.is_cp_initialized():
            global_rank = torch.distributed.get_rank()
            cp_rank = global_rank % ContextParallelUtils._CONTEXT_PARALLEL_SIZE
            return cp_rank
        else:
            return 0

    @staticmethod
    def get_cp_group_rank():
        # Index of the context-parallel group this process belongs to.
        if ContextParallelUtils.is_cp_initialized():
            rank = torch.distributed.get_rank()
            cp_group_rank = rank // ContextParallelUtils._CONTEXT_PARALLEL_SIZE
            return cp_group_rank
        else:
            return 0


def _gather_tensor_shape(local_ts):
    # All-gather the (possibly different) tensor shapes across the context-parallel group.
    cp_size = ContextParallelUtils.get_cp_size()
    local_shape = torch.tensor(local_ts.shape, dtype=torch.int64, device=local_ts.device)
    gathered_shapes = [
        torch.zeros(len(local_shape), dtype=torch.int64, device=local_ts.device)
        for _ in range(cp_size)
    ]
    dist.all_gather(gathered_shapes, local_shape, group=ContextParallelUtils._CONTEXT_PARALLEL_GROUP)
    return [shape.tolist() for shape in gathered_shapes]


@torch.compiler.disable()
def dist_encoder_gather_result(res) -> list:
    # Gather the per-rank encoder outputs from every rank in the context-parallel group.
    # Returns `res` unchanged when context parallelism is off.
    cp_size = ContextParallelUtils.get_cp_size()
    if cp_size < 2:
        return res
    shape_list = _gather_tensor_shape(res)  # e.g. [[1, 2, 3, 4], ...]: one shape per rank
    encs = [torch.zeros(s, device=res.device, dtype=res.dtype) for s in shape_list]
    dist.barrier()
    encs = distops.all_gather(encs, res, group=ContextParallelUtils._CONTEXT_PARALLEL_GROUP)
    return encs


@torch.compiler.disable()
def dist_decoder_gather_result(res) -> list:
    # Gather the per-rank decoder outputs from every rank in the context-parallel group.
    # Returns `res` unchanged when context parallelism is off.
    cp_size = ContextParallelUtils.get_cp_size()
    if cp_size < 2:
        return res
    shape_list = _gather_tensor_shape(res)  # e.g. [[1, 2, 3, 4], ...]: one shape per rank
    decs = [torch.zeros(s, device=res.device, dtype=res.dtype) for s in shape_list]
    dist.barrier()
    decs = distops.all_gather(decs, res, group=ContextParallelUtils._CONTEXT_PARALLEL_GROUP)
    return decs


def _send_with_shape(local_ts, next_rank):
    # Send the shape first so the receiver can allocate a matching buffer.
    local_shape = torch.tensor(local_ts.shape, dtype=torch.int64, device=local_ts.device)
    torch.distributed.send(local_shape.contiguous(), next_rank)
    torch.distributed.send(local_ts.contiguous(), next_rank)


def _recv_with_shape(pre_rank):
    device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu')
    # The incoming conv cache is assumed to be a 5-D float32 tensor,
    # matching what `_send_with_shape` sends from the previous rank.
    shape = torch.zeros(5, dtype=torch.int64, device=device)
    torch.distributed.recv(shape, pre_rank)
    ts = torch.zeros(shape.tolist(), device=device)
    torch.distributed.recv(ts, pre_rank)
    return ts


@torch.compiler.disable()
def dist_conv_cache_send(conv_cache):
    # Send this rank's conv cache to the next rank in the context-parallel group.
    # The last rank in the group has no successor, so it sends nothing.
    cp_rank = ContextParallelUtils.get_cp_rank()
    global_rank = torch.distributed.get_rank()
    cp_size = ContextParallelUtils.get_cp_size()
    if cp_rank == cp_size - 1:
        return
    if conv_cache is None:
        return
    next_rank = global_rank + 1
    _send_with_shape(conv_cache, next_rank)


@torch.compiler.disable()
def dist_conv_cache_recv():
    # Receive the conv cache from the previous rank in the context-parallel group.
    # The first rank in the group has no predecessor and returns None.
    cp_rank = ContextParallelUtils.get_cp_rank()
    global_rank = torch.distributed.get_rank()
    if cp_rank == 0:
        return None
    pre_rank = global_rank - 1
    return _recv_with_shape(pre_rank)
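

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): how the helpers
# above might be wired together on a multi-GPU job launched with torchrun.
# The cp_size of 2, the dummy tensor shape, and the NCCL/CUDA setup below are
# assumptions made for the example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Group every 2 consecutive ranks into one context-parallel group
    # (assumes world_size is a multiple of cp_size).
    ContextParallelUtils.initialize_context_parallel({"cp_size": 2})
    ContextParallelUtils.set_cp_on(True)

    if ContextParallelUtils.cp_on():
        # Each rank contributes a dummy "encoder output"; the gather collects
        # one tensor per rank of the context-parallel group.
        local_res = torch.randn(1, 4, 2, 8, 8, device="cuda")
        gathered = dist_encoder_gather_result(local_res)
        print(dist.get_rank(), [t.shape for t in gathered])

    dist.destroy_process_group()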