File size: 15,180 Bytes
7a60a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
import logging
from typing import Optional

import sglang.srt.distributed.parallel_state as parallel_state
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import init_model_parallel_group
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    compute_dp_attention_local_info,
    compute_dp_attention_world_info,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var

from specforge.distributed import get_tp_group as get_specforge_tp_group

logger = logging.getLogger(__name__)


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Install sglang's ``_WORLD`` coordinator on top of SpecForge's TP group.

    A ``GroupCoordinator`` is built over contiguous, TP-sized blocks of global
    ranks. Its freshly created device group is then swapped for the TP group
    SpecForge already owns, and the duplicate group is destroyed.

    Args:
        world_size: Only written to the debug log; the effective world size is
            re-read from ``torch.distributed``.
        rank: Only written to the debug log.
        local_rank: Local device index forwarded to ``GroupCoordinator``.
        backend: torch.distributed backend name given to the coordinator.

    Raises:
        AssertionError: if ``torch.distributed`` has not been initialized yet.
    """
    logger.debug(
        "world_size=%d rank=%d backend=%s",
        world_size,
        rank,
        backend,
    )
    assert (
        torch.distributed.is_initialized()
    ), "distributed environment should be initialized first"

    specforge_tp_group = get_specforge_tp_group()
    global_size = dist.get_world_size()
    tp_size = dist.get_world_size(specforge_tp_group)
    # NOTE(review): assumes SpecForge's TP groups are contiguous rank blocks,
    # i.e. [0..tp), [tp..2*tp), ... — confirm against specforge.distributed.
    rank_blocks = [
        list(range(block * tp_size, (block + 1) * tp_size))
        for block in range(global_size // tp_size)
    ]

    coordinator = GroupCoordinator(
        group_ranks=rank_blocks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="world",
    )
    parallel_state._WORLD = coordinator

    # Reuse SpecForge's existing TP group as the device group and release the
    # one the coordinator just created, to save CUDA memory.
    redundant_group = coordinator.device_group
    coordinator.device_group = specforge_tp_group
    dist.destroy_process_group(redundant_group)


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    attention_data_parallel_size: int = 1,
    attention_context_model_parallel_size: int = 1,
    moe_data_model_parallel_size: int = 1,
    backend: Optional[str] = None,
    duplicate_tp_group: bool = False,
    # NOTE: torch_compile parameter was removed in sglang 0.5.9
    # torch_compile: Optional[bool] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        expert_model_parallel_size: number of GPUs used for MoE expert
            parallelism inside each tensor-parallel group.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        attention_data_parallel_size: number of GPUs used for attention data
            parallelism. (Added in sglang 0.5.9)
        attention_context_model_parallel_size: number of GPUs used for attention
            context parallelism. (Added in sglang 0.5.9)
        moe_data_model_parallel_size: number of GPUs used for MoE data
            parallelism. (Added in sglang 0.5.9)
        backend: torch.distributed backend for the new groups. Defaults to the
            backend of ``parallel_state._WORLD.device_group``.
        duplicate_tp_group: if True, also create the PD-Multiplexing prefill TP
            group (``_PDMUX_PREFILL_TP_GROUP``) over the same ranks as ``_TP``.

    Inside each tensor-parallel group, the MoE ranks are laid out as
    ``[moe_dp][moe_ep][moe_tp]``, with ``moe_tp`` varying fastest and
    ``moe_tp = tp // moe_ep // moe_dp``. The MoE EP, TP and DP groups are all
    derived from that single layout, so they are mutually consistent.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if ``parallel_state._WORLD.world_size`` differs from
            ``tensor_model_parallel_size * pipeline_model_parallel_size``, or
            if ``tensor_model_parallel_size`` is not divisible by the MoE
            (ep x dp) or attention (cp x dp) sub-group sizes.
        AssertionError: if torch.distributed is not initialized, or if one of
            the groups built here has already been initialized.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = parallel_state._WORLD.world_size
    backend = backend or dist.get_backend(parallel_state._WORLD.device_group)

    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    # Sub-group sizes carved out of every tensor-parallel group. Validate them
    # before any process group is created. A non-divisible split would leave
    # some ranks outside every sub-group, and creation would fail half-way.
    moe_ep_size = expert_model_parallel_size
    moe_dp_size = moe_data_model_parallel_size
    if tensor_model_parallel_size % (moe_ep_size * moe_dp_size) != 0:
        raise RuntimeError(
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) is not "
            f"divisible by expert_model_parallel_size ({moe_ep_size}) x "
            f"moe_data_model_parallel_size ({moe_dp_size})"
        )
    moe_tp_size = tensor_model_parallel_size // moe_ep_size // moe_dp_size

    attn_dp_size = attention_data_parallel_size
    attn_cp_size = attention_context_model_parallel_size
    if tensor_model_parallel_size % (attn_cp_size * attn_dp_size) != 0:
        raise RuntimeError(
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) is not "
            f"divisible by attention_context_model_parallel_size ({attn_cp_size}) x "
            f"attention_data_parallel_size ({attn_dp_size})"
        )
    attn_tp_size = tensor_model_parallel_size // attn_cp_size // attn_dp_size

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (
        dist.get_world_size() // tensor_model_parallel_size
    )
    assert (
        parallel_state._TP is None
    ), "tensor model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        )
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    # NOTE: torch_compile parameter was removed in sglang 0.5.9
    parallel_state._TP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_message_queue_broadcaster=get_bool_env_var(
            "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
        ),
        group_name="tp",
        pynccl_use_current_stream=duplicate_tp_group,
    )

    if duplicate_tp_group:
        assert (
            parallel_state._PDMUX_PREFILL_TP_GROUP is None
        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
        # NOTE: torch_compile parameter was removed in sglang 0.5.9
        parallel_state._PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_message_queue_broadcaster=get_bool_env_var(
                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
            ),
            group_name="pdmux_prefill_tp",
            pynccl_use_current_stream=True,
        )
        # NOTE: Check pynccl_comm exists before accessing it (may be None in sglang 0.5.9)
        if parallel_state._TP.pynccl_comm is not None:
            parallel_state._TP.pynccl_comm.disabled = False
        if parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm is not None:
            parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

    # MoE expert-parallel groups: for each (moe_dp, moe_tp) index, the ranks
    # that differ only in their moe_ep index (stride moe_tp_size).
    assert (
        parallel_state._MOE_EP is None
    ), "expert model parallel group is already initialized"
    group_ranks = []
    for tp_group_idx in range(num_tensor_model_parallel_groups):
        for moe_dp_idx in range(moe_dp_size):
            for moe_tp_idx in range(moe_tp_size):
                st = (
                    tp_group_idx * tensor_model_parallel_size
                    + moe_dp_idx * moe_ep_size * moe_tp_size
                    + moe_tp_idx
                )
                en = (
                    tp_group_idx * tensor_model_parallel_size
                    + (moe_dp_idx + 1) * moe_ep_size * moe_tp_size
                    + moe_tp_idx
                )
                group_ranks.append(list(range(st, en, moe_tp_size)))

    parallel_state._MOE_EP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="moe_ep",
    )

    # MoE tensor-parallel groups: contiguous blocks of moe_tp_size ranks.
    assert (
        parallel_state._MOE_TP is None
    ), "moe tensor model parallel group is already initialized"
    if moe_tp_size == tensor_model_parallel_size:
        parallel_state._MOE_TP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            for ep_dp_combined_idx in range(moe_ep_size * moe_dp_size):
                st = (
                    tp_group_idx * tensor_model_parallel_size
                    + ep_dp_combined_idx * moe_tp_size
                )
                en = st + moe_tp_size
                group_ranks.append(list(range(st, en)))
        parallel_state._MOE_TP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_custom_allreduce=False,
            group_name="moe_tp",
        )

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (
        dist.get_world_size() // pipeline_model_parallel_size
    )
    assert (
        parallel_state._PP is None
    ), "pipeline model parallel group is already initialized"
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(
            range(i, dist.get_world_size(), num_pipeline_model_parallel_groups)
        )
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    parallel_state._PP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="pp",
    )

    # NOTE: Added for sglang 0.5.9 - Initialize attention parallel groups
    # These are required by get_attention_tp_group() and get_attention_cp_group()
    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

    # Initialize _ATTN_CP (attention context parallel group)
    if not hasattr(parallel_state, "_ATTN_CP"):
        parallel_state._ATTN_CP = None
    assert (
        parallel_state._ATTN_CP is None
    ), "attention context model parallel group is already initialized"
    if attn_cp_size == tensor_model_parallel_size:
        parallel_state._ATTN_CP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            for dp_idx in range(attn_dp_size):
                for attn_tp_idx in range(attn_tp_size):
                    st = (
                        tp_group_idx * tensor_model_parallel_size
                        + dp_idx * attn_tp_size * attn_cp_size
                        + attn_tp_idx
                    )
                    en = (
                        tp_group_idx * tensor_model_parallel_size
                        + (dp_idx + 1) * attn_tp_size * attn_cp_size
                        + attn_tp_idx
                    )
                    ranks = list(range(st, en, attn_tp_size))
                    group_ranks.append(ranks)
        parallel_state._ATTN_CP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            group_name="attn_cp",
        )

    # Initialize _ATTN_TP (attention tensor parallel group)
    if not hasattr(parallel_state, "_ATTN_TP"):
        parallel_state._ATTN_TP = None
    assert (
        parallel_state._ATTN_TP is None
    ), "attention tensor model parallel group is already initialized"
    if attn_tp_size == tensor_model_parallel_size:
        parallel_state._ATTN_TP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            for cp_dp_combined_idx in range(attn_cp_size * attn_dp_size):
                st = (
                    tp_group_idx * tensor_model_parallel_size
                    + cp_dp_combined_idx * attn_tp_size
                )
                en = (
                    tp_group_idx * tensor_model_parallel_size
                    + (cp_dp_combined_idx + 1) * attn_tp_size
                )
                ranks = list(range(st, en))
                group_ranks.append(ranks)
        parallel_state._ATTN_TP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
            use_mscclpp_allreduce=False,
            use_custom_allreduce=False,
            use_torch_symm_mem_allreduce=False,
            group_name="attention_tp",
        )

    # Initialize _MOE_DP (moe data parallel group): for each (moe_ep, moe_tp)
    # index, the ranks that differ only in their moe_dp index.
    if not hasattr(parallel_state, "_MOE_DP"):
        parallel_state._MOE_DP = None
    assert (
        parallel_state._MOE_DP is None
    ), "moe data parallel group is already initialized"
    if moe_dp_size == tensor_model_parallel_size:
        parallel_state._MOE_DP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            for tp_ep_combined_idx in range(moe_tp_size * moe_ep_size):
                st = tp_group_idx * tensor_model_parallel_size + tp_ep_combined_idx
                en = (
                    tp_group_idx + 1
                ) * tensor_model_parallel_size + tp_ep_combined_idx
                ranks = list(range(st, en, moe_tp_size * moe_ep_size))
                group_ranks.append(ranks)
        parallel_state._MOE_DP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            group_name="moe_dp",
        )


def initialize_dp_attention(
    server_args: ServerArgs,
    model_config: ModelConfig,
):
    """
    Initialize data parallel attention.

    Sets the module-level rank/size globals of
    ``sglang.srt.layers.dp_attention``. They are derived from ``server_args``
    and this process's tensor-parallel rank. The function also registers the
    hidden size, dtype and device used by ``_DpGatheredBufferWrapper``.

    Updated for sglang 0.5.9:
    - Added attn_cp_size parameter support
    - Removed _ATTN_TP_GROUP creation (now handled by initialize_model_parallel in sglang 0.5.9)

    Args:
        server_args: Reads ``enable_dp_attention``, ``tp_size``, ``dp_size``,
            ``moe_dense_tp_size`` and ``device``, plus ``attn_cp_size`` when
            present (treated as 1 otherwise).
        model_config: Supplies ``hidden_size`` and ``dtype`` for the gathered
            buffer metadata.
    """
    import sglang.srt.layers.dp_attention as dp_attention

    enable_dp_attention = server_args.enable_dp_attention
    tp_size = server_args.tp_size
    dp_size = server_args.dp_size
    moe_dense_tp_size = server_args.moe_dense_tp_size
    # NOTE: attn_cp_size is new in sglang 0.5.9; getattr falls back to 1 when
    # the ServerArgs instance has no such field.
    attn_cp_size = getattr(server_args, "attn_cp_size", 1)

    tp_rank = parallel_state.get_tensor_model_parallel_rank()

    dp_attention._ENABLE_DP_ATTENTION_FLAG = enable_dp_attention

    # NOTE: Added attn_cp_size parameter for sglang 0.5.9
    (
        dp_attention._ATTN_TP_RANK,
        dp_attention._ATTN_TP_SIZE,
        dp_attention._ATTN_DP_RANK,
    ) = compute_dp_attention_world_info(
        enable_dp_attention, tp_rank, tp_size, dp_size, attn_cp_size
    )
    _, _, dp_attention._LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
    )

    if enable_dp_attention:
        dp_attention._ATTN_DP_SIZE = dp_size
        if moe_dense_tp_size is None:
            dp_attention._LOCAL_ATTN_DP_SIZE = dp_attention._ATTN_DP_SIZE
        else:
            dp_attention._LOCAL_ATTN_DP_SIZE = max(
                1, dp_size // (tp_size // moe_dense_tp_size)
            )
    else:
        dp_attention._ATTN_DP_SIZE = 1
        dp_attention._LOCAL_ATTN_DP_SIZE = 1

    # NOTE: In sglang 0.5.9, _ATTN_TP_GROUP is created in initialize_model_parallel.
    # Recreating it here would trip the "already initialized" assertion, so it
    # is intentionally not built in this function.

    _DpGatheredBufferWrapper.set_metadata(
        hidden_size=model_config.hidden_size,
        dtype=model_config.dtype,
        device=torch.device(server_args.device),
    )