# Lekr0's picture
# Add files using upload-large-folder tool
# 7a60a87 verified
import logging
import os
import torch
from sglang.srt.distributed import (
get_pp_group,
get_tp_group,
get_world_group,
set_custom_all_reduce,
set_mscclpp_all_reduce,
set_torch_symm_mem_all_reduce,
)
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
initialize_dp_attention,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.utils import (
cpu_has_amx_support,
get_available_gpu_memory,
get_bool_env_var,
is_hip,
is_npu,
monkey_patch_p2p_access_check,
)
from .patch import (
init_distributed_environment,
initialize_dp_attention,
initialize_model_parallel,
)
# Platform / CPU capability probes, evaluated once when this module is imported.
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI.
# Holds the raw environment string, or None when the variable is unset.
SGLANG_CI_SMALL_KV_SIZE = os.environ.get("SGLANG_CI_SMALL_KV_SIZE")

# Detect straggler ranks in model loading
# (presumably a timeout in seconds, per the `_S` suffix -- confirm at the use site).
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)
class SGLangRunner(ModelRunner):
    """`ModelRunner` subclass that overrides distributed initialization.

    `init_distributed_environment`, `initialize_model_parallel` and
    `initialize_dp_attention` used here are the ones imported from the local
    `.patch` module (the `.patch` import of `initialize_dp_attention` comes
    after the one from `sglang.srt.layers.dp_attention` and therefore wins).
    """

    def init_torch_distributed(self):
        """Initialize torch distributed state for this worker.

        The method, in order:

        1. binds the current device via ``torch.get_device_module``;
        2. picks the communication backend from ``self.device``;
        3. applies the custom / mscclpp / torch-symm-mem all-reduce switches;
        4. on the target (non-draft) worker only: sets up the CPU AMX
           environment when applicable, then initializes the distributed
           environment, the model-parallel groups and DP attention;
        5. stores the TP, PP and attention-TP groups on ``self``;
        6. checks that available memory is balanced when ``tp_size > 1``.

        Returns:
            The value returned by ``get_available_gpu_memory`` when called
            with the world group (presumably the minimum available memory
            across ranks, as the local name suggests -- confirm against the
            helper).

        Raises:
            ValueError: if ``self.device`` has no known backend on the target
                worker, or if the memory imbalance check fails and
                ``SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK`` is not set.
            Exception: anything raised by ``set_device`` is re-raised after
                the device context has been logged.
        """
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            # Log enough context to diagnose a bad device mapping, then propagate.
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            if self.server_args.elastic_ep_backend == "mooncake":
                backend = "mooncake"
                if self.server_args.mooncake_ib_device:
                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
                    try:
                        from mooncake import ep as mooncake_ep

                        mooncake_ep.set_device_filter(mooncake_ib_device)
                    except Exception:
                        # Best effort: a warning will be raised in
                        # `init_distributed_environment`. `Exception` (not a
                        # bare `except`) keeps KeyboardInterrupt/SystemExit
                        # propagating.
                        logger.debug(
                            "Could not apply the mooncake device filter.",
                            exc_info=True,
                        )
            else:
                backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"
        else:
            # Unknown device: only an error if the backend is actually needed
            # (checked below on the target worker).
            backend = None

        # Baseline used for the memory-usage figure logged at the end.
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)

        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        # NOTE(review): `dist_init_method` is not passed to
        # `init_distributed_environment` below -- confirm that the patched
        # helper does not need it.

        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

        if not self.is_draft_worker:
            if backend is None:
                # Previously surfaced as an UnboundLocalError on `backend`.
                raise ValueError(
                    f"Unsupported device for distributed initialization: {self.device!r}"
                )

            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set local size to hint SGLang to use shared memory based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    # Fake implementation: concatenates `tp_size` copies of
                    # the input along `dim`.
                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
            )

            # NOTE: Updated for sglang 0.5.9
            # - Removed torch_compile parameter (no longer supported)
            # - Added new parameters: attention_data_parallel_size,
            #   attention_context_model_parallel_size, moe_data_model_parallel_size
            # Each size falls back to 1 when `server_args` lacks the attribute.
            dp_size = getattr(self.server_args, "dp_size", 1)
            attn_cp_size = getattr(self.server_args, "attn_cp_size", 1)
            moe_dp_size = getattr(self.server_args, "moe_dp_size", 1)
            logger.debug(
                "tp_size=%s, dp_size=%s, attn_cp_size=%s, moe_dp_size=%s",
                self.tp_size,
                dp_size,
                attn_cp_size,
                moe_dp_size,
            )

            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                attention_data_parallel_size=dp_size,
                attention_context_model_parallel_size=attn_cp_size,
                moe_data_model_parallel_size=moe_dp_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        world_group = get_world_group()
        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=world_group.world_size > 1,
            cpu_group=world_group.cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism: the memory reported across the
        # world group must be at least 90% of what this rank sees locally.
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if (
            self.tp_size > 1
            and not self.is_draft_worker
            and min_per_gpu_memory < local_gpu_memory * 0.9
        ):
            imbalance_msg = (
                "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
            )
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(imbalance_msg)
            else:
                raise ValueError(imbalance_msg)

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory