| use std::sync::Arc; |
| use std::path::PathBuf; |
| use axum::{Router, extract::DefaultBodyLimit}; |
| use tower_http::{trace::TraceLayer, services::ServeDir}; |
| use tracing::info; |
|
|
| use antigravity_server::{ |
| modules, |
| proxy::TokenManager, |
| api, |
| }; |
|
|
| #[tokio::main] |
| async fn main() { |
| |
| modules::logger::init_logger(); |
|
|
| info!("Starting Antigravity API Proxy Server..."); |
|
|
| |
| |
| let data_dir = if PathBuf::from("/data").exists() { |
| PathBuf::from("/data") |
| } else if PathBuf::from("./data").exists() { |
| PathBuf::from("./data").canonicalize().unwrap_or_else(|_| PathBuf::from("./data")) |
| } else { |
| let home_dir = dirs::home_dir() |
| .expect("Cannot get home directory") |
| .join(".antigravity_tools"); |
| |
| let _ = std::fs::create_dir_all(&home_dir); |
| let _ = std::fs::create_dir_all(home_dir.join("accounts")); |
| home_dir |
| }; |
|
|
| info!("Using data directory: {:?}", data_dir); |
|
|
| |
| let config = modules::load_app_config().unwrap_or_default(); |
| let proxy_config = config.proxy.clone(); |
|
|
| |
| let token_manager = Arc::new(TokenManager::new(data_dir.clone())); |
|
|
| |
| match token_manager.load_accounts().await { |
| Ok(count) => { |
| info!("Loaded {} accounts from file system", count); |
| } |
| Err(e) => { |
| info!("Could not load accounts from files: {} (this is ok)", e); |
| } |
| } |
|
|
| |
| |
| if let Ok(tokens_str) = std::env::var("REFRESH_TOKENS") { |
| info!("Found REFRESH_TOKENS environment variable, loading accounts..."); |
| let tokens: Vec<&str> = tokens_str.split(',') |
| .map(|s| s.trim()) |
| .filter(|s| !s.is_empty()) |
| .collect(); |
|
|
| let mut success_count = 0; |
| for (idx, refresh_token) in tokens.iter().enumerate() { |
| info!("Processing token {} of {}...", idx + 1, tokens.len()); |
| match load_account_from_refresh_token(&token_manager, refresh_token).await { |
| Ok(email) => { |
| info!("Loaded account: {}", email); |
| success_count += 1; |
| } |
| Err(e) => { |
| tracing::error!("Failed to load token {}: {}", idx + 1, e); |
| } |
| } |
| } |
| info!("Loaded {} accounts from environment variable", success_count); |
| } |
|
|
| |
| let mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.anthropic_mapping.clone())); |
| let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.openai_mapping.clone())); |
| let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.custom_mapping.clone())); |
| let proxy_state = Arc::new(tokio::sync::RwLock::new(proxy_config.upstream_proxy.clone())); |
|
|
| let app_state = antigravity_server::proxy::server::AppState { |
| token_manager: token_manager.clone(), |
| anthropic_mapping: mapping_state, |
| openai_mapping: openai_mapping_state, |
| custom_mapping: custom_mapping_state, |
| request_timeout: 300, |
| thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), |
| upstream_proxy: proxy_state, |
| upstream: Arc::new(antigravity_server::proxy::upstream::client::UpstreamClient::new( |
| Some(proxy_config.upstream_proxy.clone()) |
| )), |
| }; |
|
|
| |
| use antigravity_server::proxy::handlers; |
| use axum::routing::{get, post}; |
|
|
| let proxy_routes = Router::new() |
| |
| .route("/v1/models", get(handlers::openai::handle_list_models)) |
| .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions)) |
| .route("/v1/completions", post(handlers::openai::handle_completions)) |
| .route("/v1/responses", post(handlers::openai::handle_completions)) |
| |
| .route("/v1/messages", post(handlers::claude::handle_messages)) |
| .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens)) |
| .route("/v1/models/claude", get(handlers::claude::handle_list_models)) |
| |
| .route("/v1beta/models", get(handlers::gemini::handle_list_models)) |
| .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate)) |
| .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) |
| |
| .route("/healthz", get(health_check)) |
| .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) |
| .layer(TraceLayer::new_for_http()) |
| .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::auth_middleware)) |
| .layer(antigravity_server::proxy::middleware::cors_layer()) |
| .with_state(app_state); |
|
|
| |
| |
| |
| let app = Router::new() |
| .merge(api::api_routes::<()>()) |
| .merge(proxy_routes) |
| .fallback_service(ServeDir::new("static")) |
| .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::basic_auth_middleware)); |
|
|
| |
| let port = std::env::var("PORT") |
| .unwrap_or_else(|_| "7860".to_string()) |
| .parse::<u16>() |
| .unwrap_or(7860); |
|
|
| let addr = format!("0.0.0.0:{}", port); |
| info!("Server listening on http://{}", addr); |
|
|
| |
| let keepalive_port = port; |
| tokio::spawn(async move { |
| keepalive_task(keepalive_port).await; |
| }); |
|
|
| let listener = tokio::net::TcpListener::bind(&addr) |
| .await |
| .expect("Failed to bind address"); |
|
|
| axum::serve(listener, app) |
| .await |
| .expect("Server error"); |
| } |
|
|
| |
| |
| async fn keepalive_task(port: u16) { |
| use tokio::time::{interval, Duration}; |
|
|
| |
| tokio::time::sleep(Duration::from_secs(30)).await; |
|
|
| let client = reqwest::Client::new(); |
| let url = format!("http://127.0.0.1:{}/healthz", port); |
|
|
| |
| let enabled = std::env::var("KEEPALIVE_ENABLED") |
| .map(|v| v != "false" && v != "0") |
| .unwrap_or(true); |
|
|
| if !enabled { |
| info!("Keep-alive task disabled"); |
| return; |
| } |
|
|
| |
| let interval_secs = std::env::var("KEEPALIVE_INTERVAL") |
| .ok() |
| .and_then(|v| v.parse::<u64>().ok()) |
| .unwrap_or(300); |
|
|
| info!("Keep-alive task started (interval: {}s)", interval_secs); |
|
|
| let mut ticker = interval(Duration::from_secs(interval_secs)); |
|
|
| loop { |
| ticker.tick().await; |
|
|
| match client.get(&url).send().await { |
| Ok(resp) if resp.status().is_success() => { |
| tracing::debug!("Keep-alive ping successful"); |
| } |
| Ok(resp) => { |
| tracing::warn!("Keep-alive ping returned status: {}", resp.status()); |
| } |
| Err(e) => { |
| tracing::warn!("Keep-alive ping failed: {}", e); |
| } |
| } |
| } |
| } |
|
|
| |
| async fn health_check() -> axum::Json<serde_json::Value> { |
| axum::Json(serde_json::json!({ |
| "status": "ok", |
| "version": env!("CARGO_PKG_VERSION") |
| })) |
| } |
|
|
| |
| async fn load_account_from_refresh_token( |
| token_manager: &Arc<TokenManager>, |
| refresh_token: &str, |
| ) -> Result<String, String> { |
| use antigravity_server::modules::oauth; |
| use antigravity_server::proxy::project_resolver; |
|
|
| |
| let token_res = oauth::refresh_access_token(refresh_token) |
| .await |
| .map_err(|e| format!("Failed to refresh token: {}", e))?; |
|
|
| |
| let user_info = oauth::get_user_info(&token_res.access_token) |
| .await |
| .map_err(|e| format!("Failed to get user info: {}", e))?; |
|
|
| let email = user_info.email.clone(); |
|
|
| |
| let project_id = project_resolver::fetch_project_id(&token_res.access_token) |
| .await |
| .ok(); |
|
|
| |
| let now = chrono::Utc::now().timestamp(); |
| let expiry_timestamp = now + token_res.expires_in; |
|
|
| |
| token_manager.add_token( |
| token_res.access_token, |
| refresh_token.to_string(), |
| expiry_timestamp, |
| email.clone(), |
| project_id, |
| ).await; |
|
|
| Ok(email) |
| } |
|
|