| import unittest |
| from types import SimpleNamespace |
|
|
| from sglang.srt.utils import kill_process_tree |
| from sglang.test.ci.ci_register import register_amd_ci |
| from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k |
| from sglang.test.send_one import BenchArgs, send_one_prompt |
| from sglang.test.test_utils import ( |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| DEFAULT_URL_FOR_TEST, |
| CustomTestCase, |
| is_in_amd_ci, |
| is_in_ci, |
| popen_launch_server, |
| write_github_step_summary, |
| ) |
|
|
# Register this test file with the AMD CI suite named below.
# NOTE(review): est_time is presumably a runtime estimate in seconds -- confirm
# against register_amd_ci.
register_amd_ci(est_time=1800, suite="stage-c-test-large-8-gpu-amd")


# Model identifier passed to popen_launch_server by the test classes below.
DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2"
|
|
|
|
@unittest.skipIf(is_in_amd_ci(), "Skip DP test for AMD CI, run TP only.")
class TestDeepseekV32DP(CustomTestCase):
    """GSM8K accuracy and single-prompt speed checks with ``--tp 8 --dp 8``.

    One server is launched with DP attention enabled for the whole class and
    shared by every test method. The class is skipped when ``is_in_amd_ci()``
    is true at import time.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared server and keep its process handle on the class."""
        cls.model = DEEPSEEK_V32_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST

        launch_flags = [
            "--trust-remote-code",
            "--tp",
            "8",
            "--dp",
            "8",
            "--enable-dp-attention",
            "--model-loader-extra-config",
            '{"enable_multithread_load": true}',
        ]
        if is_in_amd_ci():
            # Under AMD CI, select the tilelang backend for NSA prefill and
            # decode. Normally not reached here, because the class-level
            # skipIf uses the same condition.
            launch_flags.extend(
                (
                    "--nsa-prefill-backend",
                    "tilelang",
                    "--nsa-decode-backend",
                    "tilelang",
                )
            )

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in ``setUpClass`` and its children."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_a_gsm8k(self):
        # unittest orders test methods alphabetically, so the "a" in the name
        # places this test ahead of test_bs_1_speed (presumably so the longer
        # eval also warms up the server -- confirm).
        port = int(self.base_url.rsplit(":", 1)[-1])
        eval_config = {
            "num_shots": 20,
            "data_path": None,
            "num_questions": 1400,
            "parallel": 1400,
            "max_new_tokens": 512,
            "host": "http://127.0.0.1",
            "port": port,
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"{metrics=}")

        # NOTE(review): the accuracy floor is only enforced when is_in_ci() is
        # true; outside CI the result is just printed.
        if not is_in_ci():
            return

        write_github_step_summary(
            f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n'
        )
        self.assertGreater(metrics["accuracy"], 0.935)

    def test_bs_1_speed(self):
        # Send one prompt and check the reported speed (token/s according to
        # the summary line). The first returned value is not used.
        port = int(self.base_url.rsplit(":", 1)[-1])
        bench_args = BenchArgs(port=port, max_new_tokens=2048)
        _, speed = send_one_prompt(bench_args)

        print(f"{speed=:.2f}")

        # The speed floor is only enforced when is_in_ci() is true.
        if not is_in_ci():
            return

        write_github_step_summary(
            f"### test_bs_1_speed (deepseek-v32)\n" f"{speed=:.2f} token/s\n"
        )
        # AMD CI gets a lower floor than other CI runners.
        min_speed = 10 if is_in_amd_ci() else 50
        self.assertGreater(speed, min_speed)
|
|
|
|
class TestDeepseekV32TP(CustomTestCase):
    """GSM8K accuracy and single-prompt speed checks with ``--tp 8`` only.

    One server is launched for the whole class and shared by every test
    method.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared server and keep its process handle on the class."""
        cls.model = DEEPSEEK_V32_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST

        launch_flags = [
            "--trust-remote-code",
            "--tp",
            "8",
            "--model-loader-extra-config",
            '{"enable_multithread_load": true}',
        ]
        if is_in_amd_ci():
            # Under AMD CI, select the tilelang backend for NSA prefill and
            # decode.
            launch_flags.extend(
                (
                    "--nsa-prefill-backend",
                    "tilelang",
                    "--nsa-decode-backend",
                    "tilelang",
                )
            )

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in ``setUpClass`` and its children."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_a_gsm8k(self):
        # unittest orders test methods alphabetically, so the "a" in the name
        # places this test ahead of test_bs_1_speed (presumably so the longer
        # eval also warms up the server -- confirm).
        port = int(self.base_url.rsplit(":", 1)[-1])
        eval_config = {
            "num_shots": 20,
            "data_path": None,
            "num_questions": 1400,
            "parallel": 1400,
            "max_new_tokens": 512,
            "host": "http://127.0.0.1",
            "port": port,
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"{metrics=}")

        # NOTE(review): the accuracy floor is only enforced when is_in_ci() is
        # true; outside CI the result is just printed.
        if not is_in_ci():
            return

        write_github_step_summary(
            f"### test_gsm8k (deepseek-v32)\n" f'{metrics["accuracy"]=:.3f}\n'
        )
        self.assertGreater(metrics["accuracy"], 0.935)

    def test_bs_1_speed(self):
        # Send one prompt and check the reported speed (token/s according to
        # the summary line). The first returned value is not used.
        port = int(self.base_url.rsplit(":", 1)[-1])
        bench_args = BenchArgs(port=port, max_new_tokens=2048)
        _, speed = send_one_prompt(bench_args)

        print(f"{speed=:.2f}")

        # The speed floor is only enforced when is_in_ci() is true.
        if not is_in_ci():
            return

        write_github_step_summary(
            f"### test_bs_1_speed (deepseek-v32)\n" f"{speed=:.2f} token/s\n"
        )
        # AMD CI gets a lower floor than other CI runners.
        min_speed = 15 if is_in_amd_ci() else 70
        self.assertGreater(speed, min_speed)
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|