| import json |
| import os |
| import unittest |
| from types import SimpleNamespace |
|
|
| import openai |
| import requests |
| from transformers import AutoTokenizer |
|
|
| from sglang.test.ci.ci_register import register_cuda_ci |
| from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k |
| from sglang.test.server_fixtures.disaggregation_fixture import ( |
| PDDisaggregationServerBase, |
| ) |
| from sglang.test.test_utils import ( |
| DEFAULT_DRAFT_MODEL_EAGLE, |
| DEFAULT_MODEL_NAME_FOR_TEST, |
| DEFAULT_TARGET_MODEL_EAGLE, |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| popen_launch_pd_server, |
| ) |
|
|
# Register this test module with the CUDA CI: suite "stage-b-test-large-2-gpu"
# and an estimated-time value of 400 (units are defined by register_cuda_ci;
# presumably seconds -- confirm against ci_register).
register_cuda_ci(
    est_time=400,
    suite="stage-b-test-large-2-gpu",
)
|
|
|
|
class TestDisaggregationAccuracy(PDDisaggregationServerBase):
    """Accuracy and API-behaviour tests for a prefill/decode disaggregated setup.

    ``setUpClass`` starts one prefill server and one decode server for
    ``DEFAULT_MODEL_NAME_FOR_TEST``, waits for both ``/health`` endpoints and
    then calls ``launch_lb``. Every test sends its requests to the
    ``lb_url`` / ``lb_port`` endpoint rather than to a server directly.
    """

    @classmethod
    def setUpClass(cls):
        """Start the prefill and decode servers, wait for both, then launch the LB."""
        super().setUpClass()
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        # Kick off both launches before waiting on either health endpoint.
        cls.start_prefill()
        cls.start_decode()

        # Block until each server answers on /health.
        cls.wait_server_ready(cls.prefill_url + "/health", process=cls.process_prefill)
        cls.wait_server_ready(cls.decode_url + "/health", process=cls.process_decode)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server and store its handle in ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        # Transfer-backend / RDMA flags are supplied by the base class.
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server and store its handle in ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            # Presumably keeps decode off the GPU used by prefill -- confirm.
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval via the LB and require accuracy > 0.62."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)

    def test_logprob(self):
        """Check that input and output logprobs are returned by ``/generate``."""
        prompt = "The capital of france is "
        response = requests.post(
            self.lb_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {"temperature": 0},
                "return_logprob": True,
                "return_input_logprob": True,
                "logprob_start_len": 0,
            },
        )
        # Fail with the server's own message instead of a KeyError further down.
        self.assertEqual(response.status_code, 200, response.text)

        meta_info = response.json()["meta_info"]
        completion_tokens = meta_info["completion_tokens"]
        input_logprobs = meta_info["input_token_logprobs"]
        output_logprobs = meta_info["output_token_logprobs"]

        # unittest assertions are used instead of bare ``assert`` so the checks
        # still run under ``python -O``.
        self.assertEqual(
            len(output_logprobs),
            completion_tokens,
            "output_logprobs and completion_tokens should have the same length",
        )
        self.assertGreater(
            len(input_logprobs),
            0,
            "input_logprobs should have at least one token",
        )

    def test_structured_output(self):
        """Check that a ``json_schema``-constrained generation yields valid JSON."""
        json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )

        response = requests.post(
            f"{self.lb_url}/generate",
            json={
                "text": "Here is the information of the capital of France in the JSON format.\n",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 64,
                    "json_schema": json_schema,
                },
            },
        )
        self.assertEqual(response.status_code, 200, response.text)
        output = response.json()["text"]

        # json.loads raises if the text is not valid JSON, which fails the test.
        parsed = json.loads(output)
        # The schema marks both properties as required, so verify they exist.
        self.assertIsInstance(parsed, dict)
        for required_key in ("name", "population"):
            self.assertIn(required_key, parsed)

    def test_first_token_finish(self):
        """Check completion length when the first sampled token ends generation."""
        client = openai.Client(api_key="empty", base_url=f"{self.lb_url}/v1")
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        eos_token = tokenizer.eos_token_id
        prompt = "The best programming language for AI is"

        # Case 1: bias the EOS token so that it is sampled first.
        res = client.completions.create(
            model="dummy", prompt=prompt, logit_bias={eos_token: 42}
        ).model_dump()
        print(f"{res=}")

        self.assertEqual(
            res["usage"]["completion_tokens"],
            1,
            "Expected completion_tokens to be 1 when first token is EOS",
        )

        # Case 2: same bias, but ignore_eos is set so generation continues.
        res = client.completions.create(
            model="dummy",
            prompt=prompt,
            logit_bias={eos_token: 42},
            extra_body={"ignore_eos": True},
        ).model_dump()
        print(f"{res=}")

        self.assertGreater(
            res["usage"]["completion_tokens"],
            1,
            "Expected completion_tokens to be greater than 1 when ignore_eos is True",
        )

        # Case 3: bias the first token of the stop string " hello".
        stop_token_id = tokenizer.encode(" hello", add_special_tokens=False)[0]
        res = client.completions.create(
            model="dummy",
            prompt=prompt,
            logit_bias={stop_token_id: 42},
            stop=[" hello"],
        ).model_dump()
        print(f"{res=}")

        self.assertEqual(
            res["usage"]["completion_tokens"],
            1,
            "Expected completion_tokens to be 1 when first token is stop token",
        )
|
|
|
|
class TestDisaggregationMooncakeFailure(PDDisaggregationServerBase):
    """Run GSM8K with ``DISAGGREGATION_TEST_FAILURE_PROB`` set to ``"0.05"``.

    The variable is exported before the servers are launched and removed
    again in ``tearDownClass``. It presumably makes the transfer layer fail
    with that probability -- confirm against the code that reads it. The
    test tolerates evaluation errors as long as both servers still answer
    ``/health_generate`` afterwards.
    """

    @classmethod
    def setUpClass(cls):
        """Export the failure-probability variable and bring up both servers and the LB."""
        super().setUpClass()

        os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
        try:
            cls.model = DEFAULT_MODEL_NAME_FOR_TEST

            # Kick off both launches before waiting on either health endpoint.
            cls.start_prefill()
            cls.start_decode()

            # Block until each server answers on /health.
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the variable here to avoid leaking it into later classes.
            os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB", None)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the failure-probability variable, then run the base teardown."""
        # A default is passed so a missing key cannot mask the base teardown.
        os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB", None)
        super().tearDownClass()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server and store its handle in ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server and store its handle in ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run GSM8K; on error, fail only if a server is no longer healthy.

        No accuracy threshold is asserted. If the evaluation raises, both
        ``/health_generate`` endpoints are probed; the original exception is
        re-raised only when one of those probes fails.
        """
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )

        try:
            metrics = run_eval_few_shot_gsm8k(args)
            print(f"Evaluation metrics: {metrics}")
        except Exception as e:
            print(f"Test encountered expected errors: {e}")
            # Errors are tolerated here, but the servers must still be healthy.
            try:
                for server_url in (self.prefill_url, self.decode_url):
                    response = requests.get(server_url + "/health_generate")
                    # unittest assertion instead of bare ``assert`` so the
                    # check is not stripped under ``python -O``.
                    self.assertEqual(response.status_code, 200)
            except Exception as health_check_error:
                # Surface the original failure, chained to the health failure.
                raise e from health_check_error
|
|
|
|
class TestDisaggregationMooncakeSpec(PDDisaggregationServerBase):
    """GSM8K test for a disaggregated setup launched with EAGLE speculative flags.

    Both the prefill and the decode server receive the flags collected in
    ``spec_args`` in addition to their disaggregation-mode flags.
    """

    @classmethod
    def setUpClass(cls):
        """Record the models and speculative flags, then start servers and the LB."""
        super().setUpClass()
        cls.model = DEFAULT_TARGET_MODEL_EAGLE
        cls.draft_model = DEFAULT_DRAFT_MODEL_EAGLE
        # Flag/value pairs, one pair per line.
        cls.spec_args = [
            "--speculative-algorithm", "EAGLE",
            "--speculative-draft-model-path", cls.draft_model,
            "--speculative-num-steps", "3",
            "--speculative-eagle-topk", "4",
            "--speculative-num-draft-tokens", "16",
            "--cuda-graph-max-bs", "8",
        ]
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        # Issue both launches first ...
        cls.start_prefill()
        cls.start_decode()

        # ... then wait for each /health endpoint in turn.
        prefill_health = cls.prefill_url + "/health"
        cls.wait_server_ready(prefill_health, process=cls.process_prefill)
        decode_health = cls.decode_url + "/health"
        cls.wait_server_ready(decode_health, process=cls.process_decode)

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server with the speculative flags appended."""
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server with the speculative flags appended."""
        launch_flags = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
            *cls.spec_args,
            *cls.transfer_backend,
            *cls.rdma_devices,
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval with parallel=2 and require accuracy > 0.20."""
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 2,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_config))
        print(f"Evaluation metrics: {metrics}")

        accuracy = metrics["accuracy"]
        self.assertGreater(accuracy, 0.20)
|
|
|
|
class TestDisaggregationSimulatedRetract(PDDisaggregationServerBase):
    """GSM8K accuracy test run with ``SGLANG_TEST_RETRACT`` set to ``"true"``.

    The variable is exported before the servers are launched and removed
    again in ``tearDownClass``. It presumably makes the servers simulate
    request retraction -- confirm against the code that reads it.
    """

    @classmethod
    def setUpClass(cls):
        """Export ``SGLANG_TEST_RETRACT`` and bring up both servers and the LB."""
        super().setUpClass()
        os.environ["SGLANG_TEST_RETRACT"] = "true"
        try:
            cls.model = DEFAULT_MODEL_NAME_FOR_TEST

            # Kick off both launches before waiting on either health endpoint.
            cls.start_prefill()
            cls.start_decode()

            # Block until each server answers on /health.
            cls.wait_server_ready(
                cls.prefill_url + "/health", process=cls.process_prefill
            )
            cls.wait_server_ready(
                cls.decode_url + "/health", process=cls.process_decode
            )

            cls.launch_lb()
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the variable here to avoid leaking it into later classes.
            os.environ.pop("SGLANG_TEST_RETRACT", None)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove ``SGLANG_TEST_RETRACT``, then run the base teardown."""
        # A default is passed so a missing key cannot mask the base teardown.
        os.environ.pop("SGLANG_TEST_RETRACT", None)
        super().tearDownClass()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server and store its handle in ``process_prefill``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp",
            "1",
        ]
        prefill_args += cls.transfer_backend + cls.rdma_devices
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server and store its handle in ``process_decode``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            "--base-gpu-id",
            "1",
        ]
        decode_args += cls.transfer_backend + cls.rdma_devices
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run the 5-shot GSM8K eval via the LB and require accuracy > 0.62."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)
|
|
|
|
# Run the test classes in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|