use pyo3::prelude::*; use smg::*; use once_cell::sync::OnceCell; use std::collections::HashMap; // Define the enums with PyO3 bindings #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] pub enum PolicyType { Random, RoundRobin, CacheAware, PowerOfTwo, Bucket, Manual, ConsistentHashing, PrefixHash, } #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] pub enum BackendType { Sglang, Openai, } #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] pub enum HistoryBackendType { Memory, None, Oracle, Postgres, Redis, } #[pyclass(eq)] #[derive(Clone, PartialEq, Debug, Default)] pub enum PyRole { Admin, #[default] User, } impl PyRole { pub fn to_auth_role(&self) -> auth::Role { match self { PyRole::Admin => auth::Role::Admin, PyRole::User => auth::Role::User, } } } #[pyclass] #[derive(Clone, Debug, PartialEq)] pub struct PyApiKeyEntry { #[pyo3(get, set)] pub id: String, #[pyo3(get, set)] pub name: String, #[pyo3(get, set)] pub key: String, #[pyo3(get, set)] pub role: PyRole, } #[pymethods] impl PyApiKeyEntry { #[new] #[pyo3(signature = (id, name, key, role = PyRole::User))] fn new(id: String, name: String, key: String, role: PyRole) -> Self { PyApiKeyEntry { id, name, key, role } } } impl PyApiKeyEntry { pub fn to_auth_api_key_entry(&self) -> auth::ApiKeyEntry { auth::ApiKeyEntry::new(&self.id, &self.name, &self.key, self.role.to_auth_role()) } } #[pyclass] #[derive(Clone, Debug, PartialEq)] pub struct PyJwtConfig { #[pyo3(get, set)] pub issuer: String, #[pyo3(get, set)] pub audience: String, #[pyo3(get, set)] pub jwks_uri: Option, #[pyo3(get, set)] pub role_mapping: HashMap, } #[pymethods] impl PyJwtConfig { #[new] #[pyo3(signature = ( issuer, audience, jwks_uri = None, role_mapping = HashMap::new(), ))] fn new( issuer: String, audience: String, jwks_uri: Option, role_mapping: HashMap, ) -> Self { PyJwtConfig { issuer, audience, jwks_uri, role_mapping, } } } impl PyJwtConfig { pub fn to_auth_jwt_config(&self) -> auth::JwtConfig { let mut config = auth::JwtConfig::new(&self.issuer, &self.audience); // Conditionally set JWKS URI if let Some(ref uri) = self.jwks_uri { config = config.with_jwks_uri(uri); } // Add role mappings for (idp_role, gateway_role) in &self.role_mapping { let role = match gateway_role.to_lowercase().as_str() { "admin" => auth::Role::Admin, _ => auth::Role::User, }; config = config.with_role_mapping(idp_role, role); } config } } #[pyclass] #[derive(Clone, Debug, Default, PartialEq)] pub struct PyControlPlaneAuthConfig { #[pyo3(get, set)] pub jwt: Option, #[pyo3(get, set)] pub api_keys: Vec, #[pyo3(get, set)] pub audit_enabled: bool, } #[pymethods] impl PyControlPlaneAuthConfig { #[new] #[pyo3(signature = ( jwt = None, api_keys = vec![], audit_enabled = true, ))] fn new( jwt: Option, api_keys: Vec, audit_enabled: bool, ) -> Self { PyControlPlaneAuthConfig { jwt, api_keys, audit_enabled, } } } impl PyControlPlaneAuthConfig { pub fn to_auth_control_plane_config(&self) -> auth::ControlPlaneAuthConfig { auth::ControlPlaneAuthConfig { jwt: self.jwt.as_ref().map(|j| j.to_auth_jwt_config()), api_keys: self.api_keys.iter().map(|k| k.to_auth_api_key_entry()).collect(), audit_enabled: self.audit_enabled, } } } #[pyclass] #[derive(Clone, PartialEq)] pub struct PyOracleConfig { #[pyo3(get, set)] pub wallet_path: Option, #[pyo3(get, set)] pub connect_descriptor: Option, #[pyo3(get, set)] pub username: Option, #[pyo3(get, set)] pub password: Option, #[pyo3(get, set)] pub pool_min: usize, #[pyo3(get, set)] pub pool_max: usize, #[pyo3(get, set)] pub pool_timeout_secs: u64, } impl std::fmt::Debug for PyOracleConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PyOracleConfig") .field("wallet_path", &self.wallet_path) .field("connect_descriptor", &"") .field("username", &self.username) .field("password", &"") .field("pool_min", &self.pool_min) .field("pool_max", &self.pool_max) .field("pool_timeout_secs", &self.pool_timeout_secs) .finish() } } #[pymethods] impl PyOracleConfig { #[new] #[pyo3(signature = ( password = None, username = None, connect_descriptor = None, wallet_path = None, pool_min = 1, pool_max = 16, pool_timeout_secs = 30, ))] fn new( password: Option, username: Option, connect_descriptor: Option, wallet_path: Option, pool_min: usize, pool_max: usize, pool_timeout_secs: u64, ) -> PyResult { if pool_min == 0 { return Err(pyo3::exceptions::PyValueError::new_err( "pool_min must be at least 1", )); } if pool_max < pool_min { return Err(pyo3::exceptions::PyValueError::new_err( "pool_max must be >= pool_min", )); } Ok(PyOracleConfig { wallet_path, connect_descriptor, username, password, pool_min, pool_max, pool_timeout_secs, }) } } impl PyOracleConfig { pub fn to_config_oracle(&self) -> config::OracleConfig { config::OracleConfig { wallet_path: self.wallet_path.clone(), connect_descriptor: self.connect_descriptor.clone().unwrap_or_default(), username: self.username.clone().unwrap_or_default(), password: self.password.clone().unwrap_or_default(), pool_min: self.pool_min, pool_max: self.pool_max, pool_timeout_secs: self.pool_timeout_secs, } } } #[pyclass] #[derive(Debug, Clone, PartialEq)] pub struct PyRedisConfig { #[pyo3(get, set)] pub url: String, #[pyo3(get, set)] pub pool_max: usize, #[pyo3(get, set)] pub retention_days: Option, } #[pymethods] impl PyRedisConfig { #[new] #[pyo3(signature = (url, pool_max = 16, retention_days = Some(30)))] fn new(url: String, pool_max: usize, retention_days: Option) -> PyResult { Ok(PyRedisConfig { url, pool_max, retention_days, }) } } impl PyRedisConfig { pub fn to_config_redis(&self) -> config::RedisConfig { config::RedisConfig { url: self.url.clone(), pool_max: self.pool_max, retention_days: self.retention_days, } } } #[pyclass] #[derive(Debug, Clone, PartialEq)] pub struct PyPostgresConfig { #[pyo3(get, set)] pub db_url: Option, #[pyo3(get, set)] pub pool_max: usize, } #[pymethods] impl PyPostgresConfig { #[new] #[pyo3(signature = (db_url = None,pool_max = 16,))] fn new(db_url: Option, pool_max: usize) -> PyResult { Ok(PyPostgresConfig { db_url, pool_max }) } } impl PyPostgresConfig { pub fn to_config_postgres(&self) -> config::PostgresConfig { config::PostgresConfig { db_url: self.db_url.clone().unwrap_or_default(), pool_max: self.pool_max, } } } #[pyclass] #[derive(Debug, Clone, PartialEq)] struct Router { host: String, port: u16, worker_urls: Vec, policy: PolicyType, worker_startup_timeout_secs: u64, worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, max_idle_secs: u64, assignment_mode: String, max_payload_size: usize, dp_aware: bool, api_key: Option, log_dir: Option, log_level: Option, json_log: bool, service_discovery: bool, selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, prefill_selector: HashMap, decode_selector: HashMap, bootstrap_port_annotation: String, prometheus_port: Option, prometheus_host: Option, prometheus_duration_buckets: Option>, request_timeout_secs: u64, shutdown_grace_period_secs: u64, request_id_headers: Option>, pd_disaggregation: bool, bucket_adjust_interval_secs: usize, prefill_urls: Option)>>, decode_urls: Option>, prefill_policy: Option, decode_policy: Option, max_concurrent_requests: i32, cors_allowed_origins: Vec, retry_max_retries: u32, retry_initial_backoff_ms: u64, retry_max_backoff_ms: u64, retry_backoff_multiplier: f32, retry_jitter_factor: f32, disable_retries: bool, cb_failure_threshold: u32, cb_success_threshold: u32, cb_timeout_duration_secs: u64, cb_window_duration_secs: u64, disable_circuit_breaker: bool, health_failure_threshold: u32, health_success_threshold: u32, health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, disable_health_check: bool, enable_igw: bool, queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, connection_mode: core::ConnectionMode, model_path: Option, tokenizer_path: Option, chat_template: Option, tokenizer_cache_enable_l0: bool, tokenizer_cache_l0_max_entries: usize, tokenizer_cache_enable_l1: bool, tokenizer_cache_l1_max_memory: usize, reasoning_parser: Option, tool_call_parser: Option, mcp_config_path: Option, backend: BackendType, history_backend: HistoryBackendType, oracle_config: Option, postgres_config: Option, redis_config: Option, client_cert_path: Option, client_key_path: Option, ca_cert_paths: Vec, server_cert_path: Option, server_key_path: Option, enable_trace: bool, otlp_traces_endpoint: String, control_plane_auth: Option, } impl Router { fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode { for url in worker_urls { if url.starts_with("grpc://") || url.starts_with("grpcs://") { return core::ConnectionMode::Grpc { port: None }; } } core::ConnectionMode::Http } pub fn to_router_config(&self) -> config::ConfigResult { use config::{ DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode, }; let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig { match policy { PolicyType::Random => ConfigPolicyConfig::Random, PolicyType::RoundRobin => ConfigPolicyConfig::RoundRobin, PolicyType::CacheAware => ConfigPolicyConfig::CacheAware { cache_threshold: self.cache_threshold, balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, eviction_interval_secs: self.eviction_interval_secs, max_tree_size: self.max_tree_size, }, PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo { load_check_interval_secs: 5, }, PolicyType::Bucket => ConfigPolicyConfig::Bucket { balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, bucket_adjust_interval_secs: self.bucket_adjust_interval_secs, }, PolicyType::Manual => ConfigPolicyConfig::Manual { eviction_interval_secs: self.eviction_interval_secs, max_idle_secs: self.max_idle_secs, assignment_mode: match self.assignment_mode.as_str() { "random" => config::ManualAssignmentMode::Random, "min_load" => config::ManualAssignmentMode::MinLoad, "min_group" => config::ManualAssignmentMode::MinGroup, other => panic!("Unknown assignment mode: {}", other), }, }, PolicyType::ConsistentHashing => ConfigPolicyConfig::ConsistentHashing, PolicyType::PrefixHash => ConfigPolicyConfig::PrefixHash { prefix_token_count: 256, load_factor: 1.25, }, } }; let mode = if self.enable_igw { RoutingMode::Regular { worker_urls: vec![], } } else if matches!(self.backend, BackendType::Openai) { RoutingMode::OpenAI { worker_urls: self.worker_urls.clone(), } } else if self.pd_disaggregation { RoutingMode::PrefillDecode { prefill_urls: self.prefill_urls.clone().unwrap_or_default(), decode_urls: self.decode_urls.clone().unwrap_or_default(), prefill_policy: self.prefill_policy.as_ref().map(convert_policy), decode_policy: self.decode_policy.as_ref().map(convert_policy), } } else { RoutingMode::Regular { worker_urls: self.worker_urls.clone(), } }; let policy = convert_policy(&self.policy); let discovery = if self.service_discovery { Some(DiscoveryConfig { enabled: true, namespace: self.service_discovery_namespace.clone(), port: self.service_discovery_port, check_interval_secs: 60, selector: self.selector.clone(), prefill_selector: self.prefill_selector.clone(), decode_selector: self.decode_selector.clone(), bootstrap_port_annotation: self.bootstrap_port_annotation.clone(), router_selector: HashMap::new(), router_mesh_port_annotation: "sglang.ai/mesh-port".to_string(), }) } else { None }; let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) { (Some(port), Some(host)) => Some(MetricsConfig { port, host: host.clone(), }), _ => None, }; let trace_config = Some(config::TraceConfig { enable_trace: self.enable_trace, otlp_traces_endpoint: self.otlp_traces_endpoint.clone(), }); let history_backend = match self.history_backend { HistoryBackendType::Memory => config::HistoryBackend::Memory, HistoryBackendType::None => config::HistoryBackend::None, HistoryBackendType::Oracle => config::HistoryBackend::Oracle, HistoryBackendType::Postgres => config::HistoryBackend::Postgres, HistoryBackendType::Redis => config::HistoryBackend::Redis, }; let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) { self.oracle_config .as_ref() .map(|cfg| cfg.to_config_oracle()) } else { None }; let postgres_config = if matches!(self.history_backend, HistoryBackendType::Postgres) { self.postgres_config .as_ref() .map(|cfg| cfg.to_config_postgres()) } else { None }; let redis_config = if matches!(self.history_backend, HistoryBackendType::Redis) { self.redis_config .as_ref() .map(|cfg| cfg.to_config_redis()) } else { None }; config::RouterConfig::builder() .mode(mode) .policy(policy) .host(&self.host) .port(self.port) .connection_mode(self.connection_mode.clone()) .max_payload_size(self.max_payload_size) .request_timeout_secs(self.request_timeout_secs) .worker_startup_timeout_secs(self.worker_startup_timeout_secs) .worker_startup_check_interval_secs(self.worker_startup_check_interval) .max_concurrent_requests(self.max_concurrent_requests) .queue_size(self.queue_size) .queue_timeout_secs(self.queue_timeout_secs) .cors_allowed_origins(self.cors_allowed_origins.clone()) .retry_config(config::RetryConfig { max_retries: self.retry_max_retries, initial_backoff_ms: self.retry_initial_backoff_ms, max_backoff_ms: self.retry_max_backoff_ms, backoff_multiplier: self.retry_backoff_multiplier, jitter_factor: self.retry_jitter_factor, }) .circuit_breaker_config(config::CircuitBreakerConfig { failure_threshold: self.cb_failure_threshold, success_threshold: self.cb_success_threshold, timeout_duration_secs: self.cb_timeout_duration_secs, window_duration_secs: self.cb_window_duration_secs, }) .health_check_config(config::HealthCheckConfig { failure_threshold: self.health_failure_threshold, success_threshold: self.health_success_threshold, timeout_secs: self.health_check_timeout_secs, check_interval_secs: self.health_check_interval_secs, endpoint: self.health_check_endpoint.clone(), disable_health_check: self.disable_health_check, }) .tokenizer_cache(config::TokenizerCacheConfig { enable_l0: self.tokenizer_cache_enable_l0, l0_max_entries: self.tokenizer_cache_l0_max_entries, enable_l1: self.tokenizer_cache_enable_l1, l1_max_memory: self.tokenizer_cache_l1_max_memory, }) .history_backend(history_backend) .maybe_api_key(self.api_key.as_ref()) .maybe_discovery(discovery) .maybe_metrics(metrics) .maybe_trace(trace_config) .maybe_log_dir(self.log_dir.as_ref()) .maybe_log_level(self.log_level.as_ref()) .maybe_request_id_headers(self.request_id_headers.clone()) .maybe_rate_limit_tokens_per_second(self.rate_limit_tokens_per_second) .maybe_model_path(self.model_path.as_ref()) .maybe_tokenizer_path(self.tokenizer_path.as_ref()) .maybe_chat_template(self.chat_template.as_ref()) .maybe_oracle(oracle) .maybe_postgres(postgres_config) .maybe_redis(redis_config) .maybe_reasoning_parser(self.reasoning_parser.as_ref()) .maybe_tool_call_parser(self.tool_call_parser.as_ref()) .maybe_mcp_config_path(self.mcp_config_path.as_ref()) .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) .igw(self.enable_igw) .maybe_client_cert_and_key( self.client_cert_path.as_ref(), self.client_key_path.as_ref(), ) .add_ca_certificates(self.ca_cert_paths.clone()) .maybe_server_cert_and_key( self.server_cert_path.as_ref(), self.server_key_path.as_ref(), ) .build() } } #[pymethods] impl Router { #[new] #[pyo3(signature = ( worker_urls, policy = PolicyType::RoundRobin, host = String::from("0.0.0.0"), port = 3001, worker_startup_timeout_secs = 600, worker_startup_check_interval = 30, cache_threshold = 0.3, balance_abs_threshold = 64, balance_rel_threshold = 1.5, eviction_interval_secs = 120, max_tree_size = 2usize.pow(26), max_idle_secs = 14400, assignment_mode = String::from("random"), max_payload_size = 512 * 1024 * 1024, dp_aware = false, api_key = None, log_dir = None, log_level = None, json_log = false, service_discovery = false, selector = HashMap::new(), service_discovery_port = 80, service_discovery_namespace = None, prefill_selector = HashMap::new(), decode_selector = HashMap::new(), bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"), prometheus_port = None, prometheus_host = None, prometheus_duration_buckets = None, request_timeout_secs = 1800, shutdown_grace_period_secs = 180, request_id_headers = None, pd_disaggregation = false, bucket_adjust_interval_secs = 5, prefill_urls = None, decode_urls = None, prefill_policy = None, decode_policy = None, max_concurrent_requests = -1, cors_allowed_origins = vec![], retry_max_retries = 5, retry_initial_backoff_ms = 50, retry_max_backoff_ms = 30_000, retry_backoff_multiplier = 1.5, retry_jitter_factor = 0.2, disable_retries = false, cb_failure_threshold = 10, cb_success_threshold = 3, cb_timeout_duration_secs = 60, cb_window_duration_secs = 120, disable_circuit_breaker = false, health_failure_threshold = 3, health_success_threshold = 2, health_check_timeout_secs = 5, health_check_interval_secs = 60, health_check_endpoint = String::from("/health"), disable_health_check = false, enable_igw = false, queue_size = 100, queue_timeout_secs = 60, rate_limit_tokens_per_second = None, model_path = None, tokenizer_path = None, chat_template = None, tokenizer_cache_enable_l0 = false, tokenizer_cache_l0_max_entries = 10000, tokenizer_cache_enable_l1 = false, tokenizer_cache_l1_max_memory = 52428800, reasoning_parser = None, tool_call_parser = None, mcp_config_path = None, backend = BackendType::Sglang, history_backend = HistoryBackendType::Memory, oracle_config = None, postgres_config = None, redis_config = None, client_cert_path = None, client_key_path = None, ca_cert_paths = vec![], server_cert_path = None, server_key_path = None, enable_trace = false, otlp_traces_endpoint = String::from("localhost:4317"), control_plane_auth = None, ))] #[allow(clippy::too_many_arguments)] fn new( worker_urls: Vec, policy: PolicyType, host: String, port: u16, worker_startup_timeout_secs: u64, worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, max_idle_secs: u64, assignment_mode: String, max_payload_size: usize, dp_aware: bool, api_key: Option, log_dir: Option, log_level: Option, json_log: bool, service_discovery: bool, selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, prefill_selector: HashMap, decode_selector: HashMap, bootstrap_port_annotation: String, prometheus_port: Option, prometheus_host: Option, prometheus_duration_buckets: Option>, request_timeout_secs: u64, shutdown_grace_period_secs: u64, request_id_headers: Option>, pd_disaggregation: bool, bucket_adjust_interval_secs: usize, prefill_urls: Option)>>, decode_urls: Option>, prefill_policy: Option, decode_policy: Option, max_concurrent_requests: i32, cors_allowed_origins: Vec, retry_max_retries: u32, retry_initial_backoff_ms: u64, retry_max_backoff_ms: u64, retry_backoff_multiplier: f32, retry_jitter_factor: f32, disable_retries: bool, cb_failure_threshold: u32, cb_success_threshold: u32, cb_timeout_duration_secs: u64, cb_window_duration_secs: u64, disable_circuit_breaker: bool, health_failure_threshold: u32, health_success_threshold: u32, health_check_timeout_secs: u64, health_check_interval_secs: u64, health_check_endpoint: String, disable_health_check: bool, enable_igw: bool, queue_size: usize, queue_timeout_secs: u64, rate_limit_tokens_per_second: Option, model_path: Option, tokenizer_path: Option, chat_template: Option, tokenizer_cache_enable_l0: bool, tokenizer_cache_l0_max_entries: usize, tokenizer_cache_enable_l1: bool, tokenizer_cache_l1_max_memory: usize, reasoning_parser: Option, tool_call_parser: Option, mcp_config_path: Option, backend: BackendType, history_backend: HistoryBackendType, oracle_config: Option, postgres_config: Option, redis_config: Option, client_cert_path: Option, client_key_path: Option, ca_cert_paths: Vec, server_cert_path: Option, server_key_path: Option, enable_trace: bool, otlp_traces_endpoint: String, control_plane_auth: Option, ) -> PyResult { let mut all_urls = worker_urls.clone(); if let Some(ref prefill_urls) = prefill_urls { for (url, _) in prefill_urls { all_urls.push(url.clone()); } } if let Some(ref decode_urls) = decode_urls { all_urls.extend(decode_urls.clone()); } let connection_mode = Self::determine_connection_mode(&all_urls); Ok(Router { host, port, worker_urls, policy, worker_startup_timeout_secs, worker_startup_check_interval, cache_threshold, balance_abs_threshold, balance_rel_threshold, eviction_interval_secs, max_tree_size, max_idle_secs, assignment_mode, max_payload_size, dp_aware, api_key, log_dir, log_level, json_log, service_discovery, selector, service_discovery_port, service_discovery_namespace, prefill_selector, decode_selector, bootstrap_port_annotation, prometheus_port, prometheus_host, prometheus_duration_buckets, request_timeout_secs, shutdown_grace_period_secs, request_id_headers, pd_disaggregation, bucket_adjust_interval_secs, prefill_urls, decode_urls, prefill_policy, decode_policy, max_concurrent_requests, cors_allowed_origins, retry_max_retries, retry_initial_backoff_ms, retry_max_backoff_ms, retry_backoff_multiplier, retry_jitter_factor, disable_retries, cb_failure_threshold, cb_success_threshold, cb_timeout_duration_secs, cb_window_duration_secs, disable_circuit_breaker, health_failure_threshold, health_success_threshold, health_check_timeout_secs, health_check_interval_secs, health_check_endpoint, disable_health_check, enable_igw, queue_size, queue_timeout_secs, rate_limit_tokens_per_second, connection_mode, model_path, tokenizer_path, chat_template, tokenizer_cache_enable_l0, tokenizer_cache_l0_max_entries, tokenizer_cache_enable_l1, tokenizer_cache_l1_max_memory, reasoning_parser, tool_call_parser, mcp_config_path, backend, history_backend, oracle_config, postgres_config, redis_config, client_cert_path, client_key_path, ca_cert_paths, server_cert_path, server_key_path, enable_trace, otlp_traces_endpoint, control_plane_auth, }) } fn start(&self) -> PyResult<()> { use observability::metrics::PrometheusConfig; let router_config = self.to_router_config().map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e)) })?; router_config.validate().map_err(|e| { pyo3::exceptions::PyValueError::new_err(format!( "Configuration validation failed: {}", e )) })?; let service_discovery_config = if self.service_discovery { Some(service_discovery::ServiceDiscoveryConfig { enabled: true, selector: self.selector.clone(), check_interval: std::time::Duration::from_secs(60), port: self.service_discovery_port, namespace: self.service_discovery_namespace.clone(), pd_mode: self.pd_disaggregation, prefill_selector: self.prefill_selector.clone(), decode_selector: self.decode_selector.clone(), bootstrap_port_annotation: self.bootstrap_port_annotation.clone(), router_selector: HashMap::new(), router_mesh_port_annotation: "sglang.ai/mesh-port".to_string(), }) } else { None }; let prometheus_config = Some(PrometheusConfig { port: self.prometheus_port.unwrap_or(29000), host: self .prometheus_host .clone() .unwrap_or_else(|| "127.0.0.1".to_string()), duration_buckets: self.prometheus_duration_buckets.clone(), }); let runtime = tokio::runtime::Runtime::new() .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; runtime.block_on(async move { server::startup(server::ServerConfig { host: self.host.clone(), port: self.port, router_config, max_payload_size: self.max_payload_size, log_dir: self.log_dir.clone(), log_level: self.log_level.clone(), json_log: self.json_log, service_discovery_config, prometheus_config, request_timeout_secs: self.request_timeout_secs, request_id_headers: self.request_id_headers.clone(), shutdown_grace_period_secs: self.shutdown_grace_period_secs, control_plane_auth: self .control_plane_auth .as_ref() .map(|c| c.to_auth_control_plane_config()), mesh_server_config: None, }) .await .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string())) }) } } /// Get simple version string (default for --version) #[pyfunction] fn get_version_string() -> String { version::get_version_string() } /// Get verbose version information string with full build details (for --version-verbose) #[pyfunction] fn get_verbose_version_string() -> String { version::get_verbose_version_string() } /// Get the list of available tool call parsers from the Rust factory. #[pyfunction] fn get_available_tool_call_parsers() -> Vec { static PARSERS: OnceCell> = OnceCell::new(); PARSERS .get_or_init(|| { let factory = tool_parser::ParserFactory::new(); factory.list_parsers() }) .clone() } #[pymodule] fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(get_version_string, m)?)?; m.add_function(wrap_pyfunction!(get_verbose_version_string, m)?)?; m.add_function(wrap_pyfunction!(get_available_tool_call_parsers, m)?)?; Ok(()) }