# Source: Hanrui / sglang / test / registered / distributed / test_epd_disaggregation.py
# Uploaded by Lekr0 ("Add files using upload-large-folder tool"), revision 61ba51e (verified).
import os
import subprocess
import threading
import time
import unittest
import grpc
import zmq
from grpc_health.v1 import health_pb2, health_pb2_grpc
from sglang.srt.utils import get_zmq_socket_on_host, kill_process_tree
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.kits.mmmu_vlm_kit import _run_lmms_eval_with_retry
from sglang.test.server_fixtures.disaggregation_fixture import (
PDDisaggregationServerBase,
)
from sglang.test.test_utils import (
DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
is_in_ci,
popen_launch_server,
)
# Register this test file with the CUDA CI registry under the
# "stage-c-test-4-gpu-h100" suite.
# NOTE(review): est_time is presumably an estimated runtime in seconds — confirm
# against sglang.test.ci.ci_register.
register_cuda_ci(est_time=150, suite="stage-c-test-4-gpu-h100")
@unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
class TestEPDDisaggregationOneEncoder(PDDisaggregationServerBase):
    """Test EPD disaggregation with a single encode server.

    ``setUpClass`` launches one ``--encoder-only`` server (no explicit
    ``--base-gpu-id``), one ``--language-only`` prefill server
    (``--base-gpu-id 1``), one decode server (``--base-gpu-id 2``) and the
    load balancer via ``launch_lb()``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch encode, prefill, decode and the LB, then export OpenAI env vars.

        unittest does not run ``tearDownClass`` when ``setUpClass`` raises, so
        any failure after the base-class setup triggers the cleanup explicitly
        before re-raising; otherwise already-started servers would be orphaned.
        """
        super().setUpClass()
        try:
            cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
            cls.encode_port = f"{int(cls.lb_port) + 300}"
            cls.encode_url = f"http://{cls.base_host}:{cls.encode_port}"

            print(
                f"Setting up EPD (one encoder): encode={cls.encode_port}, "
                f"prefill={cls.prefill_port}, decode={cls.decode_port}"
            )

            # Start servers in order: encode -> prefill/decode
            cls.start_encode()
            cls._launch_concurrently(cls.start_prefill, cls.start_decode)

            # Wait for all servers to be ready
            cls.wait_server_ready(
                cls.encode_url + "/health", process=cls.process_encode
            )
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()

            # Set OpenAI API key and base URL environment variables.
            # Needed for lmms-eval to work.
            cls.api_key = "sk-123456"
            os.environ["OPENAI_API_KEY"] = cls.api_key
            os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"
        except BaseException:
            cls.tearDownClass()
            raise

    @staticmethod
    def _launch_concurrently(*launchers):
        """Run the given zero-argument launchers in parallel, re-raising failures.

        A bare ``threading.Thread`` discards exceptions raised by its target,
        which would hide a failed server launch until a later, less specific
        error. Futures keep the exception so it can be re-raised here.
        """
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=len(launchers)) as pool:
            futures = [pool.submit(launcher) for launcher in launchers]
        # Leaving the ``with`` block has already waited for every launcher.
        for future in futures:
            future.result()

    @classmethod
    def start_encode(cls):
        """Start encode server for multimodal processing"""
        encode_args = [
            "--trust-remote-code",
            "--encoder-only",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--port",
            cls.encode_port,
            "--enable-prefix-mm-cache",
        ]
        cls.process_encode = popen_launch_server(
            cls.model,
            base_url=cls.encode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=encode_args,
        )

    @classmethod
    def start_prefill(cls):
        """Start prefill server with language model only"""
        prefill_args = [
            "--trust-remote-code",
            "--language-only",
            "--encoder-urls",
            cls.encode_url,
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            "--port",
            cls.prefill_port,
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_server(
            cls.model,
            base_url=cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Start decode server"""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "2",
            "--port",
            cls.decode_port,
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_server(
            cls.model,
            base_url=cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Clean up all processes and the env vars exported by ``setUpClass``.

        Process attributes are read with ``getattr`` so this also works after a
        partially completed ``setUpClass``.
        """
        os.environ.pop("OPENAI_API_KEY", None)
        os.environ.pop("OPENAI_API_BASE", None)
        for attr in (
            "process_lb",
            "process_decode",
            "process_prefill",
            "process_encode",
        ):
            process = getattr(cls, attr, None)
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    # Best-effort cleanup: keep going so the remaining
                    # processes still get killed.
                    print(f"Error killing process: {e}")

    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.
        Reference: test_vlm_models.py

        Args:
            model_version: Model version/checkpoint to evaluate
            output_path: Path to save evaluation results
            limit: Number of samples to evaluate (default: "50" for CI time constraints)
        """
        os.makedirs(output_path, exist_ok=True)
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            "openai_compatible",
            "--model_args",
            f'model_version="{model_version}",tp=1',
            "--tasks",
            "mmmu_val",
            "--batch_size",
            "32",
            "--log_samples",
            "--log_samples_suffix",
            "openai_compatible",
            "--output_path",
            str(output_path),
            "--limit",
            # Command-line arguments must be strings; tolerate an int limit.
            str(limit),
        ]
        _run_lmms_eval_with_retry(cmd, timeout=3600)

    def test_mmmu(self):
        """Test MMMU evaluation with EPD disaggregation"""
        import glob
        import json

        output_path = "./logs/epd_one_encoder_mmmu"
        self.run_mmmu_eval(self.model, output_path)

        # With recursive=True, "**" also matches zero directories, so this
        # covers JSON files directly inside output_path as well.
        result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
        if not result_files:
            self.fail(f"No JSON result files found in {output_path}")

        # Earlier runs may have left files behind, so read the newest one.
        # NOTE(review): lmms-eval presumably names the aggregated file
        # "*results.json" — confirm for the pinned version; if no such file
        # exists every JSON file is considered.
        candidates = [
            path for path in result_files if path.endswith("results.json")
        ] or result_files
        result_file_path = max(candidates, key=os.path.getmtime)

        with open(result_file_path, "r") as f:
            result = json.load(f)
        print(f"MMMU result: {result}")

        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
        print(f"MMMU accuracy: {mmmu_accuracy:.4f}")
        # for qwen2.5-vl-3b-instruct, the accuracy is 0.40
        self.assertGreater(mmmu_accuracy, 0.40)
class TestEPDDisaggregationMultiEncoders(PDDisaggregationServerBase):
    """
    Test EPD disaggregation with multiple encode servers for load balancing.

    Two encode servers are launched on different ports, one on GPU 0 and one
    with ``--base-gpu-id 1``. The prefill server uses ``--base-gpu-id 2`` and
    the decode server ``--base-gpu-id 3``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch both encoders, prefill, decode and the LB, then export env vars.

        unittest does not run ``tearDownClass`` when ``setUpClass`` raises, so
        any failure after the base-class setup triggers the cleanup explicitly
        before re-raising; otherwise already-started servers would be orphaned.
        """
        super().setUpClass()
        try:
            cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
            cls.encode_port1 = f"{int(cls.lb_port) + 300}"
            cls.encode_port2 = f"{int(cls.lb_port) + 301}"
            cls.encode_url1 = f"http://{cls.base_host}:{cls.encode_port1}"
            cls.encode_url2 = f"http://{cls.base_host}:{cls.encode_port2}"

            print(
                f"Setting up EPD (multiple encoders): encode1={cls.encode_port1}, "
                f"encode2={cls.encode_port2}, prefill={cls.prefill_port}, decode={cls.decode_port}"
            )

            # Start two encode servers on GPU 0/1
            cls._launch_concurrently(
                lambda: cls.start_encode_server(cls.encode_port1, 0),
                lambda: cls.start_encode_server(cls.encode_port2, 1),
            )
            cls._launch_concurrently(cls.start_prefill, cls.start_decode)

            cls.wait_server_ready(
                cls.encode_url1 + "/health", process=cls.process_encode1
            )
            cls.wait_server_ready(
                cls.encode_url2 + "/health", process=cls.process_encode2
            )
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()

            # Set OpenAI API key and base URL environment variables.
            # Needed for lmms-eval to work.
            cls.api_key = "sk-123456"
            os.environ["OPENAI_API_KEY"] = cls.api_key
            os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"
        except BaseException:
            cls.tearDownClass()
            raise

    @staticmethod
    def _launch_concurrently(*launchers):
        """Run the given zero-argument launchers in parallel, re-raising failures.

        A bare ``threading.Thread`` discards exceptions raised by its target,
        which would hide a failed server launch until a later, less specific
        error. Futures keep the exception so it can be re-raised here.
        """
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=len(launchers)) as pool:
            futures = [pool.submit(launcher) for launcher in launchers]
        # Leaving the ``with`` block has already waited for every launcher.
        for future in futures:
            future.result()

    @classmethod
    def start_encode_server(cls, port, gpu_id):
        """Start an encode server on specific port and GPU.

        Args:
            port: Port (as a string) for the encode server to listen on.
            gpu_id: GPU index; passed as ``--base-gpu-id`` unless it is 0.
        """
        encode_args = [
            "--trust-remote-code",
            "--encoder-only",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--port",
            port,
            "--enable-prefix-mm-cache",
        ]
        # Only set base-gpu-id if not using GPU 0
        if gpu_id != 0:
            encode_args.extend(["--base-gpu-id", str(gpu_id)])

        process = popen_launch_server(
            cls.model,
            base_url=f"http://{cls.base_host}:{port}",
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=encode_args,
        )
        # Record the handle under the attribute matching the port it serves.
        if port == cls.encode_port1:
            cls.process_encode1 = process
        else:
            cls.process_encode2 = process

    @classmethod
    def start_prefill(cls):
        """Start prefill server with multiple encode URLs"""
        prefill_args = [
            "--trust-remote-code",
            "--language-only",
            "--encoder-urls",
            cls.encode_url1,
            cls.encode_url2,
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            "--base-gpu-id",
            "2",
            "--port",
            cls.prefill_port,
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_server(
            cls.model,
            base_url=cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Start decode server"""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "3",
            "--port",
            cls.decode_port,
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_server(
            cls.model,
            base_url=cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Clean up all processes and the env vars exported by ``setUpClass``.

        Process attributes are read with ``getattr`` so this also works after a
        partially completed ``setUpClass``.
        """
        os.environ.pop("OPENAI_API_KEY", None)
        os.environ.pop("OPENAI_API_BASE", None)
        for attr in (
            "process_lb",
            "process_decode",
            "process_prefill",
            "process_encode1",
            "process_encode2",
        ):
            process = getattr(cls, attr, None)
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    # Best-effort cleanup: keep going so the remaining
                    # processes still get killed.
                    print(f"Error killing process: {e}")

    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.
        Reference: test_vlm_models.py

        Args:
            model_version: Model version/checkpoint to evaluate
            output_path: Path to save evaluation results
            limit: Number of samples to evaluate (default: "50" for CI time constraints)
        """
        os.makedirs(output_path, exist_ok=True)
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            "openai_compatible",
            "--model_args",
            f'model_version="{model_version}",tp=1',
            "--tasks",
            "mmmu_val",
            "--batch_size",
            "32",
            "--log_samples",
            "--log_samples_suffix",
            "openai_compatible",
            "--output_path",
            str(output_path),
            "--limit",
            # Command-line arguments must be strings; tolerate an int limit.
            str(limit),
        ]
        _run_lmms_eval_with_retry(cmd, timeout=3600)

    def test_mmmu(self):
        """Test MMMU evaluation with EPD disaggregation (multiple encoders)"""
        import glob
        import json

        output_path = "./logs/epd_multi_encoder_mmmu"
        self.run_mmmu_eval(self.model, output_path)

        # With recursive=True, "**" also matches zero directories, so this
        # covers JSON files directly inside output_path as well.
        result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
        if not result_files:
            self.fail(f"No JSON result files found in {output_path}")

        # Earlier runs may have left files behind, so read the newest one.
        # NOTE(review): lmms-eval presumably names the aggregated file
        # "*results.json" — confirm for the pinned version; if no such file
        # exists every JSON file is considered.
        candidates = [
            path for path in result_files if path.endswith("results.json")
        ] or result_files
        result_file_path = max(candidates, key=os.path.getmtime)

        with open(result_file_path, "r") as f:
            result = json.load(f)
        print(f"MMMU result (multi encoder): {result}")

        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
        print(f"MMMU accuracy (multi encoder): {mmmu_accuracy:.4f}")
        # for qwen2.5-vl-3b-instruct, the accuracy is 0.40
        self.assertGreater(mmmu_accuracy, 0.40)
@unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
class TestEPDDisaggregationGrpcEncoderMMMU(PDDisaggregationServerBase):
    """Test MMMU evaluation with gRPC encoder in EPD mode.

    The encoder is started with ``--grpc-mode`` on ``--base-gpu-id 0`` and is
    addressed through a ``grpc://`` URL; prefill uses ``--base-gpu-id 1`` and
    decode ``--base-gpu-id 2``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the gRPC encoder, prefill, decode and the LB, then export env vars.

        unittest does not run ``tearDownClass`` when ``setUpClass`` raises, so
        any failure after the base-class setup triggers the cleanup explicitly
        before re-raising; otherwise already-started servers would be orphaned.
        """
        super().setUpClass()
        try:
            cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
            cls.encode_port = f"{int(cls.lb_port) + 304}"
            cls.encode_url = f"grpc://{cls.base_host}:{cls.encode_port}"

            print(
                f"Setting up gRPC EPD (one encoder): encode={cls.encode_port}, "
                f"prefill={cls.prefill_port}, decode={cls.decode_port}"
            )

            # The encoder is spawned without waiting; readiness is checked
            # below through the gRPC health service.
            cls.start_encode()
            cls._launch_concurrently(cls.start_prefill, cls.start_decode)

            cls.wait_grpc_ready(cls.base_host, cls.encode_port, cls.process_encode)
            # Pass the process handles, as the sibling test classes in this
            # file do when calling wait_server_ready.
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()

            # Needed for lmms-eval to reach the load balancer.
            cls.api_key = "sk-123456"
            os.environ["OPENAI_API_KEY"] = cls.api_key
            os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"
        except BaseException:
            cls.tearDownClass()
            raise

    @staticmethod
    def _launch_concurrently(*launchers):
        """Run the given zero-argument launchers in parallel, re-raising failures.

        A bare ``threading.Thread`` discards exceptions raised by its target,
        which would hide a failed server launch until a later, less specific
        error. Futures keep the exception so it can be re-raised here.
        """
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=len(launchers)) as pool:
            futures = [pool.submit(launcher) for launcher in launchers]
        # Leaving the ``with`` block has already waited for every launcher.
        for future in futures:
            future.result()

    @classmethod
    def start_encode(cls):
        """Spawn the gRPC encode server as a subprocess without waiting for it."""
        encode_command = [
            "python3",
            "-m",
            "sglang.launch_server",
            "--model-path",
            cls.model,
            "--host",
            cls.base_host,
            "--port",
            cls.encode_port,
            "--trust-remote-code",
            "--encoder-only",
            "--grpc-mode",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--base-gpu-id",
            "0",
            "--enable-prefix-mm-cache",
        ]
        cls.process_encode = subprocess.Popen(encode_command)

    @classmethod
    def start_prefill(cls):
        """Start the language-only prefill server pointed at the gRPC encoder."""
        prefill_args = [
            "--trust-remote-code",
            "--language-only",
            "--encoder-urls",
            cls.encode_url,
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            "--port",
            cls.prefill_port,
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        # Set the receiver mode only in the child's environment copy so the
        # test process's own os.environ is left untouched.
        prefill_env = os.environ.copy()
        prefill_env["SGLANG_ENCODER_MM_RECEIVER_MODE"] = "grpc"
        cls.process_prefill = popen_launch_server(
            cls.model,
            base_url=cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
            env=prefill_env,
        )

    @classmethod
    def start_decode(cls):
        """Start decode server"""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "2",
            "--port",
            cls.decode_port,
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_server(
            cls.model,
            base_url=cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @staticmethod
    def wait_grpc_ready(host, port, process, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Block until the gRPC health service at ``host:port`` reports SERVING.

        Args:
            host: Host of the gRPC server.
            port: Port of the gRPC server.
            process: Server process handle; polled so an early exit fails fast.
            timeout: Maximum number of seconds to wait.

        Raises:
            RuntimeError: If the process exits or the deadline passes first.
        """
        # monotonic() is unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        channel = grpc.insecure_channel(f"{host}:{port}")
        stub = health_pb2_grpc.HealthStub(channel)
        try:
            while time.monotonic() < deadline:
                if process.poll() is not None:
                    raise RuntimeError(
                        f"gRPC encoder server exited with code {process.returncode}"
                    )
                try:
                    response = stub.Check(
                        health_pb2.HealthCheckRequest(service=""), timeout=2
                    )
                    if response.status == health_pb2.HealthCheckResponse.SERVING:
                        return
                except grpc.RpcError:
                    # Server not accepting RPCs yet; retry after a short pause.
                    pass
                time.sleep(1)
        finally:
            channel.close()
        raise RuntimeError(
            f"gRPC encoder server not ready at {host}:{port} within {timeout}s"
        )

    @classmethod
    def tearDownClass(cls):
        """Remove the test env vars and kill every launched process.

        Process attributes are read with ``getattr`` so this also works after a
        partially completed ``setUpClass``.
        """
        os.environ.pop("SGLANG_ENCODER_MM_RECEIVER_MODE", None)
        os.environ.pop("OPENAI_API_KEY", None)
        os.environ.pop("OPENAI_API_BASE", None)
        for attr in (
            "process_lb",
            "process_decode",
            "process_prefill",
            "process_encode",
        ):
            process = getattr(cls, attr, None)
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    # Best-effort cleanup: keep going so the remaining
                    # processes still get killed.
                    print(f"Error killing process: {e}")

    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.

        Args:
            model_version: Model version/checkpoint to evaluate
            output_path: Path to save evaluation results
            limit: Number of samples to evaluate
        """
        os.makedirs(output_path, exist_ok=True)
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            "openai_compatible",
            "--model_args",
            f'model_version="{model_version}",tp=1',
            "--tasks",
            "mmmu_val",
            "--batch_size",
            "32",
            "--log_samples",
            "--log_samples_suffix",
            "openai_compatible",
            "--output_path",
            str(output_path),
            "--limit",
            # Command-line arguments must be strings; tolerate an int limit.
            str(limit),
        ]
        _run_lmms_eval_with_retry(cmd, timeout=3600)

    def test_mmmu(self):
        """Test MMMU evaluation with EPD disaggregation (gRPC encoder)"""
        import glob
        import json

        output_path = "./logs/epd_grpc_encoder_mmmu"
        self.run_mmmu_eval(self.model, output_path)

        # With recursive=True, "**" also matches zero directories, so this
        # covers JSON files directly inside output_path as well.
        result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
        if not result_files:
            self.fail(f"No JSON result files found in {output_path}")

        # Earlier runs may have left files behind, so read the newest one.
        # NOTE(review): lmms-eval presumably names the aggregated file
        # "*results.json" — confirm for the pinned version; if no such file
        # exists every JSON file is considered.
        candidates = [
            path for path in result_files if path.endswith("results.json")
        ] or result_files
        result_file_path = max(candidates, key=os.path.getmtime)

        with open(result_file_path, "r") as f:
            result = json.load(f)
        print(f"MMMU result (grpc encoder): {result}")

        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
        print(f"MMMU accuracy (grpc encoder): {mmmu_accuracy:.4f}")
        # for qwen2.5-vl-3b-instruct, the accuracy is 0.40
        self.assertGreater(mmmu_accuracy, 0.40)
@unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
class TestEPDDisaggregationGrpcEncoderOnly(PDDisaggregationServerBase):
    """Test gRPC encoder server integration with zmq_to_scheduler transfers.

    Only the encode server is launched; the test itself plays the receiving
    side by opening a ZMQ PULL socket and calling the encoder over gRPC.
    """

    @classmethod
    def setUpClass(cls):
        """Set the gRPC receiver-mode env var and launch the gRPC encode server.

        unittest does not run ``tearDownClass`` when ``setUpClass`` raises, so
        any failure after the base-class setup triggers the cleanup explicitly
        before re-raising. Without it the encoder process and the process-wide
        env var would leak into test classes that run afterwards.
        """
        super().setUpClass()
        try:
            os.environ["SGLANG_ENCODER_MM_RECEIVER_MODE"] = "grpc"
            cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
            cls.encode_port = f"{int(cls.lb_port) + 302}"

            print(f"Setting up gRPC EPD encoder: encode={cls.encode_port}")

            cls.start_encode()
            cls.wait_grpc_ready(cls.base_host, cls.encode_port, cls.process_encode)
        except BaseException:
            cls.tearDownClass()
            raise

    @classmethod
    def start_encode(cls):
        """Spawn the gRPC encode server as a subprocess without waiting for it."""
        encode_command = [
            "python3",
            "-m",
            "sglang.launch_server",
            "--model-path",
            cls.model,
            "--host",
            cls.base_host,
            "--port",
            cls.encode_port,
            "--trust-remote-code",
            "--encoder-only",
            "--grpc-mode",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--base-gpu-id",
            "0",
            "--enable-prefix-mm-cache",
        ]
        cls.process_encode = subprocess.Popen(encode_command)

    @staticmethod
    def wait_grpc_ready(host, port, process, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Block until the gRPC health service at ``host:port`` reports SERVING.

        Args:
            host: Host of the gRPC server.
            port: Port of the gRPC server.
            process: Server process handle; polled so an early exit fails fast.
            timeout: Maximum number of seconds to wait.

        Raises:
            RuntimeError: If the process exits or the deadline passes first.
        """
        # monotonic() is unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        channel = grpc.insecure_channel(f"{host}:{port}")
        stub = health_pb2_grpc.HealthStub(channel)
        try:
            while time.monotonic() < deadline:
                if process.poll() is not None:
                    raise RuntimeError(
                        f"gRPC encoder server exited with code {process.returncode}"
                    )
                try:
                    response = stub.Check(
                        health_pb2.HealthCheckRequest(service=""), timeout=2
                    )
                    if response.status == health_pb2.HealthCheckResponse.SERVING:
                        return
                except grpc.RpcError:
                    # Server not accepting RPCs yet; retry after a short pause.
                    pass
                time.sleep(1)
        finally:
            channel.close()
        raise RuntimeError(
            f"gRPC encoder server not ready at {host}:{port} within {timeout}s"
        )

    @classmethod
    def tearDownClass(cls):
        """Remove the receiver-mode env var, kill the encoder, then run base cleanup.

        ``process_encode`` is read with ``getattr`` so this also works when
        ``setUpClass`` failed before the encoder was spawned.
        """
        os.environ.pop("SGLANG_ENCODER_MM_RECEIVER_MODE", None)
        process = getattr(cls, "process_encode", None)
        if process:
            try:
                kill_process_tree(process.pid)
            except Exception as e:
                # Best-effort cleanup: still run the base-class teardown.
                print(f"Error killing process: {e}")
        super().tearDownClass()

    def test_grpc_encoder_zmq_to_scheduler(self):
        """Encode one image over gRPC and expect a payload on a local ZMQ PULL socket."""
        from contextlib import ExitStack

        from smg_grpc_proto import sglang_encoder_pb2, sglang_encoder_pb2_grpc

        req_id = f"grpc-epd-{int(time.time() * 1000)}"
        # NOTE(review): relative path, so this presumably expects the test to
        # run from the repository root — confirm.
        image_path = os.path.abspath("examples/assets/example_image.png")

        # ExitStack runs callbacks in reverse registration order: the channel
        # closes first, then the socket, then the context is terminated. The
        # socket must close before context.term(), which waits for open sockets.
        # Each resource is released even if acquiring a later one fails.
        with ExitStack() as stack:
            context = zmq.Context()
            stack.callback(context.term)

            recv_port, recv_socket = get_zmq_socket_on_host(
                context, zmq.PULL, host=self.base_host
            )
            # linger=0 so closing never blocks on undelivered messages.
            stack.callback(recv_socket.close, linger=0)

            channel = grpc.insecure_channel(f"{self.base_host}:{self.encode_port}")
            stack.callback(channel.close)
            stub = sglang_encoder_pb2_grpc.SglangEncoderStub(channel)

            # Tell the encoder where to push the result for this request.
            stub.SchedulerReceiveUrl(
                sglang_encoder_pb2.SchedulerReceiveUrlRequest(
                    req_id=req_id,
                    receive_url=f"{self.base_host}:{recv_port}",
                    receive_count=1,
                ),
                timeout=60,
            )
            stub.Encode(
                sglang_encoder_pb2.EncodeRequest(
                    mm_items=[image_path],
                    req_id=req_id,
                    num_parts=1,
                    part_idx=0,
                ),
                timeout=300,
            )

            # Wait up to 60 s (poll timeout is in milliseconds) for the payload.
            poller = zmq.Poller()
            poller.register(recv_socket, zmq.POLLIN)
            socks = dict(poller.poll(60000))
            self.assertIn(
                recv_socket,
                socks,
                "No embedding payload received from gRPC encoder server",
            )
            parts = recv_socket.recv_multipart()
            self.assertTrue(parts, "Empty embedding payload from gRPC encoder server")
# Run the test classes in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()