use std::sync::Arc; use std::path::PathBuf; use axum::{Router, extract::DefaultBodyLimit}; use tower_http::{trace::TraceLayer, services::ServeDir}; use tracing::info; use antigravity_server::{ modules, proxy::TokenManager, api, }; #[tokio::main] async fn main() { // Initialize logger modules::logger::init_logger(); info!("Starting Antigravity API Proxy Server..."); // Get data directory // Priority: /data (HF persistent) > ./data (current dir) > ~/.antigravity_tools (fallback) let data_dir = if PathBuf::from("/data").exists() { PathBuf::from("/data") } else if PathBuf::from("./data").exists() { PathBuf::from("./data").canonicalize().unwrap_or_else(|_| PathBuf::from("./data")) } else { let home_dir = dirs::home_dir() .expect("Cannot get home directory") .join(".antigravity_tools"); // Create the directory if it doesn't exist let _ = std::fs::create_dir_all(&home_dir); let _ = std::fs::create_dir_all(home_dir.join("accounts")); home_dir }; info!("Using data directory: {:?}", data_dir); // Load configuration let config = modules::load_app_config().unwrap_or_default(); let proxy_config = config.proxy.clone(); // Initialize token manager let token_manager = Arc::new(TokenManager::new(data_dir.clone())); // Load accounts from file system (if any) match token_manager.load_accounts().await { Ok(count) => { info!("Loaded {} accounts from file system", count); } Err(e) => { info!("Could not load accounts from files: {} (this is ok)", e); } } // Load accounts from environment variable (REFRESH_TOKENS) // Format: comma-separated refresh tokens if let Ok(tokens_str) = std::env::var("REFRESH_TOKENS") { info!("Found REFRESH_TOKENS environment variable, loading accounts..."); let tokens: Vec<&str> = tokens_str.split(',') .map(|s| s.trim()) .filter(|s| !s.is_empty()) .collect(); let mut success_count = 0; for (idx, refresh_token) in tokens.iter().enumerate() { info!("Processing token {} of {}...", idx + 1, tokens.len()); match load_account_from_refresh_token(&token_manager, refresh_token).await { Ok(email) => { info!("Loaded account: {}", email); success_count += 1; } Err(e) => { tracing::error!("Failed to load token {}: {}", idx + 1, e); } } } info!("Loaded {} accounts from environment variable", success_count); } // Build proxy state let mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.anthropic_mapping.clone())); let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.openai_mapping.clone())); let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.custom_mapping.clone())); let proxy_state = Arc::new(tokio::sync::RwLock::new(proxy_config.upstream_proxy.clone())); let app_state = antigravity_server::proxy::server::AppState { token_manager: token_manager.clone(), anthropic_mapping: mapping_state, openai_mapping: openai_mapping_state, custom_mapping: custom_mapping_state, request_timeout: 300, thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), upstream_proxy: proxy_state, upstream: Arc::new(antigravity_server::proxy::upstream::client::UpstreamClient::new( Some(proxy_config.upstream_proxy.clone()) )), }; // Build routes use antigravity_server::proxy::handlers; use axum::routing::{get, post}; let proxy_routes = Router::new() // OpenAI Protocol .route("/v1/models", get(handlers::openai::handle_list_models)) .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions)) .route("/v1/completions", post(handlers::openai::handle_completions)) .route("/v1/responses", post(handlers::openai::handle_completions)) // Claude Protocol .route("/v1/messages", post(handlers::claude::handle_messages)) .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens)) .route("/v1/models/claude", get(handlers::claude::handle_list_models)) // Gemini Protocol .route("/v1beta/models", get(handlers::gemini::handle_list_models)) .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate)) .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) // Health check .route("/healthz", get(health_check)) .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) .layer(TraceLayer::new_for_http()) .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::auth_middleware)) .layer(antigravity_server::proxy::middleware::cors_layer()) .with_state(app_state); // Combine API routes and proxy routes // Apply basic auth middleware to admin UI and management API // Proxy routes (/v1/*) are protected by API key auth instead let app = Router::new() .merge(api::api_routes::<()>()) .merge(proxy_routes) .fallback_service(ServeDir::new("static")) .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::basic_auth_middleware)); // Bind to port 7860 (HuggingFace Spaces default) let port = std::env::var("PORT") .unwrap_or_else(|_| "7860".to_string()) .parse::() .unwrap_or(7860); let addr = format!("0.0.0.0:{}", port); info!("Server listening on http://{}", addr); // Start keep-alive background task to prevent HuggingFace Spaces from sleeping let keepalive_port = port; tokio::spawn(async move { keepalive_task(keepalive_port).await; }); let listener = tokio::net::TcpListener::bind(&addr) .await .expect("Failed to bind address"); axum::serve(listener, app) .await .expect("Server error"); } /// Keep-alive background task /// Pings the health check endpoint every 5 minutes to prevent HuggingFace Spaces from sleeping async fn keepalive_task(port: u16) { use tokio::time::{interval, Duration}; // Wait 30 seconds for server to start tokio::time::sleep(Duration::from_secs(30)).await; let client = reqwest::Client::new(); let url = format!("http://127.0.0.1:{}/healthz", port); // Check if keepalive is enabled (default: enabled) let enabled = std::env::var("KEEPALIVE_ENABLED") .map(|v| v != "false" && v != "0") .unwrap_or(true); if !enabled { info!("Keep-alive task disabled"); return; } // Get interval from env (default: 5 minutes) let interval_secs = std::env::var("KEEPALIVE_INTERVAL") .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(300); // 5 minutes info!("Keep-alive task started (interval: {}s)", interval_secs); let mut ticker = interval(Duration::from_secs(interval_secs)); loop { ticker.tick().await; match client.get(&url).send().await { Ok(resp) if resp.status().is_success() => { tracing::debug!("Keep-alive ping successful"); } Ok(resp) => { tracing::warn!("Keep-alive ping returned status: {}", resp.status()); } Err(e) => { tracing::warn!("Keep-alive ping failed: {}", e); } } } } /// Health check handler async fn health_check() -> axum::Json { axum::Json(serde_json::json!({ "status": "ok", "version": env!("CARGO_PKG_VERSION") })) } /// Load account from refresh token and add to token manager async fn load_account_from_refresh_token( token_manager: &Arc, refresh_token: &str, ) -> Result { use antigravity_server::modules::oauth; use antigravity_server::proxy::project_resolver; // 1. Refresh to get access token let token_res = oauth::refresh_access_token(refresh_token) .await .map_err(|e| format!("Failed to refresh token: {}", e))?; // 2. Get user info let user_info = oauth::get_user_info(&token_res.access_token) .await .map_err(|e| format!("Failed to get user info: {}", e))?; let email = user_info.email.clone(); // 3. Get project ID let project_id = project_resolver::fetch_project_id(&token_res.access_token) .await .ok(); // 4. Calculate expiry timestamp let now = chrono::Utc::now().timestamp(); let expiry_timestamp = now + token_res.expires_in; // 5. Add to token manager (in-memory only) token_manager.add_token( token_res.access_token, refresh_token.to_string(), expiry_timestamp, email.clone(), project_id, ).await; Ok(email) }