File size: 9,182 Bytes
bbb1195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7881083
bbb1195
 
7881083
 
bbb1195
7881083
bbb1195
7881083
 
 
 
 
bbb1195
 
 
 
 
 
 
 
 
 
 
aec2274
bbb1195
 
aec2274
bbb1195
 
aec2274
bbb1195
 
 
aec2274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbb1195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7881083
 
 
 
 
 
bbb1195
 
 
 
 
 
 
 
 
7881083
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbb1195
 
 
 
 
 
 
aec2274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
use std::sync::Arc;
use std::path::PathBuf;
use axum::{Router, extract::DefaultBodyLimit};
use tower_http::{trace::TraceLayer, services::ServeDir};
use tracing::info;

use antigravity_server::{
    modules,
    proxy::TokenManager,
    api,
};

#[tokio::main]
async fn main() {
    // Initialize logger
    modules::logger::init_logger();

    info!("Starting Antigravity API Proxy Server...");

    // Get data directory
    // Priority: /data (HF persistent) > ./data (current dir) > ~/.antigravity_tools (fallback)
    let data_dir = if PathBuf::from("/data").exists() {
        PathBuf::from("/data")
    } else if PathBuf::from("./data").exists() {
        PathBuf::from("./data").canonicalize().unwrap_or_else(|_| PathBuf::from("./data"))
    } else {
        let home_dir = dirs::home_dir()
            .expect("Cannot get home directory")
            .join(".antigravity_tools");
        // Create the directory if it doesn't exist
        let _ = std::fs::create_dir_all(&home_dir);
        let _ = std::fs::create_dir_all(home_dir.join("accounts"));
        home_dir
    };

    info!("Using data directory: {:?}", data_dir);

    // Load configuration
    let config = modules::load_app_config().unwrap_or_default();
    let proxy_config = config.proxy.clone();

    // Initialize token manager
    let token_manager = Arc::new(TokenManager::new(data_dir.clone()));

    // Load accounts from file system (if any)
    match token_manager.load_accounts().await {
        Ok(count) => {
            info!("Loaded {} accounts from file system", count);
        }
        Err(e) => {
            info!("Could not load accounts from files: {} (this is ok)", e);
        }
    }

    // Load accounts from environment variable (REFRESH_TOKENS)
    // Format: comma-separated refresh tokens
    if let Ok(tokens_str) = std::env::var("REFRESH_TOKENS") {
        info!("Found REFRESH_TOKENS environment variable, loading accounts...");
        let tokens: Vec<&str> = tokens_str.split(',')
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .collect();

        let mut success_count = 0;
        for (idx, refresh_token) in tokens.iter().enumerate() {
            info!("Processing token {} of {}...", idx + 1, tokens.len());
            match load_account_from_refresh_token(&token_manager, refresh_token).await {
                Ok(email) => {
                    info!("Loaded account: {}", email);
                    success_count += 1;
                }
                Err(e) => {
                    tracing::error!("Failed to load token {}: {}", idx + 1, e);
                }
            }
        }
        info!("Loaded {} accounts from environment variable", success_count);
    }

    // Build proxy state
    let mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.anthropic_mapping.clone()));
    let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.openai_mapping.clone()));
    let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(proxy_config.custom_mapping.clone()));
    let proxy_state = Arc::new(tokio::sync::RwLock::new(proxy_config.upstream_proxy.clone()));

    let app_state = antigravity_server::proxy::server::AppState {
        token_manager: token_manager.clone(),
        anthropic_mapping: mapping_state,
        openai_mapping: openai_mapping_state,
        custom_mapping: custom_mapping_state,
        request_timeout: 300,
        thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
        upstream_proxy: proxy_state,
        upstream: Arc::new(antigravity_server::proxy::upstream::client::UpstreamClient::new(
            Some(proxy_config.upstream_proxy.clone())
        )),
    };

    // Build routes
    use antigravity_server::proxy::handlers;
    use axum::routing::{get, post};

    let proxy_routes = Router::new()
        // OpenAI Protocol
        .route("/v1/models", get(handlers::openai::handle_list_models))
        .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions))
        .route("/v1/completions", post(handlers::openai::handle_completions))
        .route("/v1/responses", post(handlers::openai::handle_completions))
        // Claude Protocol
        .route("/v1/messages", post(handlers::claude::handle_messages))
        .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens))
        .route("/v1/models/claude", get(handlers::claude::handle_list_models))
        // Gemini Protocol
        .route("/v1beta/models", get(handlers::gemini::handle_list_models))
        .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate))
        .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens))
        // Health check
        .route("/healthz", get(health_check))
        .layer(DefaultBodyLimit::max(100 * 1024 * 1024))
        .layer(TraceLayer::new_for_http())
        .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::auth_middleware))
        .layer(antigravity_server::proxy::middleware::cors_layer())
        .with_state(app_state);

    // Combine API routes and proxy routes
    // Apply basic auth middleware to admin UI and management API
    // Proxy routes (/v1/*) are protected by API key auth instead
    let app = Router::new()
        .merge(api::api_routes::<()>())
        .merge(proxy_routes)
        .fallback_service(ServeDir::new("static"))
        .layer(axum::middleware::from_fn(antigravity_server::proxy::middleware::basic_auth_middleware));

    // Bind to port 7860 (HuggingFace Spaces default)
    let port = std::env::var("PORT")
        .unwrap_or_else(|_| "7860".to_string())
        .parse::<u16>()
        .unwrap_or(7860);

    let addr = format!("0.0.0.0:{}", port);
    info!("Server listening on http://{}", addr);

    // Start keep-alive background task to prevent HuggingFace Spaces from sleeping
    let keepalive_port = port;
    tokio::spawn(async move {
        keepalive_task(keepalive_port).await;
    });

    let listener = tokio::net::TcpListener::bind(&addr)
        .await
        .expect("Failed to bind address");

    axum::serve(listener, app)
        .await
        .expect("Server error");
}

/// Keep-alive background task.
///
/// Periodically sends `GET /healthz` to this server on `127.0.0.1:port`, so that
/// HuggingFace Spaces does not put the Space to sleep.
///
/// Controlled by environment variables:
/// - `KEEPALIVE_ENABLED`: `false` or `0` disables the task. Matching is
///   case-insensitive and ignores surrounding whitespace. Any other value, or unset,
///   keeps it enabled.
/// - `KEEPALIVE_INTERVAL`: ping period in seconds (default 300). Zero or unparsable
///   values fall back to the default, because `tokio::time::interval` panics on a
///   zero period.
///
/// The first ping happens about 30 seconds after the task starts, giving the server
/// time to begin accepting connections. Ping failures are logged and the loop
/// continues; while enabled, the task never returns.
async fn keepalive_task(port: u16) {
    use tokio::time::{interval, Duration};

    const DEFAULT_INTERVAL_SECS: u64 = 300; // 5 minutes
    const STARTUP_DELAY: Duration = Duration::from_secs(30);
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

    // Decide up front so a disabled task exits immediately instead of after the startup delay.
    let enabled = std::env::var("KEEPALIVE_ENABLED")
        .map(|v| {
            let v = v.trim();
            !(v.eq_ignore_ascii_case("false") || v == "0")
        })
        .unwrap_or(true);

    if !enabled {
        info!("Keep-alive task disabled");
        return;
    }

    // Reject 0: `interval(Duration::ZERO)` would panic.
    let interval_secs = std::env::var("KEEPALIVE_INTERVAL")
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        .filter(|&secs| secs > 0)
        .unwrap_or(DEFAULT_INTERVAL_SECS);

    // Without a timeout, one hung request would stall every subsequent ping.
    let client = reqwest::Client::builder()
        .timeout(REQUEST_TIMEOUT)
        .build()
        .unwrap_or_else(|e| {
            tracing::warn!("Failed to build keep-alive HTTP client ({}); using defaults", e);
            reqwest::Client::new()
        });
    let url = format!("http://127.0.0.1:{}/healthz", port);

    // Wait for the server to start listening.
    tokio::time::sleep(STARTUP_DELAY).await;

    info!("Keep-alive task started (interval: {}s)", interval_secs);

    // The first tick completes immediately, so the first ping follows the startup delay.
    let mut ticker = interval(Duration::from_secs(interval_secs));

    loop {
        ticker.tick().await;

        match client.get(&url).send().await {
            Ok(resp) if resp.status().is_success() => {
                tracing::debug!("Keep-alive ping successful");
            }
            Ok(resp) => {
                tracing::warn!("Keep-alive ping returned status: {}", resp.status());
            }
            Err(e) => {
                tracing::warn!("Keep-alive ping failed: {}", e);
            }
        }
    }
}

/// Liveness probe served at `GET /healthz`.
///
/// Always answers with a JSON object whose `status` is `"ok"` and whose `version`
/// is this crate's version, embedded at compile time from `CARGO_PKG_VERSION`.
async fn health_check() -> axum::Json<serde_json::Value> {
    let body = serde_json::json!({
        "status": "ok",
        "version": env!("CARGO_PKG_VERSION"),
    });
    axum::Json(body)
}

/// Load account from refresh token and add to token manager
async fn load_account_from_refresh_token(
    token_manager: &Arc<TokenManager>,
    refresh_token: &str,
) -> Result<String, String> {
    use antigravity_server::modules::oauth;
    use antigravity_server::proxy::project_resolver;

    // 1. Refresh to get access token
    let token_res = oauth::refresh_access_token(refresh_token)
        .await
        .map_err(|e| format!("Failed to refresh token: {}", e))?;

    // 2. Get user info
    let user_info = oauth::get_user_info(&token_res.access_token)
        .await
        .map_err(|e| format!("Failed to get user info: {}", e))?;

    let email = user_info.email.clone();

    // 3. Get project ID
    let project_id = project_resolver::fetch_project_id(&token_res.access_token)
        .await
        .ok();

    // 4. Calculate expiry timestamp
    let now = chrono::Utc::now().timestamp();
    let expiry_timestamp = now + token_res.expires_in;

    // 5. Add to token manager (in-memory only)
    token_manager.add_token(
        token_res.access_token,
        refresh_token.to_string(),
        expiry_timestamp,
        email.clone(),
        project_id,
    ).await;

    Ok(email)
}