| import os |
| import torch |
| import torch.distributed as dist |
| from .utils import is_dist_avail_and_initialized, get_rank |
|
|
|
|
# Process group over which sequence-parallel collectives run.
# Set once by init_sequence_parallel_group(); None until then.
SEQ_PARALLEL_GROUP = None
# Number of ranks in each sequence-parallel group (args.sp_group_size).
SEQ_PARALLEL_SIZE = None
# Total number of ranks that participate in sequence parallelism
# (args.sp_proc_num, or the full world size when that is -1).
SEQ_PARALLEL_PROC_NUM = None


# Process group used to synchronize inputs across ranks.
# Set once by init_sync_input_group(); None until then.
SYNC_INPUT_GROUP = None
# Number of ranks in each sync-input group (args.max_frames).
SYNC_INPUT_SIZE = None
|
|
def is_sequence_parallel_initialized():
    """Return True iff the sequence-parallel group has been created."""
    return SEQ_PARALLEL_GROUP is not None
|
|
|
|
def init_sequence_parallel_group(args):
    """Create the sequence-parallel process groups and remember ours.

    Partitions the first ``SEQ_PARALLEL_PROC_NUM`` ranks into contiguous
    groups of ``args.sp_group_size`` ranks each, and stores the group
    containing the current rank in the module-global ``SEQ_PARALLEL_GROUP``.

    Args:
        args: namespace with ``sp_group_size`` (ranks per sequence-parallel
            group) and ``sp_proc_num`` (number of ranks participating in
            sequence parallelism; -1 means the entire world).

    Raises:
        AssertionError: if the group was already initialized, if
            torch.distributed is not initialized, or if the participating
            process count is not a multiple of the group size.
    """
    global SEQ_PARALLEL_GROUP
    global SEQ_PARALLEL_SIZE
    global SEQ_PARALLEL_PROC_NUM

    assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
    SEQ_PARALLEL_SIZE = args.sp_group_size

    print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    if args.sp_proc_num == -1:
        SEQ_PARALLEL_PROC_NUM = world_size
    else:
        SEQ_PARALLEL_PROC_NUM = args.sp_proc_num

    assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The process needs to be evenly divided"

    for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
        ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
        # new_group() is a collective: EVERY rank must call it for EVERY
        # group, including groups it does not belong to. Breaking out of
        # this loop after finding our own group (the previous behavior)
        # makes low-rank processes skip the later new_group calls, which
        # desynchronizes group creation and can hang the job.
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            SEQ_PARALLEL_GROUP = group
|
|
|
|
def init_sync_input_group(args):
    """Create the input-synchronization process groups and remember ours.

    Partitions the world into contiguous groups of ``args.max_frames`` ranks
    each, and stores the group containing the current rank in the
    module-global ``SYNC_INPUT_GROUP``.

    Args:
        args: namespace with ``max_frames`` (ranks per sync-input group).

    Raises:
        AssertionError: if the group was already initialized or
            torch.distributed is not initialized.
    """
    global SYNC_INPUT_GROUP
    global SYNC_INPUT_SIZE

    assert SYNC_INPUT_GROUP is None, "parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
    SYNC_INPUT_SIZE = args.max_frames

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    for i in range(0, world_size, SYNC_INPUT_SIZE):
        ranks = list(range(i, i + SYNC_INPUT_SIZE))
        # new_group() is a collective that all ranks must enter for every
        # group, even ones they are not members of — so no early break here
        # (the previous `break` could desynchronize ranks and hang startup).
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            SYNC_INPUT_GROUP = group
|
|
|
|
def get_sequence_parallel_group():
    """Return the sequence-parallel process group (must be initialized)."""
    group = SEQ_PARALLEL_GROUP
    assert group is not None, "sequence parallel group is not initialized"
    return group
|
|
|
|
def get_sync_input_group():
    """Return the sync-input process group, or None when uninitialized."""
    group = SYNC_INPUT_GROUP
    return group
|
|
|
|
def get_sequence_parallel_world_size():
    """Return the number of ranks per sequence-parallel group."""
    size = SEQ_PARALLEL_SIZE
    assert size is not None, "sequence parallel size is not initialized"
    return size
|
|
|
|
def get_sequence_parallel_rank():
    """Return this process's rank within its sequence-parallel group."""
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    # Groups are contiguous rank blocks, so the in-group index is the
    # global rank modulo the group size.
    return get_rank() % SEQ_PARALLEL_SIZE
|
|
|
|
def get_sequence_parallel_group_rank():
    """Return the index of the sequence-parallel group this rank belongs to."""
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    # Contiguous rank blocks: group index is global rank floor-divided
    # by the group size.
    return get_rank() // SEQ_PARALLEL_SIZE
|
|
|
|
def get_sequence_parallel_proc_num():
    """Return the total rank count participating in sequence parallelism.

    May be None if init_sequence_parallel_group() has not run.
    """
    count = SEQ_PARALLEL_PROC_NUM
    return count
|
|