File size: 15,089 Bytes
a402b9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 | from typing import Optional
from sglang_router.router_args import RouterArgs
from sglang_router.sglang_router_rs import (
BackendType,
HistoryBackendType,
PolicyType,
PyApiKeyEntry,
PyControlPlaneAuthConfig,
PyJwtConfig,
PyOracleConfig,
PyPostgresConfig,
PyRedisConfig,
PyRole,
)
from sglang_router.sglang_router_rs import Router as _Router
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
"""Convert policy string to PolicyType enum."""
if policy_str is None:
return None
policy_map = {
"random": PolicyType.Random,
"round_robin": PolicyType.RoundRobin,
"cache_aware": PolicyType.CacheAware,
"power_of_two": PolicyType.PowerOfTwo,
"bucket": PolicyType.Bucket,
"manual": PolicyType.Manual,
"consistent_hashing": PolicyType.ConsistentHashing,
"prefix_hash": PolicyType.PrefixHash,
}
return policy_map[policy_str]
def backend_from_str(backend_str: Optional[str]) -> BackendType:
"""Convert backend string to BackendType enum."""
if isinstance(backend_str, BackendType):
return backend_str
if backend_str is None:
return BackendType.Sglang
backend_map = {"sglang": BackendType.Sglang, "openai": BackendType.Openai}
backend_lower = backend_str.lower()
if backend_lower not in backend_map:
raise ValueError(
f"Unknown backend: {backend_str}. Valid options: {', '.join(backend_map.keys())}"
)
return backend_map[backend_lower]
def history_backend_from_str(backend_str: Optional[str]) -> HistoryBackendType:
"""Convert history backend string to HistoryBackendType enum."""
if isinstance(backend_str, HistoryBackendType):
return backend_str
if backend_str is None:
return HistoryBackendType.Memory
backend_lower = backend_str.lower()
if backend_lower == "memory":
return HistoryBackendType.Memory
elif backend_lower == "none":
# Use getattr to access 'None' which is a Python keyword
return getattr(HistoryBackendType, "None")
elif backend_lower == "oracle":
return HistoryBackendType.Oracle
elif backend_lower == "postgres":
return HistoryBackendType.Postgres
elif backend_lower == "redis":
return HistoryBackendType.Redis
else:
raise ValueError(f"Unknown history backend: {backend_str}")
def role_from_str(role_str: str) -> PyRole:
"""Convert role string to PyRole enum."""
if role_str.lower() == "admin":
return PyRole.Admin
return PyRole.User
def build_control_plane_auth_config(
args_dict: dict,
) -> Optional[PyControlPlaneAuthConfig]:
"""Build control plane auth config from args dict."""
api_keys = args_dict.get("control_plane_api_keys", [])
jwt_issuer = args_dict.get("jwt_issuer")
jwt_audience = args_dict.get("jwt_audience")
audit_enabled = args_dict.get("control_plane_audit_enabled", False)
# Check if any auth is configured
has_api_keys = bool(api_keys)
has_jwt = jwt_issuer is not None and jwt_audience is not None
if not has_api_keys and not has_jwt:
return None
# Build API key entries
py_api_keys = []
for key_tuple in api_keys:
# Tuple format: (id, name, key, role)
key_id, name, key, role = key_tuple
py_api_keys.append(
PyApiKeyEntry(
id=key_id,
name=name,
key=key,
role=role_from_str(role),
)
)
# Build JWT config if present
jwt_config = None
if has_jwt:
jwt_config = PyJwtConfig(
issuer=jwt_issuer,
audience=jwt_audience,
jwks_uri=args_dict.get("jwt_jwks_uri"),
role_mapping=args_dict.get("jwt_role_mapping", {}),
)
return PyControlPlaneAuthConfig(
jwt=jwt_config,
api_keys=py_api_keys,
audit_enabled=audit_enabled,
)
class Router:
"""
A high-performance router for distributing requests across worker nodes.
Args:
worker_urls: List of URLs for worker nodes that will handle requests. Each URL should include
the protocol, host, and port (e.g., ['http://worker1:8000', 'http://worker2:8000'])
policy: Load balancing policy to use. Options:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
host: Host address to bind the router server. Supports IPv4, IPv6 (e.g., ::, ::1), or 0.0.0.0 for all interfaces. Default: '0.0.0.0'
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup and registration. Large models can take significant time to load into GPU memory. Default: 1800 (30 minutes)
worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
balance_abs_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32
balance_rel_threshold: Load balancing is triggered when (max_load - min_load) > abs_threshold
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
dp_aware: Enable data parallelism aware schedule. Default: False
enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
the router can manage multiple models simultaneously with per-model load balancing
policies. Default: False
api_key: The api key used for the authorization with the worker.
Useful when the dp aware scheduling strategy is enabled.
Default: None
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
log_level: Logging level. Options: 'debug', 'info', 'warn', 'error'.
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
automatically discover worker pods based on the selector. Default: False
selector: Dictionary mapping of label keys to values for Kubernetes pod selection.
Example: {"app": "sglang-worker"}. Default: {}
service_discovery_port: Port to use for service discovery. The router will generate
worker URLs using this port. Default: 80
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for prefill servers (PD mode only). Default: {}
decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
for decode servers (PD mode only). Default: {}
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
Default: 'sglang.ai/bootstrap-port'
request_timeout_secs: Request timeout in seconds. Default: 600
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
queue_size: Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately). Default: 100
queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
health_check_endpoint: Health check endpoint path. Default: '/health'
model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
server_cert_path: Path to server TLS certificate (PEM format). Default: None
server_key_path: Path to server TLS private key (PEM format). Default: None
"""
def __init__(self, router: _Router):
self._router = router
@staticmethod
def from_args(args: RouterArgs) -> "Router":
"""Create a router from a RouterArgs instance."""
args_dict = vars(args)
# Convert RouterArgs to _Router parameters
args_dict["worker_urls"] = (
[]
if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
else args_dict["worker_urls"]
)
args_dict["policy"] = policy_from_str(args_dict["policy"])
args_dict["prefill_urls"] = (
args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["decode_urls"] = (
args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
)
args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])
# Convert backend
args_dict["backend"] = backend_from_str(args_dict.get("backend"))
# Convert history_backend to enum first
history_backend_raw = args_dict.get("history_backend", "memory")
history_backend = history_backend_from_str(history_backend_raw)
# Convert Oracle config if needed
oracle_config = None
if history_backend == HistoryBackendType.Oracle:
# Prioritize TNS alias over connect descriptor
tns_alias = args_dict.get("oracle_tns_alias")
connect_descriptor = args_dict.get("oracle_connect_descriptor")
# Use TNS alias if provided, otherwise use connect descriptor
final_descriptor = tns_alias if tns_alias else connect_descriptor
oracle_config = PyOracleConfig(
password=args_dict.get("oracle_password"),
username=args_dict.get("oracle_username"),
connect_descriptor=final_descriptor,
wallet_path=args_dict.get("oracle_wallet_path"),
pool_min=args_dict.get("oracle_pool_min", 1),
pool_max=args_dict.get("oracle_pool_max", 16),
pool_timeout_secs=args_dict.get("oracle_pool_timeout_secs", 30),
)
args_dict["oracle_config"] = oracle_config
args_dict["history_backend"] = history_backend
# Convert Postgres config if needed
postgres_config = None
if history_backend == HistoryBackendType.Postgres:
postgres_config = PyPostgresConfig(
db_url=args_dict.get("postgres_db_url"),
pool_max=args_dict.get("postgres_pool_max", 16),
)
args_dict["postgres_config"] = postgres_config
# Convert Redis config if needed
redis_config = None
if history_backend == HistoryBackendType.Redis:
retention_days = args_dict.get("redis_retention_days", 30)
# If retention_days is negative, it means persistent storage (None in Rust)
retention_arg = None if retention_days < 0 else retention_days
redis_config = PyRedisConfig(
url=args_dict.get("redis_url"),
pool_max=args_dict.get("redis_pool_max", 16),
retention_days=retention_arg,
)
args_dict["redis_config"] = redis_config
# Build control plane auth config
args_dict["control_plane_auth"] = build_control_plane_auth_config(args_dict)
# Remove fields that shouldn't be passed to Rust Router constructor
fields_to_remove = [
"mini_lb",
"test_external_dp_routing",
"oracle_wallet_path",
"oracle_tns_alias",
"oracle_connect_descriptor",
"oracle_username",
"oracle_password",
"oracle_pool_min",
"oracle_pool_max",
"oracle_pool_timeout_secs",
"postgres_db_url",
"postgres_pool_max",
"redis_url",
"redis_pool_max",
"redis_retention_days",
# Control plane auth fields (converted to control_plane_auth)
"control_plane_api_keys",
"control_plane_audit_enabled",
"jwt_issuer",
"jwt_audience",
"jwt_jwks_uri",
"jwt_role_mapping",
]
for field in fields_to_remove:
args_dict.pop(field, None)
return Router(_Router(**args_dict))
def start(self) -> None:
"""Start the router server.
This method blocks until the server is shut down.
"""
self._router.start()
|