| import json |
| import tempfile |
| import unittest |
| from unittest.mock import MagicMock, patch |
|
|
| from sglang.srt.server_args import PortArgs, ServerArgs, prepare_server_args |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
| from sglang.test.test_utils import CustomTestCase |
|
|
| register_cuda_ci(est_time=9, suite="stage-b-test-small-1-gpu") |
| register_amd_ci(est_time=1, suite="stage-b-test-small-1-gpu-amd") |
|
|
|
|
| class TestPrepareServerArgs(CustomTestCase): |
| def test_prepare_server_args(self): |
| server_args = prepare_server_args( |
| [ |
| "--model-path", |
| "meta-llama/Meta-Llama-3.1-8B-Instruct", |
| "--json-model-override-args", |
| '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}', |
| ] |
| ) |
| self.assertEqual( |
| server_args.model_path, "meta-llama/Meta-Llama-3.1-8B-Instruct" |
| ) |
| self.assertEqual( |
| json.loads(server_args.json_model_override_args), |
| {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}, |
| ) |
|
|
|
|
| class TestLoadBalanceMethod(unittest.TestCase): |
| def test_non_pd_defaults_to_round_robin(self): |
| server_args = ServerArgs(model_path="dummy", disaggregation_mode="null") |
| self.assertEqual(server_args.load_balance_method, "round_robin") |
|
|
| def test_pd_prefill_defaults_to_follow_bootstrap_room(self): |
| server_args = ServerArgs(model_path="dummy", disaggregation_mode="prefill") |
| self.assertEqual(server_args.load_balance_method, "follow_bootstrap_room") |
|
|
| def test_pd_decode_defaults_to_round_robin(self): |
| server_args = ServerArgs(model_path="dummy", disaggregation_mode="decode") |
| self.assertEqual(server_args.load_balance_method, "round_robin") |
|
|
|
|
| class TestPortArgs(unittest.TestCase): |
| @patch("sglang.srt.server_args.get_free_port") |
| @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile") |
| def test_init_new_with_nccl_port_none(self, mock_temp_file, mock_get_free_port): |
| """Test that get_free_port() is called when nccl_port is None""" |
| mock_temp_file.return_value.name = "temp_file" |
| mock_get_free_port.return_value = 45678 |
|
|
| |
| server_args = MagicMock() |
| server_args.nccl_port = None |
| server_args.enable_dp_attention = False |
| server_args.tokenizer_worker_num = 1 |
|
|
| port_args = PortArgs.init_new(server_args) |
|
|
| |
| mock_get_free_port.assert_called_once() |
|
|
| |
| self.assertEqual(port_args.nccl_port, 45678) |
|
|
| @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile") |
| def test_init_new_standard_case(self, mock_temp_file): |
| mock_temp_file.return_value.name = "temp_file" |
|
|
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
| server_args.enable_dp_attention = False |
|
|
| port_args = PortArgs.init_new(server_args) |
|
|
| self.assertTrue(port_args.tokenizer_ipc_name.startswith("ipc://")) |
| self.assertTrue(port_args.scheduler_input_ipc_name.startswith("ipc://")) |
| self.assertTrue(port_args.detokenizer_ipc_name.startswith("ipc://")) |
| self.assertIsInstance(port_args.nccl_port, int) |
|
|
| def test_init_new_with_single_node_dp_attention(self): |
|
|
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
| server_args.enable_dp_attention = True |
| server_args.nnodes = 1 |
| server_args.dist_init_addr = None |
|
|
| port_args = PortArgs.init_new(server_args) |
|
|
| self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://127.0.0.1:")) |
| self.assertTrue( |
| port_args.scheduler_input_ipc_name.startswith("tcp://127.0.0.1:") |
| ) |
| self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://127.0.0.1:")) |
| self.assertIsInstance(port_args.nccl_port, int) |
|
|
| def test_init_new_with_dp_rank(self): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
| server_args.enable_dp_attention = True |
| server_args.nnodes = 1 |
| server_args.dist_init_addr = "192.168.1.1:25000" |
|
|
| worker_ports = [25006, 25007, 25008, 25009] |
| port_args = PortArgs.init_new(server_args, dp_rank=2, worker_ports=worker_ports) |
|
|
| self.assertTrue(port_args.scheduler_input_ipc_name.endswith(":25008")) |
|
|
| self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:")) |
| self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:")) |
| self.assertIsInstance(port_args.nccl_port, int) |
|
|
| def test_init_new_with_ipv4_address(self): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "192.168.1.1:25000" |
|
|
| port_args = PortArgs.init_new(server_args) |
|
|
| self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://192.168.1.1:")) |
| self.assertTrue( |
| port_args.scheduler_input_ipc_name.startswith("tcp://192.168.1.1:") |
| ) |
| self.assertTrue(port_args.detokenizer_ipc_name.startswith("tcp://192.168.1.1:")) |
| self.assertIsInstance(port_args.nccl_port, int) |
|
|
| def test_init_new_with_malformed_ipv4_address(self): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "192.168.1.1" |
|
|
| with self.assertRaises(AssertionError) as context: |
| PortArgs.init_new(server_args) |
|
|
| self.assertIn( |
| "please provide --dist-init-addr as host:port", str(context.exception) |
| ) |
|
|
| def test_init_new_with_malformed_ipv4_address_invalid_port(self): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "192.168.1.1:abc" |
|
|
| with self.assertRaises(ValueError): |
| PortArgs.init_new(server_args) |
|
|
| @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True) |
| def test_init_new_with_ipv6_address(self, mock_is_valid_ipv6): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "[2001:db8::1]:25000" |
|
|
| port_args = PortArgs.init_new(server_args) |
|
|
| self.assertTrue(port_args.tokenizer_ipc_name.startswith("tcp://[2001:db8::1]:")) |
| self.assertTrue( |
| port_args.scheduler_input_ipc_name.startswith("tcp://[2001:db8::1]:") |
| ) |
| self.assertTrue( |
| port_args.detokenizer_ipc_name.startswith("tcp://[2001:db8::1]:") |
| ) |
| self.assertIsInstance(port_args.nccl_port, int) |
|
|
| @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=False) |
| def test_init_new_with_invalid_ipv6_address(self, mock_is_valid_ipv6): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "[invalid-ipv6]:25000" |
|
|
| with self.assertRaises(ValueError) as context: |
| PortArgs.init_new(server_args) |
|
|
| self.assertIn("invalid IPv6 address", str(context.exception)) |
|
|
| def test_init_new_with_malformed_ipv6_address_missing_bracket(self): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "[2001:db8::1:25000" |
|
|
| with self.assertRaises(ValueError) as context: |
| PortArgs.init_new(server_args) |
|
|
| self.assertIn("invalid IPv6 address format", str(context.exception)) |
|
|
| @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True) |
| def test_init_new_with_malformed_ipv6_address_missing_port( |
| self, mock_is_valid_ipv6 |
| ): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "[2001:db8::1]" |
|
|
| with self.assertRaises(ValueError) as context: |
| PortArgs.init_new(server_args) |
|
|
| self.assertIn( |
| "a port must be specified in IPv6 address", str(context.exception) |
| ) |
|
|
| @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True) |
| def test_init_new_with_malformed_ipv6_address_invalid_port( |
| self, mock_is_valid_ipv6 |
| ): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "[2001:db8::1]:abcde" |
|
|
| with self.assertRaises(ValueError) as context: |
| PortArgs.init_new(server_args) |
|
|
| self.assertIn("invalid port in IPv6 address", str(context.exception)) |
|
|
| @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True) |
| def test_init_new_with_malformed_ipv6_address_wrong_separator( |
| self, mock_is_valid_ipv6 |
| ): |
| server_args = ServerArgs(model_path="dummy") |
| server_args.port = 30000 |
| server_args.nccl_port = None |
|
|
| server_args.enable_dp_attention = True |
| server_args.nnodes = 2 |
| server_args.dist_init_addr = "[2001:db8::1]#25000" |
|
|
| with self.assertRaises(ValueError) as context: |
| PortArgs.init_new(server_args) |
|
|
| self.assertIn("expected ':' after ']'", str(context.exception)) |
|
|
|
|
| class TestSSLArgs(unittest.TestCase): |
| def test_default_ssl_fields_are_none(self): |
| server_args = ServerArgs(model_path="dummy") |
| self.assertIsNone(server_args.ssl_keyfile) |
| self.assertIsNone(server_args.ssl_certfile) |
| self.assertIsNone(server_args.ssl_ca_certs) |
| self.assertIsNone(server_args.ssl_keyfile_password) |
|
|
| def test_ssl_keyfile_without_certfile_raises(self): |
| with self.assertRaises(ValueError) as context: |
| ServerArgs(model_path="dummy", ssl_keyfile="key.pem") |
| self.assertIn("--ssl-certfile", str(context.exception)) |
|
|
| def test_ssl_certfile_without_keyfile_raises(self): |
| with self.assertRaises(ValueError) as context: |
| ServerArgs(model_path="dummy", ssl_certfile="cert.pem") |
| self.assertIn("--ssl-keyfile", str(context.exception)) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_ssl_both_keyfile_and_certfile_accepted(self, _mock_isfile): |
| server_args = ServerArgs( |
| model_path="dummy", ssl_keyfile="key.pem", ssl_certfile="cert.pem" |
| ) |
| self.assertEqual(server_args.ssl_keyfile, "key.pem") |
| self.assertEqual(server_args.ssl_certfile, "cert.pem") |
|
|
| def test_url_returns_http_without_ssl(self): |
| server_args = ServerArgs(model_path="dummy") |
| self.assertTrue(server_args.url().startswith("http://")) |
|
|
| def test_url_rewrites_all_interfaces_to_loopback(self): |
| server_args = ServerArgs(model_path="dummy", host="0.0.0.0") |
| self.assertEqual(server_args.url(), "http://127.0.0.1:30000") |
|
|
| def test_url_rewrites_empty_host_to_loopback(self): |
| server_args = ServerArgs(model_path="dummy", host="") |
| self.assertEqual(server_args.url(), "http://127.0.0.1:30000") |
|
|
| def test_url_rewrites_ipv6_all_interfaces_to_loopback(self): |
| server_args = ServerArgs(model_path="dummy", host="::") |
| self.assertEqual(server_args.url(), "http://[::1]:30000") |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_url_returns_https_with_ssl(self, _mock_isfile): |
| server_args = ServerArgs( |
| model_path="dummy", ssl_keyfile="key.pem", ssl_certfile="cert.pem" |
| ) |
| self.assertTrue(server_args.url().startswith("https://")) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_ssl_cli_args_parsed(self, _mock_isfile): |
| server_args = prepare_server_args( |
| [ |
| "--model-path", |
| "dummy", |
| "--ssl-keyfile", |
| "key.pem", |
| "--ssl-certfile", |
| "cert.pem", |
| "--ssl-ca-certs", |
| "ca.pem", |
| "--ssl-keyfile-password", |
| "secret", |
| ] |
| ) |
| self.assertEqual(server_args.ssl_keyfile, "key.pem") |
| self.assertEqual(server_args.ssl_certfile, "cert.pem") |
| self.assertEqual(server_args.ssl_ca_certs, "ca.pem") |
| self.assertEqual(server_args.ssl_keyfile_password, "secret") |
|
|
| def test_ssl_verify_without_ssl(self): |
| server_args = ServerArgs(model_path="dummy") |
| self.assertIs(server_args.ssl_verify(), True) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_ssl_verify_with_ssl_no_ca(self, _mock_isfile): |
| server_args = ServerArgs( |
| model_path="dummy", ssl_keyfile="key.pem", ssl_certfile="cert.pem" |
| ) |
| self.assertIs(server_args.ssl_verify(), False) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_ssl_verify_with_ssl_and_ca(self, _mock_isfile): |
| server_args = ServerArgs( |
| model_path="dummy", |
| ssl_keyfile="key.pem", |
| ssl_certfile="cert.pem", |
| ssl_ca_certs="ca.pem", |
| ) |
| self.assertEqual(server_args.ssl_verify(), "ca.pem") |
|
|
| def test_ssl_ca_certs_without_certfile_raises(self): |
| with self.assertRaises(ValueError) as context: |
| ServerArgs(model_path="dummy", ssl_ca_certs="ca.pem") |
| self.assertIn("--ssl-ca-certs", str(context.exception)) |
|
|
| def test_ssl_keyfile_password_without_certfile_raises(self): |
| with self.assertRaises(ValueError) as context: |
| ServerArgs(model_path="dummy", ssl_keyfile_password="secret") |
| self.assertIn("--ssl-keyfile-password", str(context.exception)) |
|
|
| def test_ssl_keyfile_not_found_raises(self): |
| with self.assertRaises(ValueError) as context: |
| ServerArgs( |
| model_path="dummy", |
| ssl_keyfile="/nonexistent/key.pem", |
| ssl_certfile="/nonexistent/cert.pem", |
| ) |
| self.assertIn("not found", str(context.exception)) |
|
|
| def test_ssl_certfile_not_found_raises(self): |
| with tempfile.NamedTemporaryFile(suffix=".pem") as keyfile: |
| with self.assertRaises(ValueError) as context: |
| ServerArgs( |
| model_path="dummy", |
| ssl_keyfile=keyfile.name, |
| ssl_certfile="/nonexistent/cert.pem", |
| ) |
| self.assertIn("SSL certificate file not found", str(context.exception)) |
|
|
| def test_ssl_ca_certs_not_found_raises(self): |
| with tempfile.NamedTemporaryFile(suffix=".pem") as keyfile: |
| with tempfile.NamedTemporaryFile(suffix=".pem") as certfile: |
| with self.assertRaises(ValueError) as context: |
| ServerArgs( |
| model_path="dummy", |
| ssl_keyfile=keyfile.name, |
| ssl_certfile=certfile.name, |
| ssl_ca_certs="/nonexistent/ca.pem", |
| ) |
| self.assertIn( |
| "SSL CA certificates file not found", str(context.exception) |
| ) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_url_returns_https_with_ssl_and_ipv6(self, _mock_isfile): |
| server_args = ServerArgs( |
| model_path="dummy", |
| host="::1", |
| ssl_keyfile="key.pem", |
| ssl_certfile="cert.pem", |
| ) |
| self.assertEqual(server_args.url(), "https://[::1]:30000") |
|
|
| def test_enable_ssl_refresh_default_false(self): |
| server_args = ServerArgs(model_path="dummy") |
| self.assertFalse(server_args.enable_ssl_refresh) |
|
|
| def test_enable_ssl_refresh_without_ssl_raises(self): |
| with self.assertRaises(ValueError) as context: |
| ServerArgs(model_path="dummy", enable_ssl_refresh=True) |
| self.assertIn("--enable-ssl-refresh", str(context.exception)) |
| self.assertIn("--ssl-certfile", str(context.exception)) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_enable_ssl_refresh_with_ssl_accepted(self, _mock_isfile): |
| server_args = ServerArgs( |
| model_path="dummy", |
| ssl_keyfile="key.pem", |
| ssl_certfile="cert.pem", |
| enable_ssl_refresh=True, |
| ) |
| self.assertTrue(server_args.enable_ssl_refresh) |
|
|
| @patch("os.path.isfile", return_value=True) |
| def test_enable_ssl_refresh_cli_flag(self, _mock_isfile): |
| server_args = prepare_server_args( |
| [ |
| "--model-path", |
| "dummy", |
| "--ssl-keyfile", |
| "key.pem", |
| "--ssl-certfile", |
| "cert.pem", |
| "--enable-ssl-refresh", |
| ] |
| ) |
| self.assertTrue(server_args.enable_ssl_refresh) |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|