# Lekr0's picture
# Add files using upload-large-folder tool
# 62dca4c verified
import logging
from typing import Optional
import sglang.srt.distributed.parallel_state as parallel_state
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import init_model_parallel_group
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.dp_attention import (
_DpGatheredBufferWrapper,
compute_dp_attention_local_info,
compute_dp_attention_world_info,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var
from specforge.distributed import get_tp_group as get_specforge_tp_group
logger = logging.getLogger(__name__)
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Install sglang's world ``GroupCoordinator`` on top of SpecForge's TP group.

    ``torch.distributed`` must already be initialized (asserted). The global
    ranks are split into consecutive blocks whose size equals SpecForge's TP
    group size, and a ``GroupCoordinator`` over those blocks is stored as
    ``parallel_state._WORLD``. The device group that the coordinator just
    created is then swapped for SpecForge's existing TP group and destroyed,
    to avoid keeping a duplicate communicator (saves CUDA memory).

    NOTE(review): the consecutive-block layout assumes SpecForge's TP groups
    are laid out the same way — confirm against ``specforge.distributed``.

    Args:
        world_size: only used in the debug log; the real global size is read
            from ``torch.distributed``.
        rank: only used in the debug log.
        local_rank: forwarded to ``GroupCoordinator``.
        backend: torch.distributed backend forwarded to ``GroupCoordinator``.
    """
    # Log the caller-supplied arguments before anything else happens.
    logger.debug("world_size=%d rank=%d backend=%s", world_size, rank, backend)
    assert (
        torch.distributed.is_initialized()
    ), "distributed environment should be initialized first"

    specforge_tp_group = get_specforge_tp_group()
    global_size = dist.get_world_size()
    group_size = dist.get_world_size(specforge_tp_group)

    # Consecutive rank blocks: [0, g), [g, 2g), ... for every full block.
    block_count = global_size // group_size
    block_ranks = [
        list(range(start, start + group_size))
        for start in range(0, block_count * group_size, group_size)
    ]

    world = GroupCoordinator(
        group_ranks=block_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="world",
    )
    parallel_state._WORLD = world

    # Reuse SpecForge's TP process group as the world device group and tear
    # down the one the coordinator just built, to save CUDA memory.
    redundant_group, world.device_group = world.device_group, specforge_tp_group
    dist.destroy_process_group(redundant_group)
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
    duplicate_tp_group: bool = False,
    torch_compile: Optional[bool] = None,
) -> None:
    """
    Initialize model parallel groups.

    Builds and stores on ``parallel_state`` the tensor-parallel (``_TP``),
    MoE expert-parallel (``_MOE_EP``), MoE tensor-parallel (``_MOE_TP``) and
    pipeline-parallel (``_PP``) groups, plus optionally a duplicate TP group
    (``_PDMUX_PREFILL_TP_GROUP``).

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        expert_model_parallel_size: number of GPUs used for expert (MoE)
            parallelism inside each tensor-parallel group; must be a positive
            divisor of ``tensor_model_parallel_size``.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups; defaults to
            the backend of ``parallel_state._WORLD.device_group``.
        duplicate_tp_group: if True, also build a second TP group named
            "pdmux_prefill_tp" over the same ranks and set
            ``pynccl_comm.disabled = False`` on both TP groups.
        torch_compile: forwarded to ``init_model_parallel_group`` for the TP
            group(s).

    Raises:
        RuntimeError: if ``parallel_state._WORLD.world_size`` differs from
            ``tensor_model_parallel_size * pipeline_model_parallel_size``.
        ValueError: if ``expert_model_parallel_size`` is not a positive
            divisor of ``tensor_model_parallel_size``.
        AssertionError: if torch.distributed is not initialized, or if any of
            the groups has already been initialized.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    # NOTE(review): when _WORLD was installed by init_distributed_environment,
    # it coordinates one SpecForge TP group per block of ranks, so this is
    # presumably the per-group size, whereas dist.get_world_size() (below) is
    # the global size used to lay out every group. Confirm before changing.
    world_size: int = parallel_state._WORLD.world_size
    backend = backend or dist.get_backend(parallel_state._WORLD.device_group)
    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )
    # Validate the MoE split before creating any group: a non-divisor would
    # silently build overlapping/incomplete MoE groups (or divide by zero)
    # only after the TP group had already been created.
    if (
        expert_model_parallel_size <= 0
        or tensor_model_parallel_size % expert_model_parallel_size != 0
    ):
        raise ValueError(
            f"expert_model_parallel_size ({expert_model_parallel_size}) must be a "
            f"positive divisor of tensor_model_parallel_size "
            f"({tensor_model_parallel_size})"
        )

    global_world_size = dist.get_world_size()

    # Build the tensor model-parallel groups: consecutive blocks of ranks.
    num_tensor_model_parallel_groups: int = (
        global_world_size // tensor_model_parallel_size
    )
    assert (
        parallel_state._TP is None
    ), "tensor model parallel group is already initialized"
    group_ranks = [
        list(
            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        )
        for i in range(num_tensor_model_parallel_groups)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    parallel_state._TP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_message_queue_broadcaster=get_bool_env_var(
            "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
        ),
        group_name="tp",
        pynccl_use_current_stream=duplicate_tp_group,
        torch_compile=torch_compile,
    )

    if duplicate_tp_group:
        assert (
            parallel_state._PDMUX_PREFILL_TP_GROUP is None
        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
        # Same ranks as _TP, but a separate communicator for PD-multiplexing.
        parallel_state._PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_message_queue_broadcaster=get_bool_env_var(
                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
            ),
            group_name="pdmux_prefill_tp",
            pynccl_use_current_stream=True,
            torch_compile=torch_compile,
        )
        parallel_state._TP.pynccl_comm.disabled = False
        parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

    # Each TP group of size tp is factored as ep (expert) x moe_tp.
    moe_ep_size = expert_model_parallel_size
    moe_tp_size = tensor_model_parallel_size // moe_ep_size

    # Expert-parallel groups: ranks strided by moe_tp_size within a TP group.
    assert (
        parallel_state._MOE_EP is None
    ), "expert model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tp_size):
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            ranks = list(range(st, en, moe_tp_size))
            group_ranks.append(ranks)

    parallel_state._MOE_EP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="moe_ep",
    )

    # MoE tensor-parallel groups: contiguous slices of moe_tp_size ranks.
    assert (
        parallel_state._MOE_TP is None
    ), "moe tensor model parallel group is already initialized"
    if moe_ep_size == 1:
        # Without expert parallelism the MoE TP group is the plain TP group.
        parallel_state._MOE_TP = parallel_state._TP
    else:
        group_ranks = []
        for i in range(num_tensor_model_parallel_groups):
            for j in range(moe_ep_size):
                st = i * tensor_model_parallel_size + j * moe_tp_size
                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
                ranks = list(range(st, en))
                group_ranks.append(ranks)
        parallel_state._MOE_TP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_custom_allreduce=False,
            group_name="moe_tp",
        )

    # Build the pipeline model-parallel groups: ranks strided across the world.
    num_pipeline_model_parallel_groups: int = (
        global_world_size // pipeline_model_parallel_size
    )
    assert (
        parallel_state._PP is None
    ), "pipeline model parallel group is already initialized"
    group_ranks = [
        list(range(i, global_world_size, num_pipeline_model_parallel_groups))
        for i in range(num_pipeline_model_parallel_groups)
    ]
    # pipeline parallel does not need custom allreduce
    parallel_state._PP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="pp",
    )
def initialize_dp_attention(
    server_args: ServerArgs,
    model_config: ModelConfig,
):
    """Populate the module-level DP-attention state of ``sglang.srt.layers.dp_attention``.

    Writes ``_ENABLE_DP_ATTENTION_FLAG``, ``_ATTN_TP_RANK``, ``_ATTN_TP_SIZE``,
    ``_ATTN_DP_RANK``, ``_LOCAL_ATTN_DP_RANK``, ``_ATTN_DP_SIZE``,
    ``_LOCAL_ATTN_DP_SIZE`` and ``_ATTN_TP_GROUP`` on that module, then
    configures ``_DpGatheredBufferWrapper`` with the model's hidden size,
    dtype and the server device.

    It reads this rank's TP rank and TP group from ``parallel_state``, so the
    TP group must already exist (presumably via ``initialize_model_parallel``).

    Args:
        server_args: source of ``enable_dp_attention``, ``tp_size``,
            ``dp_size``, ``moe_dense_tp_size``, ``pp_size`` and ``device``.
        model_config: source of ``hidden_size`` and ``dtype`` for the
            gathered-buffer metadata.
    """
    import sglang.srt.layers.dp_attention as dp_attention
    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

    enable_dp_attention = server_args.enable_dp_attention
    tp_size = server_args.tp_size
    dp_size = server_args.dp_size
    moe_dense_tp_size = server_args.moe_dense_tp_size
    pp_size = server_args.pp_size
    tp_rank = parallel_state.get_tensor_model_parallel_rank()

    dp_attention._ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
    (
        dp_attention._ATTN_TP_RANK,
        dp_attention._ATTN_TP_SIZE,
        dp_attention._ATTN_DP_RANK,
    ) = compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size)
    _, _, dp_attention._LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
    )

    if enable_dp_attention:
        dp_attention._ATTN_DP_SIZE = dp_size
        if moe_dense_tp_size is None:
            dp_attention._LOCAL_ATTN_DP_SIZE = dp_attention._ATTN_DP_SIZE
        else:
            # Never let the local DP size drop below one.
            dp_attention._LOCAL_ATTN_DP_SIZE = max(
                1, dp_size // (tp_size // moe_dense_tp_size)
            )
    else:
        dp_attention._ATTN_DP_SIZE = 1
        dp_attention._LOCAL_ATTN_DP_SIZE = 1

    tp_group = parallel_state.get_tp_group()

    # Every model-parallel replica spans pp_size * tp_size consecutive global
    # ranks; split each replica into contiguous attention-TP groups of
    # _ATTN_TP_SIZE ranks (heads restart at each replica boundary).
    mp_size = pp_size * tp_size
    num_model_parallel_groups = dist.get_world_size() // mp_size
    group_ranks = []
    for i in range(num_model_parallel_groups):
        ranks = [
            list(range(head, head + dp_attention._ATTN_TP_SIZE))
            for head in range(
                mp_size * i, mp_size * (i + 1), dp_attention._ATTN_TP_SIZE
            )
        ]
        group_ranks.extend(ranks)

    dp_attention._ATTN_TP_GROUP = GroupCoordinator(
        group_ranks,
        tp_group.local_rank,
        torch.distributed.get_backend(tp_group.device_group),
        use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="attention_tp",
    )

    _DpGatheredBufferWrapper.set_metadata(
        hidden_size=model_config.hidden_size,
        dtype=model_config.dtype,
        device=torch.device(server_args.device),
    )