use axum::{
Router,
routing::{get, post},
extract::DefaultBodyLimit,
response::{IntoResponse, Response, Json},
};
use tracing::{debug, error};
use tower_http::trace::TraceLayer;
use std::sync::Arc;
use tokio::sync::oneshot;
use crate::proxy::TokenManager;
/// Shared application state injected into every Axum handler via `with_state`.
#[derive(Clone)]
pub struct AppState {
    /// Manages auth tokens for upstream requests (see [`TokenManager`]).
    pub token_manager: Arc<TokenManager>,
    /// Anthropic-protocol model-name mapping (alias -> upstream model); RwLock allows hot updates.
    pub anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// OpenAI-protocol model-name mapping; RwLock allows hot updates.
    pub openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Custom model-name mapping; RwLock allows hot updates.
    pub custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    #[allow(dead_code)]
    pub request_timeout: u64, // API request timeout in seconds (not read by current handlers)
    #[allow(dead_code)]
    pub thought_signature_map: Arc<tokio::sync::Mutex<std::collections::HashMap<String, String>>>, // chain-of-thought signature map (ID -> signature)
    #[allow(dead_code)]
    pub upstream_proxy: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
    /// Shared client for talking to the upstream service.
    pub upstream: Arc<crate::proxy::upstream::client::UpstreamClient>,
}
/// Handle to a running Axum reverse-proxy server instance.
///
/// Holds the shutdown channel plus the same hot-updatable shared state
/// (`Arc<RwLock<…>>`) that is stored in `AppState`, so configuration can
/// be swapped at runtime without restarting the server.
pub struct AxumServer {
    /// Oneshot sender used by `stop` to break the accept loop; taken on shutdown.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Anthropic model mapping shared with handlers; replaced by `update_mapping`.
    anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// OpenAI model mapping shared with handlers; replaced by `update_mapping`.
    openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Custom model mapping shared with handlers; replaced by `update_mapping`.
    custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Upstream proxy configuration shared with handlers; replaced by `update_proxy`.
    proxy_state: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
}
impl AxumServer {
/// Hot-swap all three model-mapping tables from a fresh `ProxyConfig`.
///
/// Each write lock is held only for the duration of its own assignment,
/// so handlers are never blocked across more than one table swap.
pub async fn update_mapping(&self, config: &crate::proxy::config::ProxyConfig) {
    *self.anthropic_mapping.write().await = config.anthropic_mapping.clone();
    *self.openai_mapping.write().await = config.openai_mapping.clone();
    *self.custom_mapping.write().await = config.custom_mapping.clone();
    tracing::info!("模型映射 (Anthropic/OpenAI/Custom) 已全量热更新");
}
/// Hot-swap the upstream proxy configuration used by handlers.
pub async fn update_proxy(&self, new_config: crate::proxy::config::UpstreamProxyConfig) {
    *self.proxy_state.write().await = new_config;
    tracing::info!("上游代理配置已热更新");
}
/// Start the Axum reverse-proxy server.
///
/// Binds `host:port`, builds the protocol routes (OpenAI / Claude / Gemini
/// native), and serves connections on a background task until
/// [`AxumServer::stop`] is called.
///
/// Returns the server handle (for hot updates / shutdown) plus the
/// accept-loop `JoinHandle`, or an error string if the address bind fails.
pub async fn start(
    host: String,
    port: u16,
    token_manager: Arc<TokenManager>,
    anthropic_mapping: std::collections::HashMap<String, String>,
    openai_mapping: std::collections::HashMap<String, String>,
    custom_mapping: std::collections::HashMap<String, String>,
    request_timeout: u64,
    upstream_proxy: crate::proxy::config::UpstreamProxyConfig,
) -> Result<(Self, tokio::task::JoinHandle<()>), String> {
    // Hot-swappable shared state: handlers read these through RwLocks so
    // update_mapping / update_proxy can replace the contents at runtime.
    let mapping_state = Arc::new(tokio::sync::RwLock::new(anthropic_mapping));
    let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(openai_mapping));
    let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(custom_mapping));
    let proxy_state = Arc::new(tokio::sync::RwLock::new(upstream_proxy.clone()));
    let state = AppState {
        token_manager: token_manager.clone(),
        anthropic_mapping: mapping_state.clone(),
        openai_mapping: openai_mapping_state.clone(),
        custom_mapping: custom_mapping_state.clone(),
        // Fix: honor the caller-supplied timeout instead of hardcoding 300s
        // (the parameter was previously accepted but silently ignored).
        request_timeout,
        thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
        upstream_proxy: proxy_state.clone(),
        upstream: Arc::new(crate::proxy::upstream::client::UpstreamClient::new(Some(upstream_proxy.clone()))),
    };
    // Routes are served by the new handler architecture (src/proxy/handlers/*).
    use crate::proxy::handlers;
    let app = Router::new()
        // OpenAI protocol
        .route("/v1/models", get(handlers::openai::handle_list_models))
        .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions))
        .route("/v1/completions", post(handlers::openai::handle_completions))
        .route("/v1/responses", post(handlers::openai::handle_completions)) // Codex CLI compatibility
        // Claude protocol
        .route("/v1/messages", post(handlers::claude::handle_messages))
        .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens))
        .route("/v1/models/claude", get(handlers::claude::handle_list_models))
        // Gemini protocol (native)
        .route("/v1beta/models", get(handlers::gemini::handle_list_models))
        // GET returns model info; POST handles generateContent (colon-suffixed paths).
        .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate))
        .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) // more specific route, matched first
        .route("/healthz", get(health_check_handler))
        .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) // allow request bodies up to 100 MiB
        .layer(TraceLayer::new_for_http())
        .layer(axum::middleware::from_fn(crate::proxy::middleware::auth_middleware))
        .layer(crate::proxy::middleware::cors_layer())
        .with_state(state);
    // Bind before spawning so bind errors surface to the caller synchronously.
    let addr = format!("{}:{}", host, port);
    let listener = tokio::net::TcpListener::bind(&addr)
        .await
        .map_err(|e| format!("地址 {} 绑定失败: {}", addr, e))?;
    tracing::info!("反代服务器启动在 http://{}", addr);
    // Oneshot channel used by `stop` to break the accept loop.
    let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
    let server_instance = Self {
        shutdown_tx: Some(shutdown_tx),
        anthropic_mapping: mapping_state.clone(),
        openai_mapping: openai_mapping_state.clone(),
        custom_mapping: custom_mapping_state.clone(),
        proxy_state,
    };
    // Accept loop runs on its own task; each connection is served on a sub-task.
    let handle = tokio::spawn(async move {
        use hyper_util::rt::TokioIo;
        use hyper::server::conn::http1;
        use hyper_util::service::TowerToHyperService;
        loop {
            tokio::select! {
                res = listener.accept() => {
                    match res {
                        Ok((stream, _)) => {
                            let io = TokioIo::new(stream);
                            let service = TowerToHyperService::new(app.clone());
                            tokio::task::spawn(async move {
                                if let Err(err) = http1::Builder::new()
                                    .serve_connection(io, service)
                                    .with_upgrades() // keep upgrade support (e.g. future WebSocket) available
                                    .await
                                {
                                    // Connection-level errors (client resets, etc.) are routine; log at debug.
                                    debug!("连接处理结束或出错: {:?}", err);
                                }
                            });
                        }
                        Err(e) => {
                            // Accept failures can be transient (e.g. fd exhaustion); keep the loop alive.
                            error!("接收连接失败: {:?}", e);
                        }
                    }
                }
                _ = &mut shutdown_rx => {
                    tracing::info!("反代服务器停止监听");
                    break;
                }
            }
        }
    });
    Ok((
        server_instance,
        handle,
    ))
}
/// Signal the accept loop to shut down.
///
/// Consumes the handle; a missing sender (already fired) is a no-op, and a
/// failed send (receiver already gone) is deliberately ignored.
pub fn stop(mut self) {
    let tx = match self.shutdown_tx.take() {
        Some(tx) => tx,
        None => return,
    };
    let _ = tx.send(());
}
}
// ===== API 处理器 (旧代码已移除,由 src/proxy/handlers/* 接管) =====
/// Liveness probe for `/healthz`: always responds with `{"status":"ok"}`.
async fn health_check_handler() -> Response {
    let payload = serde_json::json!({
        "status": "ok"
    });
    Json(payload).into_response()
}