File size: 4,964 Bytes
e14f899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142

import torch
import torch.distributed as dist
import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union

class COMM_INFO:
    """Mutable record of this process's communication-group bookkeeping.

    Every field starts at a single-process default: sizes are 1, ranks and
    ids are 0, and group handles are ``None``.
    """

    def __init__(self):
        # Sequence-parallel bookkeeping.
        self.sp_size = 1
        self.group = None
        self.group_id = 0
        self.rank_within_group = 0
        self.global_rank = 0

        # Teacher-student parallel group; the group handle is the one used
        # for FSDP data-parallel communication.
        self.ts_group = None
        self.ts_group_size = 1
        self.ts_group_id = 0

        # Teacher-student "unit" group (the union of one teacher and one
        # student).
        self.ts_unit_group = None
        self.ts_unit_size = 1
        self.ts_unit_group_id = 0
        self.rank_within_ts_unit_group = 0


# Module-wide communication record; the initialize_* functions in this module
# fill in its fields.
nccl_info = COMM_INFO()
# Flags recording whether sequence parallelism / teacher-student parallelism
# has been switched on. Both start disabled.
_SEQUENCE_PARALLEL_STATE = False
_TEACHER_STUDENT_PARALLEL_STATE = False

def initialize_sequence_parallel_state(sequence_parallel_size):
    """Configure sequence parallelism for this process.

    When ``sequence_parallel_size`` is greater than 1, the module-level flag
    is switched on and the process groups are built by
    ``initialize_sequence_parallel_group``. Otherwise ``nccl_info`` is filled
    with single-rank values: ``sp_size`` 1, ``rank_within_group`` 0, and both
    ``global_rank`` and ``group_id`` taken from the ``RANK`` environment
    variable (0 when unset). The flag is not modified in that case.
    """
    global _SEQUENCE_PARALLEL_STATE

    if sequence_parallel_size > 1:
        _SEQUENCE_PARALLEL_STATE = True
        initialize_sequence_parallel_group(sequence_parallel_size)
        return

    # No sequence parallelism: each rank forms its own group.
    nccl_info.sp_size = 1
    env_rank = int(os.getenv("RANK", "0"))
    nccl_info.global_rank = env_rank
    nccl_info.rank_within_group = 0
    nccl_info.group_id = env_rank


def set_sequence_parallel_state(state):
    """Overwrite the module-level sequence-parallel flag with ``state``."""
    global _SEQUENCE_PARALLEL_STATE
    _SEQUENCE_PARALLEL_STATE = state


def get_sequence_parallel_state():
    """Return the current value of the module-level sequence-parallel flag."""
    return _SEQUENCE_PARALLEL_STATE


def initialize_sequence_parallel_group(sequence_parallel_size):
    """Build the sequence-parallel process groups and record this rank's place.

    ``RANK`` and ``WORLD_SIZE`` are read from the environment (defaulting to
    0 and 1). Ranks are split into consecutive chunks of
    ``sequence_parallel_size``. ``dist.new_group`` is called once per chunk,
    whether or not this rank belongs to it. The chunk containing this rank
    is stored in ``nccl_info`` along with the rank's offset inside the chunk
    and the chunk index.

    Raises:
        AssertionError: if ``WORLD_SIZE`` is not a multiple of
            ``sequence_parallel_size``.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    assert world_size % sequence_parallel_size == 0, (
        "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
            world_size, sequence_parallel_size
        )
    )

    nccl_info.sp_size = sequence_parallel_size  # sequence-parallel size
    nccl_info.global_rank = rank  # global rank

    group_count: int = world_size // sequence_parallel_size
    for group_index in range(group_count):
        first_rank = group_index * sequence_parallel_size
        member_ranks = range(first_rank, first_rank + sequence_parallel_size)
        handle = dist.new_group(member_ranks)
        if rank not in member_ranks:
            continue
        nccl_info.group = handle
        # Rank's position inside its sequence-parallel group.
        nccl_info.rank_within_group = rank - first_rank
        # Index of the sequence-parallel group this rank belongs to.
        nccl_info.group_id = group_index

def get_sequence_parallel_state():
    """Return the current value of the module-level sequence-parallel flag."""
    # NOTE(review): this repeats, with an identical body, a function of the
    # same name defined earlier in this module; one of the two can be removed.
    return _SEQUENCE_PARALLEL_STATE


def set_teacher_student_parallel_state(state):
    """Overwrite the module-level teacher-student parallel flag with ``state``."""
    global _TEACHER_STUDENT_PARALLEL_STATE
    _TEACHER_STUDENT_PARALLEL_STATE = state


def get_teacher_student_parallel_state():
    """Return the current value of the module-level teacher-student flag."""
    return _TEACHER_STUDENT_PARALLEL_STATE


            
def initialize_teacher_student_parallel_state(sequence_parallel_size):
    """Initialize the teacher-student parallel groups.

    ``RANK`` and ``WORLD_SIZE`` are read from the environment (defaulting to
    0 and 1). Two families of process groups are created with
    ``dist.new_group``, and this rank's membership is stored in ``nccl_info``:

    * Unit groups: consecutive blocks of ``2 * sequence_parallel_size``
      ranks. This sets ``ts_unit_group``, ``ts_unit_group_id``,
      ``rank_within_ts_unit_group`` and ``ts_unit_size``.
    * Two ts groups: group 0 gathers the first ``sequence_parallel_size``
      ranks of every unit, group 1 gathers the remaining ranks of every
      unit. This sets ``ts_group``, ``ts_group_id`` and ``ts_group_size``.

    The module-level teacher-student flag is switched on once the world-size
    check has passed.

    Args:
        sequence_parallel_size: Sequence-parallel size shared by teacher and
            student.

    Raises:
        AssertionError: if ``WORLD_SIZE`` is not a multiple of
            ``2 * sequence_parallel_size``.
    """
    global _TEACHER_STUDENT_PARALLEL_STATE

    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    assert (
        world_size % (2 * sequence_parallel_size) == 0
    ), "world_size must be divisible by 2 * sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
        world_size, sequence_parallel_size
    )
    _TEACHER_STUDENT_PARALLEL_STATE = True
    nccl_info.global_rank = rank

    # Init ts_unit_group and assign info.
    # Teacher and student must have the same sp size for now.
    # Within a unit, the front half is the student and the back half is the
    # teacher.
    nccl_info.ts_unit_size = sequence_parallel_size * 2
    num_teacher_student_union_groups = world_size // sequence_parallel_size // 2
    for j in range(num_teacher_student_union_groups):
        ts_unit_ranks = range(
            j * sequence_parallel_size * 2, (j + 1) * sequence_parallel_size * 2
        )
        # new_group is invoked for every unit, whether or not this rank is in it.
        ts_unit_group = dist.new_group(ts_unit_ranks)
        if rank in ts_unit_ranks:
            nccl_info.ts_unit_group = ts_unit_group
            nccl_info.ts_unit_group_id = j
            nccl_info.rank_within_ts_unit_group = rank - j * sequence_parallel_size * 2

    # Init ts_group and assign info.
    nccl_info.ts_group_size = world_size // 2
    for i in range(2):
        # Collect half ``i`` (0 = front, 1 = back) of every unit.
        ranks = []
        for j in range(num_teacher_student_union_groups):
            ranks += range(
                (j * 2 + i) * sequence_parallel_size,
                (j * 2 + i + 1) * sequence_parallel_size,
            )

        ts_group = dist.new_group(ranks)
        if rank in ranks:
            nccl_info.ts_group = ts_group
            nccl_info.ts_group_id = i
    
def destroy_sequence_parallel_group():
    """Destroy the sequence parallel group."""
    # NOTE(review): destroy_process_group is called without a group argument,
    # so nccl_info.group is not passed explicitly. Per the torch.distributed
    # docs this form presumably tears down every process group, including the
    # default one -- confirm that is intended. nccl_info and the module-level
    # state flags are not reset here.
    dist.destroy_process_group()

def is_teacher_group():
    """Report whether this process counts as a teacher.

    When teacher-student parallelism is off, every process is treated as a
    teacher. When it is on, the answer is the parity of
    ``nccl_info.group_id``: odd ids are teachers.
    """
    if not _TEACHER_STUDENT_PARALLEL_STATE:
        return True
    return nccl_info.group_id % 2 == 1

def is_student_group():
    """Report whether this process counts as a student.

    When teacher-student parallelism is off, every process is treated as a
    student. When it is on, the answer is the parity of
    ``nccl_info.group_id``: even ids are students.
    """
    if not _TEACHER_STUDENT_PARALLEL_STATE:
        return True
    return nccl_info.group_id % 2 == 0