use axum::{ Router, routing::{get, post}, extract::DefaultBodyLimit, response::{IntoResponse, Response, Json}, }; use tracing::{debug, error}; use tower_http::trace::TraceLayer; use std::sync::Arc; use tokio::sync::oneshot; use crate::proxy::TokenManager; /// Axum 应用状态 #[derive(Clone)] pub struct AppState { pub token_manager: Arc, pub anthropic_mapping: Arc>>, pub openai_mapping: Arc>>, pub custom_mapping: Arc>>, #[allow(dead_code)] pub request_timeout: u64, // API 请求超时(秒) #[allow(dead_code)] pub thought_signature_map: Arc>>, // 思维链签名映射 (ID -> Signature) #[allow(dead_code)] pub upstream_proxy: Arc>, pub upstream: Arc, } /// Axum 服务器实例 pub struct AxumServer { shutdown_tx: Option>, anthropic_mapping: Arc>>, openai_mapping: Arc>>, custom_mapping: Arc>>, proxy_state: Arc>, } impl AxumServer { pub async fn update_mapping(&self, config: &crate::proxy::config::ProxyConfig) { { let mut m = self.anthropic_mapping.write().await; *m = config.anthropic_mapping.clone(); } { let mut m = self.openai_mapping.write().await; *m = config.openai_mapping.clone(); } { let mut m = self.custom_mapping.write().await; *m = config.custom_mapping.clone(); } tracing::info!("模型映射 (Anthropic/OpenAI/Custom) 已全量热更新"); } /// 更新代理配置 pub async fn update_proxy(&self, new_config: crate::proxy::config::UpstreamProxyConfig) { let mut proxy = self.proxy_state.write().await; *proxy = new_config; tracing::info!("上游代理配置已热更新"); } /// 启动 Axum 服务器 pub async fn start( host: String, port: u16, token_manager: Arc, anthropic_mapping: std::collections::HashMap, openai_mapping: std::collections::HashMap, custom_mapping: std::collections::HashMap, _request_timeout: u64, upstream_proxy: crate::proxy::config::UpstreamProxyConfig, ) -> Result<(Self, tokio::task::JoinHandle<()>), String> { let mapping_state = Arc::new(tokio::sync::RwLock::new(anthropic_mapping)); let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(openai_mapping)); let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(custom_mapping)); let proxy_state = Arc::new(tokio::sync::RwLock::new(upstream_proxy.clone())); let state = AppState { token_manager: token_manager.clone(), anthropic_mapping: mapping_state.clone(), openai_mapping: openai_mapping_state.clone(), custom_mapping: custom_mapping_state.clone(), request_timeout: 300, // 5分钟超时 thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), upstream_proxy: proxy_state.clone(), upstream: Arc::new(crate::proxy::upstream::client::UpstreamClient::new(Some(upstream_proxy.clone()))), }; // 构建路由 - 使用新架构的 handlers! use crate::proxy::handlers; // 构建路由 let app = Router::new() // OpenAI Protocol .route("/v1/models", get(handlers::openai::handle_list_models)) .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions)) .route("/v1/completions", post(handlers::openai::handle_completions)) .route("/v1/responses", post(handlers::openai::handle_completions)) // 兼容 Codex CLI // Claude Protocol .route("/v1/messages", post(handlers::claude::handle_messages)) .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens)) .route("/v1/models/claude", get(handlers::claude::handle_list_models)) // Gemini Protocol (Native) .route("/v1beta/models", get(handlers::gemini::handle_list_models)) // Handle both GET (get info) and POST (generateContent with colon) at the same route .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate)) .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) // Specific route priority .route("/healthz", get(health_check_handler)) .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) .layer(TraceLayer::new_for_http()) .layer(axum::middleware::from_fn(crate::proxy::middleware::auth_middleware)) .layer(crate::proxy::middleware::cors_layer()) .with_state(state); // 绑定地址 let addr = format!("{}:{}", host, port); let listener = tokio::net::TcpListener::bind(&addr) .await .map_err(|e| format!("地址 {} 绑定失败: {}", addr, e))?; tracing::info!("反代服务器启动在 http://{}", addr); // 创建关闭通道 let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); let server_instance = Self { shutdown_tx: Some(shutdown_tx), anthropic_mapping: mapping_state.clone(), openai_mapping: openai_mapping_state.clone(), custom_mapping: custom_mapping_state.clone(), proxy_state, }; // 在新任务中启动服务器 let handle = tokio::spawn(async move { use hyper_util::rt::TokioIo; use hyper::server::conn::http1; use hyper_util::service::TowerToHyperService; loop { tokio::select! { res = listener.accept() => { match res { Ok((stream, _)) => { let io = TokioIo::new(stream); let service = TowerToHyperService::new(app.clone()); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .serve_connection(io, service) .with_upgrades() // 支持 WebSocket (如果以后需要) .await { debug!("连接处理结束或出错: {:?}", err); } }); } Err(e) => { error!("接收连接失败: {:?}", e); } } } _ = &mut shutdown_rx => { tracing::info!("反代服务器停止监听"); break; } } } }); Ok(( server_instance, handle, )) } /// 停止服务器 pub fn stop(mut self) { if let Some(tx) = self.shutdown_tx.take() { let _ = tx.send(()); } } } // ===== API 处理器 (旧代码已移除,由 src/proxy/handlers/* 接管) ===== /// 健康检查处理器 async fn health_check_handler() -> Response { Json(serde_json::json!({ "status": "ok" })).into_response() }