import logging
import os
import torch
from sglang.srt.distributed import (
get_pp_group,
get_tp_group,
get_world_group,
set_custom_all_reduce,
set_mscclpp_all_reduce,
set_torch_symm_mem_all_reduce,
)
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
initialize_dp_attention,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.utils import (
cpu_has_amx_support,
get_available_gpu_memory,
get_bool_env_var,
is_hip,
is_npu,
monkey_patch_p2p_access_check,
)
from .patch import (
init_distributed_environment,
initialize_dp_attention,
initialize_model_parallel,
)
# Platform / feature probes, evaluated once when this module is imported.
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI.
# Holds the raw environment string, or None when the variable is unset.
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Detect straggler ranks in model loading
# (presumably a timeout in seconds, per the `_S` suffix -- confirm at the use site).
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)
class SGLangRunner(ModelRunner):
    """`ModelRunner` subclass that overrides torch.distributed initialization.

    The override calls the `init_distributed_environment`,
    `initialize_model_parallel` and `initialize_dp_attention` names that are in
    scope in this module (the file imports them from `.patch`).
    """

    def init_torch_distributed(self):
        """Set up the device, the distributed backend and the parallel groups.

        Steps, in order:
          1. Bind this process to `self.gpu_id` on `self.device`.
          2. Pick the distributed backend name from the device type.
          3. Configure the all-reduce implementations from `self.server_args`.
          4. On non-draft workers only: optional CPU/AMX setup, then the
             distributed environment, model-parallel groups and DP attention.
          5. Cache the TP / PP / attention-TP groups on `self`.
          6. Check that available memory is balanced across ranks.

        Returns:
            The value of `get_available_gpu_memory(..., distributed=...)`
            (presumably the minimum available memory across ranks, in GB --
            confirm against `get_available_gpu_memory`).

        Raises:
            ValueError: if `self.device` has no known backend (non-draft
                workers only), or if memory is unbalanced across ranks and
                `SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK` is not set.
            Exception: whatever `set_device` raises is logged with context and
                re-raised unchanged.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            # Log enough context to diagnose a bad device mapping, then propagate.
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        # Stays None for an unrecognized device; validated below only on the
        # path that actually needs a backend.
        backend = None
        if self.device == "cuda":
            if self.server_args.elastic_ep_backend == "mooncake":
                backend = "mooncake"
                if self.server_args.mooncake_ib_device:
                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
                    try:
                        from mooncake import ep as mooncake_ep

                        mooncake_ep.set_device_filter(mooncake_ib_device)
                    except Exception:
                        # Best effort: a warning will be raised in
                        # `init_distributed_environment`. `Exception` (not a bare
                        # `except`) lets KeyboardInterrupt/SystemExit propagate.
                        logger.debug(
                            "Could not apply the mooncake IB device filter.",
                            exc_info=True,
                        )
            else:
                backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"

        # Baseline used for the "mem usage" figure in the final log line.
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        # NOTE(review): `dist_init_method` is computed but never passed to
        # `init_distributed_environment` below. Confirm whether the patched
        # helper intentionally uses its own default init method or whether
        # this should be forwarded.
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"

        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if backend is None:
                # Previously this surfaced as an UnboundLocalError at the call below.
                raise ValueError(
                    f"Unsupported device for torch distributed initialization: {self.device!r}"
                )

            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    # Fake (meta) implementation: concatenates `tp_size` copies
                    # of the input along `dim`.
                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
            )

            # NOTE: Updated for sglang 0.5.9
            # - Removed torch_compile parameter (no longer supported)
            # - Added new parameters: attention_data_parallel_size,
            #   attention_context_model_parallel_size, moe_data_model_parallel_size
            # Each size falls back to 1 when `server_args` lacks the attribute.
            dp_size = getattr(self.server_args, "dp_size", 1)
            attn_cp_size = getattr(self.server_args, "attn_cp_size", 1)
            moe_dp_size = getattr(self.server_args, "moe_dp_size", 1)
            logger.debug(
                "tp_size=%s, dp_size=%s, attn_cp_size=%s, moe_dp_size=%s",
                self.tp_size,
                dp_size,
                attn_cp_size,
                moe_dp_size,
            )

            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                attention_data_parallel_size=dp_size,
                attention_context_model_parallel_size=attn_cp_size,
                moe_data_model_parallel_size=moe_dp_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if (
            self.tp_size > 1
            and not self.is_draft_worker
            and min_per_gpu_memory < local_gpu_memory * 0.9
        ):
            imbalance_msg = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            # The environment variable downgrades the hard failure to a warning.
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(imbalance_msg)
            else:
                raise ValueError(imbalance_msg)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
|