| import unittest |
| from types import SimpleNamespace |
|
|
| from sglang.srt.environ import envs |
| from sglang.test.ci.ci_register import register_cuda_ci |
| from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k |
| from sglang.test.server_fixtures.disaggregation_fixture import ( |
| PDDisaggregationServerBase, |
| ) |
| from sglang.test.test_utils import ( |
| DEFAULT_MODEL_NAME_FOR_TEST, |
| DEFAULT_MODEL_NAME_FOR_TEST_MLA, |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| popen_launch_pd_server, |
| try_cached_model, |
| ) |
|
|
# Register this file with the CUDA CI under the "stage-c-test-8-gpu-h20" suite.
# NOTE(review): est_time is presumably in seconds -- confirm against ci_register.
register_cuda_ci(
    est_time=600,
    suite="stage-c-test-8-gpu-h20",
)
|
|
|
|
class TestDisaggregationMooncakePrefillLargerTP(PDDisaggregationServerBase):
    """GSM8K accuracy check with prefill at TP=4 and decode at TP=2.

    The model is ``DEFAULT_MODEL_NAME_FOR_TEST_MLA``. Extra launch flags are
    read from ``cls.transfer_backend`` and ``cls.rdma_devices``; these are
    presumably set up by the base class (not visible in this file).
    """

    @classmethod
    def setUpClass(cls):
        """Launch the prefill and decode servers, wait for both, then the LB."""
        super().setUpClass()
        envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
        cls.model = try_cached_model(DEFAULT_MODEL_NAME_FOR_TEST_MLA)

        # Both servers are launched before either health endpoint is polled.
        cls.start_prefill()
        cls.start_decode()

        health_checks = (
            (cls.prefill_url, cls.process_prefill),
            (cls.decode_url, cls.process_decode),
        )
        for server_url, server_proc in health_checks:
            cls.wait_server_ready(server_url + "/health", process=server_proc)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill server (TP=4) and keep it in ``process_prefill``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "4",
            *transport_flags,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode server (TP=2, base GPU id 4) into ``process_decode``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "2",
            "--base-gpu-id",
            "4",
            *transport_flags,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K against ``lb_port`` and require accuracy > 0.60."""
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.60)
|
|
|
|
class TestDisaggregationMooncakeDecodeLargerTP(PDDisaggregationServerBase):
    """GSM8K accuracy check with prefill at TP=2 and decode at TP=4.

    The model is ``DEFAULT_MODEL_NAME_FOR_TEST_MLA``. Extra launch flags are
    read from ``cls.transfer_backend`` and ``cls.rdma_devices``; these are
    presumably set up by the base class (not visible in this file).
    """

    @classmethod
    def setUpClass(cls):
        """Launch the prefill and decode servers, wait for both, then the LB."""
        super().setUpClass()
        envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
        cls.model = try_cached_model(DEFAULT_MODEL_NAME_FOR_TEST_MLA)

        # Both servers are launched before either health endpoint is polled.
        cls.start_prefill()
        cls.start_decode()

        health_checks = (
            (cls.prefill_url, cls.process_prefill),
            (cls.decode_url, cls.process_decode),
        )
        for server_url, server_proc in health_checks:
            cls.wait_server_ready(server_url + "/health", process=server_proc)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill server (TP=2) and keep it in ``process_prefill``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "2",
            *transport_flags,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode server (TP=4, base GPU id 4) into ``process_decode``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "4",
            "--base-gpu-id",
            "4",
            *transport_flags,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K against ``lb_port`` and require accuracy > 0.60."""
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.60)
|
|
|
|
class TestDisaggregationMooncakeMHAPrefillLargerTP(PDDisaggregationServerBase):
    """GSM8K accuracy check with prefill at TP=4 and decode at TP=2.

    The model is ``DEFAULT_MODEL_NAME_FOR_TEST``. Extra launch flags are read
    from ``cls.transfer_backend`` and ``cls.rdma_devices``; these are
    presumably set up by the base class (not visible in this file).
    """

    @classmethod
    def setUpClass(cls):
        """Launch the prefill and decode servers, wait for both, then the LB."""
        super().setUpClass()
        envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
        cls.model = try_cached_model(DEFAULT_MODEL_NAME_FOR_TEST)

        # Both servers are launched before either health endpoint is polled.
        cls.start_prefill()
        cls.start_decode()

        health_checks = (
            (cls.prefill_url, cls.process_prefill),
            (cls.decode_url, cls.process_decode),
        )
        for server_url, server_proc in health_checks:
            cls.wait_server_ready(server_url + "/health", process=server_proc)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill server (TP=4) and keep it in ``process_prefill``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "4",
            *transport_flags,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode server (TP=2, base GPU id 4) into ``process_decode``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "2",
            "--base-gpu-id",
            "4",
            *transport_flags,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K against ``lb_port`` and require accuracy > 0.60."""
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.60)
|
|
|
|
class TestDisaggregationMooncakeMHADecodeLargerTP(PDDisaggregationServerBase):
    """GSM8K accuracy check with prefill at TP=2 and decode at TP=4.

    The model is ``DEFAULT_MODEL_NAME_FOR_TEST``. Extra launch flags are read
    from ``cls.transfer_backend`` and ``cls.rdma_devices``; these are
    presumably set up by the base class (not visible in this file).
    """

    @classmethod
    def setUpClass(cls):
        """Launch the prefill and decode servers, wait for both, then the LB."""
        super().setUpClass()
        envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)
        cls.model = try_cached_model(DEFAULT_MODEL_NAME_FOR_TEST)

        # Both servers are launched before either health endpoint is polled.
        cls.start_prefill()
        cls.start_decode()

        health_checks = (
            (cls.prefill_url, cls.process_prefill),
            (cls.decode_url, cls.process_decode),
        )
        for server_url, server_proc in health_checks:
            cls.wait_server_ready(server_url + "/health", process=server_proc)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Spawn the prefill server (TP=2) and keep it in ``process_prefill``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "2",
            *transport_flags,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Spawn the decode server (TP=4, base GPU id 4) into ``process_decode``."""
        transport_flags = cls.transfer_backend + cls.rdma_devices
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "4",
            "--base-gpu-id",
            "4",
            *transport_flags,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run 5-shot GSM8K against ``lb_port`` and require accuracy > 0.60."""
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.60)
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|