# Hanrui/sglang: test/registered/core/test_server_args.py
# Uploaded by Lekr0 ("Add files using upload-large-folder tool"),
# commit 61ba51e (verified).
import json
import tempfile
import unittest
from unittest.mock import MagicMock, patch
from sglang.srt.server_args import PortArgs, ServerArgs, prepare_server_args
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import CustomTestCase
# Register this test module with two CI suites at import time: the CUDA
# registration passes est_time=9 and the AMD registration passes est_time=1.
# NOTE(review): est_time is presumably an estimated runtime used for
# scheduling; its units are not visible here -- confirm in ci_register.
# NOTE(review): keep the arguments as plain literals until it is confirmed
# whether the CI tooling reads these calls statically.
register_cuda_ci(est_time=9, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=1, suite="stage-b-test-small-1-gpu-amd")
class TestPrepareServerArgs(CustomTestCase):
    """Checks that ``prepare_server_args`` turns CLI tokens into server args."""

    def test_prepare_server_args(self):
        """The parsed result exposes the model path and the JSON override.

        The override is compared after ``json.loads`` so the assertion is
        about the decoded content rather than the exact stored string.
        """
        expected_model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        expected_override = {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}
        cli_tokens = [
            "--model-path",
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "--json-model-override-args",
            '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}',
        ]

        parsed = prepare_server_args(cli_tokens)

        self.assertEqual(parsed.model_path, expected_model_path)
        decoded_override = json.loads(parsed.json_model_override_args)
        self.assertEqual(decoded_override, expected_override)
class TestLoadBalanceMethod(unittest.TestCase):
    """Default ``load_balance_method`` for each ``disaggregation_mode`` value."""

    @staticmethod
    def _default_method_for(mode):
        """Build ``ServerArgs`` with *mode* and return its ``load_balance_method``."""
        args = ServerArgs(model_path="dummy", disaggregation_mode=mode)
        return args.load_balance_method

    def test_non_pd_defaults_to_round_robin(self):
        """Mode ``"null"`` yields ``"round_robin"``."""
        chosen = self._default_method_for("null")
        self.assertEqual(chosen, "round_robin")

    def test_pd_prefill_defaults_to_follow_bootstrap_room(self):
        """Mode ``"prefill"`` yields ``"follow_bootstrap_room"``."""
        chosen = self._default_method_for("prefill")
        self.assertEqual(chosen, "follow_bootstrap_room")

    def test_pd_decode_defaults_to_round_robin(self):
        """Mode ``"decode"`` yields ``"round_robin"``."""
        chosen = self._default_method_for("decode")
        self.assertEqual(chosen, "round_robin")
class TestPortArgs(unittest.TestCase):
    """Tests for ``PortArgs.init_new``.

    Covers the default ``ipc://`` endpoints, the ``tcp://`` endpoints used
    when DP attention is enabled (IPv4 and bracketed IPv6 ``dist_init_addr``),
    and the errors raised for malformed addresses.

    The shared fixture setup and the repeated assertion patterns live in the
    private helpers below so each test states only what is specific to it.
    """

    # PortArgs attributes holding endpoint strings, in the order the tests
    # check them.
    _ENDPOINT_FIELDS = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
    )

    @staticmethod
    def _make_server_args(**overrides):
        """Return a ``ServerArgs`` fixture with ``port=30000`` and no NCCL port.

        Args:
            **overrides: Extra attributes assigned on the instance, in the
                order given.

        All attributes are assigned *after* ``ServerArgs(model_path="dummy")``
        is constructed rather than passed to the constructor.
        NOTE(review): presumably this keeps constructor-time processing from
        seeing these values -- confirm against ``ServerArgs`` before changing.
        """
        server_args = ServerArgs(model_path="dummy")
        server_args.port = 30000
        server_args.nccl_port = None
        for attr_name, attr_value in overrides.items():
            setattr(server_args, attr_name, attr_value)
        return server_args

    @classmethod
    def _make_dp_server_args(cls, dist_init_addr, nnodes=2):
        """Return the fixture with DP attention on and the given address.

        Args:
            dist_init_addr: Value assigned to ``server_args.dist_init_addr``.
            nnodes: Value assigned to ``server_args.nnodes`` (default 2).
        """
        return cls._make_server_args(
            enable_dp_attention=True,
            nnodes=nnodes,
            dist_init_addr=dist_init_addr,
        )

    def _assert_endpoints_start_with(
        self, port_args, prefix, fields=_ENDPOINT_FIELDS
    ):
        """Assert that each named endpoint on *port_args* starts with *prefix*.

        On failure the message names the field and shows the actual endpoint,
        instead of the bare "False is not true" of a plain ``assertTrue``.
        """
        for field in fields:
            endpoint = getattr(port_args, field)
            self.assertTrue(
                endpoint.startswith(prefix),
                msg=f"{field}={endpoint!r} does not start with {prefix!r}",
            )

    def _assert_init_new_raises(self, exc_type, dist_init_addr, fragment=None):
        """Assert ``PortArgs.init_new`` raises *exc_type* for *dist_init_addr*.

        The fixture is built outside the ``assertRaises`` block so only
        ``init_new`` itself may raise. When *fragment* is given it must occur
        in ``str(exception)``.
        """
        server_args = self._make_dp_server_args(dist_init_addr)
        with self.assertRaises(exc_type) as caught:
            PortArgs.init_new(server_args)
        if fragment is not None:
            self.assertIn(fragment, str(caught.exception))

    @patch("sglang.srt.server_args.get_free_port")
    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_with_nccl_port_none(self, mock_temp_file, mock_get_free_port):
        """Test that get_free_port() is called when nccl_port is None"""
        # Decorators apply bottom-up, so the NamedTemporaryFile mock is the
        # first injected argument and the get_free_port mock the second.
        mock_temp_file.return_value.name = "temp_file"
        mock_get_free_port.return_value = 45678  # Mock ephemeral port

        # Use MagicMock here to verify get_free_port is called
        server_args = MagicMock()
        server_args.nccl_port = None
        server_args.enable_dp_attention = False
        server_args.tokenizer_worker_num = 1

        port_args = PortArgs.init_new(server_args)

        # Verify get_free_port was called
        mock_get_free_port.assert_called_once()
        # Verify the returned port is used
        self.assertEqual(port_args.nccl_port, 45678)

    @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")
    def test_init_new_standard_case(self, mock_temp_file):
        """Without DP attention all three endpoints are ``ipc://`` names."""
        mock_temp_file.return_value.name = "temp_file"
        server_args = self._make_server_args(enable_dp_attention=False)

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "ipc://")
        self.assertIsInstance(port_args.nccl_port, int)

    def test_init_new_with_single_node_dp_attention(self):
        """Single node, no ``dist_init_addr``: endpoints use 127.0.0.1."""
        server_args = self._make_dp_server_args(None, nnodes=1)

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://127.0.0.1:")
        self.assertIsInstance(port_args.nccl_port, int)

    def test_init_new_with_dp_rank(self):
        """With ``dp_rank=2`` the scheduler endpoint ends in ``worker_ports[2]``."""
        server_args = self._make_dp_server_args("192.168.1.1:25000", nnodes=1)
        worker_ports = [25006, 25007, 25008, 25009]

        port_args = PortArgs.init_new(
            server_args, dp_rank=2, worker_ports=worker_ports
        )

        scheduler_endpoint = port_args.scheduler_input_ipc_name
        self.assertTrue(
            scheduler_endpoint.endswith(":25008"),
            msg=f"scheduler_input_ipc_name={scheduler_endpoint!r} "
            "does not end with ':25008'",
        )
        self._assert_endpoints_start_with(
            port_args,
            "tcp://192.168.1.1:",
            fields=("tokenizer_ipc_name", "detokenizer_ipc_name"),
        )
        self.assertIsInstance(port_args.nccl_port, int)

    def test_init_new_with_ipv4_address(self):
        """An IPv4 ``host:port`` address is used as the endpoint host."""
        server_args = self._make_dp_server_args("192.168.1.1:25000")

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://192.168.1.1:")
        self.assertIsInstance(port_args.nccl_port, int)

    def test_init_new_with_malformed_ipv4_address(self):
        """An IPv4 address without a port raises ``AssertionError``."""
        self._assert_init_new_raises(
            AssertionError,
            "192.168.1.1",
            "please provide --dist-init-addr as host:port",
        )

    def test_init_new_with_malformed_ipv4_address_invalid_port(self):
        """A non-numeric IPv4 port raises ``ValueError``."""
        self._assert_init_new_raises(ValueError, "192.168.1.1:abc")

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_ipv6_address(self, mock_is_valid_ipv6):
        """A bracketed IPv6 address is used, brackets included, as the host."""
        server_args = self._make_dp_server_args("[2001:db8::1]:25000")

        port_args = PortArgs.init_new(server_args)

        self._assert_endpoints_start_with(port_args, "tcp://[2001:db8::1]:")
        self.assertIsInstance(port_args.nccl_port, int)

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=False)
    def test_init_new_with_invalid_ipv6_address(self, mock_is_valid_ipv6):
        """A bracketed address rejected by the validator raises ``ValueError``."""
        self._assert_init_new_raises(
            ValueError, "[invalid-ipv6]:25000", "invalid IPv6 address"
        )

    def test_init_new_with_malformed_ipv6_address_missing_bracket(self):
        """A missing closing bracket raises ``ValueError``."""
        self._assert_init_new_raises(
            ValueError, "[2001:db8::1:25000", "invalid IPv6 address format"
        )

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_missing_port(
        self, mock_is_valid_ipv6
    ):
        """A bracketed IPv6 address without a port raises ``ValueError``."""
        self._assert_init_new_raises(
            ValueError, "[2001:db8::1]", "a port must be specified in IPv6 address"
        )

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_invalid_port(
        self, mock_is_valid_ipv6
    ):
        """A non-numeric port after the bracket raises ``ValueError``."""
        self._assert_init_new_raises(
            ValueError, "[2001:db8::1]:abcde", "invalid port in IPv6 address"
        )

    @patch("sglang.srt.server_args.is_valid_ipv6_address", return_value=True)
    def test_init_new_with_malformed_ipv6_address_wrong_separator(
        self, mock_is_valid_ipv6
    ):
        """A separator other than ``:`` after ``]`` raises ``ValueError``."""
        self._assert_init_new_raises(
            ValueError, "[2001:db8::1]#25000", "expected ':' after ']'"
        )
class TestSSLArgs(unittest.TestCase):
    """SSL-related ``ServerArgs`` fields, validation, ``url()`` and ``ssl_verify()``.

    Tests that pass placeholder file names patch ``os.path.isfile`` to return
    True; the "not found" tests use real temporary files and nonexistent paths.
    """

    # Placeholder key/cert pair used wherever a test just needs SSL enabled.
    _KEY_AND_CERT = {"ssl_keyfile": "key.pem", "ssl_certfile": "cert.pem"}

    def _expect_value_error(self, *fragments, **server_kwargs):
        """Assert constructing ``ServerArgs`` raises ``ValueError``.

        ``ServerArgs(model_path="dummy", **server_kwargs)`` must raise, and
        every string in *fragments* must occur in the exception text.
        """
        with self.assertRaises(ValueError) as caught:
            ServerArgs(model_path="dummy", **server_kwargs)
        message = str(caught.exception)
        for fragment in fragments:
            self.assertIn(fragment, message)

    def test_default_ssl_fields_are_none(self):
        """All four SSL fields default to ``None``."""
        defaults = ServerArgs(model_path="dummy")
        self.assertIsNone(defaults.ssl_keyfile)
        self.assertIsNone(defaults.ssl_certfile)
        self.assertIsNone(defaults.ssl_ca_certs)
        self.assertIsNone(defaults.ssl_keyfile_password)

    def test_ssl_keyfile_without_certfile_raises(self):
        """A key file alone is rejected; the error mentions ``--ssl-certfile``."""
        self._expect_value_error("--ssl-certfile", ssl_keyfile="key.pem")

    def test_ssl_certfile_without_keyfile_raises(self):
        """A cert file alone is rejected; the error mentions ``--ssl-keyfile``."""
        self._expect_value_error("--ssl-keyfile", ssl_certfile="cert.pem")

    @patch("os.path.isfile", return_value=True)
    def test_ssl_both_keyfile_and_certfile_accepted(self, _mock_isfile):
        """Key and cert together are accepted and stored as given."""
        secured = ServerArgs(model_path="dummy", **self._KEY_AND_CERT)
        self.assertEqual(secured.ssl_keyfile, "key.pem")
        self.assertEqual(secured.ssl_certfile, "cert.pem")

    def test_url_returns_http_without_ssl(self):
        """Without SSL the URL scheme is ``http://``."""
        address = ServerArgs(model_path="dummy").url()
        self.assertTrue(address.startswith("http://"))

    def test_url_rewrites_all_interfaces_to_loopback(self):
        """Host ``0.0.0.0`` is reported as ``127.0.0.1``."""
        address = ServerArgs(model_path="dummy", host="0.0.0.0").url()
        self.assertEqual(address, "http://127.0.0.1:30000")

    def test_url_rewrites_empty_host_to_loopback(self):
        """An empty host is reported as ``127.0.0.1``."""
        address = ServerArgs(model_path="dummy", host="").url()
        self.assertEqual(address, "http://127.0.0.1:30000")

    def test_url_rewrites_ipv6_all_interfaces_to_loopback(self):
        """Host ``::`` is reported as bracketed ``::1``."""
        address = ServerArgs(model_path="dummy", host="::").url()
        self.assertEqual(address, "http://[::1]:30000")

    @patch("os.path.isfile", return_value=True)
    def test_url_returns_https_with_ssl(self, _mock_isfile):
        """With key and cert the URL scheme is ``https://``."""
        address = ServerArgs(model_path="dummy", **self._KEY_AND_CERT).url()
        self.assertTrue(address.startswith("https://"))

    @patch("os.path.isfile", return_value=True)
    def test_ssl_cli_args_parsed(self, _mock_isfile):
        """The four ``--ssl-*`` CLI options land on the matching fields."""
        cli_tokens = [
            "--model-path",
            "dummy",
            "--ssl-keyfile",
            "key.pem",
            "--ssl-certfile",
            "cert.pem",
            "--ssl-ca-certs",
            "ca.pem",
            "--ssl-keyfile-password",
            "secret",
        ]
        parsed = prepare_server_args(cli_tokens)
        self.assertEqual(parsed.ssl_keyfile, "key.pem")
        self.assertEqual(parsed.ssl_certfile, "cert.pem")
        self.assertEqual(parsed.ssl_ca_certs, "ca.pem")
        self.assertEqual(parsed.ssl_keyfile_password, "secret")

    def test_ssl_verify_without_ssl(self):
        """``ssl_verify()`` is exactly ``True`` when SSL is not configured."""
        plain = ServerArgs(model_path="dummy")
        self.assertIs(plain.ssl_verify(), True)

    @patch("os.path.isfile", return_value=True)
    def test_ssl_verify_with_ssl_no_ca(self, _mock_isfile):
        """``ssl_verify()`` is exactly ``False`` with key/cert but no CA file."""
        secured = ServerArgs(model_path="dummy", **self._KEY_AND_CERT)
        self.assertIs(secured.ssl_verify(), False)

    @patch("os.path.isfile", return_value=True)
    def test_ssl_verify_with_ssl_and_ca(self, _mock_isfile):
        """``ssl_verify()`` returns the CA path when one is configured."""
        secured = ServerArgs(
            model_path="dummy", ssl_ca_certs="ca.pem", **self._KEY_AND_CERT
        )
        self.assertEqual(secured.ssl_verify(), "ca.pem")

    def test_ssl_ca_certs_without_certfile_raises(self):
        """A CA file without key/cert is rejected."""
        self._expect_value_error("--ssl-ca-certs", ssl_ca_certs="ca.pem")

    def test_ssl_keyfile_password_without_certfile_raises(self):
        """A key password without key/cert is rejected."""
        self._expect_value_error(
            "--ssl-keyfile-password", ssl_keyfile_password="secret"
        )

    def test_ssl_keyfile_not_found_raises(self):
        """Nonexistent key and cert paths raise an error containing "not found"."""
        self._expect_value_error(
            "not found",
            ssl_keyfile="/nonexistent/key.pem",
            ssl_certfile="/nonexistent/cert.pem",
        )

    def test_ssl_certfile_not_found_raises(self):
        """A real key file with a nonexistent cert path is rejected."""
        with tempfile.NamedTemporaryFile(suffix=".pem") as keyfile:
            self._expect_value_error(
                "SSL certificate file not found",
                ssl_keyfile=keyfile.name,
                ssl_certfile="/nonexistent/cert.pem",
            )

    def test_ssl_ca_certs_not_found_raises(self):
        """Real key and cert files with a nonexistent CA path are rejected."""
        with tempfile.NamedTemporaryFile(
            suffix=".pem"
        ) as keyfile, tempfile.NamedTemporaryFile(suffix=".pem") as certfile:
            self._expect_value_error(
                "SSL CA certificates file not found",
                ssl_keyfile=keyfile.name,
                ssl_certfile=certfile.name,
                ssl_ca_certs="/nonexistent/ca.pem",
            )

    @patch("os.path.isfile", return_value=True)
    def test_url_returns_https_with_ssl_and_ipv6(self, _mock_isfile):
        """An IPv6 host is bracketed in the ``https://`` URL."""
        secured = ServerArgs(model_path="dummy", host="::1", **self._KEY_AND_CERT)
        self.assertEqual(secured.url(), "https://[::1]:30000")

    def test_enable_ssl_refresh_default_false(self):
        """``enable_ssl_refresh`` is falsy by default."""
        defaults = ServerArgs(model_path="dummy")
        self.assertFalse(defaults.enable_ssl_refresh)

    def test_enable_ssl_refresh_without_ssl_raises(self):
        """Refresh without SSL is rejected; the error names both options."""
        self._expect_value_error(
            "--enable-ssl-refresh", "--ssl-certfile", enable_ssl_refresh=True
        )

    @patch("os.path.isfile", return_value=True)
    def test_enable_ssl_refresh_with_ssl_accepted(self, _mock_isfile):
        """Refresh together with key and cert is accepted."""
        secured = ServerArgs(
            model_path="dummy", enable_ssl_refresh=True, **self._KEY_AND_CERT
        )
        self.assertTrue(secured.enable_ssl_refresh)

    @patch("os.path.isfile", return_value=True)
    def test_enable_ssl_refresh_cli_flag(self, _mock_isfile):
        """``--enable-ssl-refresh`` on the CLI sets the field."""
        cli_tokens = [
            "--model-path",
            "dummy",
            "--ssl-keyfile",
            "key.pem",
            "--ssl-certfile",
            "cert.pem",
            "--enable-ssl-refresh",
        ]
        parsed = prepare_server_args(cli_tokens)
        self.assertTrue(parsed.enable_ssl_refresh)
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()