| """ | |
| Usage: | |
| python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k | |
| python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency | |
| python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs | |
| python3 -m unittest test_pp_single_node.TestQwenVLPPAccuracy.test_mmmu | |
| """ | |
| import time | |
| import unittest | |
| from types import SimpleNamespace | |
| import requests | |
| from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs | |
| from sglang.srt.server_args import ServerArgs | |
| from sglang.srt.utils import kill_process_tree | |
| from sglang.test.ci.ci_register import register_cuda_ci | |
| from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k | |
| from sglang.test.run_eval import run_eval | |
| from sglang.test.test_utils import ( | |
| DEFAULT_MLA_MODEL_NAME_FOR_TEST, | |
| DEFAULT_MODEL_NAME_FOR_TEST, | |
| DEFAULT_MODEL_NAME_FOR_TEST_GLM_41V_PP, | |
| DEFAULT_MODEL_NAME_FOR_TEST_VL_PP, | |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | |
| DEFAULT_URL_FOR_TEST, | |
| CustomTestCase, | |
| is_in_ci, | |
| popen_launch_server, | |
| run_bench_one_batch_server, | |
| ) | |
# Register this test file with the CUDA CI registry.
# NOTE(review): est_time is presumably the expected runtime in seconds and
# suite the CI stage that runs this file -- confirm against ci_register.
register_cuda_ci(
    est_time=500,
    suite="stage-c-test-4-gpu-h100",
)
class TestPPAccuracy(unittest.TestCase):
    """Accuracy and logprob checks against a server launched with TP=2, PP=2.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared server with chunked prefill size 256."""
        cls.base_url = "http://127.0.0.1:23333"
        cls.process = popen_launch_server(
            DEFAULT_MODEL_NAME_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line values are given as strings so the argument list
            # is valid argv regardless of how the launcher assembles it.
            other_args=[
                "--tp-size",
                "2",
                "--pp-size",
                "2",
                "--chunked-prefill-size",
                "256",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run 5-shot GSM8K on 200 questions and require accuracy > 0.74."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            # The port is the last ":"-separated component of base_url.
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.74)
        # Wait a little bit so that the memory check happens.
        time.sleep(4)

    def test_logprob(self):
        """Request logprobs from ``/generate`` and check the returned lengths."""
        response = requests.post(
            f"{self.base_url}/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 16,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "logprob_start_len": 0,
            },
        )
        meta_info = response.json()["meta_info"]
        input_token_logprobs = meta_info["input_token_logprobs"]
        output_token_logprobs = meta_info["output_token_logprobs"]
        output_top_logprobs = meta_info["output_top_logprobs"]

        # unittest assertions are not stripped under ``python -O`` and report
        # the actual value on failure, unlike bare ``assert`` statements.
        self.assertEqual(len(input_token_logprobs), 6)
        # One entry per generated token (max_new_tokens=16).
        self.assertEqual(len(output_token_logprobs), 16)
        self.assertEqual(len(output_top_logprobs), 16)
class TestDPAttentionDP2PP2(CustomTestCase):
    """MGSM-EN accuracy check with DP attention enabled (TP=2, PP=2, DP=2).

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared server for the MLA test model."""
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--pp-size",
                "2",
                "--enable-dp-attention",
                "--dp",
                "2",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        """Run the ``mgsm_en`` eval and require a score above 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.8)
class TestQwenVLPPAccuracy(unittest.TestCase):
    """GSM8K and MMMU accuracy checks for the VL test model with PP=4.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared multimodal server (TP=1, PP=4)."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_VL_PP
        cls.base_url = "http://127.0.0.1:23333"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line values are given as strings so the argument list
            # is valid argv regardless of how the launcher assembles it.
            other_args=[
                "--tp-size",
                "1",
                "--pp-size",
                "4",
                "--chunked-prefill-size",
                "8192",
                "--enable-multimodal",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run 5-shot GSM8K on 200 questions and require accuracy > 0.65."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            # The port is the last ":"-separated component of base_url.
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.65)
        # Wait a little bit so that the memory check happens.
        time.sleep(4)

    def test_mmmu(self):
        """Run the ``mmmu`` eval and require a score above 0.26."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmmu",
            num_examples=None,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.26)
class TestQwenPPAccuracy(unittest.TestCase):
    """Compare GSM8K accuracy of Qwen3-8B with PP=2 against a PP=1 baseline.

    Each measurement launches its own server, so no process is kept on the
    class; ``setUpClass`` only records the URL and model name.
    """

    @classmethod
    def setUpClass(cls):
        """Record the server URL and model used by every launch."""
        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
        cls.model_name = "Qwen/Qwen3-8B"  # replace with your Qwen Model if needed

    def run_gsm8k_test(self, pp_size):
        """Launch a server with the given PP size and return GSM8K metrics.

        Args:
            pp_size: Value passed to ``--pp-size``.

        Returns:
            The metrics object returned by ``run_eval_few_shot_gsm8k``.
        """
        process = popen_launch_server(
            self.model_name,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line values are given as strings so the argument list
            # is valid argv regardless of how the launcher assembles it.
            other_args=[
                "--pp-size",
                str(pp_size),
                "--chunked-prefill-size",
                "256",
            ],
        )

        try:
            args = SimpleNamespace(
                num_shots=5,
                data_path=None,
                num_questions=200,
                max_new_tokens=512,
                parallel=128,
                host="http://127.0.0.1",
                # The port is the last ":"-separated component of base_url.
                port=int(self.base_url.split(":")[-1]),
            )
            metrics = run_eval_few_shot_gsm8k(args)
            time.sleep(5)
            return metrics
        finally:
            # Always kill the server, even when the eval raises, so the next
            # launch can reuse the same port.
            kill_process_tree(process.pid)

    def test_pp_consistency(self):
        """Require baseline accuracy >= 0.74 and PP=2 within 0.02 of it."""
        baseline = self.run_gsm8k_test(pp_size=1)
        pp_metrics = self.run_gsm8k_test(pp_size=2)

        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")

        self.assertGreaterEqual(baseline["accuracy"], 0.74)
        self.assertGreaterEqual(
            pp_metrics["accuracy"],
            baseline["accuracy"] - 0.02,
            # The message states 2% to match the 0.02 tolerance enforced above.
            msg=(
                "PP accuracy dropped more than 2% compared to baseline. "
                f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
            ),
        )
class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
    """Compare GSM8K accuracy of Qwen3-0.6B with PP=2 against a PP=1 baseline.

    Each measurement launches its own server, so no process is kept on the
    class; ``setUpClass`` only records the URL and model name.
    """

    @classmethod
    def setUpClass(cls):
        """Record the server URL and model used by every launch."""
        cls.base_url = "http://127.0.0.1:23335"  # different ports to avoid conflicts
        cls.model_name = (
            "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
        )

    def run_gsm8k_test(self, pp_size):
        """Launch a server with the given PP size and return GSM8K metrics.

        Args:
            pp_size: Value passed to ``--pp-size``.

        Returns:
            The metrics object returned by ``run_eval_few_shot_gsm8k``.
        """
        process = popen_launch_server(
            self.model_name,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line values are given as strings so the argument list
            # is valid argv regardless of how the launcher assembles it.
            other_args=[
                "--pp-size",
                str(pp_size),
                "--chunked-prefill-size",
                "256",
            ],
        )

        try:
            args = SimpleNamespace(
                num_shots=5,
                data_path=None,
                num_questions=200,
                max_new_tokens=512,
                parallel=128,
                host="http://127.0.0.1",
                # The port is the last ":"-separated component of base_url.
                port=int(self.base_url.split(":")[-1]),
            )
            metrics = run_eval_few_shot_gsm8k(args)
            time.sleep(5)
            return metrics
        finally:
            # Always kill the server, even when the eval raises, so the next
            # launch can reuse the same port.
            kill_process_tree(process.pid)

    def test_pp_consistency(self):
        """Require baseline accuracy >= 0.38 and PP=2 within 0.02 of it."""
        baseline = self.run_gsm8k_test(pp_size=1)
        pp_metrics = self.run_gsm8k_test(pp_size=2)

        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")

        self.assertGreaterEqual(baseline["accuracy"], 0.38)
        self.assertGreaterEqual(
            pp_metrics["accuracy"],
            baseline["accuracy"] - 0.02,
            # The message states 2% to match the 0.02 tolerance enforced above.
            msg=(
                "PP accuracy dropped more than 2% compared to baseline. "
                f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
            ),
        )
class TestQwenMoePPAccuracy(unittest.TestCase):
    """Compare GSM8K accuracy of Qwen3-30B-A3B with PP=2 against PP=1.

    Each measurement launches its own server, so no process is kept on the
    class; ``setUpClass`` only records the URL and model name.
    """

    @classmethod
    def setUpClass(cls):
        """Record the server URL and model used by every launch."""
        cls.base_url = "http://127.0.0.1:23336"  # different ports to avoid conflicts
        cls.model_name = "Qwen/Qwen3-30B-A3B"  # replace with your Qwen Model if needed

    def run_gsm8k_test(self, pp_size):
        """Launch a server with the given PP size and return GSM8K metrics.

        Args:
            pp_size: Value passed to ``--pp-size``.

        Returns:
            The metrics object returned by ``run_eval_few_shot_gsm8k``.
        """
        process = popen_launch_server(
            self.model_name,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line values are given as strings so the argument list
            # is valid argv regardless of how the launcher assembles it.
            other_args=[
                "--pp-size",
                str(pp_size),
                "--chunked-prefill-size",
                "256",
            ],
        )

        try:
            args = SimpleNamespace(
                num_shots=5,
                data_path=None,
                num_questions=200,
                max_new_tokens=512,
                parallel=128,
                host="http://127.0.0.1",
                # The port is the last ":"-separated component of base_url.
                port=int(self.base_url.split(":")[-1]),
            )
            metrics = run_eval_few_shot_gsm8k(args)
            time.sleep(5)
            return metrics
        finally:
            # Always kill the server, even when the eval raises, so the next
            # launch can reuse the same port.
            kill_process_tree(process.pid)

    def test_pp_consistency(self):
        """Require baseline accuracy >= 0.74 and PP=2 within 0.02 of it."""
        baseline = self.run_gsm8k_test(pp_size=1)
        pp_metrics = self.run_gsm8k_test(pp_size=2)

        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")

        self.assertGreaterEqual(baseline["accuracy"], 0.74)
        self.assertGreaterEqual(
            pp_metrics["accuracy"],
            baseline["accuracy"] - 0.02,
            # The message states 2% to match the 0.02 tolerance enforced above.
            msg=(
                "PP accuracy dropped more than 2% compared to baseline. "
                f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
            ),
        )
class TestFixedBugs(unittest.TestCase):
    """Regression tests for previously fixed pipeline-parallel bugs."""

    def test_chunked_prefill_with_small_bs(self):
        """Run a one-batch benchmark with batch size 1 under chunked prefill.

        The test has no explicit assertions; it passes when
        ``run_bench_one_batch_server`` completes without raising.
        """
        model = DEFAULT_MODEL_NAME_FOR_TEST
        server_args = ServerArgs(model_path=model)
        bench_args = OneBatchBenchArgs(
            batch_size=(1,),
            input_len=(1,),
            output_len=(1,),
            base_url=DEFAULT_URL_FOR_TEST,
        )
        # The full option name is spelled out instead of relying on argparse
        # prefix matching, which breaks as soon as another option shares the
        # "--chunked-prefill" prefix. Values are strings, as argv entries are.
        other_server_args = [
            "--tp-size",
            "2",
            "--pp-size",
            "2",
            "--chunked-prefill-size",
            "256",
            "--max-running-requests",
            "2",
        ]
        run_bench_one_batch_server(
            model,
            DEFAULT_URL_FOR_TEST,
            server_args,
            bench_args,
            other_server_args,
        )
class TestGLM41VPPAccuracy(unittest.TestCase):
    """MMMU accuracy check for the GLM-4.1V test model with PP=2.

    One server process is started for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared multimodal server (TP=1, PP=2)."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_GLM_41V_PP
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line values are given as strings so the argument list
            # is valid argv regardless of how the launcher assembles it.
            other_args=[
                "--tp-size",
                "1",
                "--pp-size",
                "2",
                "--chunked-prefill-size",
                "8192",
                "--enable-multimodal",
                "--reasoning-parser",
                "glm45",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_mmmu(self):
        """Run the ``mmmu`` eval and require a score above 0.45."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmmu",
            num_examples=None,
            num_threads=32,
            # Capture the text between the begin_of_box / end_of_box markers.
            response_answer_regex=r"<\|begin_of_box\|>(.*)<\|end_of_box\|>",
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["score"], 0.45)
# Script entry point: discover and run every test case in this module.
if __name__ == "__main__":
    unittest.main()