| import copy |
| import unittest |
|
|
| from sglang.srt.managers.io_struct import GenerateReqInput |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
| from sglang.test.test_utils import ( |
| DEFAULT_SMALL_MODEL_NAME_FOR_TEST, |
| DEFAULT_URL_FOR_TEST, |
| CustomTestCase, |
| ) |
|
|
# Register this test file with the CUDA and AMD CI suites named below.
# NOTE(review): est_time is presumably an estimated runtime hint consumed by
# the CI scheduler -- units are defined in sglang.test.ci.ci_register; confirm.
register_cuda_ci(est_time=8, suite="stage-b-test-large-1-gpu")
register_amd_ci(est_time=8, suite="stage-b-test-small-1-gpu-amd")
|
|
|
|
class TestGenerateReqInputNormalization(CustomTestCase):
    """Exercise ``GenerateReqInput.normalize_batch_and_arguments`` and friends.

    The tests cover single vs. batched inputs, image/audio fields, parallel
    sampling expansion (``n > 1``), per-request options (LoRA paths, logprob
    settings, custom logit processors, session params), ``__getitem__``,
    ``regenerate_rid`` and the ``ValueError`` paths for invalid inputs.
    """

    @classmethod
    def setUpClass(cls):
        """Store the default test model name and base URL on the class."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST

    def setUp(self):
        """Create the two-item batch request that most tests copy and tweak."""
        template_kwargs = {
            "text": ["Hello", "World"],
            "sampling_params": [{}, {}],
            "rid": ["id1", "id2"],
        }
        self.base_req = GenerateReqInput(**template_kwargs)

    def _clone_base(self):
        """Return an independent deep copy of the template batch request."""
        return copy.deepcopy(self.base_req)

    def _assert_fields(self, obj, expected_fields):
        """Assert that each named attribute of ``obj`` equals its expected value."""
        for name, expected in expected_fields.items():
            self.assertEqual(getattr(obj, name), expected, msg=name)

    def test_single_image_to_list_of_lists(self):
        """A single image string becomes one single-image list per batch item."""
        request = self._clone_base()
        request.image_data = "single_image.jpg"

        request.normalize_batch_and_arguments()

        # One entry per batch item, each holding exactly the shared image.
        self.assertEqual(len(request.image_data), 2)
        for idx in range(2):
            self.assertEqual(len(request.image_data[idx]), 1)
            self.assertEqual(request.image_data[idx][0], "single_image.jpg")

        self.assertEqual(request.modalities, ["image", "image"])

    def test_list_of_images_to_list_of_lists(self):
        """A flat list of images becomes a list of single-image lists."""
        request = self._clone_base()
        request.image_data = ["image1.jpg", "image2.jpg"]

        request.normalize_batch_and_arguments()

        # Each batch item keeps its own image, wrapped in a one-element list.
        self.assertEqual(len(request.image_data), 2)
        for idx, image_name in enumerate(("image1.jpg", "image2.jpg")):
            self.assertEqual(len(request.image_data[idx]), 1)
            self.assertEqual(request.image_data[idx][0], image_name)

        self.assertEqual(request.modalities, ["image", "image"])

    def test_list_of_lists_with_different_modalities(self):
        """Nested image lists of different lengths yield different modalities."""
        request = self._clone_base()
        request.image_data = [
            ["image1.jpg"],
            ["image2.jpg", "image3.jpg"],
        ]

        request.normalize_batch_and_arguments()

        # The nested structure is kept: one image, then two images.
        self.assertEqual(len(request.image_data), 2)
        for idx, image_count in enumerate((1, 2)):
            self.assertEqual(len(request.image_data[idx]), image_count)

        self.assertEqual(request.modalities, ["image", "multi-images"])

    def test_list_of_lists_with_none_values(self):
        """A nested list containing ``None`` yields a ``None`` modality."""
        request = self._clone_base()
        request.image_data = [
            [None],
            ["image.jpg"],
        ]

        request.normalize_batch_and_arguments()

        self.assertEqual(len(request.image_data), 2)
        for idx in range(2):
            self.assertEqual(len(request.image_data[idx]), 1)

        self.assertEqual(request.modalities, [None, "image"])

    def test_expanding_parallel_sample_correlation(self):
        """With ``n=3``, prompts, images and modalities are repeated in step."""
        request = self._clone_base()
        request.text = ["Prompt 1", "Prompt 2"]
        request.image_data = [
            ["image1.jpg"],
            ["image2.jpg", "image3.jpg"],
        ]
        request.sampling_params = {"n": 3}

        # Expectations are captured before normalization rewrites the fields.
        repeat = 3
        expected_text = request.text * repeat
        expected_images = request.image_data * repeat
        expected_modalities = ["image", "multi-images"] * repeat

        request.normalize_batch_and_arguments()

        self.assertEqual(len(request.image_data), 6)
        self.assertEqual(request.image_data, expected_images)
        self.assertEqual(request.modalities, expected_modalities)
        self.assertEqual(request.text, expected_text)

    def test_specific_parallel_n_per_sample(self):
        """Parallel expansion with a per-request sampling_params list (n=2 each)."""
        request = self._clone_base()
        request.text = ["Prompt 1", "Prompt 2"]
        request.image_data = [
            ["image1.jpg"],
            ["image2.jpg", "image3.jpg"],
        ]
        request.sampling_params = [{"n": 2}, {"n": 2}]

        # Expectations are captured before normalization rewrites the fields.
        repeat = 2
        expected_images = request.image_data * repeat
        expected_modalities = ["image", "multi-images"] * repeat
        expected_text = request.text * repeat

        request.normalize_batch_and_arguments()

        self.assertEqual(len(request.image_data), 4)
        self.assertEqual(request.image_data, expected_images)
        self.assertEqual(request.modalities, expected_modalities)
        self.assertEqual(request.text, expected_text)

    def test_mixed_none_and_images_with_parallel_samples(self):
        """Parallel expansion keeps ``None`` image entries aligned with prompts."""
        request = self._clone_base()
        request.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
        request.rid = ["id1", "id2", "id3"]
        request.image_data = [
            ["image1.jpg"],
            None,
            ["image3_1.jpg", "image3_2.jpg"],
        ]
        request.sampling_params = {"n": 2}

        # Expectations are captured before normalization rewrites the fields.
        repeat = 2
        expected_images = request.image_data * repeat
        expected_modalities = ["image", None, "multi-images"] * repeat
        expected_text = request.text * repeat

        request.normalize_batch_and_arguments()

        self.assertEqual(len(request.image_data), 6)
        self.assertEqual(request.image_data, expected_images)
        self.assertEqual(request.modalities, expected_modalities)
        self.assertEqual(request.text, expected_text)

    def test_correlation_with_sampling_params(self):
        """Per-request sampling params stay paired with their images on expansion."""
        request = self._clone_base()
        request.text = ["Prompt 1", "Prompt 2"]
        request.image_data = [
            ["image1.jpg"],
            ["image2.jpg"],
        ]
        request.sampling_params = [
            {"temperature": 0.7, "n": 2},
            {"temperature": 0.9, "n": 2},
        ]

        request.normalize_batch_and_arguments()

        # The expanded batch repeats the original order: item 0, item 1, ...
        self.assertEqual(len(request.sampling_params), 4)
        for idx, temperature in enumerate((0.7, 0.9, 0.7, 0.9)):
            self.assertEqual(request.sampling_params[idx]["temperature"], temperature)

        self.assertEqual(len(request.image_data), 4)
        expected_image_lists = (
            ["image1.jpg"],
            ["image2.jpg"],
            ["image1.jpg"],
            ["image2.jpg"],
        )
        for idx, image_list in enumerate(expected_image_lists):
            self.assertEqual(request.image_data[idx], image_list)

    def test_single_example_with_image(self):
        """A single (non-batch) request keeps its image string as given."""
        request = GenerateReqInput(
            text="Hello",
            image_data="single_image.jpg",
        )

        request.normalize_batch_and_arguments()

        self.assertEqual(request.image_data, "single_image.jpg")
        self.assertIsNone(request.modalities)

    def test_single_to_batch_with_parallel_sampling(self):
        """A single request with ``n=3`` is expanded into a three-item batch."""
        request = GenerateReqInput(
            text="Hello",
            image_data="single_image.jpg",
            sampling_params={"n": 3},
        )
        expected_text = ["Hello"] * 3

        request.normalize_batch_and_arguments()

        self.assertEqual(request.text, expected_text)

        # Every expanded item carries the same image as its first entry.
        self.assertEqual(len(request.image_data), 3)
        for idx in range(3):
            self.assertEqual(request.image_data[idx][0], "single_image.jpg")

        self.assertEqual(request.modalities, ["image", "image", "image"])

    def test_audio_data_handling(self):
        """Audio given as one string or as a per-item list ends up per item."""
        scenarios = (
            # A single audio string is shared by both batch items.
            ("audio.mp3", ("audio.mp3", "audio.mp3")),
            # A list of audio entries maps one-to-one onto batch items.
            (["audio1.mp3", "audio2.mp3"], ("audio1.mp3", "audio2.mp3")),
        )
        for audio_input, expected_clips in scenarios:
            request = self._clone_base()
            request.audio_data = audio_input

            request.normalize_batch_and_arguments()

            self.assertEqual(len(request.audio_data), 2)
            for idx, clip in enumerate(expected_clips):
                self.assertEqual(request.audio_data[idx], clip)

    def test_input_ids_normalization(self):
        """``input_ids`` can stand in for ``text`` in single and batch form."""
        # A flat list of token ids is a single request.
        single = GenerateReqInput(input_ids=[1, 2, 3])
        single.normalize_batch_and_arguments()
        self.assertTrue(single.is_single)
        self.assertEqual(single.batch_size, 1)

        # A list of token-id lists is a batch.
        batch = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5, 6]])
        batch.normalize_batch_and_arguments()
        self.assertFalse(batch.is_single)
        self.assertEqual(batch.batch_size, 2)

        # A batch of two with n=2 expands to four entries.
        expanded = GenerateReqInput(
            input_ids=[[1, 2, 3], [4, 5, 6]], sampling_params={"n": 2}
        )
        expanded.normalize_batch_and_arguments()
        self.assertEqual(len(expanded.input_ids), 4)

    def test_input_embeds_normalization(self):
        """``input_embeds`` is classified as single or batch by nesting depth."""
        cases = (
            # Two-level nesting: one request.
            ([[0.1, 0.2], [0.3, 0.4]], self.assertTrue, 1),
            # Three-level nesting: a batch of two requests.
            ([[[0.1, 0.2]], [[0.3, 0.4]]], self.assertFalse, 2),
        )
        for embeds, check_is_single, expected_batch_size in cases:
            request = GenerateReqInput(input_embeds=embeds)
            request.normalize_batch_and_arguments()
            check_is_single(request.is_single)
            self.assertEqual(request.batch_size, expected_batch_size)

    def test_input_embeds_with_parallel_sampling(self):
        """``input_embeds`` is expanded for ``n > 1``; mismatched n values raise."""
        # Single request with n=2 becomes a batch of two identical entries.
        single = GenerateReqInput(
            input_embeds=[[0.1, 0.2]],
            sampling_params={"n": 2},
        )
        single.normalize_batch_and_arguments()

        self.assertFalse(single.is_single)
        self.assertEqual(len(single.input_embeds), 2)
        for idx in range(2):
            self.assertEqual(single.input_embeds[idx], [[0.1, 0.2]])

        # Batch of two with n=3 becomes six entries in repeated batch order.
        batch = GenerateReqInput(
            input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]], sampling_params={"n": 3}
        )
        batch.normalize_batch_and_arguments()

        self.assertFalse(batch.is_single)
        self.assertEqual(len(batch.input_embeds), 6)
        expected_embeds = [[[0.1, 0.2]], [[0.3, 0.4]]] * 3
        self.assertEqual(batch.input_embeds, expected_embeds)

        # Per-request n values that disagree are rejected.
        mismatched = GenerateReqInput(
            input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]],
            sampling_params=[{"n": 2}, {"n": 3}],
        )
        with self.assertRaises(ValueError):
            mismatched.normalize_batch_and_arguments()

    def test_input_embeds_single_to_batch_conversion(self):
        """A single ``input_embeds`` request is replicated ``n`` times."""
        for sample_count in (2, 5):
            request = GenerateReqInput(
                input_embeds=[[0.1, 0.2, 0.3]],
                sampling_params={"n": sample_count},
            )
            request.normalize_batch_and_arguments()

            self.assertFalse(request.is_single)
            self.assertEqual(len(request.input_embeds), sample_count)

            # Every replica must equal the original embedding sequence.
            for idx in range(sample_count):
                self.assertEqual(request.input_embeds[idx], [[0.1, 0.2, 0.3]])

    def test_lora_path_normalization(self):
        """``lora_path`` is broadcast, kept, or expanded to match the batch."""
        cases = (
            # A single path is broadcast to every batch item.
            (
                {"lora_path": "path/to/lora"},
                ["path/to/lora", "path/to/lora"],
            ),
            # A per-item list is kept as is.
            (
                {"lora_path": ["path1", "path2"]},
                ["path1", "path2"],
            ),
            # A per-item list is repeated under parallel sampling.
            (
                {"lora_path": ["path1", "path2"], "sampling_params": {"n": 2}},
                ["path1", "path2"] * 2,
            ),
        )
        for extra_kwargs, expected_lora_paths in cases:
            request = GenerateReqInput(text=["Hello", "World"], **extra_kwargs)
            request.normalize_batch_and_arguments()
            self.assertEqual(request.lora_path, expected_lora_paths)

    def test_logprob_parameters_normalization(self):
        """Logprob-related settings are kept for single and shaped for batch."""
        # Single request: the scalar settings are unchanged.
        single = GenerateReqInput(
            text="Hello",
            return_logprob=True,
            logprob_start_len=10,
            top_logprobs_num=5,
            token_ids_logprob=[7, 8, 9],
        )
        single.normalize_batch_and_arguments()
        self._assert_fields(
            single,
            {
                "return_logprob": True,
                "logprob_start_len": 10,
                "top_logprobs_num": 5,
                "token_ids_logprob": [7, 8, 9],
            },
        )

        # Batch with scalar settings: each setting is repeated per item.
        broadcast = GenerateReqInput(
            text=["Hello", "World"],
            return_logprob=True,
            logprob_start_len=10,
            top_logprobs_num=5,
            token_ids_logprob=[7, 8, 9],
        )
        broadcast.normalize_batch_and_arguments()
        self._assert_fields(
            broadcast,
            {
                "return_logprob": [True, True],
                "logprob_start_len": [10, 10],
                "top_logprobs_num": [5, 5],
                "token_ids_logprob": [[7, 8, 9], [7, 8, 9]],
            },
        )

        # Batch with list settings: the lists are kept as provided.
        per_item = GenerateReqInput(
            text=["Hello", "World"],
            return_logprob=[True, False],
            logprob_start_len=[10, 5],
            top_logprobs_num=[5, 3],
            token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
            return_hidden_states=[False, False, True],
        )
        per_item.normalize_batch_and_arguments()
        self._assert_fields(
            per_item,
            {
                "return_logprob": [True, False],
                "logprob_start_len": [10, 5],
                "top_logprobs_num": [5, 3],
                "token_ids_logprob": [[7, 8, 9], [4, 5, 6]],
                "return_hidden_states": [False, False, True],
            },
        )

    def test_custom_logit_processor_normalization(self):
        """``custom_logit_processor`` is broadcast or kept per batch item."""
        # A single processor string is repeated for every batch item.
        shared = GenerateReqInput(
            text=["Hello", "World"], custom_logit_processor="serialized_processor"
        )
        shared.normalize_batch_and_arguments()
        self.assertEqual(
            shared.custom_logit_processor,
            ["serialized_processor", "serialized_processor"],
        )

        # A per-item list is kept as provided.
        per_item = GenerateReqInput(
            text=["Hello", "World"],
            custom_logit_processor=["processor1", "processor2"],
        )
        per_item.normalize_batch_and_arguments()
        self.assertEqual(per_item.custom_logit_processor, ["processor1", "processor2"])

    def test_session_params_handling(self):
        """``session_params`` is left untouched in dict and list form."""
        # A single dict stays a dict even for a batch request.
        with_dict = GenerateReqInput(
            text=["Hello", "World"], session_params={"id": "session1", "offset": 10}
        )
        with_dict.normalize_batch_and_arguments()
        self.assertEqual(with_dict.session_params, {"id": "session1", "offset": 10})

        # A list of dicts stays the same list of dicts.
        with_list = GenerateReqInput(
            text=["Hello", "World"],
            session_params=[{"id": "session1"}, {"id": "session2"}],
        )
        with_list.normalize_batch_and_arguments()
        self.assertEqual(
            with_list.session_params, [{"id": "session1"}, {"id": "session2"}]
        )

    def test_getitem_method(self):
        """Indexing a normalized batch returns the fields of that one item."""
        batch = GenerateReqInput(
            text=["Hello", "World"],
            image_data=[["img1.jpg"], ["img2.jpg"]],
            audio_data=["audio1.mp3", "audio2.mp3"],
            sampling_params=[{"temp": 0.7}, {"temp": 0.8}],
            rid=["id1", "id2"],
            return_logprob=[True, False],
            logprob_start_len=[10, 5],
            top_logprobs_num=[5, 3],
            token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
            stream=True,
            log_metrics=True,
            modalities=["image", "image"],
            lora_path=["path1", "path2"],
            custom_logit_processor=["processor1", "processor2"],
            return_hidden_states=True,
        )
        batch.normalize_batch_and_arguments()

        first_item = batch[0]

        # Per-item fields take the first element; shared flags carry over.
        self._assert_fields(
            first_item,
            {
                "text": "Hello",
                "image_data": ["img1.jpg"],
                "audio_data": "audio1.mp3",
                "sampling_params": {"temp": 0.7},
                "rid": "id1",
                "return_logprob": True,
                "logprob_start_len": 10,
                "top_logprobs_num": 5,
                "token_ids_logprob": [7, 8, 9],
                "stream": True,
                "log_metrics": True,
                "modalities": "image",
                "lora_path": "path1",
                "custom_logit_processor": "processor1",
                "return_hidden_states": True,
            },
        )

    def test_regenerate_rid(self):
        """``regenerate_rid`` returns a new rid and stores it on the request."""
        request = GenerateReqInput(text="Hello")
        request.normalize_batch_and_arguments()

        rid_before = request.rid
        rid_after = request.regenerate_rid()

        self.assertNotEqual(rid_before, rid_after)
        self.assertEqual(request.rid, rid_after)

    def test_error_cases(self):
        """Invalid input combinations raise ``ValueError``."""
        invalid_inputs = (
            # No text, input_ids or input_embeds at all.
            {},
            # text, input_ids and input_embeds all supplied at once.
            {
                "text": "Hello",
                "input_ids": [1, 2, 3],
                "input_embeds": [[0.1, 0.2]],
            },
        )
        for kwargs in invalid_inputs:
            with self.assertRaises(ValueError):
                request = GenerateReqInput(**kwargs)
                request.normalize_batch_and_arguments()

    def test_multiple_input_formats(self):
        """Text, token ids and embeddings are each accepted as a single request."""
        single_inputs = (
            {"text": "Hello"},
            {"input_ids": [1, 2, 3]},
            {"input_embeds": [[0.1, 0.2]]},
        )
        for kwargs in single_inputs:
            request = GenerateReqInput(**kwargs)
            request.normalize_batch_and_arguments()
            self.assertTrue(request.is_single)
|
|
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|