"""Unit tests for ``prepare_server_args``, ``ServerArgs`` and ``PortArgs``."""

import json
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from sglang.srt.server_args import PortArgs, ServerArgs, prepare_server_args
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import CustomTestCase

# CI registration: one call per backend, each with an estimated time and suite.
register_cuda_ci(est_time=9, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=1, suite="stage-b-test-small-1-gpu-amd")


def _dp_attention_args(nnodes, dist_init_addr):
    """Build the ``ServerArgs`` fixture shared by the DP-attention tests.

    Args:
        nnodes: Value assigned to ``ServerArgs.nnodes``.
        dist_init_addr: Value assigned to ``ServerArgs.dist_init_addr``
            (``None`` or an address string).

    Returns:
        A ``ServerArgs`` for model path ``"dummy"`` with ``port=30000``,
        ``nccl_port=None`` and ``enable_dp_attention=True``.
    """
    args = ServerArgs(model_path="dummy")
    args.port = 30000
    args.nccl_port = None
    args.enable_dp_attention = True
    args.nnodes = nnodes
    args.dist_init_addr = dist_init_addr
    return args


class TestPrepareServerArgs(CustomTestCase):
    """Command-line parsing through ``prepare_server_args``."""

    def test_prepare_server_args(self):
        # Both the model path and the raw JSON override string must survive
        # argument parsing unchanged.
        model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        override_json = '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}'

        parsed = prepare_server_args(
            ["--model-path", model, "--json-model-override-args", override_json]
        )

        self.assertEqual(parsed.model_path, model)
        self.assertEqual(
            json.loads(parsed.json_model_override_args),
            {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}},
        )


class TestLoadBalanceMethod(unittest.TestCase):
    """Default ``load_balance_method`` for each ``disaggregation_mode``."""

    @staticmethod
    def _default_method(mode):
        """Return ``load_balance_method`` of a fresh ``ServerArgs`` in ``mode``."""
        return ServerArgs(
            model_path="dummy", disaggregation_mode=mode
        ).load_balance_method

    def test_non_pd_defaults_to_round_robin(self):
        self.assertEqual(self._default_method("null"), "round_robin")

    def test_pd_prefill_defaults_to_follow_bootstrap_room(self):
        self.assertEqual(self._default_method("prefill"), "follow_bootstrap_room")

    def test_pd_decode_defaults_to_round_robin(self):
        self.assertEqual(self._default_method("decode"), "round_robin")


class TestPortArgs(unittest.TestCase):
    """Endpoint naming and address validation in ``PortArgs.init_new``."""

    def _check_endpoints(self, ports, prefix):
        """Assert the three IPC endpoint names start with ``prefix``.

        Also asserts that ``ports.nccl_port`` is an ``int``. Checks run in
        the order tokenizer, scheduler input, detokenizer, NCCL port.
        """
        for endpoint in (
            ports.tokenizer_ipc_name,
            ports.scheduler_input_ipc_name,
            ports.detokenizer_ipc_name,
        ):
            self.assertTrue(endpoint.startswith(prefix))
        self.assertIsInstance(ports.nccl_port, int)

    def _assert_init_new_rejects(self, dist_init_addr, error_type, fragment=None):
        """Assert ``PortArgs.init_new`` raises for a two-node DP-attention setup.

        Args:
            dist_init_addr: Address string placed on the ``ServerArgs`` fixture.
            error_type: Exception class expected from ``PortArgs.init_new``.
            fragment: Optional substring that must appear in the error text.
        """
        args = _dp_attention_args(nnodes=2, dist_init_addr=dist_init_addr)
        with self.assertRaises(error_type) as raised:
            PortArgs.init_new(args)
        if fragment is not None:
            self.assertIn(fragment, str(raised.exception))

    @patch("sglang.srt.server_args.get_free_port")
    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_with_nccl_port_none(self, temp_file_mock, free_port_mock):
        """Test that get_free_port() is called when nccl_port is None"""
        temp_file_mock.return_value.name = "temp_file"
        free_port_mock.return_value = 45678  # stand-in for an ephemeral port

        # A MagicMock replaces ServerArgs here; only the three attributes
        # below are pinned, every other attribute is an auto-created mock.
        fake_args = MagicMock(
            nccl_port=None,
            enable_dp_attention=False,
            tokenizer_worker_num=1,
        )

        ports = PortArgs.init_new(fake_args)

        # The free-port helper is consulted exactly once and its value is used.
        free_port_mock.assert_called_once()
        self.assertEqual(ports.nccl_port, 45678)

    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_standard_case(self, temp_file_mock):
        # Without DP attention every endpoint is an ipc:// name.
        temp_file_mock.return_value.name = "temp_file"

        args = ServerArgs(model_path="dummy")
        args.port = 30000
        args.nccl_port = None
        args.enable_dp_attention = False

        self._check_endpoints(PortArgs.init_new(args), "ipc://")

    def test_init_new_with_single_node_dp_attention(self):
        # One node and no dist_init_addr: endpoints are TCP on loopback.
        args = _dp_attention_args(nnodes=1, dist_init_addr=None)

        self._check_endpoints(PortArgs.init_new(args), "tcp://127.0.0.1:")

    def test_init_new_with_dp_rank(self):
        # dp_rank=2 selects worker_ports[2] for the scheduler input endpoint.
        args = _dp_attention_args(nnodes=1, dist_init_addr="192.168.1.1:25000")
        candidate_ports = [25006, 25007, 25008, 25009]

        ports = PortArgs.init_new(args, dp_rank=2, worker_ports=candidate_ports)

        self.assertTrue(ports.scheduler_input_ipc_name.endswith(":25008"))
        self.assertTrue(ports.tokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertTrue(ports.detokenizer_ipc_name.startswith("tcp://192.168.1.1:"))
        self.assertIsInstance(ports.nccl_port, int)

    def test_init_new_with_ipv4_address(self):
        args = _dp_attention_args(nnodes=2, dist_init_addr="192.168.1.1:25000")

        self._check_endpoints(PortArgs.init_new(args), "tcp://192.168.1.1:")

    def test_init_new_with_malformed_ipv4_address(self):
        # An address without a port trips the host:port assertion.
        self._assert_init_new_rejects(
            "192.168.1.1",
            AssertionError,
            "please provide --dist-init-addr as host:port",
        )

    def test_init_new_with_malformed_ipv4_address_invalid_port(self):
        # Only the exception type is checked for a non-numeric IPv4 port.
        self._assert_init_new_rejects("192.168.1.1:abc", ValueError)

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_ipv6_address(self, _ipv6_check):
        args = _dp_attention_args(nnodes=2, dist_init_addr="[2001:db8::1]:25000")

        self._check_endpoints(PortArgs.init_new(args), "tcp://[2001:db8::1]:")

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=False)
    def test_init_new_with_invalid_ipv6_address(self, _ipv6_check):
        self._assert_init_new_rejects(
            "[invalid-ipv6]:25000", ValueError, "invalid IPv6 address"
        )

    def test_init_new_with_malformed_ipv6_address_missing_bracket(self):
        self._assert_init_new_rejects(
            "[2001:db8::1:25000", ValueError, "invalid IPv6 address format"
        )

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_missing_port(self, _ipv6_check):
        self._assert_init_new_rejects(
            "[2001:db8::1]", ValueError, "a port must be specified in IPv6 address"
        )

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_invalid_port(self, _ipv6_check):
        self._assert_init_new_rejects(
            "[2001:db8::1]:abcde", ValueError, "invalid port in IPv6 address"
        )

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_wrong_separator(self, _ipv6_check):
        self._assert_init_new_rejects(
            "[2001:db8::1]#25000", ValueError, "expected ':' after ']'"
        )


class TestSSLArgs(unittest.TestCase):
    """SSL option validation, ``url()`` construction and ``ssl_verify()``."""

    # Key/cert keyword pair reused by every test that enables SSL. These
    # paths do not exist; such tests patch ``os.path.isfile`` to return True.
    _SSL_PAIR = {"ssl_keyfile": "key.pem", "ssl_certfile": "cert.pem"}

    def _assert_rejected(self, fragments, **overrides):
        """Assert ``ServerArgs(model_path="dummy", **overrides)`` raises ValueError.

        Args:
            fragments: Substrings that must each appear in the error text.
            **overrides: Extra keyword arguments for ``ServerArgs``.
        """
        with self.assertRaises(ValueError) as raised:
            ServerArgs(model_path="dummy", **overrides)
        for fragment in fragments:
            self.assertIn(fragment, str(raised.exception))

    def test_default_ssl_fields_are_none(self):
        args = ServerArgs(model_path="dummy")
        for field_name in (
            "ssl_keyfile",
            "ssl_certfile",
            "ssl_ca_certs",
            "ssl_keyfile_password",
        ):
            self.assertIsNone(getattr(args, field_name))

    def test_ssl_keyfile_without_certfile_raises(self):
        self._assert_rejected(["--ssl-certfile"], ssl_keyfile="key.pem")

    def test_ssl_certfile_without_keyfile_raises(self):
        self._assert_rejected(["--ssl-keyfile"], ssl_certfile="cert.pem")

    @patch("os.path.isfile", return_value=True)
    def test_ssl_both_keyfile_and_certfile_accepted(self, _isfile):
        args = ServerArgs(model_path="dummy", **self._SSL_PAIR)
        self.assertEqual(args.ssl_keyfile, "key.pem")
        self.assertEqual(args.ssl_certfile, "cert.pem")

    def test_url_returns_http_without_ssl(self):
        base_url = ServerArgs(model_path="dummy").url()
        self.assertTrue(base_url.startswith("http://"))

    def test_url_rewrites_all_interfaces_to_loopback(self):
        args = ServerArgs(model_path="dummy", host="0.0.0.0")
        self.assertEqual(args.url(), "http://127.0.0.1:30000")

    def test_url_rewrites_empty_host_to_loopback(self):
        args = ServerArgs(model_path="dummy", host="")
        self.assertEqual(args.url(), "http://127.0.0.1:30000")

    def test_url_rewrites_ipv6_all_interfaces_to_loopback(self):
        args = ServerArgs(model_path="dummy", host="::")
        self.assertEqual(args.url(), "http://[::1]:30000")

    @patch("os.path.isfile", return_value=True)
    def test_url_returns_https_with_ssl(self, _isfile):
        base_url = ServerArgs(model_path="dummy", **self._SSL_PAIR).url()
        self.assertTrue(base_url.startswith("https://"))

    @patch("os.path.isfile", return_value=True)
    def test_ssl_cli_args_parsed(self, _isfile):
        argv = ["--model-path", "dummy"]
        argv += ["--ssl-keyfile", "key.pem"]
        argv += ["--ssl-certfile", "cert.pem"]
        argv += ["--ssl-ca-certs", "ca.pem"]
        argv += ["--ssl-keyfile-password", "secret"]

        parsed = prepare_server_args(argv)

        self.assertEqual(parsed.ssl_keyfile, "key.pem")
        self.assertEqual(parsed.ssl_certfile, "cert.pem")
        self.assertEqual(parsed.ssl_ca_certs, "ca.pem")
        self.assertEqual(parsed.ssl_keyfile_password, "secret")

    def test_ssl_verify_without_ssl(self):
        # Identity check: the result must be the bool True itself.
        args = ServerArgs(model_path="dummy")
        self.assertIs(args.ssl_verify(), True)

    @patch("os.path.isfile", return_value=True)
    def test_ssl_verify_with_ssl_no_ca(self, _isfile):
        # SSL enabled but no CA bundle given: the result must be False itself.
        args = ServerArgs(model_path="dummy", **self._SSL_PAIR)
        self.assertIs(args.ssl_verify(), False)

    @patch("os.path.isfile", return_value=True)
    def test_ssl_verify_with_ssl_and_ca(self, _isfile):
        # With a CA bundle configured, its path is what ssl_verify() returns.
        args = ServerArgs(
            model_path="dummy", ssl_ca_certs="ca.pem", **self._SSL_PAIR
        )
        self.assertEqual(args.ssl_verify(), "ca.pem")

    def test_ssl_ca_certs_without_certfile_raises(self):
        self._assert_rejected(["--ssl-ca-certs"], ssl_ca_certs="ca.pem")

    def test_ssl_keyfile_password_without_certfile_raises(self):
        self._assert_rejected(
            ["--ssl-keyfile-password"], ssl_keyfile_password="secret"
        )

    def test_ssl_keyfile_not_found_raises(self):
        # isfile is NOT patched here, so the nonexistent paths are really checked.
        self._assert_rejected(
            ["not found"],
            ssl_keyfile="/nonexistent/key.pem",
            ssl_certfile="/nonexistent/cert.pem",
        )

    def test_ssl_certfile_not_found_raises(self):
        # A real temporary key file exists, so only the certificate is missing.
        with tempfile.NamedTemporaryFile(suffix=".pem") as keyfile:
            self._assert_rejected(
                ["SSL certificate file not found"],
                ssl_keyfile=keyfile.name,
                ssl_certfile="/nonexistent/cert.pem",
            )

    def test_ssl_ca_certs_not_found_raises(self):
        # Key and certificate both exist on disk; only the CA bundle is missing.
        with tempfile.NamedTemporaryFile(suffix=".pem") as keyfile:
            with tempfile.NamedTemporaryFile(suffix=".pem") as certfile:
                self._assert_rejected(
                    ["SSL CA certificates file not found"],
                    ssl_keyfile=keyfile.name,
                    ssl_certfile=certfile.name,
                    ssl_ca_certs="/nonexistent/ca.pem",
                )

    @patch("os.path.isfile", return_value=True)
    def test_url_returns_https_with_ssl_and_ipv6(self, _isfile):
        args = ServerArgs(model_path="dummy", host="::1", **self._SSL_PAIR)
        self.assertEqual(args.url(), "https://[::1]:30000")

    def test_enable_ssl_refresh_default_false(self):
        args = ServerArgs(model_path="dummy")
        self.assertFalse(args.enable_ssl_refresh)

    def test_enable_ssl_refresh_without_ssl_raises(self):
        # The error text must mention both the flag and the missing cert option.
        self._assert_rejected(
            ["--enable-ssl-refresh", "--ssl-certfile"], enable_ssl_refresh=True
        )

    @patch("os.path.isfile", return_value=True)
    def test_enable_ssl_refresh_with_ssl_accepted(self, _isfile):
        args = ServerArgs(
            model_path="dummy", enable_ssl_refresh=True, **self._SSL_PAIR
        )
        self.assertTrue(args.enable_ssl_refresh)

    @patch("os.path.isfile", return_value=True)
    def test_enable_ssl_refresh_cli_flag(self, _isfile):
        argv = ["--model-path", "dummy"]
        argv += ["--ssl-keyfile", "key.pem"]
        argv += ["--ssl-certfile", "cert.pem"]
        argv += ["--enable-ssl-refresh"]

        parsed = prepare_server_args(argv)

        self.assertTrue(parsed.enable_ssl_refresh)


if __name__ == "__main__":
    unittest.main()