// gemini/server/src/proxy/server.rs
// Author: yinming
// feat: Antigravity API Proxy for HuggingFace Spaces
// Commit: bbb1195
use axum::{
Router,
routing::{get, post},
extract::DefaultBodyLimit,
response::{IntoResponse, Response, Json},
};
use tracing::{debug, error};
use tower_http::trace::TraceLayer;
use std::sync::Arc;
use tokio::sync::oneshot;
use crate::proxy::TokenManager;
/// Shared Axum application state, cloned into every request handler
/// via `.with_state(...)`.
#[derive(Clone)]
pub struct AppState {
// Token manager for upstream credentials, shared across handlers.
pub token_manager: Arc<TokenManager>,
// Model-name mapping for the Anthropic-protocol endpoints (hot-reloadable via `update_mapping`).
pub anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
// Model-name mapping for the OpenAI-protocol endpoints (hot-reloadable via `update_mapping`).
pub openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
// Custom model-name mapping (hot-reloadable via `update_mapping`).
pub custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
#[allow(dead_code)] // kept for handlers that may consume it later
pub request_timeout: u64, // API request timeout in seconds
#[allow(dead_code)] // kept for handlers that may consume it later
pub thought_signature_map: Arc<tokio::sync::Mutex<std::collections::HashMap<String, String>>>, // chain-of-thought signature map (ID -> signature)
#[allow(dead_code)] // kept for handlers that may consume it later
pub upstream_proxy: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>, // upstream proxy config (hot-reloadable via `update_proxy`)
// Shared HTTP client used for all upstream requests.
pub upstream: Arc<crate::proxy::upstream::client::UpstreamClient>,
}
/// Running Axum server instance.
///
/// Holds the shutdown signal plus the same `Arc`-shared state handed to
/// `AppState`, so mappings and proxy config can be hot-updated while the
/// server is serving requests.
pub struct AxumServer {
// One-shot sender used by `stop()`; `None` after the signal has been sent.
shutdown_tx: Option<oneshot::Sender<()>>,
// Shared with `AppState.anthropic_mapping` — writes here are visible to handlers.
anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
// Shared with `AppState.openai_mapping`.
openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
// Shared with `AppState.custom_mapping`.
custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
// Shared with `AppState.upstream_proxy`.
proxy_state: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
}
impl AxumServer {
/// Hot-swap all three model-mapping tables (Anthropic / OpenAI / Custom)
/// from a freshly loaded proxy config.
///
/// Each table is replaced under its own short-lived write lock so in-flight
/// requests are blocked for at most one map swap at a time.
pub async fn update_mapping(&self, config: &crate::proxy::config::ProxyConfig) {
{
let mut m = self.anthropic_mapping.write().await;
*m = config.anthropic_mapping.clone();
}
{
let mut m = self.openai_mapping.write().await;
*m = config.openai_mapping.clone();
}
{
let mut m = self.custom_mapping.write().await;
*m = config.custom_mapping.clone();
}
tracing::info!("模型映射 (Anthropic/OpenAI/Custom) 已全量热更新");
}
/// Hot-swap the upstream proxy configuration shared with request handlers.
pub async fn update_proxy(&self, new_config: crate::proxy::config::UpstreamProxyConfig) {
let mut proxy = self.proxy_state.write().await;
*proxy = new_config;
tracing::info!("上游代理配置已热更新");
}
/// Start the Axum server.
///
/// Binds `host:port`, builds the protocol routes (OpenAI / Claude / Gemini
/// plus `/healthz`), and serves connections on a spawned task until
/// [`AxumServer::stop`] is called.
///
/// Returns the server handle (for hot-reload and shutdown) together with
/// the `JoinHandle` of the accept loop.
///
/// # Errors
/// Returns a descriptive `String` if the listen address cannot be bound.
pub async fn start(
host: String,
port: u16,
token_manager: Arc<TokenManager>,
anthropic_mapping: std::collections::HashMap<String, String>,
openai_mapping: std::collections::HashMap<String, String>,
custom_mapping: std::collections::HashMap<String, String>,
request_timeout: u64, // API request timeout in seconds
upstream_proxy: crate::proxy::config::UpstreamProxyConfig,
) -> Result<(Self, tokio::task::JoinHandle<()>), String> {
let mapping_state = Arc::new(tokio::sync::RwLock::new(anthropic_mapping));
let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(openai_mapping));
let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(custom_mapping));
let proxy_state = Arc::new(tokio::sync::RwLock::new(upstream_proxy.clone()));
let state = AppState {
token_manager: token_manager.clone(),
anthropic_mapping: mapping_state.clone(),
openai_mapping: openai_mapping_state.clone(),
custom_mapping: custom_mapping_state.clone(),
// Bug fix: honor the caller-supplied timeout. Previously this was
// hardcoded to 300 and the parameter was silently ignored.
request_timeout,
thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
upstream_proxy: proxy_state.clone(),
upstream: Arc::new(crate::proxy::upstream::client::UpstreamClient::new(Some(upstream_proxy.clone()))),
};
// Build the router using the new-architecture handlers.
use crate::proxy::handlers;
let app = Router::new()
// OpenAI protocol
.route("/v1/models", get(handlers::openai::handle_list_models))
.route("/v1/chat/completions", post(handlers::openai::handle_chat_completions))
.route("/v1/completions", post(handlers::openai::handle_completions))
.route("/v1/responses", post(handlers::openai::handle_completions)) // Codex CLI compatibility
// Claude protocol
.route("/v1/messages", post(handlers::claude::handle_messages))
.route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens))
.route("/v1/models/claude", get(handlers::claude::handle_list_models))
// Gemini protocol (native)
.route("/v1beta/models", get(handlers::gemini::handle_list_models))
// Handle both GET (model info) and POST (generateContent with colon) at the same route.
.route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate))
.route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) // specific route takes priority
.route("/healthz", get(health_check_handler))
.layer(DefaultBodyLimit::max(100 * 1024 * 1024)) // allow request bodies up to 100 MiB
.layer(TraceLayer::new_for_http())
.layer(axum::middleware::from_fn(crate::proxy::middleware::auth_middleware))
.layer(crate::proxy::middleware::cors_layer())
.with_state(state);
// Bind the listen address.
let addr = format!("{}:{}", host, port);
let listener = tokio::net::TcpListener::bind(&addr)
.await
.map_err(|e| format!("地址 {} 绑定失败: {}", addr, e))?;
tracing::info!("反代服务器启动在 http://{}", addr);
// One-shot channel used by `stop()` to break the accept loop.
let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
let server_instance = Self {
shutdown_tx: Some(shutdown_tx),
anthropic_mapping: mapping_state.clone(),
openai_mapping: openai_mapping_state.clone(),
custom_mapping: custom_mapping_state.clone(),
proxy_state,
};
// Run the accept loop on its own task so `start` can return immediately.
let handle = tokio::spawn(async move {
use hyper_util::rt::TokioIo;
use hyper::server::conn::http1;
use hyper_util::service::TowerToHyperService;
loop {
tokio::select! {
res = listener.accept() => {
match res {
Ok((stream, _)) => {
let io = TokioIo::new(stream);
let service = TowerToHyperService::new(app.clone());
// Serve each connection on its own task.
tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(io, service)
.with_upgrades() // WebSocket support, should it ever be needed
.await
{
debug!("连接处理结束或出错: {:?}", err);
}
});
}
Err(e) => {
// Accept failures are logged but do not kill the loop.
error!("接收连接失败: {:?}", e);
}
}
}
_ = &mut shutdown_rx => {
tracing::info!("反代服务器停止监听");
break;
}
}
}
});
Ok((
server_instance,
handle,
))
}
/// Stop the server by signalling the accept loop; consumes `self`.
///
/// Best-effort: if the accept loop has already exited, the send error
/// is deliberately ignored.
pub fn stop(mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
}
}
// ===== API 处理器 (旧代码已移除,由 src/proxy/handlers/* 接管) =====
/// Liveness probe for `/healthz`: always answers 200 with `{"status": "ok"}`.
async fn health_check_handler() -> Response {
    let body = serde_json::json!({
        "status": "ok"
    });
    Json(body).into_response()
}