| use axum::{ |
| Router, |
| routing::{get, post}, |
| extract::DefaultBodyLimit, |
| response::{IntoResponse, Response, Json}, |
| }; |
| use tracing::{debug, error}; |
| use tower_http::trace::TraceLayer; |
| use std::sync::Arc; |
| use tokio::sync::oneshot; |
| use crate::proxy::TokenManager; |
|
|
|
|
| |
/// Shared state injected into every Axum handler via `.with_state(..)`.
///
/// The mapping tables and the upstream-proxy config are wrapped in
/// `Arc<RwLock<..>>` so `AxumServer::update_mapping` / `update_proxy` can
/// hot-swap them while requests are in flight (both sides hold the same `Arc`).
#[derive(Clone)]
pub struct AppState {
    // Token manager shared with the handlers for upstream authentication.
    pub token_manager: Arc<TokenManager>,
    // Model-name mapping applied to Anthropic/Claude-compatible requests.
    pub anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    // Model-name mapping applied to OpenAI-compatible requests.
    pub openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    // User-defined custom model-name mapping.
    pub custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    // Request timeout value; not read by any code visible here (hence dead_code).
    #[allow(dead_code)]
    pub request_timeout: u64,
    // NOTE(review): presumably caches "thought signatures" keyed per request —
    // confirm against the handlers that read it; unused in this file.
    #[allow(dead_code)]
    pub thought_signature_map: Arc<tokio::sync::Mutex<std::collections::HashMap<String, String>>>,
    // Hot-reloadable upstream proxy configuration (same lock as AxumServer::proxy_state).
    #[allow(dead_code)]
    pub upstream_proxy: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
    // HTTP client used to reach the upstream service.
    pub upstream: Arc<crate::proxy::upstream::client::UpstreamClient>,
}
|
|
| |
/// Handle to a running proxy server.
///
/// Owns the one-shot shutdown channel plus the same `Arc<RwLock<..>>`
/// instances stored in `AppState`, which is what makes the mapping and
/// proxy-config hot updates visible to the router.
pub struct AxumServer {
    // Fires once in `stop()` to break the accept loop; `None` after use.
    shutdown_tx: Option<oneshot::Sender<()>>,
    // Shared with AppState.anthropic_mapping (same Arc).
    anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    // Shared with AppState.openai_mapping (same Arc).
    openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    // Shared with AppState.custom_mapping (same Arc).
    custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    // Shared with AppState.upstream_proxy (same Arc).
    proxy_state: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
}
|
|
| impl AxumServer { |
| pub async fn update_mapping(&self, config: &crate::proxy::config::ProxyConfig) { |
| { |
| let mut m = self.anthropic_mapping.write().await; |
| *m = config.anthropic_mapping.clone(); |
| } |
| { |
| let mut m = self.openai_mapping.write().await; |
| *m = config.openai_mapping.clone(); |
| } |
| { |
| let mut m = self.custom_mapping.write().await; |
| *m = config.custom_mapping.clone(); |
| } |
| tracing::info!("模型映射 (Anthropic/OpenAI/Custom) 已全量热更新"); |
| } |
|
|
| |
| pub async fn update_proxy(&self, new_config: crate::proxy::config::UpstreamProxyConfig) { |
| let mut proxy = self.proxy_state.write().await; |
| *proxy = new_config; |
| tracing::info!("上游代理配置已热更新"); |
| } |
| |
| pub async fn start( |
| host: String, |
| port: u16, |
| token_manager: Arc<TokenManager>, |
| anthropic_mapping: std::collections::HashMap<String, String>, |
| openai_mapping: std::collections::HashMap<String, String>, |
| custom_mapping: std::collections::HashMap<String, String>, |
| _request_timeout: u64, |
| upstream_proxy: crate::proxy::config::UpstreamProxyConfig, |
| ) -> Result<(Self, tokio::task::JoinHandle<()>), String> { |
| let mapping_state = Arc::new(tokio::sync::RwLock::new(anthropic_mapping)); |
| let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(openai_mapping)); |
| let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(custom_mapping)); |
| let proxy_state = Arc::new(tokio::sync::RwLock::new(upstream_proxy.clone())); |
|
|
| let state = AppState { |
| token_manager: token_manager.clone(), |
| anthropic_mapping: mapping_state.clone(), |
| openai_mapping: openai_mapping_state.clone(), |
| custom_mapping: custom_mapping_state.clone(), |
| request_timeout: 300, |
| thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), |
| upstream_proxy: proxy_state.clone(), |
| upstream: Arc::new(crate::proxy::upstream::client::UpstreamClient::new(Some(upstream_proxy.clone()))), |
| }; |
| |
| |
| use crate::proxy::handlers; |
| |
| let app = Router::new() |
| |
| .route("/v1/models", get(handlers::openai::handle_list_models)) |
| .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions)) |
| .route("/v1/completions", post(handlers::openai::handle_completions)) |
| .route("/v1/responses", post(handlers::openai::handle_completions)) |
| |
| |
| .route("/v1/messages", post(handlers::claude::handle_messages)) |
| .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens)) |
| .route("/v1/models/claude", get(handlers::claude::handle_list_models)) |
| |
| |
| .route("/v1beta/models", get(handlers::gemini::handle_list_models)) |
| |
| .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate)) |
| .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) |
| .route("/healthz", get(health_check_handler)) |
| .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) |
| .layer(TraceLayer::new_for_http()) |
| .layer(axum::middleware::from_fn(crate::proxy::middleware::auth_middleware)) |
| .layer(crate::proxy::middleware::cors_layer()) |
| .with_state(state); |
| |
| |
| let addr = format!("{}:{}", host, port); |
| let listener = tokio::net::TcpListener::bind(&addr) |
| .await |
| .map_err(|e| format!("地址 {} 绑定失败: {}", addr, e))?; |
| |
| tracing::info!("反代服务器启动在 http://{}", addr); |
| |
| |
| let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); |
| |
| let server_instance = Self { |
| shutdown_tx: Some(shutdown_tx), |
| anthropic_mapping: mapping_state.clone(), |
| openai_mapping: openai_mapping_state.clone(), |
| custom_mapping: custom_mapping_state.clone(), |
| proxy_state, |
| }; |
| |
| |
| let handle = tokio::spawn(async move { |
| use hyper_util::rt::TokioIo; |
| use hyper::server::conn::http1; |
| use hyper_util::service::TowerToHyperService; |
|
|
| loop { |
| tokio::select! { |
| res = listener.accept() => { |
| match res { |
| Ok((stream, _)) => { |
| let io = TokioIo::new(stream); |
| let service = TowerToHyperService::new(app.clone()); |
| |
| tokio::task::spawn(async move { |
| if let Err(err) = http1::Builder::new() |
| .serve_connection(io, service) |
| .with_upgrades() |
| .await |
| { |
| debug!("连接处理结束或出错: {:?}", err); |
| } |
| }); |
| } |
| Err(e) => { |
| error!("接收连接失败: {:?}", e); |
| } |
| } |
| } |
| _ = &mut shutdown_rx => { |
| tracing::info!("反代服务器停止监听"); |
| break; |
| } |
| } |
| } |
| }); |
| |
| Ok(( |
| server_instance, |
| handle, |
| )) |
| } |
| |
| |
| pub fn stop(mut self) { |
| if let Some(tx) = self.shutdown_tx.take() { |
| let _ = tx.send(()); |
| } |
| } |
| } |
|
|
| |
|
|
| |
| async fn health_check_handler() -> Response { |
| Json(serde_json::json!({ |
| "status": "ok" |
| })).into_response() |
| } |
|
|