| import unittest |
| from types import SimpleNamespace |
|
|
| import torch |
|
|
| from sglang.srt.utils import kill_process_tree |
| from sglang.test.run_eval import run_eval |
| from sglang.test.test_utils import ( |
| DEFAULT_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_FP8, |
| DEFAULT_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_FP8_REVISION, |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| DEFAULT_URL_FOR_TEST, |
| CustomTestCase, |
| popen_launch_server, |
| ) |
|
|
|
|
class TestEvalFP8ModelOptQuantAccuracy(CustomTestCase):
    """Accuracy checks for a ModelOpt FP8-quantized model served via ``popen_launch_server``."""

    def _run_test(
        self,
        model,
        other_args,
        expected_score,
        *,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    ):
        """Launch a server for ``model``, run an eval against it, and check the score.

        Args:
            model: Model name/path passed to both ``popen_launch_server`` and
                ``run_eval``.
            other_args: Extra server command-line arguments; a falsy value
                (e.g. ``None``) is treated as an empty list.
            expected_score: Minimum acceptable value of ``metrics["score"]``.
            eval_name: Eval name passed to ``run_eval`` (default ``"mmlu"``).
            num_examples: ``num_examples`` passed to ``run_eval`` (default 64).
            num_threads: ``num_threads`` passed to ``run_eval`` (default 32).
            temperature: ``temperature`` passed to ``run_eval`` (default 0.1).

        Raises:
            AssertionError: If the reported score is below ``expected_score``.
        """
        base_url = DEFAULT_URL_FOR_TEST
        other_args = other_args or []

        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

        # Tear the launched server down in `finally` so it does not outlive the
        # test, even when the eval raises or the score assertion fails.
        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name=eval_name,
                num_examples=num_examples,
                num_threads=num_threads,
                temperature=temperature,
            )
            metrics = run_eval(args)
            score = metrics["score"]
            self.assertGreaterEqual(
                score,
                expected_score,
                f"{eval_name} score {score} for {model} is below the "
                f"expected minimum {expected_score}",
            )
        finally:
            kill_process_tree(process.pid)

    @unittest.skipIf(
        torch.version.hip is not None, "modelopt quantization unsupported on ROCm"
    )
    def test_mmlu_offline_only(self):
        """Test with offline quantization only (pre-quantized checkpoint at a pinned revision)."""
        self._run_test(
            model=DEFAULT_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_FP8,
            other_args=[
                "--revision",
                DEFAULT_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_FP8_REVISION,
            ],
            expected_score=0.64,
        )
|
|