|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import os |
|
|
import time |
|
|
import random |
|
|
import functools |
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
class COMM_INFO: |
|
|
|
|
|
def __init__(self): |
|
|
self.group = None |
|
|
self.sp_size = 1 |
|
|
self.global_rank = 0 |
|
|
self.rank_within_group = 0 |
|
|
self.group_id = 0 |
|
|
|
|
|
|
|
|
self.ts_group_size = 1 |
|
|
self.ts_group = None |
|
|
self.ts_group_id = 0 |
|
|
|
|
|
|
|
|
self.ts_unit_size = 1 |
|
|
self.ts_unit_group = None |
|
|
self.rank_within_ts_unit_group = 0 |
|
|
self.ts_unit_group_id = 0 |
|
|
|
|
|
|
|
|
nccl_info = COMM_INFO() |
|
|
_SEQUENCE_PARALLEL_STATE = False |
|
|
_TEACHER_STUDENT_PARALLEL_STATE = False |
|
|
|
|
|
def initialize_sequence_parallel_state(sequence_parallel_size): |
|
|
global _SEQUENCE_PARALLEL_STATE |
|
|
if sequence_parallel_size > 1: |
|
|
_SEQUENCE_PARALLEL_STATE = True |
|
|
initialize_sequence_parallel_group(sequence_parallel_size) |
|
|
else: |
|
|
nccl_info.sp_size = 1 |
|
|
nccl_info.global_rank = int(os.getenv("RANK", "0")) |
|
|
nccl_info.rank_within_group = 0 |
|
|
nccl_info.group_id = int(os.getenv("RANK", "0")) |
|
|
|
|
|
|
|
|
def set_sequence_parallel_state(state): |
|
|
global _SEQUENCE_PARALLEL_STATE |
|
|
_SEQUENCE_PARALLEL_STATE = state |
|
|
|
|
|
|
|
|
def get_sequence_parallel_state(): |
|
|
return _SEQUENCE_PARALLEL_STATE |
|
|
|
|
|
|
|
|
def initialize_sequence_parallel_group(sequence_parallel_size): |
|
|
"""Initialize the sequence parallel group.""" |
|
|
rank = int(os.getenv("RANK", "0")) |
|
|
world_size = int(os.getenv("WORLD_SIZE", "1")) |
|
|
assert ( |
|
|
world_size % sequence_parallel_size == 0 |
|
|
), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format( |
|
|
world_size, sequence_parallel_size |
|
|
) |
|
|
nccl_info.sp_size = sequence_parallel_size |
|
|
nccl_info.global_rank = rank |
|
|
num_sequence_parallel_groups: int = world_size // sequence_parallel_size |
|
|
for i in range(num_sequence_parallel_groups): |
|
|
ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) |
|
|
group = dist.new_group(ranks) |
|
|
if rank in ranks: |
|
|
nccl_info.group = group |
|
|
nccl_info.rank_within_group = rank - i * sequence_parallel_size |
|
|
nccl_info.group_id = i |
|
|
|
|
|
def get_sequence_parallel_state(): |
|
|
return _SEQUENCE_PARALLEL_STATE |
|
|
|
|
|
|
|
|
def set_teacher_student_parallel_state(state): |
|
|
global _TEACHER_STUDENT_PARALLEL_STATE |
|
|
_TEACHER_STUDENT_PARALLEL_STATE = state |
|
|
|
|
|
|
|
|
def get_teacher_student_parallel_state(): |
|
|
return _TEACHER_STUDENT_PARALLEL_STATE |
|
|
|
|
|
|
|
|
|
|
|
def initialize_teacher_student_parallel_state(sequence_parallel_size): |
|
|
global _TEACHER_STUDENT_PARALLEL_STATE |
|
|
"""Initialize the teacher-student parallel group.""" |
|
|
rank = int(os.getenv("RANK", "0")) |
|
|
world_size = int(os.getenv("WORLD_SIZE", "1")) |
|
|
assert ( |
|
|
world_size % (2 * sequence_parallel_size) == 0 |
|
|
), "world_size must be divisible by 2 * sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format( |
|
|
world_size, sequence_parallel_size |
|
|
) |
|
|
_TEACHER_STUDENT_PARALLEL_STATE = True |
|
|
nccl_info.global_rank = rank |
|
|
|
|
|
|
|
|
|
|
|
nccl_info.ts_unit_size = sequence_parallel_size * 2 |
|
|
num_teacher_student_union_groups = world_size // sequence_parallel_size // 2 |
|
|
for j in range(num_teacher_student_union_groups): |
|
|
ts_unit_ranks = range(j * sequence_parallel_size * 2, (j+1) * sequence_parallel_size * 2) |
|
|
ts_unit_group = dist.new_group(ts_unit_ranks) |
|
|
if rank in ts_unit_ranks: |
|
|
nccl_info.ts_unit_group = ts_unit_group |
|
|
nccl_info.ts_unit_group_id = j |
|
|
nccl_info.rank_within_ts_unit_group = rank - j * sequence_parallel_size * 2 |
|
|
|
|
|
|
|
|
nccl_info.ts_group_size = world_size // 2 |
|
|
for i in range(2): |
|
|
ranks = [] |
|
|
for j in range(num_teacher_student_union_groups): |
|
|
ranks += range((j*2+i) * sequence_parallel_size, (j*2+i+1) * sequence_parallel_size) |
|
|
|
|
|
ts_group = dist.new_group(ranks) |
|
|
if rank in ranks: |
|
|
nccl_info.ts_group = ts_group |
|
|
nccl_info.ts_group_id = i |
|
|
|
|
|
def destroy_sequence_parallel_group(): |
|
|
"""Destroy the sequence parallel group.""" |
|
|
dist.destroy_process_group() |
|
|
|
|
|
def is_teacher_group(): |
|
|
if _TEACHER_STUDENT_PARALLEL_STATE: |
|
|
return nccl_info.group_id % 2 == 1 |
|
|
else: |
|
|
return True |
|
|
|
|
|
def is_student_group(): |
|
|
if _TEACHER_STUDENT_PARALLEL_STATE: |
|
|
return nccl_info.group_id % 2 == 0 |
|
|
else: |
|
|
return True |
|
|
|