// File: gemini/server/src/main.rs
// Commit aec2274 (yinming): Add REFRESH_TOKENS env var support for in-memory account loading
use std::sync::Arc;
use std::path::PathBuf;
use axum::{Router, extract::DefaultBodyLimit};
use tower_http::{trace::TraceLayer, services::ServeDir};
use tracing::info;
use antigravity_server::{
modules,
proxy::TokenManager,
api,
};
#[tokio::main]
async fn main() {
// Initialize logger
modules::logger::init_logger();
info!("Starting Antigravity API Proxy Server...");
// Get data directory
// Priority: /data (HF persistent) > ./data (current dir) > ~/.antigravity_tools (fallback)
let data_dir = if PathBuf::from("/data").exists() {
PathBuf::from("/data")
} else if PathBuf::from("./data").exists() {
PathBuf::from("./data").canonicalize().unwrap_or_else(|_| PathBuf::from("./data"))
} else {
let home_dir = dirs::home_dir()
.expect("Cannot get home directory")
.join(".antigravity_tools");
// Create the directory if it doesn't exist
let _ = std::fs::create_dir_all(&home_dir);
let _ = std::fs::create_dir_all(home_dir.join("accounts"));
home_dir
};
info!("Using data directory: {:?}", data_dir);
// Load configuration
let config = modules::load_app_config().unwrap_or_default();
let proxy_config = config.proxy.clone();
// Initialize token manager
let token_manager = Arc::new(TokenManager::new(data_dir.clone()));
// Load accounts from file system (if any)
match token_manager.load_accounts().await {
Ok(count) => {
info!("Loaded {} accounts from file system", count);
}
Err(e) => {
info!("Could not load accounts from files: {} (this is ok)", e);
}
}
// Load accounts from environment variable (REFRESH_TOKENS)
// Format: comma-separated refresh tokens
if let Ok(tokens_str) = std::env::var("REFRESH_TOKENS") {
info!("Found REFRESH_TOKENS environment variable, loading accounts...");
let tokens: Vec<&str> = tokens_str.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
let mut success_count = 0;
for (idx, refresh_token) in tokens.iter().enumerate() {
info!("Processing token {} of {}...", idx + 1, tokens.len());
match load_account_from_refresh_token(&token_manager, refresh_token).await {
Ok(email) => {
info!("Loaded account: {}", email);
success_count += 1;
}
Err(e) => {
tracing::error!("Failed to load token {}: {}", idx + 1, e);
}
}
}
info!("Loaded {} accounts from environment variable", success_count);
}
// Build proxy state
let mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.anthropic_mapping.clone()));
let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.openai_mapping.clone()));
let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.custom_mapping.clone()));
let proxy_state = Arc::new(tokio::sync::RwLock::new(proxy_config.upstream_proxy.clone()));
let app_state = antigravity_server::proxy::server::AppState {
token_manager: token_manager.clone(),
anthropic_mapping: mapping_state,
openai_mapping: openai_mapping_state,
custom_mapping: custom_mapping_state,
request_timeout: 300,
thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
upstream_proxy: proxy_state,
upstream: Arc::new(antigravity_server::proxy::upstream::client::UpstreamClient::new(
Some(proxy_config.upstream_proxy.clone())
)),
};
// Build routes
use antigravity_server::proxy::handlers;
use axum::routing::{get, post};
let proxy_routes = Router::new()
// OpenAI Protocol
.route("/v1/models", get(handlers::openai::handle_list_models))
.route("/v1/chat/completions", post(handlers::openai::handle_chat_completions))
.route("/v1/completions", post(handlers::openai::handle_completions))
.route("/v1/responses", post(handlers::openai::handle_completions))
// Claude Protocol
.route("/v1/messages", post(handlers::claude::handle_messages))
.route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens))
.route("/v1/models/claude", get(handlers::claude::handle_list_models))
// Gemini Protocol
.route("/v1beta/models", get(handlers::gemini::handle_list_models))
.route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate))
.route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens))
// Health check
.route("/healthz", get(health_check))
.layer(DefaultBodyLimit::max(100 * 1024 * 1024))
.layer(TraceLayer::new_for_http())
.layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::auth_middleware))
.layer(antigravity_server::proxy::middleware::cors_layer())
.with_state(app_state);
// Combine API routes and proxy routes
// Apply basic auth middleware to admin UI and management API
// Proxy routes (/v1/*) are protected by API key auth instead
let app = Router::new()
.merge(api::api_routes::<()>())
.merge(proxy_routes)
.fallback_service(ServeDir::new("static"))
.layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::basic_auth_middleware));
// Bind to port 7860 (HuggingFace Spaces default)
let port = std::env::var("PORT")
.unwrap_or_else(|_| "7860".to_string())
.parse::<u16>()
.unwrap_or(7860);
let addr = format!("0.0.0.0:{}", port);
info!("Server listening on http://{}", addr);
// Start keep-alive background task to prevent HuggingFace Spaces from sleeping
let keepalive_port = port;
tokio::spawn(async move {
keepalive_task(keepalive_port).await;
});
let listener = tokio::net::TcpListener::bind(&addr)
.await
.expect("Failed to bind address");
axum::serve(listener, app)
.await
.expect("Server error");
}
/// Keep-alive background task.
///
/// Periodically sends `GET http://127.0.0.1:<port>/healthz` to this server so
/// that HuggingFace Spaces does not put it to sleep. The first ping happens
/// about 30 seconds after start. The loop then runs forever unless the task is
/// disabled.
///
/// Environment:
/// - `KEEPALIVE_ENABLED`: `"false"` or `"0"` disables the task (default: enabled).
/// - `KEEPALIVE_INTERVAL`: ping interval in seconds (default: 300). Zero or
///   unparseable values fall back to the default.
async fn keepalive_task(port: u16) {
    use tokio::time::{interval, Duration, MissedTickBehavior};

    const DEFAULT_INTERVAL_SECS: u64 = 300; // 5 minutes
    // Per-request cap so that a hung connection cannot stall the loop indefinitely.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

    // Check the flag first, so that disabling takes effect (and is logged) immediately.
    let enabled = std::env::var("KEEPALIVE_ENABLED")
        .map(|v| v != "false" && v != "0")
        .unwrap_or(true);
    if !enabled {
        info!("Keep-alive task disabled");
        return;
    }

    // `tokio::time::interval` panics on a zero period, so reject 0 here.
    let interval_secs = std::env::var("KEEPALIVE_INTERVAL")
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        .filter(|&secs| secs > 0)
        .unwrap_or(DEFAULT_INTERVAL_SECS);

    // Wait 30 seconds for server to start
    tokio::time::sleep(Duration::from_secs(30)).await;

    let client = reqwest::Client::new();
    let url = format!("http://127.0.0.1:{}/healthz", port);

    info!("Keep-alive task started (interval: {}s)", interval_secs);
    let mut ticker = interval(Duration::from_secs(interval_secs));
    // After a slow ping, resume the normal cadence instead of bursting catch-up ticks.
    ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
    loop {
        ticker.tick().await;
        match client.get(&url).timeout(REQUEST_TIMEOUT).send().await {
            Ok(resp) if resp.status().is_success() => {
                tracing::debug!("Keep-alive ping successful");
            }
            Ok(resp) => {
                tracing::warn!("Keep-alive ping returned status: {}", resp.status());
            }
            Err(e) => {
                tracing::warn!("Keep-alive ping failed: {}", e);
            }
        }
    }
}
/// Health check handler
async fn health_check() -> axum::Json<serde_json::Value> {
axum::Json(serde_json::json!({
"status": "ok",
"version": env!("CARGO_PKG_VERSION")
}))
}
/// Load account from refresh token and add to token manager
async fn load_account_from_refresh_token(
token_manager: &Arc<TokenManager>,
refresh_token: &str,
) -> Result<String, String> {
use antigravity_server::modules::oauth;
use antigravity_server::proxy::project_resolver;
// 1. Refresh to get access token
let token_res = oauth::refresh_access_token(refresh_token)
.await
.map_err(|e| format!("Failed to refresh token: {}", e))?;
// 2. Get user info
let user_info = oauth::get_user_info(&token_res.access_token)
.await
.map_err(|e| format!("Failed to get user info: {}", e))?;
let email = user_info.email.clone();
// 3. Get project ID
let project_id = project_resolver::fetch_project_id(&token_res.access_token)
.await
.ok();
// 4. Calculate expiry timestamp
let now = chrono::Utc::now().timestamp();
let expiry_timestamp = now + token_res.expires_in;
// 5. Add to token manager (in-memory only)
token_manager.add_token(
token_res.access_token,
refresh_token.to_string(),
expiry_timestamp,
email.clone(),
project_id,
).await;
Ok(email)
}