import logging
import os

import torch
from sglang.srt.distributed import (
    get_pp_group,
    get_tp_group,
    get_world_group,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
    set_torch_symm_mem_all_reduce,
)
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    initialize_dp_attention,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_available_gpu_memory,
    get_bool_env_var,
    is_hip,
    is_npu,
    monkey_patch_p2p_access_check,
)

# NOTE: this import rebinds `initialize_dp_attention` (also imported from
# sglang.srt.layers.dp_attention above), so the calls below resolve to the
# versions provided by the local `.patch` module.
from .patch import (
    init_distributed_environment,
    initialize_dp_attention,
    initialize_model_parallel,
)

_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)


class SGLangRunner(ModelRunner):
    """ModelRunner variant that sets up torch.distributed via the `.patch` helpers."""

    def init_torch_distributed(self):
        """Initialize the distributed environment and model-parallel groups.

        Selects the communication backend from ``self.device``, configures the
        all-reduce implementations from ``self.server_args``, and (on
        non-draft workers only) initializes the distributed environment, the
        model-parallel groups and DP attention. Finally stores the TP / PP /
        attention-TP groups on ``self`` and checks memory balance.

        Returns:
            The value returned by ``get_available_gpu_memory`` when queried
            with the world group (presumably the minimum available memory
            across ranks, in GB -- confirm against that helper).

        Raises:
            ValueError: if ``self.device`` has no known backend, or if
                ``tp_size > 1`` on a non-draft worker and the reported minimum
                memory is below 90% of the local available memory while
                ``SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`` is not set.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            if self.server_args.elastic_ep_backend == "mooncake":
                backend = "mooncake"
                if self.server_args.mooncake_ib_device:
                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
                    try:
                        from mooncake import ep as mooncake_ep

                        mooncake_ep.set_device_filter(mooncake_ib_device)
                    except Exception:
                        # Best effort: a warning will be raised in
                        # `init_distributed_environment`. Only `Exception` is
                        # caught so KeyboardInterrupt/SystemExit still propagate.
                        logger.debug(
                            "Failed to apply mooncake device filter.", exc_info=True
                        )
            else:
                backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"
        else:
            # Fail with a clear message instead of an UnboundLocalError later.
            raise ValueError(
                f"Unsupported device for torch distributed init: {self.device!r}"
            )

        # Snapshot taken before initialization; used for the final usage log.
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        # NOTE(review): dist_init_method is computed but not passed to the
        # patched `init_distributed_environment` below -- confirm whether the
        # patch derives the rendezvous address on its own.
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
            )

            # NOTE: Updated for sglang 0.5.9
            # - Removed torch_compile parameter (no longer supported)
            # - Added new parameters: attention_data_parallel_size,
            #   attention_context_model_parallel_size, moe_data_model_parallel_size
            # Fall back to 1 when server_args does not define these sizes.
            dp_size = getattr(self.server_args, "dp_size", 1)
            attn_cp_size = getattr(self.server_args, "attn_cp_size", 1)
            moe_dp_size = getattr(self.server_args, "moe_dp_size", 1)
            logger.debug(
                "tp_size=%s, dp_size=%s, attn_cp_size=%s, moe_dp_size=%s",
                self.tp_size,
                dp_size,
                attn_cp_size,
                moe_dp_size,
            )

            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                attention_data_parallel_size=dp_size,
                attention_context_model_parallel_size=attn_cp_size,
                moe_data_model_parallel_size=moe_dp_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory