File size: 3,139 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import torch
import torch.nn as nn
import torch.distributed as dist
from .comm.pg_utils import ProcessGroupManager
from .comm.comm import set_sp_comm_group, split_sequence, gather_sequence, all_to_all_comm
from .comm.operation import gather_forward_split_backward

class SequenceParallelManager:
    """Process-wide registry for the sequence-parallel (SP) process group.

    Holds the SP communication group handle and its size as class-level
    state so any module can query the SP topology without threading it
    through call sites. All methods are static; the class is never
    instantiated.
    """

    # Handle to the SP process group; None until init_sp() succeeds.
    _SP_GROUP = None
    # Number of ranks in each SP group; 0 until init_sp() succeeds.
    _SP_SIZE = 0

    @staticmethod
    def sp_on():
        """Return True once the sequence-parallel group has been initialized."""
        return SequenceParallelManager._SP_GROUP is not None

    @staticmethod
    def init_sp(sp_size):
        """Create the sequence-parallel process group of ``sp_size`` ranks.

        No-op (with a warning) if the group is already initialized or if
        ``sp_size <= 1``. The global world size must be divisible by
        ``sp_size``.
        """
        if SequenceParallelManager._SP_GROUP is not None:
            print("WARN: sequence parallel group is already initialized")
            return

        if sp_size <= 1:
            print(f"WARN: sequence parallel size must > 1 but got {sp_size}")
            return

        world_size = dist.get_world_size()
        assert world_size % sp_size == 0, f"world_size {world_size} must be divisible by sp_size({sp_size})"

        # Axis 0 enumerates the data-parallel groups; axis 1 enumerates the
        # ranks within one sequence-parallel group.
        pm = ProcessGroupManager(
            world_size // sp_size,
            sp_size,
            dp_axis=0,
            sp_axis=1,
        )
        pm_group = pm.sp_group
        set_sp_comm_group(pm_group)
        # Publish state only after group creation succeeded, so a failure in
        # ProcessGroupManager cannot leave _SP_SIZE set while sp_on() is False.
        SequenceParallelManager._SP_SIZE = sp_size
        SequenceParallelManager._SP_GROUP = pm_group
        return

    @staticmethod
    def get_sp_group():
        """Return the SP process group handle (None when SP is disabled)."""
        return SequenceParallelManager._SP_GROUP

    @staticmethod
    def get_sp_size():
        """Return the number of ranks per SP group (0 when SP is disabled)."""
        return SequenceParallelManager._SP_SIZE

    @staticmethod
    def get_sp_group_nums():
        """Return the number of SP groups, or 0 when SP is disabled.

        Example: with sp_size=2 and 8 ranks there are 4 groups.
        """
        if SequenceParallelManager.sp_on():
            world_size = dist.get_world_size()
            return world_size // SequenceParallelManager._SP_SIZE
        else:
            return 0

    @staticmethod
    def get_sp_rank():
        """Return this rank's position within its SP group (0 when SP is off).

        NOTE: assumes SP groups are formed from consecutive global ranks,
        matching the dp_axis=0 / sp_axis=1 layout used in init_sp().
        """
        if SequenceParallelManager.sp_on():
            global_rank = dist.get_rank()
            return global_rank % SequenceParallelManager._SP_SIZE
        else:
            return 0

    @staticmethod
    def get_sp_group_rank():
        """Return the index of the SP group this rank belongs to (0 when SP is off)."""
        if SequenceParallelManager.sp_on():
            global_rank = dist.get_rank()
            return global_rank // SequenceParallelManager._SP_SIZE
        else:
            return 0

def sp_split_sequence_by_dim(seq, seqlen_dim=1) -> torch.Tensor:
    """Shard the full sequence along ``seqlen_dim`` across the SP group."""
    sp_group = SequenceParallelManager.get_sp_group()
    return split_sequence(seq, sp_group, seqlen_dim, 'down')

def sp_gather_sequence_by_dim(seq, seqlen_dim=1) -> torch.Tensor:
    """Reassemble the full sequence by gathering ``seqlen_dim`` over the SP group."""
    sp_group = SequenceParallelManager.get_sp_group()
    return gather_sequence(seq, sp_group, seqlen_dim, 'up')

def sp_all_to_all(ts, scatter_dim, gather_dim):
    """All-to-all exchange over the SP group, trading one dimension for another.

    E.g. [raw_seq_len/sp_size, hidden_dim] becomes [raw_seq_len, hidden_dim/sp_size].

    scatter_dim: dimension that gets split across ranks
    gather_dim: dimension that gets concatenated from all ranks
    """
    sp_group = SequenceParallelManager.get_sp_group()
    return all_to_all_comm(ts, sp_group, scatter_dim, gather_dim)