"""Bridge sglang's parallel-state globals onto specforge's process groups.

Adapts sglang (0.5.9) model-parallel initialization so that sglang's
``parallel_state`` module reuses the tensor-parallel process group already
created by specforge, instead of creating duplicate NCCL communicators.
"""

import logging
from typing import Optional

import sglang.srt.distributed.parallel_state as parallel_state
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import init_model_parallel_group
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.dp_attention import (
    _DpGatheredBufferWrapper,
    compute_dp_attention_local_info,
    compute_dp_attention_world_info,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var

from specforge.distributed import get_tp_group as get_specforge_tp_group

logger = logging.getLogger(__name__)


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Install sglang's ``parallel_state._WORLD`` on top of specforge's TP group.

    Requires ``torch.distributed`` to already be initialized (asserted below).
    The ``world_size`` and ``rank`` parameters are only logged; the actual
    world size is re-read from ``torch.distributed``.

    Args:
        world_size: advertised world size (log-only; overwritten locally).
        rank: advertised global rank (log-only).
        local_rank: local device rank passed through to the GroupCoordinator.
        backend: torch.distributed backend name for the new coordinator.
    """
    logger.debug(
        "world_size=%d rank=%d backend=%s",
        world_size,
        rank,
        backend,
    )
    assert (
        torch.distributed.is_initialized()
    ), "distributed environment should be initialized first"
    tp_group = get_specforge_tp_group()
    # NOTE: shadows the `world_size` parameter with the live process-group size.
    world_size = dist.get_world_size()
    tp_size = dist.get_world_size(tp_group)
    num_tp_groups = world_size // tp_size
    # Contiguous rank blocks of size tp_size, e.g. [[0,1],[2,3],...] for tp=2.
    tp_ranks = []
    for i in range(num_tp_groups):
        tp_ranks.append(list(range(i * tp_size, (i + 1) * tp_size)))
    # All optional communicators disabled: this coordinator only carries
    # group bookkeeping, not accelerated collectives.
    parallel_state._WORLD = GroupCoordinator(
        group_ranks=tp_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="world",
    )
    # we destroy the newly created world group and replace it
    # with the existing tp group from specforge to save CUDA memory
    group_to_destroy = parallel_state._WORLD.device_group
    parallel_state._WORLD.device_group = tp_group
    dist.destroy_process_group(group_to_destroy)


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    attention_data_parallel_size: int = 1,
    attention_context_model_parallel_size: int = 1,
    moe_data_model_parallel_size: int = 1,
    backend: Optional[str] = None,
    duplicate_tp_group: bool = False,
    # NOTE: torch_compile parameter was removed in sglang 0.5.9
    # torch_compile: Optional[bool] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
        attention_data_parallel_size: number of GPUs used for attention data
            parallelism. (Added in sglang 0.5.9)
        attention_context_model_parallel_size: number of GPUs used for attention
            context parallelism. (Added in sglang 0.5.9)
        moe_data_model_parallel_size: number of GPUs used for moe data
            parallelism. (Added in sglang 0.5.9)

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs
    to parallelize the model tensor, and 4 GPUs to parallelize the model
    pipeline. The present function will create 4 tensor model-parallel groups
    and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = parallel_state._WORLD.world_size
    backend = backend or dist.get_backend(parallel_state._WORLD.device_group)

    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (
        dist.get_world_size() // tensor_model_parallel_size
    )
    assert (
        parallel_state._TP is None
    ), "tensor model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        # Contiguous blocks: TP group i owns ranks [i*tp, (i+1)*tp).
        ranks = list(
            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        )
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    # NOTE: torch_compile parameter was removed in sglang 0.5.9
    parallel_state._TP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_message_queue_broadcaster=get_bool_env_var(
            "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
        ),
        group_name="tp",
        pynccl_use_current_stream=duplicate_tp_group,
    )

    if duplicate_tp_group:
        # Second TP group over the same ranks, used for PD-Multiplexing prefill.
        assert (
            parallel_state._PDMUX_PREFILL_TP_GROUP is None
        ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
        # NOTE: torch_compile parameter was removed in sglang 0.5.9
        parallel_state._PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_message_queue_broadcaster=get_bool_env_var(
                "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
            ),
            group_name="pdmux_prefill_tp",
            pynccl_use_current_stream=True,
        )
        # NOTE: Check pynccl_comm exists before accessing it (may be None in sglang 0.5.9)
        if parallel_state._TP.pynccl_comm is not None:
            parallel_state._TP.pynccl_comm.disabled = False
        if parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm is not None:
            parallel_state._PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

    # MoE expert parallelism factors the TP block: tp = moe_ep * moe_tp.
    moe_ep_size = expert_model_parallel_size
    moe_tp_size = tensor_model_parallel_size // moe_ep_size

    assert (
        parallel_state._MOE_EP is None
    ), "expert model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tp_size):
            # Strided slice of the TP block: ranks j, j+moe_tp, j+2*moe_tp, ...
            # yields exactly moe_ep_size ranks per EP group.
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            ranks = list(range(st, en, moe_tp_size))
            group_ranks.append(ranks)

    parallel_state._MOE_EP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="moe_ep",
    )

    assert (
        parallel_state._MOE_TP is None
    ), "moe tensor model parallel group is already initialized"
    if moe_ep_size == 1:
        # No EP split: MoE-TP coincides with the plain TP group; reuse it
        # to avoid allocating another communicator.
        parallel_state._MOE_TP = parallel_state._TP
    else:
        group_ranks = []
        for i in range(num_tensor_model_parallel_groups):
            for j in range(moe_ep_size):
                # Contiguous sub-block of moe_tp_size ranks per MoE-TP group.
                st = i * tensor_model_parallel_size + j * moe_tp_size
                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
                ranks = list(range(st, en))
                group_ranks.append(ranks)
        parallel_state._MOE_TP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_custom_allreduce=False,
            group_name="moe_tp",
        )

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (
        dist.get_world_size() // pipeline_model_parallel_size
    )
    assert (
        parallel_state._PP is None
    ), "pipeline model parallel group is already initialized"
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        # PP groups are strided across the world, orthogonal to the TP blocks.
        ranks = list(
            range(i, dist.get_world_size(), num_pipeline_model_parallel_groups)
        )
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    parallel_state._PP = init_model_parallel_group(
        group_ranks,
        parallel_state._WORLD.local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="pp",
    )

    # NOTE: Added for sglang 0.5.9 - Initialize attention parallel groups
    # These are required by get_attention_tp_group() and get_attention_cp_group()
    # Imported locally, presumably to avoid a circular import at module load.
    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

    # Attention parallelism factors the TP block: tp = attn_dp * attn_cp * attn_tp.
    attn_dp_size = attention_data_parallel_size
    attn_cp_size = attention_context_model_parallel_size
    attn_tp_size = tensor_model_parallel_size // attn_cp_size // attn_dp_size

    # Initialize _ATTN_CP (attention context parallel group)
    # hasattr guard: older sglang builds may not declare the attribute at all,
    # so create it as None before the is-None assertion.
    if not hasattr(parallel_state, "_ATTN_CP"):
        parallel_state._ATTN_CP = None
    assert (
        parallel_state._ATTN_CP is None
    ), "attention context model parallel group is already initialized"
    if attn_cp_size == tensor_model_parallel_size:
        # CP spans the whole TP block; reuse the TP group directly.
        parallel_state._ATTN_CP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            for dp_idx in range(attn_dp_size):
                for attn_tp_idx in range(attn_tp_size):
                    # Stride attn_tp_size within one attn-DP slice picks the
                    # attn_cp_size ranks that share this (dp, attn_tp) position.
                    st = (
                        tp_group_idx * tensor_model_parallel_size
                        + dp_idx * attn_tp_size * attn_cp_size
                        + attn_tp_idx
                    )
                    en = (
                        tp_group_idx * tensor_model_parallel_size
                        + (dp_idx + 1) * attn_tp_size * attn_cp_size
                        + attn_tp_idx
                    )
                    ranks = list(range(st, en, attn_tp_size))
                    group_ranks.append(ranks)
        parallel_state._ATTN_CP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            group_name="attn_cp",
        )

    # Initialize _ATTN_TP (attention tensor parallel group)
    if not hasattr(parallel_state, "_ATTN_TP"):
        parallel_state._ATTN_TP = None
    assert (
        parallel_state._ATTN_TP is None
    ), "attention tensor model parallel group is already initialized"
    if attn_tp_size == tensor_model_parallel_size:
        parallel_state._ATTN_TP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            # Each combined (cp, dp) index owns one contiguous attn_tp slice.
            for cp_dp_combined_idx in range(attn_cp_size * attn_dp_size):
                st = (
                    tp_group_idx * tensor_model_parallel_size
                    + cp_dp_combined_idx * attn_tp_size
                )
                en = (
                    tp_group_idx * tensor_model_parallel_size
                    + (cp_dp_combined_idx + 1) * attn_tp_size
                )
                ranks = list(range(st, en))
                group_ranks.append(ranks)
        # pynccl is only enabled here when sampler needs to sync token ids
        # across the attention-TP ranks.
        parallel_state._ATTN_TP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
            use_mscclpp_allreduce=False,
            use_custom_allreduce=False,
            use_torch_symm_mem_allreduce=False,
            group_name="attention_tp",
        )

    # Initialize _MOE_DP (moe data parallel group)
    if not hasattr(parallel_state, "_MOE_DP"):
        parallel_state._MOE_DP = None
    assert (
        parallel_state._MOE_DP is None
    ), "moe data parallel group is already initialized"
    moe_dp_size = moe_data_model_parallel_size
    moe_tp_size_for_dp = tensor_model_parallel_size // moe_ep_size // moe_dp_size
    if moe_dp_size == tensor_model_parallel_size:
        parallel_state._MOE_DP = parallel_state._TP
    else:
        group_ranks = []
        for tp_group_idx in range(num_tensor_model_parallel_groups):
            for tp_ep_combined_idx in range(moe_tp_size_for_dp * moe_ep_size):
                # Stride by (moe_tp_for_dp * moe_ep) picks the moe_dp_size
                # ranks sharing this combined (tp, ep) position.
                st = tp_group_idx * tensor_model_parallel_size + tp_ep_combined_idx
                en = (
                    tp_group_idx + 1
                ) * tensor_model_parallel_size + tp_ep_combined_idx
                ranks = list(range(st, en, moe_tp_size_for_dp * moe_ep_size))
                group_ranks.append(ranks)
        parallel_state._MOE_DP = init_model_parallel_group(
            group_ranks,
            parallel_state._WORLD.local_rank,
            backend,
            group_name="moe_dp",
        )


def initialize_dp_attention(
    server_args: ServerArgs,
    model_config: ModelConfig,
):
    """
    Initialize data parallel attention.

    Populates the module-level globals of ``sglang.srt.layers.dp_attention``
    (ranks/sizes for attention TP/DP) and sets the gathered-buffer metadata.

    Updated for sglang 0.5.9:
    - Added attn_cp_size parameter support
    - Removed _ATTN_TP_GROUP creation (now handled by initialize_model_parallel
      in sglang 0.5.9)
    """
    # Local import, presumably to avoid a circular import with dp_attention.
    import sglang.srt.layers.dp_attention as dp_attention

    enable_dp_attention = server_args.enable_dp_attention
    tp_size = server_args.tp_size
    dp_size = server_args.dp_size
    moe_dense_tp_size = server_args.moe_dense_tp_size
    # NOTE(review): pp_size is read but never used below — confirm whether it
    # can be dropped or was meant to feed one of the compute_* calls.
    pp_size = server_args.pp_size
    # NOTE: attn_cp_size is new in sglang 0.5.9
    attn_cp_size = getattr(server_args, "attn_cp_size", 1)
    tp_rank = parallel_state.get_tensor_model_parallel_rank()

    dp_attention._ENABLE_DP_ATTENTION_FLAG = enable_dp_attention

    # NOTE: Added attn_cp_size parameter for sglang 0.5.9
    (
        dp_attention._ATTN_TP_RANK,
        dp_attention._ATTN_TP_SIZE,
        dp_attention._ATTN_DP_RANK,
    ) = compute_dp_attention_world_info(
        enable_dp_attention, tp_rank, tp_size, dp_size, attn_cp_size
    )
    # NOTE(review): the world-info call passes attn_cp_size as its 5th arg
    # while the local-info call passes moe_dense_tp_size — verify both against
    # the sglang 0.5.9 signatures.
    _, _, dp_attention._LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
    )

    if enable_dp_attention:
        dp_attention._ATTN_DP_SIZE = dp_size
        if moe_dense_tp_size is None:
            dp_attention._LOCAL_ATTN_DP_SIZE = dp_attention._ATTN_DP_SIZE
        else:
            # Shrink local DP size by the dense-TP replication factor (min 1).
            dp_attention._LOCAL_ATTN_DP_SIZE = max(
                1, dp_size // (tp_size // moe_dense_tp_size)
            )
    else:
        dp_attention._ATTN_DP_SIZE = 1
        dp_attention._LOCAL_ATTN_DP_SIZE = 1

    # NOTE: In sglang 0.5.9, _ATTN_TP_GROUP is created in initialize_model_parallel.
    # We no longer need to manually create it here to avoid conflicts.
    # The assertion error occurs because we were trying to recreate an already-initialized group.
    _DpGatheredBufferWrapper.set_metadata(
        hidden_size=model_config.hidden_size,
        dtype=model_config.dtype,
        device=torch.device(server_args.device),
    )