| import os |
| import subprocess |
| import threading |
| import time |
| import unittest |
|
|
| import grpc |
| import zmq |
| from grpc_health.v1 import health_pb2, health_pb2_grpc |
|
|
| from sglang.srt.utils import get_zmq_socket_on_host, kill_process_tree |
| from sglang.test.ci.ci_register import register_cuda_ci |
| from sglang.test.kits.mmmu_vlm_kit import _run_lmms_eval_with_retry |
| from sglang.test.server_fixtures.disaggregation_fixture import ( |
| PDDisaggregationServerBase, |
| ) |
| from sglang.test.test_utils import ( |
| DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST, |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| is_in_ci, |
| popen_launch_server, |
| ) |
|
|
# Register this module with the CUDA CI runner (4-GPU H100 stage-c suite,
# estimated runtime 150 — presumably seconds).
register_cuda_ci(suite="stage-c-test-4-gpu-h100", est_time=150)
|
|
|
|
@unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
class TestEPDDisaggregationOneEncoder(PDDisaggregationServerBase):
    """Test EPD disaggregation with a single encode server.

    The encode server is launched without ``--base-gpu-id`` (presumably GPU 0),
    prefill with ``--base-gpu-id 1`` and decode with ``--base-gpu-id 2``;
    evaluation traffic goes through the load balancer started by ``launch_lb``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch encode, prefill, decode and the load balancer.

        If any step fails, servers that were already started are killed before
        the error propagates, because unittest does not call ``tearDownClass``
        when ``setUpClass`` raises.
        """
        super().setUpClass()
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        cls.encode_port = f"{int(cls.lb_port) + 300}"
        cls.encode_url = f"http://{cls.base_host}:{cls.encode_port}"

        print(
            f"Setting up EPD (one encoder): encode={cls.encode_port}, "
            f"prefill={cls.prefill_port}, decode={cls.decode_port}"
        )

        try:
            # Encoder is started first; prefill is given its URL via --encoder-urls.
            cls.start_encode()
            cls._run_concurrently(cls.start_prefill, cls.start_decode)

            cls.wait_server_ready(
                cls.encode_url + "/health", process=cls.process_encode
            )
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()
        except BaseException:
            cls.tearDownClass()
            raise

        # Presumably consumed by lmms-eval's "openai_compatible" model in
        # run_mmmu_eval, pointing it at the load balancer's /v1 endpoint.
        cls.api_key = "sk-123456"
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"

    @staticmethod
    def _run_concurrently(*targets):
        """Call each zero-argument callable in its own thread; re-raise the first error.

        ``threading.Thread`` only reports exceptions raised by its target via
        ``threading.excepthook``, so without this a failed launch would surface
        later, typically as an error about a missing process handle.
        """
        errors = []

        def guarded(fn):
            try:
                fn()
            except BaseException as exc:  # collected and re-raised below
                errors.append(exc)

        threads = [threading.Thread(target=guarded, args=(fn,)) for fn in targets]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        if errors:
            raise errors[0]

    @classmethod
    def start_encode(cls):
        """Start the encoder-only server used for multimodal processing."""
        encode_args = [
            "--trust-remote-code",
            "--encoder-only",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--port",
            cls.encode_port,
            "--enable-prefix-mm-cache",
        ]
        cls.process_encode = popen_launch_server(
            cls.model,
            base_url=cls.encode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=encode_args,
        )

    @classmethod
    def start_prefill(cls):
        """Start the language-only prefill server that pulls embeddings from the encoder."""
        prefill_args = [
            "--trust-remote-code",
            "--language-only",
            "--encoder-urls",
            cls.encode_url,
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            "--port",
            cls.prefill_port,
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_server(
            cls.model,
            base_url=cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Start the decode server."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "2",
            "--port",
            cls.decode_port,
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_server(
            cls.model,
            base_url=cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill every server process started by this class and clear the OpenAI env vars.

        Safe to call after a partially completed ``setUpClass``: process handles
        that were never assigned are skipped.
        """
        os.environ.pop("OPENAI_API_KEY", None)
        os.environ.pop("OPENAI_API_BASE", None)
        for attr in (
            "process_lb",
            "process_decode",
            "process_prefill",
            "process_encode",
        ):
            process = getattr(cls, attr, None)
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")

    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.
        Reference: test_vlm_models.py

        Args:
            model_version: Model version/checkpoint to evaluate
            output_path: Path to save evaluation results
            limit: Number of samples to evaluate (default: "50" for CI time constraints)
        """
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 32
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        model_args = f'model_version="{model_version}",tp={tp}'

        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
            "--limit",
            limit,
        ]

        _run_lmms_eval_with_retry(cmd, timeout=3600)

    @staticmethod
    def _load_mmmu_result(output_path):
        """Return the newest lmms-eval result dict under *output_path* with MMMU scores.

        ``output_path`` is reused across runs (``os.makedirs(..., exist_ok=True)``)
        and ``glob`` returns matches in arbitrary order, so candidates are tried
        newest-first by modification time. JSON files without a
        ``results["mmmu_val"]`` entry are skipped. Returns ``None`` if none match.
        """
        import glob
        import json

        # With recursive=True, "**" also matches files directly in output_path.
        pattern = os.path.join(output_path, "**", "*.json")
        candidates = sorted(
            glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
        )
        for path in candidates:
            try:
                with open(path, "r") as f:
                    data = json.load(f)
            except (OSError, ValueError):
                continue
            results = data.get("results") if isinstance(data, dict) else None
            if isinstance(results, dict) and "mmmu_val" in results:
                return data
        return None

    def test_mmmu(self):
        """Test MMMU evaluation with EPD disaggregation"""
        output_path = "./logs/epd_one_encoder_mmmu"
        self.run_mmmu_eval(self.model, output_path)

        result = self._load_mmmu_result(output_path)
        if result is None:
            self.fail(f"No lmms-eval MMMU result JSON found in {output_path}")
        print(f"MMMU result: {result}")

        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
        print(f"MMMU accuracy: {mmmu_accuracy:.4f}")

        # Accuracy floor on the limited sample set (run_mmmu_eval default: 50).
        self.assertGreater(mmmu_accuracy, 0.40)
|
|
|
|
class TestEPDDisaggregationMultiEncoders(PDDisaggregationServerBase):
    """
    Test EPD disaggregation with multiple encode servers for load balancing.

    The two encode servers listen on different ports and are launched on GPU
    ids 0 and 1; prefill uses ``--base-gpu-id 2`` and decode ``--base-gpu-id 3``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch both encoders, prefill, decode and the load balancer.

        If any step fails, servers that were already started are killed before
        the error propagates, because unittest does not call ``tearDownClass``
        when ``setUpClass`` raises.
        """
        super().setUpClass()
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        cls.encode_port1 = f"{int(cls.lb_port) + 300}"
        cls.encode_port2 = f"{int(cls.lb_port) + 301}"
        cls.encode_url1 = f"http://{cls.base_host}:{cls.encode_port1}"
        cls.encode_url2 = f"http://{cls.base_host}:{cls.encode_port2}"

        print(
            f"Setting up EPD (multiple encoders): encode1={cls.encode_port1}, "
            f"encode2={cls.encode_port2}, prefill={cls.prefill_port}, decode={cls.decode_port}"
        )

        try:
            # Both encoders are started before prefill, which is given their URLs.
            cls._run_concurrently(
                lambda: cls.start_encode_server(cls.encode_port1, 0),
                lambda: cls.start_encode_server(cls.encode_port2, 1),
            )
            cls._run_concurrently(cls.start_prefill, cls.start_decode)

            cls.wait_server_ready(
                cls.encode_url1 + "/health", process=cls.process_encode1
            )
            cls.wait_server_ready(
                cls.encode_url2 + "/health", process=cls.process_encode2
            )
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()
        except BaseException:
            cls.tearDownClass()
            raise

        # Presumably consumed by lmms-eval's "openai_compatible" model in
        # run_mmmu_eval, pointing it at the load balancer's /v1 endpoint.
        cls.api_key = "sk-123456"
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"

    @staticmethod
    def _run_concurrently(*targets):
        """Call each zero-argument callable in its own thread; re-raise the first error.

        ``threading.Thread`` only reports exceptions raised by its target via
        ``threading.excepthook``, so without this a failed launch would surface
        later, typically as an error about a missing process handle.
        """
        errors = []

        def guarded(fn):
            try:
                fn()
            except BaseException as exc:  # collected and re-raised below
                errors.append(exc)

        threads = [threading.Thread(target=guarded, args=(fn,)) for fn in targets]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        if errors:
            raise errors[0]

    @classmethod
    def start_encode_server(cls, port, gpu_id):
        """Start an encode server on a specific port and GPU.

        The handle is stored as ``process_encode1`` when *port* is
        ``encode_port1`` and as ``process_encode2`` otherwise.
        """
        encode_args = [
            "--trust-remote-code",
            "--encoder-only",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--port",
            port,
            "--enable-prefix-mm-cache",
        ]
        # GPU 0 is left to the server's default (presumably base GPU 0).
        if gpu_id != 0:
            encode_args.extend(["--base-gpu-id", str(gpu_id)])

        process = popen_launch_server(
            cls.model,
            base_url=f"http://{cls.base_host}:{port}",
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=encode_args,
        )
        if port == cls.encode_port1:
            cls.process_encode1 = process
        else:
            cls.process_encode2 = process

    @classmethod
    def start_prefill(cls):
        """Start prefill server with multiple encode URLs"""
        prefill_args = [
            "--trust-remote-code",
            "--language-only",
            "--encoder-urls",
            cls.encode_url1,
            cls.encode_url2,
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            "--base-gpu-id",
            "2",
            "--port",
            cls.prefill_port,
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_server(
            cls.model,
            base_url=cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Start decode server"""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "3",
            "--port",
            cls.decode_port,
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_server(
            cls.model,
            base_url=cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill every server process started by this class and clear the OpenAI env vars.

        Safe to call after a partially completed ``setUpClass``: process handles
        that were never assigned are skipped.
        """
        os.environ.pop("OPENAI_API_KEY", None)
        os.environ.pop("OPENAI_API_BASE", None)
        for attr in (
            "process_lb",
            "process_decode",
            "process_prefill",
            "process_encode1",
            "process_encode2",
        ):
            process = getattr(cls, attr, None)
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")

    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.
        Reference: test_vlm_models.py

        Args:
            model_version: Model version/checkpoint to evaluate
            output_path: Path to save evaluation results
            limit: Number of samples to evaluate (default: "50" for CI time constraints)
        """
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 32
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        model_args = f'model_version="{model_version}",tp={tp}'

        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
            "--limit",
            limit,
        ]

        _run_lmms_eval_with_retry(cmd, timeout=3600)

    @staticmethod
    def _load_mmmu_result(output_path):
        """Return the newest lmms-eval result dict under *output_path* with MMMU scores.

        ``output_path`` is reused across runs (``os.makedirs(..., exist_ok=True)``)
        and ``glob`` returns matches in arbitrary order, so candidates are tried
        newest-first by modification time. JSON files without a
        ``results["mmmu_val"]`` entry are skipped. Returns ``None`` if none match.
        """
        import glob
        import json

        # With recursive=True, "**" also matches files directly in output_path.
        pattern = os.path.join(output_path, "**", "*.json")
        candidates = sorted(
            glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
        )
        for path in candidates:
            try:
                with open(path, "r") as f:
                    data = json.load(f)
            except (OSError, ValueError):
                continue
            results = data.get("results") if isinstance(data, dict) else None
            if isinstance(results, dict) and "mmmu_val" in results:
                return data
        return None

    def test_mmmu(self):
        """Test MMMU evaluation with EPD disaggregation (multiple encoders)"""
        output_path = "./logs/epd_multi_encoder_mmmu"
        self.run_mmmu_eval(self.model, output_path)

        result = self._load_mmmu_result(output_path)
        if result is None:
            self.fail(f"No lmms-eval MMMU result JSON found in {output_path}")
        print(f"MMMU result (multi encoder): {result}")

        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
        print(f"MMMU accuracy (multi encoder): {mmmu_accuracy:.4f}")

        # Accuracy floor on the limited sample set (run_mmmu_eval default: 50).
        self.assertGreater(mmmu_accuracy, 0.40)
|
|
|
|
@unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
class TestEPDDisaggregationGrpcEncoderMMMU(PDDisaggregationServerBase):
    """Test MMMU evaluation with gRPC encoder in EPD mode.

    The encoder is launched directly via ``sglang.launch_server --grpc-mode``
    on ``--base-gpu-id 0``; prefill (``--base-gpu-id 1``) reaches it through a
    ``grpc://`` encoder URL, and decode runs on ``--base-gpu-id 2``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the gRPC encoder, prefill, decode and the load balancer.

        If any step fails, servers that were already started are killed before
        the error propagates, because unittest does not call ``tearDownClass``
        when ``setUpClass`` raises.
        """
        super().setUpClass()
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        cls.encode_port = f"{int(cls.lb_port) + 304}"
        cls.encode_url = f"grpc://{cls.base_host}:{cls.encode_port}"

        print(
            f"Setting up gRPC EPD (one encoder): encode={cls.encode_port}, "
            f"prefill={cls.prefill_port}, decode={cls.decode_port}"
        )

        try:
            # start_encode only spawns the process; readiness is checked below.
            cls.start_encode()
            cls._run_concurrently(cls.start_prefill, cls.start_decode)

            cls.wait_grpc_ready(cls.base_host, cls.encode_port, cls.process_encode)
            # Pass the process so a crashed server fails fast instead of timing out.
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()
        except BaseException:
            cls.tearDownClass()
            raise

        # Presumably consumed by lmms-eval's "openai_compatible" model in
        # run_mmmu_eval, pointing it at the load balancer's /v1 endpoint.
        cls.api_key = "sk-123456"
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"

    @staticmethod
    def _run_concurrently(*targets):
        """Call each zero-argument callable in its own thread; re-raise the first error.

        ``threading.Thread`` only reports exceptions raised by its target via
        ``threading.excepthook``, so without this a failed launch would surface
        later, typically as an error about a missing process handle.
        """
        errors = []

        def guarded(fn):
            try:
                fn()
            except BaseException as exc:  # collected and re-raised below
                errors.append(exc)

        threads = [threading.Thread(target=guarded, args=(fn,)) for fn in targets]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        if errors:
            raise errors[0]

    @classmethod
    def start_encode(cls):
        """Spawn the gRPC encoder-only server without waiting for it to be ready."""
        encode_command = [
            "python3",
            "-m",
            "sglang.launch_server",
            "--model-path",
            cls.model,
            "--host",
            cls.base_host,
            "--port",
            cls.encode_port,
            "--trust-remote-code",
            "--encoder-only",
            "--grpc-mode",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--base-gpu-id",
            "0",
            "--enable-prefix-mm-cache",
        ]
        cls.process_encode = subprocess.Popen(encode_command)

    @classmethod
    def start_prefill(cls):
        """Start the language-only prefill server in gRPC receiver mode.

        SGLANG_ENCODER_MM_RECEIVER_MODE is set only in a copy of the environment
        passed to this server, so the test process's own env is untouched.
        """
        prefill_args = [
            "--trust-remote-code",
            "--language-only",
            "--encoder-urls",
            cls.encode_url,
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            "--port",
            cls.prefill_port,
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        prefill_env = os.environ.copy()
        prefill_env["SGLANG_ENCODER_MM_RECEIVER_MODE"] = "grpc"
        cls.process_prefill = popen_launch_server(
            cls.model,
            base_url=cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
            env=prefill_env,
        )

    @classmethod
    def start_decode(cls):
        """Start the decode server."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "2",
            "--port",
            cls.decode_port,
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_server(
            cls.model,
            base_url=cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @staticmethod
    def wait_grpc_ready(host, port, process, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll the standard gRPC health service until it reports SERVING.

        Raises:
            RuntimeError: if *process* exits first, or *timeout* seconds elapse.
        """
        deadline = time.time() + timeout
        channel = grpc.insecure_channel(f"{host}:{port}")
        stub = health_pb2_grpc.HealthStub(channel)
        try:
            while time.time() < deadline:
                if process.poll() is not None:
                    raise RuntimeError(
                        f"gRPC encoder server exited with code {process.returncode}"
                    )
                try:
                    response = stub.Check(
                        health_pb2.HealthCheckRequest(service=""), timeout=2
                    )
                    if response.status == health_pb2.HealthCheckResponse.SERVING:
                        return
                except grpc.RpcError:
                    # Server not listening yet (or not serving); keep polling.
                    pass
                time.sleep(1)
        finally:
            channel.close()

        raise RuntimeError(
            f"gRPC encoder server not ready at {host}:{port} within {timeout}s"
        )

    @classmethod
    def tearDownClass(cls):
        """Clear env vars and kill every server process started by this class.

        Safe to call after a partially completed ``setUpClass``: process handles
        that were never assigned are skipped.
        """
        # This class never sets SGLANG_ENCODER_MM_RECEIVER_MODE in os.environ
        # (only in prefill's copied env); the pop is kept as a defensive reset.
        os.environ.pop("SGLANG_ENCODER_MM_RECEIVER_MODE", None)
        os.environ.pop("OPENAI_API_KEY", None)
        os.environ.pop("OPENAI_API_BASE", None)
        for attr in (
            "process_lb",
            "process_decode",
            "process_prefill",
            "process_encode",
        ):
            process = getattr(cls, attr, None)
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")

    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.

        Args:
            model_version: Model version/checkpoint to evaluate
            output_path: Path to save evaluation results
            limit: Number of samples to evaluate (default: "50")
        """
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 32
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)

        model_args = f'model_version="{model_version}",tp={tp}'

        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
            "--limit",
            limit,
        ]

        _run_lmms_eval_with_retry(cmd, timeout=3600)

    @staticmethod
    def _load_mmmu_result(output_path):
        """Return the newest lmms-eval result dict under *output_path* with MMMU scores.

        ``output_path`` is reused across runs (``os.makedirs(..., exist_ok=True)``)
        and ``glob`` returns matches in arbitrary order, so candidates are tried
        newest-first by modification time. JSON files without a
        ``results["mmmu_val"]`` entry are skipped. Returns ``None`` if none match.
        """
        import glob
        import json

        # With recursive=True, "**" also matches files directly in output_path.
        pattern = os.path.join(output_path, "**", "*.json")
        candidates = sorted(
            glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
        )
        for path in candidates:
            try:
                with open(path, "r") as f:
                    data = json.load(f)
            except (OSError, ValueError):
                continue
            results = data.get("results") if isinstance(data, dict) else None
            if isinstance(results, dict) and "mmmu_val" in results:
                return data
        return None

    def test_mmmu(self):
        """Test MMMU evaluation through the gRPC encoder EPD pipeline."""
        output_path = "./logs/epd_grpc_encoder_mmmu"
        self.run_mmmu_eval(self.model, output_path)

        result = self._load_mmmu_result(output_path)
        if result is None:
            self.fail(f"No lmms-eval MMMU result JSON found in {output_path}")
        print(f"MMMU result (grpc encoder): {result}")

        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
        print(f"MMMU accuracy (grpc encoder): {mmmu_accuracy:.4f}")

        # Accuracy floor on the limited sample set (run_mmmu_eval default: 50).
        self.assertGreater(mmmu_accuracy, 0.40)
|
|
|
|
@unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
class TestEPDDisaggregationGrpcEncoderOnly(PDDisaggregationServerBase):
    """Test gRPC encoder server integration with zmq_to_scheduler transfers.

    Only the gRPC encoder is launched. The test binds its own ZMQ PULL socket,
    gives that address to the encoder via ``SchedulerReceiveUrl``, requests an
    ``Encode`` and checks that a payload arrives on the socket.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the gRPC encoder and wait for its health service.

        If launch or readiness fails, the encoder is killed and the env var is
        cleared before the error propagates, because unittest does not call
        ``tearDownClass`` when ``setUpClass`` raises.
        """
        super().setUpClass()
        # Inherited by the encoder subprocess spawned in start_encode.
        os.environ["SGLANG_ENCODER_MM_RECEIVER_MODE"] = "grpc"
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        cls.encode_port = f"{int(cls.lb_port) + 302}"

        print(f"Setting up gRPC EPD encoder: encode={cls.encode_port}")

        try:
            cls.start_encode()
            cls.wait_grpc_ready(cls.base_host, cls.encode_port, cls.process_encode)
        except BaseException:
            cls.tearDownClass()
            raise

    @classmethod
    def start_encode(cls):
        """Spawn the gRPC encoder-only server without waiting for it to be ready."""
        encode_command = [
            "python3",
            "-m",
            "sglang.launch_server",
            "--model-path",
            cls.model,
            "--host",
            cls.base_host,
            "--port",
            cls.encode_port,
            "--trust-remote-code",
            "--encoder-only",
            "--grpc-mode",
            "--encoder-transfer-backend",
            "zmq_to_scheduler",
            "--tp",
            "1",
            "--base-gpu-id",
            "0",
            "--enable-prefix-mm-cache",
        ]
        cls.process_encode = subprocess.Popen(encode_command)

    @staticmethod
    def wait_grpc_ready(host, port, process, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll the standard gRPC health service until it reports SERVING.

        Raises:
            RuntimeError: if *process* exits first, or *timeout* seconds elapse.
        """
        deadline = time.time() + timeout
        channel = grpc.insecure_channel(f"{host}:{port}")
        stub = health_pb2_grpc.HealthStub(channel)
        try:
            while time.time() < deadline:
                if process.poll() is not None:
                    raise RuntimeError(
                        f"gRPC encoder server exited with code {process.returncode}"
                    )
                try:
                    response = stub.Check(
                        health_pb2.HealthCheckRequest(service=""), timeout=2
                    )
                    if response.status == health_pb2.HealthCheckResponse.SERVING:
                        return
                except grpc.RpcError:
                    # Server not listening yet (or not serving); keep polling.
                    pass
                time.sleep(1)
        finally:
            channel.close()

        raise RuntimeError(
            f"gRPC encoder server not ready at {host}:{port} within {timeout}s"
        )

    @classmethod
    def tearDownClass(cls):
        """Clear the receiver-mode env var, kill the encoder, then run base teardown.

        Safe to call after a partially completed ``setUpClass``.
        """
        os.environ.pop("SGLANG_ENCODER_MM_RECEIVER_MODE", None)
        process = getattr(cls, "process_encode", None)
        if process:
            try:
                kill_process_tree(process.pid)
            except Exception as e:
                print(f"Error killing process: {e}")
        super().tearDownClass()

    def test_grpc_encoder_zmq_to_scheduler(self):
        """Encode one image over gRPC and expect a payload on our ZMQ PULL socket."""
        from smg_grpc_proto import sglang_encoder_pb2, sglang_encoder_pb2_grpc

        req_id = f"grpc-epd-{int(time.time() * 1000)}"
        # Resolved against the current working directory.
        image_path = os.path.abspath("examples/assets/example_image.png")

        context = zmq.Context()
        recv_socket = None
        channel = None
        # Every resource is released even if socket/channel creation itself fails.
        try:
            recv_port, recv_socket = get_zmq_socket_on_host(
                context, zmq.PULL, host=self.base_host
            )
            channel = grpc.insecure_channel(f"{self.base_host}:{self.encode_port}")
            stub = sglang_encoder_pb2_grpc.SglangEncoderStub(channel)

            # Give the encoder our PULL socket's address as the receive URL for req_id.
            stub.SchedulerReceiveUrl(
                sglang_encoder_pb2.SchedulerReceiveUrlRequest(
                    req_id=req_id,
                    receive_url=f"{self.base_host}:{recv_port}",
                    receive_count=1,
                ),
                timeout=60,
            )
            stub.Encode(
                sglang_encoder_pb2.EncodeRequest(
                    mm_items=[image_path],
                    req_id=req_id,
                    num_parts=1,
                    part_idx=0,
                ),
                timeout=300,
            )

            poller = zmq.Poller()
            poller.register(recv_socket, zmq.POLLIN)
            socks = dict(poller.poll(60000))  # timeout in milliseconds
            self.assertIn(
                recv_socket,
                socks,
                "No embedding payload received from gRPC encoder server",
            )
            parts = recv_socket.recv_multipart()
            self.assertTrue(parts, "Empty embedding payload from gRPC encoder server")
        finally:
            if channel is not None:
                channel.close()
            if recv_socket is not None:
                recv_socket.close()
            context.term()
| channel.close() |
|
|
|
|
# Allow running this module directly (python <this file>) as well as via a test runner.
if __name__ == "__main__":
    unittest.main()
|
|