Spaces:
Sleeping
Sleeping
| use std::sync::Arc; | |
| use std::path::PathBuf; | |
| use axum::{Router, extract::DefaultBodyLimit}; | |
| use tower_http::{trace::TraceLayer, services::ServeDir}; | |
| use tracing::info; | |
| use antigravity_server::{ | |
| modules, | |
| proxy::TokenManager, | |
| api, | |
| }; | |
| async fn main() { | |
| // Initialize logger | |
| modules::logger::init_logger(); | |
| info!("Starting Antigravity API Proxy Server..."); | |
| // Get data directory | |
| // Priority: /data (HF persistent) > ./data (current dir) > ~/.antigravity_tools (fallback) | |
| let data_dir = if PathBuf::from("/data").exists() { | |
| PathBuf::from("/data") | |
| } else if PathBuf::from("./data").exists() { | |
| PathBuf::from("./data").canonicalize().unwrap_or_else(|_| PathBuf::from("./data")) | |
| } else { | |
| let home_dir = dirs::home_dir() | |
| .expect("Cannot get home directory") | |
| .join(".antigravity_tools"); | |
| // Create the directory if it doesn't exist | |
| let _ = std::fs::create_dir_all(&home_dir); | |
| let _ = std::fs::create_dir_all(home_dir.join("accounts")); | |
| home_dir | |
| }; | |
| info!("Using data directory: {:?}", data_dir); | |
| // Load configuration | |
| let config = modules::load_app_config().unwrap_or_default(); | |
| let proxy_config = config.proxy.clone(); | |
| // Initialize token manager | |
| let token_manager = Arc::new(TokenManager::new(data_dir.clone())); | |
| // Load accounts from file system (if any) | |
| match token_manager.load_accounts().await { | |
| Ok(count) => { | |
| info!("Loaded {} accounts from file system", count); | |
| } | |
| Err(e) => { | |
| info!("Could not load accounts from files: {} (this is ok)", e); | |
| } | |
| } | |
| // Load accounts from environment variable (REFRESH_TOKENS) | |
| // Format: comma-separated refresh tokens | |
| if let Ok(tokens_str) = std::env::var("REFRESH_TOKENS") { | |
| info!("Found REFRESH_TOKENS environment variable, loading accounts..."); | |
| let tokens: Vec<&str> = tokens_str.split(',') | |
| .map(|s| s.trim()) | |
| .filter(|s| !s.is_empty()) | |
| .collect(); | |
| let mut success_count = 0; | |
| for (idx, refresh_token) in tokens.iter().enumerate() { | |
| info!("Processing token {} of {}...", idx + 1, tokens.len()); | |
| match load_account_from_refresh_token(&token_manager, refresh_token).await { | |
| Ok(email) => { | |
| info!("Loaded account: {}", email); | |
| success_count += 1; | |
| } | |
| Err(e) => { | |
| tracing::error!("Failed to load token {}: {}", idx + 1, e); | |
| } | |
| } | |
| } | |
| info!("Loaded {} accounts from environment variable", success_count); | |
| } | |
| // Build proxy state | |
| let mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.anthropic_mapping.clone())); | |
| let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.openai_mapping.clone())); | |
| let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.custom_mapping.clone())); | |
| let proxy_state = Arc::new(tokio::sync::RwLock::new(proxy_config.upstream_proxy.clone())); | |
| let app_state = antigravity_server::proxy::server::AppState { | |
| token_manager: token_manager.clone(), | |
| anthropic_mapping: mapping_state, | |
| openai_mapping: openai_mapping_state, | |
| custom_mapping: custom_mapping_state, | |
| request_timeout: 300, | |
| thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), | |
| upstream_proxy: proxy_state, | |
| upstream: Arc::new(antigravity_server::proxy::upstream::client::UpstreamClient::new( | |
| Some(proxy_config.upstream_proxy.clone()) | |
| )), | |
| }; | |
| // Build routes | |
| use antigravity_server::proxy::handlers; | |
| use axum::routing::{get, post}; | |
| let proxy_routes = Router::new() | |
| // OpenAI Protocol | |
| .route("/v1/models", get(handlers::openai::handle_list_models)) | |
| .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions)) | |
| .route("/v1/completions", post(handlers::openai::handle_completions)) | |
| .route("/v1/responses", post(handlers::openai::handle_completions)) | |
| // Claude Protocol | |
| .route("/v1/messages", post(handlers::claude::handle_messages)) | |
| .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens)) | |
| .route("/v1/models/claude", get(handlers::claude::handle_list_models)) | |
| // Gemini Protocol | |
| .route("/v1beta/models", get(handlers::gemini::handle_list_models)) | |
| .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate)) | |
| .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) | |
| // Health check | |
| .route("/healthz", get(health_check)) | |
| .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) | |
| .layer(TraceLayer::new_for_http()) | |
| .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::auth_middleware)) | |
| .layer(antigravity_server::proxy::middleware::cors_layer()) | |
| .with_state(app_state); | |
| // Combine API routes and proxy routes | |
| // Apply basic auth middleware to admin UI and management API | |
| // Proxy routes (/v1/*) are protected by API key auth instead | |
| let app = Router::new() | |
| .merge(api::api_routes::<()>()) | |
| .merge(proxy_routes) | |
| .fallback_service(ServeDir::new("static")) | |
| .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::basic_auth_middleware)); | |
| // Bind to port 7860 (HuggingFace Spaces default) | |
| let port = std::env::var("PORT") | |
| .unwrap_or_else(|_| "7860".to_string()) | |
| .parse::<u16>() | |
| .unwrap_or(7860); | |
| let addr = format!("0.0.0.0:{}", port); | |
| info!("Server listening on http://{}", addr); | |
| // Start keep-alive background task to prevent HuggingFace Spaces from sleeping | |
| let keepalive_port = port; | |
| tokio::spawn(async move { | |
| keepalive_task(keepalive_port).await; | |
| }); | |
| let listener = tokio::net::TcpListener::bind(&addr) | |
| .await | |
| .expect("Failed to bind address"); | |
| axum::serve(listener, app) | |
| .await | |
| .expect("Server error"); | |
| } | |
| /// Keep-alive background task | |
| /// Pings the health check endpoint every 5 minutes to prevent HuggingFace Spaces from sleeping | |
| async fn keepalive_task(port: u16) { | |
| use tokio::time::{interval, Duration}; | |
| // Wait 30 seconds for server to start | |
| tokio::time::sleep(Duration::from_secs(30)).await; | |
| let client = reqwest::Client::new(); | |
| let url = format!("http://127.0.0.1:{}/healthz", port); | |
| // Check if keepalive is enabled (default: enabled) | |
| let enabled = std::env::var("KEEPALIVE_ENABLED") | |
| .map(|v| v != "false" && v != "0") | |
| .unwrap_or(true); | |
| if !enabled { | |
| info!("Keep-alive task disabled"); | |
| return; | |
| } | |
| // Get interval from env (default: 5 minutes) | |
| let interval_secs = std::env::var("KEEPALIVE_INTERVAL") | |
| .ok() | |
| .and_then(|v| v.parse::<u64>().ok()) | |
| .unwrap_or(300); // 5 minutes | |
| info!("Keep-alive task started (interval: {}s)", interval_secs); | |
| let mut ticker = interval(Duration::from_secs(interval_secs)); | |
| loop { | |
| ticker.tick().await; | |
| match client.get(&url).send().await { | |
| Ok(resp) if resp.status().is_success() => { | |
| tracing::debug!("Keep-alive ping successful"); | |
| } | |
| Ok(resp) => { | |
| tracing::warn!("Keep-alive ping returned status: {}", resp.status()); | |
| } | |
| Err(e) => { | |
| tracing::warn!("Keep-alive ping failed: {}", e); | |
| } | |
| } | |
| } | |
| } | |
| /// Health check handler | |
| async fn health_check() -> axum::Json<serde_json::Value> { | |
| axum::Json(serde_json::json!({ | |
| "status": "ok", | |
| "version": env!("CARGO_PKG_VERSION") | |
| })) | |
| } | |
| /// Load account from refresh token and add to token manager | |
| async fn load_account_from_refresh_token( | |
| token_manager: &Arc<TokenManager>, | |
| refresh_token: &str, | |
| ) -> Result<String, String> { | |
| use antigravity_server::modules::oauth; | |
| use antigravity_server::proxy::project_resolver; | |
| // 1. Refresh to get access token | |
| let token_res = oauth::refresh_access_token(refresh_token) | |
| .await | |
| .map_err(|e| format!("Failed to refresh token: {}", e))?; | |
| // 2. Get user info | |
| let user_info = oauth::get_user_info(&token_res.access_token) | |
| .await | |
| .map_err(|e| format!("Failed to get user info: {}", e))?; | |
| let email = user_info.email.clone(); | |
| // 3. Get project ID | |
| let project_id = project_resolver::fetch_project_id(&token_res.access_token) | |
| .await | |
| .ok(); | |
| // 4. Calculate expiry timestamp | |
| let now = chrono::Utc::now().timestamp(); | |
| let expiry_timestamp = now + token_res.expires_in; | |
| // 5. Add to token manager (in-memory only) | |
| token_manager.add_token( | |
| token_res.access_token, | |
| refresh_token.to_string(), | |
| expiry_timestamp, | |
| email.clone(), | |
| project_id, | |
| ).await; | |
| Ok(email) | |
| } | |