| import unittest |
|
|
| from sglang.srt.environ import envs |
| from sglang.srt.utils import kill_process_tree |
| from sglang.test.ci.ci_register import register_cuda_ci |
| from sglang.test.kits.gsm8k_accuracy_kit import GSM8KMixin |
| from sglang.test.test_utils import ( |
| DEFAULT_TARGET_MODEL_NGRAM, |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| DEFAULT_URL_FOR_TEST, |
| CustomTestCase, |
| popen_launch_server, |
| ) |
|
|
# Register this file with the CUDA CI runner: an estimated runtime
# (presumably seconds -- confirm in ci_register) and the target suite.
register_cuda_ci(
    est_time=230,
    suite="stage-b-test-large-1-gpu",
)


# NOTE(review): not referenced anywhere in this file; possibly a leftover or
# read by external tooling -- confirm before removing.
GSM_DATASET_PATH = None
|
|
|
|
| |
# Server CLI flags shared by every test class in this file:
#   --trust-remote-code
#   --cuda-graph-max-bs 8
#   --speculative-algorithm NGRAM
#   --speculative-num-draft-tokens 16
#   --mem-fraction-static 0.8
# Each test class appends its own attention-backend flags to a copy of this.
# Every element is a str so the list is valid argv on its own; subprocess
# argv entries must be strings, and a float only worked if the launcher
# happened to stringify its arguments.
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "8",
    "--speculative-algorithm",
    "NGRAM",
    "--speculative-num-draft-tokens",
    "16",
    "--mem-fraction-static",
    "0.8",
]
|
|
|
|
class TestNgramSpeculativeDecodingBase(GSM8KMixin, CustomTestCase):
    """Shared fixture for the NGRAM speculative-decoding server tests.

    One server process is launched per test class in ``setUpClass`` and its
    process tree is killed in ``tearDownClass``. No test methods are defined
    here; presumably they are supplied by ``GSM8KMixin``, which is assumed to
    read the ``gsm8k_*`` thresholds below -- confirm against the mixin.

    This class launches with ``--attention-backend fa3``. Subclasses pick a
    different server configuration by overriding ``get_server_args``.
    """

    model = DEFAULT_TARGET_MODEL_NGRAM
    base_url = DEFAULT_URL_FOR_TEST
    # Presumably the minimum GSM8K accuracy and the minimum speculative accept
    # length needed to pass; consumed by GSM8KMixin (verify there).
    gsm8k_accuracy_thres = 0.79
    gsm8k_accept_length_thres = 1.8

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses.

        Builds a new list: the shared defaults followed by the attention
        backend selection, so DEFAULT_SERVER_ARGS itself is never mutated.
        """
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "fa3"]

    @classmethod
    def setUpClass(cls):
        """Turn off DeepGEMM JIT settings, then start the server under test."""
        # Disable DeepGEMM JIT and its precompile step before launching.
        # NOTE(review): assumes these `envs` setters reach the spawned server
        # (e.g. through the inherited process environment) -- confirm.
        envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False)
        envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False)

        launch_args = cls.get_server_args()
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server started in setUpClass via kill_process_tree."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)
|
|
|
|
class TestNgramSpeculativeDecodingTriton(TestNgramSpeculativeDecodingBase):
    """Inherited base-class checks, with the server on the Triton attention backend."""

    @classmethod
    def get_server_args(cls):
        """Return the shared defaults plus ``--attention-backend triton``."""
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "triton"]
|
|
|
|
class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase):
    """Inherited base-class checks, with the server on the FlashInfer attention backend."""

    @classmethod
    def get_server_args(cls):
        """Return the shared defaults plus ``--attention-backend flashinfer``."""
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "flashinfer"]
|
|
|
|
class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase):
    """FlashInfer attention backend, additionally launched with ``--page-size 64``."""

    @classmethod
    def get_server_args(cls):
        """Return the shared defaults plus the backend and page-size flags."""
        backend_flags = ["--attention-backend", "flashinfer"]
        paging_flags = ["--page-size", "64"]
        return [*DEFAULT_SERVER_ARGS, *backend_flags, *paging_flags]
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs every TestCase class defined above.
    unittest.main()
|
|