"""
Unit tests for argument parsing functionality in sglang_router.

These tests focus on testing the argument parsing logic in isolation,
without starting actual router instances.
"""
|
|
| from types import SimpleNamespace |
|
|
| import pytest |
| from sglang_router.launch_router import RouterArgs, parse_router_args |
| from sglang_router.router import policy_from_str |
|
|
|
|
class TestRouterArgs:
    """Checks for RouterArgs: defaults, the static parsers, and from_cli_args."""

    def test_default_values(self):
        """A RouterArgs built with no arguments carries the expected defaults."""
        defaults = RouterArgs()

        # Basic routing settings
        assert defaults.host == "0.0.0.0"
        assert defaults.port == 30000
        assert defaults.policy == "cache_aware"
        assert defaults.worker_urls == []
        assert defaults.pd_disaggregation is False
        assert defaults.prefill_urls == []
        assert defaults.decode_urls == []

        # Per-role policies for PD mode
        assert defaults.prefill_policy is None
        assert defaults.decode_policy is None

        # Service discovery
        assert defaults.service_discovery is False
        assert defaults.selector == {}
        assert defaults.service_discovery_port == 80
        assert defaults.service_discovery_namespace is None

        # Retry and circuit breaker
        assert defaults.retry_max_retries == 5
        assert defaults.cb_failure_threshold == 10
        assert defaults.disable_retries is False
        assert defaults.disable_circuit_breaker is False

    def test_parse_selector_valid(self):
        """Well-formed (or absent) selector input maps to the expected dict."""
        cases = [
            # Single key=value pair
            (["app=worker"], {"app": "worker"}),
            # Several pairs
            (
                ["app=worker", "env=prod", "version=v1"],
                {"app": "worker", "env": "prod", "version": "v1"},
            ),
            # Empty list
            ([], {}),
            # No selector given at all
            (None, {}),
        ]

        for raw, expected in cases:
            assert RouterArgs._parse_selector(raw) == expected

    def test_parse_selector_invalid(self):
        """Malformed selector entries produce the results asserted below."""
        cases = [
            # An entry without '=' is expected to yield no mapping
            (["app"], {}),
            # Only the first '=' separates key from value
            (["app=worker=extra"], {"app": "worker=extra"}),
        ]

        for raw, expected in cases:
            assert RouterArgs._parse_selector(raw) == expected

    def test_parse_prefill_urls_valid(self):
        """Valid prefill entries become (url, bootstrap_port) tuples."""
        cases = [
            # URL with a numeric bootstrap port
            (
                [["http://prefill1:8000", "9000"]],
                [("http://prefill1:8000", 9000)],
            ),
            # URL with the literal "none" as bootstrap port
            (
                [["http://prefill1:8000", "none"]],
                [("http://prefill1:8000", None)],
            ),
            # URL with the bootstrap port omitted
            (
                [["http://prefill1:8000"]],
                [("http://prefill1:8000", None)],
            ),
            # A mix of all three forms
            (
                [
                    ["http://prefill1:8000", "9000"],
                    ["http://prefill2:8000", "none"],
                    ["http://prefill3:8000"],
                ],
                [
                    ("http://prefill1:8000", 9000),
                    ("http://prefill2:8000", None),
                    ("http://prefill3:8000", None),
                ],
            ),
            # Empty list
            ([], []),
            # No prefill argument given at all
            (None, []),
        ]

        for raw, expected in cases:
            assert RouterArgs._parse_prefill_urls(raw) == expected

    def test_parse_prefill_urls_invalid(self):
        """A non-numeric, non-"none" bootstrap port raises ValueError."""
        bad_entry = [["http://prefill1:8000", "invalid"]]

        with pytest.raises(ValueError, match="Invalid bootstrap port"):
            RouterArgs._parse_prefill_urls(bad_entry)

    def test_parse_decode_urls_valid(self):
        """Nested decode URL lists are flattened into a plain list of URLs."""
        cases = [
            # Single decode URL
            ([["http://decode1:8001"]], ["http://decode1:8001"]),
            # Several decode URLs
            (
                [["http://decode1:8001"], ["http://decode2:8001"]],
                ["http://decode1:8001", "http://decode2:8001"],
            ),
            # Empty list
            ([], []),
            # No decode argument given at all
            (None, []),
        ]

        for raw, expected in cases:
            assert RouterArgs._parse_decode_urls(raw) == expected

    def test_from_cli_args_basic(self):
        """from_cli_args with router-prefixed names populates every checked field."""
        namespace = SimpleNamespace(
            host="0.0.0.0",
            port=30001,
            worker_urls=["http://worker1:8000", "http://worker2:8000"],
            policy="round_robin",
            prefill=None,
            decode=None,
            router_policy="round_robin",
            router_pd_disaggregation=False,
            router_prefill_policy=None,
            router_decode_policy=None,
            router_worker_startup_timeout_secs=300,
            router_worker_startup_check_interval=15,
            router_cache_threshold=0.7,
            router_balance_abs_threshold=128,
            router_balance_rel_threshold=2.0,
            router_eviction_interval=180,
            router_max_tree_size=2**28,
            router_max_payload_size=1024 * 1024 * 1024,
            router_dp_aware=True,
            router_api_key="test-key",
            router_log_dir="/tmp/logs",
            router_log_level="debug",
            router_service_discovery=True,
            router_selector=["app=worker", "env=test"],
            router_service_discovery_port=8080,
            router_service_discovery_namespace="default",
            router_prefill_selector=["app=prefill"],
            router_decode_selector=["app=decode"],
            router_prometheus_port=29000,
            router_prometheus_host="0.0.0.0",
            router_request_id_headers=["x-request-id", "x-trace-id"],
            router_request_timeout_secs=1200,
            router_max_concurrent_requests=512,
            router_queue_size=200,
            router_queue_timeout_secs=120,
            router_rate_limit_tokens_per_second=100,
            router_cors_allowed_origins=["http://localhost:3000"],
            router_retry_max_retries=3,
            router_retry_initial_backoff_ms=100,
            router_retry_max_backoff_ms=10000,
            router_retry_backoff_multiplier=2.0,
            router_retry_jitter_factor=0.1,
            router_cb_failure_threshold=5,
            router_cb_success_threshold=2,
            router_cb_timeout_duration_secs=30,
            router_cb_window_duration_secs=60,
            router_disable_retries=False,
            router_disable_circuit_breaker=False,
            router_health_failure_threshold=2,
            router_health_success_threshold=1,
            router_health_check_timeout_secs=3,
            router_health_check_interval_secs=30,
            router_health_check_endpoint="/healthz",
        )

        parsed = RouterArgs.from_cli_args(namespace, use_router_prefix=True)

        # Basic routing settings
        assert parsed.host == "0.0.0.0"
        assert parsed.port == 30001
        assert parsed.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
        assert parsed.policy == "round_robin"

        # PD mode is off, so no prefill/decode URLs
        assert parsed.pd_disaggregation is False
        assert parsed.prefill_urls == []
        assert parsed.decode_urls == []

        # Service discovery, including parsed selectors
        assert parsed.service_discovery is True
        assert parsed.selector == {"app": "worker", "env": "test"}
        assert parsed.service_discovery_port == 8080
        assert parsed.service_discovery_namespace == "default"
        assert parsed.prefill_selector == {"app": "prefill"}
        assert parsed.decode_selector == {"app": "decode"}

        # Miscellaneous settings
        assert parsed.dp_aware is True
        assert parsed.api_key == "test-key"
        assert parsed.log_dir == "/tmp/logs"
        assert parsed.log_level == "debug"
        assert parsed.prometheus_port == 29000
        assert parsed.prometheus_host == "0.0.0.0"
        assert parsed.request_id_headers == ["x-request-id", "x-trace-id"]
        assert parsed.request_timeout_secs == 1200
        assert parsed.max_concurrent_requests == 512
        assert parsed.queue_size == 200
        assert parsed.queue_timeout_secs == 120
        assert parsed.rate_limit_tokens_per_second == 100
        assert parsed.cors_allowed_origins == ["http://localhost:3000"]

        # Retry settings
        assert parsed.retry_max_retries == 3
        assert parsed.retry_initial_backoff_ms == 100
        assert parsed.retry_max_backoff_ms == 10000
        assert parsed.retry_backoff_multiplier == 2.0
        assert parsed.retry_jitter_factor == 0.1

        # Circuit breaker settings and the two disable flags
        assert parsed.cb_failure_threshold == 5
        assert parsed.cb_success_threshold == 2
        assert parsed.cb_timeout_duration_secs == 30
        assert parsed.cb_window_duration_secs == 60
        assert parsed.disable_retries is False
        assert parsed.disable_circuit_breaker is False

        # Health check settings
        assert parsed.health_failure_threshold == 2
        assert parsed.health_success_threshold == 1
        assert parsed.health_check_timeout_secs == 3
        assert parsed.health_check_interval_secs == 30
        assert parsed.health_check_endpoint == "/healthz"

    def test_from_cli_args_pd_mode(self):
        """from_cli_args in PD mode yields parsed prefill/decode URLs and policies."""
        namespace = SimpleNamespace(
            host="127.0.0.1",
            port=30000,
            worker_urls=[],
            policy="cache_aware",
            prefill=[
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
            ],
            decode=[["http://decode1:8001"], ["http://decode2:8001"]],
            router_prefill=[
                ["http://prefill1:8000", "9000"],
                ["http://prefill2:8000", "none"],
            ],
            router_decode=[["http://decode1:8001"], ["http://decode2:8001"]],
            router_policy="cache_aware",
            router_pd_disaggregation=True,
            router_prefill_policy="power_of_two",
            router_decode_policy="round_robin",
            # Remaining router-prefixed settings
            router_worker_startup_timeout_secs=600,
            router_worker_startup_check_interval=30,
            router_cache_threshold=0.3,
            router_balance_abs_threshold=64,
            router_balance_rel_threshold=1.5,
            router_eviction_interval=120,
            router_max_tree_size=2**26,
            router_max_payload_size=512 * 1024 * 1024,
            router_dp_aware=False,
            router_api_key=None,
            router_log_dir=None,
            router_log_level=None,
            router_service_discovery=False,
            router_selector=None,
            router_service_discovery_port=80,
            router_service_discovery_namespace=None,
            router_prefill_selector=None,
            router_decode_selector=None,
            router_prometheus_port=None,
            router_prometheus_host=None,
            router_request_id_headers=None,
            router_request_timeout_secs=1800,
            router_max_concurrent_requests=256,
            router_queue_size=100,
            router_queue_timeout_secs=60,
            router_rate_limit_tokens_per_second=None,
            router_cors_allowed_origins=[],
            router_retry_max_retries=5,
            router_retry_initial_backoff_ms=50,
            router_retry_max_backoff_ms=30000,
            router_retry_backoff_multiplier=1.5,
            router_retry_jitter_factor=0.2,
            router_cb_failure_threshold=10,
            router_cb_success_threshold=3,
            router_cb_timeout_duration_secs=60,
            router_cb_window_duration_secs=120,
            router_disable_retries=False,
            router_disable_circuit_breaker=False,
            router_health_failure_threshold=3,
            router_health_success_threshold=2,
            router_health_check_timeout_secs=5,
            router_health_check_interval_secs=60,
            router_health_check_endpoint="/health",
        )

        parsed = RouterArgs.from_cli_args(namespace, use_router_prefix=True)

        # PD-specific fields
        expected_prefill = [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        expected_decode = ["http://decode1:8001", "http://decode2:8001"]

        assert parsed.pd_disaggregation is True
        assert parsed.prefill_urls == expected_prefill
        assert parsed.decode_urls == expected_decode
        assert parsed.prefill_policy == "power_of_two"
        assert parsed.decode_policy == "round_robin"
        assert parsed.policy == "cache_aware"

    def test_from_cli_args_without_prefix(self):
        """from_cli_args with use_router_prefix=False reads the unprefixed names."""
        namespace = SimpleNamespace(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="random",
            prefill=None,
            decode=None,
            pd_disaggregation=False,
            prefill_policy=None,
            decode_policy=None,
            worker_startup_timeout_secs=600,
            worker_startup_check_interval=30,
            cache_threshold=0.3,
            balance_abs_threshold=64,
            balance_rel_threshold=1.5,
            eviction_interval=120,
            max_tree_size=2**26,
            max_payload_size=512 * 1024 * 1024,
            dp_aware=False,
            api_key=None,
            log_dir=None,
            log_level=None,
            service_discovery=False,
            selector=None,
            service_discovery_port=80,
            service_discovery_namespace=None,
            prefill_selector=None,
            decode_selector=None,
            prometheus_port=None,
            prometheus_host=None,
            request_id_headers=None,
            request_timeout_secs=1800,
            max_concurrent_requests=256,
            queue_size=100,
            queue_timeout_secs=60,
            rate_limit_tokens_per_second=None,
            cors_allowed_origins=[],
            retry_max_retries=5,
            retry_initial_backoff_ms=50,
            retry_max_backoff_ms=30000,
            retry_backoff_multiplier=1.5,
            retry_jitter_factor=0.2,
            cb_failure_threshold=10,
            cb_success_threshold=3,
            cb_timeout_duration_secs=60,
            cb_window_duration_secs=120,
            disable_retries=False,
            disable_circuit_breaker=False,
            health_failure_threshold=3,
            health_success_threshold=2,
            health_check_timeout_secs=5,
            health_check_interval_secs=60,
            health_check_endpoint="/health",
            model_path=None,
            tokenizer_path=None,
        )

        parsed = RouterArgs.from_cli_args(namespace, use_router_prefix=False)

        assert parsed.host == "127.0.0.1"
        assert parsed.port == 30000
        assert parsed.worker_urls == ["http://worker1:8000"]
        assert parsed.policy == "random"
        assert parsed.pd_disaggregation is False
|
|
|
class TestPolicyFromStr:
    """Checks for policy_from_str, which maps policy names to PolicyType members."""

    def test_valid_policies(self):
        """Each known policy name converts to its PolicyType member."""
        # Imported inside the test, as in the original, rather than at module level.
        from sglang_router.sglang_router_rs import PolicyType

        name_to_member = {
            "random": PolicyType.Random,
            "round_robin": PolicyType.RoundRobin,
            "cache_aware": PolicyType.CacheAware,
            "power_of_two": PolicyType.PowerOfTwo,
        }

        for name, member in name_to_member.items():
            assert policy_from_str(name) == member

    def test_invalid_policy(self):
        """An unrecognised policy name is expected to raise KeyError."""
        unknown_name = "invalid_policy"

        with pytest.raises(KeyError):
            policy_from_str(unknown_name)
|
|
|
|
class TestParseRouterArgs:
    """Checks for parse_router_args, driven by raw argv-style token lists."""

    def test_parse_basic_args(self):
        """Host, port, worker URLs and policy are read from the command line."""
        argv = [
            "--host",
            "0.0.0.0",
            "--port",
            "30001",
            "--worker-urls",
            "http://worker1:8000",
            "http://worker2:8000",
            "--policy",
            "round_robin",
        ]

        parsed = parse_router_args(argv)

        assert parsed.host == "0.0.0.0"
        assert parsed.port == 30001
        assert parsed.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
        assert parsed.policy == "round_robin"

    def test_parse_pd_args(self):
        """Repeated --prefill/--decode flags and per-role policies are parsed."""
        argv = [
            "--pd-disaggregation",
            "--prefill",
            "http://prefill1:8000",
            "9000",
            "--prefill",
            "http://prefill2:8000",
            "none",
            "--decode",
            "http://decode1:8001",
            "--decode",
            "http://decode2:8001",
            "--prefill-policy",
            "power_of_two",
            "--decode-policy",
            "round_robin",
        ]

        parsed = parse_router_args(argv)

        expected_prefill = [
            ("http://prefill1:8000", 9000),
            ("http://prefill2:8000", None),
        ]
        expected_decode = ["http://decode1:8001", "http://decode2:8001"]

        assert parsed.pd_disaggregation is True
        assert parsed.prefill_urls == expected_prefill
        assert parsed.decode_urls == expected_decode
        assert parsed.prefill_policy == "power_of_two"
        assert parsed.decode_policy == "round_robin"

    def test_parse_service_discovery_args(self):
        """Service discovery flags parse identically for both selector spellings."""
        separate_tokens = [
            "--service-discovery",
            "--selector",
            "app=worker",
            "env=prod",
            "--service-discovery-port",
            "8080",
            "--service-discovery-namespace",
            "default",
        ]
        single_token = [
            "--service-discovery",
            "--selector",
            # Both pairs packed into one whitespace-separated argument
            "app=worker env=prod",
            "--service-discovery-port",
            "8080",
            "--service-discovery-namespace",
            "default",
        ]

        # Both argv variants must produce the same parsed result.
        for argv in (separate_tokens, single_token):
            parsed = parse_router_args(argv)

            assert parsed.service_discovery is True
            assert parsed.selector == {"app": "worker", "env": "prod"}
            assert parsed.service_discovery_port == 8080
            assert parsed.service_discovery_namespace == "default"

    def test_parse_retry_and_circuit_breaker_args(self):
        """Retry and circuit breaker options, including disable flags, are parsed."""
        argv = [
            "--retry-max-retries",
            "3",
            "--retry-initial-backoff-ms",
            "100",
            "--retry-max-backoff-ms",
            "10000",
            "--retry-backoff-multiplier",
            "2.0",
            "--retry-jitter-factor",
            "0.1",
            "--disable-retries",
            "--cb-failure-threshold",
            "5",
            "--cb-success-threshold",
            "2",
            "--cb-timeout-duration-secs",
            "30",
            "--cb-window-duration-secs",
            "60",
            "--disable-circuit-breaker",
        ]

        parsed = parse_router_args(argv)

        # Retry settings
        assert parsed.retry_max_retries == 3
        assert parsed.retry_initial_backoff_ms == 100
        assert parsed.retry_max_backoff_ms == 10000
        assert parsed.retry_backoff_multiplier == 2.0
        assert parsed.retry_jitter_factor == 0.1
        assert parsed.disable_retries is True

        # Circuit breaker settings
        assert parsed.cb_failure_threshold == 5
        assert parsed.cb_success_threshold == 2
        assert parsed.cb_timeout_duration_secs == 30
        assert parsed.cb_window_duration_secs == 60
        assert parsed.disable_circuit_breaker is True

    def test_parse_rate_limiting_args(self):
        """Concurrency, queue and token-rate options are parsed as integers."""
        argv = [
            "--max-concurrent-requests",
            "512",
            "--queue-size",
            "200",
            "--queue-timeout-secs",
            "120",
            "--rate-limit-tokens-per-second",
            "100",
        ]

        parsed = parse_router_args(argv)

        assert parsed.max_concurrent_requests == 512
        assert parsed.queue_size == 200
        assert parsed.queue_timeout_secs == 120
        assert parsed.rate_limit_tokens_per_second == 100

    def test_parse_health_check_args(self):
        """Health check thresholds, timings and endpoint are parsed."""
        argv = [
            "--health-failure-threshold",
            "2",
            "--health-success-threshold",
            "1",
            "--health-check-timeout-secs",
            "3",
            "--health-check-interval-secs",
            "30",
            "--health-check-endpoint",
            "/healthz",
        ]

        parsed = parse_router_args(argv)

        assert parsed.health_failure_threshold == 2
        assert parsed.health_success_threshold == 1
        assert parsed.health_check_timeout_secs == 3
        assert parsed.health_check_interval_secs == 30
        assert parsed.health_check_endpoint == "/healthz"

    def test_parse_cors_args(self):
        """Multiple CORS origins after one flag are collected into a list."""
        origins = ["http://localhost:3000", "https://example.com"]
        argv = [
            "--cors-allowed-origins",
            "http://localhost:3000",
            "https://example.com",
        ]

        parsed = parse_router_args(argv)

        assert parsed.cors_allowed_origins == origins

    def test_parse_tokenizer_args(self):
        """Placeholder for tokenizer argument parsing; currently always skipped."""
        # Skipped unconditionally, with the reason given in the message below.
        pytest.skip("Tokenizer arguments not available in current implementation")

    def test_parse_invalid_args(self):
        """Invalid input is rejected, either by the parser or during URL parsing."""
        # An unknown policy is expected to make the parser exit
        with pytest.raises(SystemExit):
            parse_router_args(["--policy", "invalid_policy"])

        # A non-numeric bootstrap port is expected to raise ValueError
        bad_port_argv = [
            "--pd-disaggregation",
            "--prefill",
            "http://prefill1:8000",
            "invalid_port",
        ]
        with pytest.raises(ValueError, match="Invalid bootstrap port"):
            parse_router_args(bad_port_argv)

    def test_help_output(self):
        """--help exits via SystemExit with exit code 0."""
        with pytest.raises(SystemExit) as exit_info:
            parse_router_args(["--help"])

        # Checked after the block: the exit code is read from the captured exception
        assert exit_info.value.code == 0
|
|