import logging
from typing import Optional

import sglang.srt.distributed.parallel_state as parallel_state
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import init_model_parallel_group
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    compute_dp_attention_local_info,
    compute_dp_attention_world_info,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var

from specforge.distributed import get_tp_group as get_specforge_tp_group

logger = logging.getLogger(__name__)


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    local_rank: int = -1,
    backend: str = "nccl",
):
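    """Set up SGLang's world group on top of SpecForge's TP groups.

    Assumes torch.distributed is already initialized. Each SpecForge TP group
    becomes one SGLang "world" group, and the existing TP process group is
    re-used as its device group to save CUDA memory (see below).
    """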
    logger.debug(
        "world_size=%d rank=%d backend=%s",
        world_size,
        rank,
        backend,
    )
    assert (
        torch.distributed.is_initialized()
    ), "distributed environment should be initialized first"

    tp_group = get_specforge_tp_group()
    world_size = dist.get_world_size()
    tp_size = dist.get_world_size(tp_group)
    num_tp_groups = world_size // tp_size
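    # Partition all ranks into contiguous TP groups, e.g. with world_size=8 and
    # tp_size=4 this yields [[0, 1, 2, 3], [4, 5, 6, 7]].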
    tp_ranks = []
    for i in range(num_tp_groups):
        tp_ranks.append(list(range(i * tp_size, (i + 1) * tp_size)))

    parallel_state._WORLD = GroupCoordinator(
        group_ranks=tp_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="world",
    )
    # we destroy the newly created world group and replace it
    # with the existing tp group from specforge to save CUDA memory
    group_to_destroy = parallel_state._WORLD.device_group
    parallel_state._WORLD.device_group = tp_group
    dist.destroy_process_group(group_to_destroy)


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
    duplicate_tp_group: bool = False,
    torch_compile: Optional[bool] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        expert_model_parallel_size: number of GPUs used for expert model
            parallelism; must divide tensor_model_parallel_size.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        duplicate_tp_group: if True, additionally create a separate tensor
            model-parallel group for PD-Multiplexing prefill.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = parallel_state._WORLD.world_size
    backend = backend or dist.get_backend(parallel_state._WORLD.device_group)

    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (
        dist.get_world_size() // tensor_model_parallel_size
    )
    assert (
        parallel_state._TP is None
    ), "tensor model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        )
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    parallel_state._TP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_message_queue_broadcaster=get_bool_env_var(
            "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
        ),
        group_name="tp",
        pynccl_use_current_stream=duplicate_tp_group,
        torch_compile=torch_compile,
    )

    if duplicate_tp_group:
        assert (
            parallel_state._PDMUX_PREFILL_TP_GROUP is None
        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
        parallel_state._PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_message_queue_broadcaster=get_bool_env_var(
                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
            ),
            group_name="pdmux_prefill_tp",
            pynccl_use_current_stream=True,
            torch_compile=torch_compile,
        )
        parallel_state._TP.pynccl_comm.disabled = False
        parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

    moe_ep_size = expert_model_parallel_size

    moe_tp_size = tensor_model_parallel_size // moe_ep_size
    assert (
        parallel_state._MOE_EP is None
    ), "expert model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tp_size):
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            ranks = list(range(st, en, moe_tp_size))
            group_ranks.append(ranks)

    parallel_state._MOE_EP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="moe_ep",
    )

    assert (
        parallel_state._MOE_TP is None
    ), "moe tensor model parallel group is already initialized"
    if moe_ep_size == 1:
        parallel_state._MOE_TP = parallel_state._TP
    else:
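        # MoE-TP ranks are the contiguous blocks of moe_tp_size ranks inside
        # each TP group, e.g. tensor_model_parallel_size=4 with
        # expert_model_parallel_size=2 gives MoE-TP groups [0, 1] and [2, 3].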
        group_ranks = []
        for i in range(num_tensor_model_parallel_groups):
            for j in range(moe_ep_size):
                st = i * tensor_model_parallel_size + j * moe_tp_size
                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
                ranks = list(range(st, en))
                group_ranks.append(ranks)
        parallel_state._MOE_TP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_custom_allreduce=False,
            group_name="moe_tp",
        )

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (
        dist.get_world_size() // pipeline_model_parallel_size
    )
    assert (
        parallel_state._PP is None
    ), "pipeline model parallel group is already initialized"
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(
            range(i, dist.get_world_size(), num_pipeline_model_parallel_groups)
        )
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    parallel_state._PP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="pp",
    )


def initialize_dp_attention(
    server_args: ServerArgs,
    model_config: ModelConfig,
):
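    """Populate the sglang.srt.layers.dp_attention module globals.

    Computes the attention TP/DP ranks and sizes from the server arguments,
    builds the attention-TP GroupCoordinator, and sets the metadata used for
    DP-gathered buffers.
    """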
    import sglang.srt.layers.dp_attention as dp_attention
    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

    enable_dp_attention = server_args.enable_dp_attention
    tp_size = server_args.tp_size
    dp_size = server_args.dp_size
    moe_dense_tp_size = server_args.moe_dense_tp_size
    pp_size = server_args.pp_size

    tp_rank = parallel_state.get_tensor_model_parallel_rank()

    dp_attention._ENABLE_DP_ATTENTION_FLAG = enable_dp_attention

    (
        dp_attention._ATTN_TP_RANK,
        dp_attention._ATTN_TP_SIZE,
        dp_attention._ATTN_DP_RANK,
    ) = compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size)
    _, _, dp_attention._LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
    )

    if enable_dp_attention:
        dp_attention._ATTN_DP_SIZE = dp_size
        if moe_dense_tp_size is None:
            dp_attention._LOCAL_ATTN_DP_SIZE = dp_attention._ATTN_DP_SIZE
        else:
            dp_attention._LOCAL_ATTN_DP_SIZE = max(
                1, dp_size // (tp_size // moe_dense_tp_size)
            )
    else:
        dp_attention._ATTN_DP_SIZE = 1
        dp_attention._LOCAL_ATTN_DP_SIZE = 1

    tp_group = parallel_state.get_tp_group()
    num_model_parallel_groups = dist.get_world_size() // (pp_size * tp_size)
    mp_size = pp_size * tp_size
    group_ranks = []
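    # Split each model-parallel block of pp_size * tp_size ranks into
    # attention-TP groups of _ATTN_TP_SIZE consecutive ranks, e.g. with
    # _ATTN_TP_SIZE=2, tp_size=4, and pp_size=1 this yields [0, 1] and [2, 3]
    # per block.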

    for i in range(num_model_parallel_groups):
        ranks = [
            list(range(head, head + dp_attention._ATTN_TP_SIZE))
            for head in range(
                mp_size * i, mp_size * (i + 1), dp_attention._ATTN_TP_SIZE
            )
        ]
        group_ranks.extend(ranks)

    dp_attention._ATTN_TP_GROUP = GroupCoordinator(
        group_ranks,
        tp_group.local_rank,
        torch.distributed.get_backend(tp_group.device_group),
        use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="attention_tp",
    )

    _DpGatheredBufferWrapper.set_metadata(
        hidden_size=model_config.hidden_size,
        dtype=model_config.dtype,
        device=torch.device(server_args.device),
    )