# antigravity-proxy / account_manager.py
# asemxin
# fix: add default OAuth client secret for Antigravity
# 2fbc7af
"""
账号池管理器 - 支持 OAuth Token 的增删改查和自动刷新
"""
import json
import os
import httpx
from typing import List, Optional
from datetime import datetime, timedelta
from models import Account, OAuthToken, AccountStats
# Data file locations. DATA_DIR can be overridden through the environment
# (the original note describes it as the HF Spaces persistent directory).
DATA_DIR = os.getenv("DATA_DIR", "./data")
ACCOUNTS_FILE = os.path.join(DATA_DIR, "accounts.json")
CONFIG_FILE = os.path.join(DATA_DIR, "config.json")

# Google OAuth settings (the client ID used by Antigravity, per the original
# note). Both values can be overridden through the environment.
OAUTH_CLIENT_ID = os.getenv(
    "OAUTH_CLIENT_ID",
    "595848968694-r5ng3t6qb9elhe1u1h1hqgq4j2r3hgvk.apps.googleusercontent.com",
)
# The original note says this default is AI Studio's public client secret.
# NOTE(review): a credential is hard-coded in source here; confirm that it is
# really meant to be public before distributing this file.
OAUTH_CLIENT_SECRET = os.getenv(
    "OAUTH_CLIENT_SECRET",
    "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",
)
OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
class AccountManager:
    """Account pool manager backed by a JSON file.

    Supports adding, removing, and querying OAuth-token accounts, round-robin
    selection of the next available account, and refreshing access tokens.
    Accounts are held in memory keyed by ``Account.id`` and written back to
    ``ACCOUNTS_FILE`` after every mutation. Persistence is best-effort:
    load/save failures are printed, not raised.
    """

    def __init__(self):
        self._accounts: dict[str, Account] = {}
        # Position of the next pick within the list of available accounts.
        self._current_index = 0
        self._ensure_data_dir()
        self._load_accounts()

    def _ensure_data_dir(self) -> None:
        """Create ``DATA_DIR`` if it does not exist yet."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_accounts(self) -> None:
        """Load accounts from ``ACCOUNTS_FILE`` into memory.

        An unreadable or non-list file is reported and ignored. Each record
        is built on its own, so one malformed entry is skipped instead of
        discarding every record that follows it.
        """
        if not os.path.exists(ACCOUNTS_FILE):
            return
        try:
            with open(ACCOUNTS_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON and text-decoding errors.
            print(f"加载账号失败: {e}")
            return
        if not isinstance(data, list):
            print(f"加载账号失败: 期望列表,实际为 {type(data).__name__}")
            return
        for item in data:
            try:
                account = Account(**item)
            except Exception as e:
                # Model construction errors are defined outside this file, so
                # stay broad here and keep loading the remaining records.
                print(f"跳过无效的账号记录: {e}")
                continue
            self._accounts[account.id] = account

    def _save_accounts(self) -> None:
        """Write all accounts to ``ACCOUNTS_FILE``.

        The data is written to a temporary sibling file and then moved over
        the real file with ``os.replace``, so a failure part-way through the
        write leaves the previous file intact. Errors are printed, not raised.
        """
        tmp_path = f"{ACCOUNTS_FILE}.tmp"
        try:
            data = [acc.model_dump(mode="json") for acc in self._accounts.values()]
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2, default=str)
            os.replace(tmp_path, ACCOUNTS_FILE)
        except Exception as e:
            print(f"保存账号失败: {e}")
            # Do not leave a half-written temporary file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def add_account(
        self,
        email: str,
        access_token: str,
        refresh_token: str,
        expires_in: int = 3600,
        project_id: Optional[str] = None,
    ) -> Account:
        """Create an account from OAuth credentials, store it, and persist.

        Args:
            email: Email address recorded on the account.
            access_token: Current OAuth access token.
            refresh_token: OAuth refresh token used for later refreshes.
            expires_in: Access-token lifetime in seconds, counted from now.
            project_id: Optional project identifier stored on the token.

        Returns:
            The newly created account.
        """
        now = int(datetime.now().timestamp())
        token = OAuthToken(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=expires_in,
            expiry_timestamp=now + expires_in,
            project_id=project_id,
        )
        account = Account(email=email, token=token)
        self._accounts[account.id] = account
        self._save_accounts()
        return account

    def remove_account(self, account_id: str) -> bool:
        """Delete an account; return True if it existed, False otherwise."""
        if account_id not in self._accounts:
            return False
        del self._accounts[account_id]
        self._save_accounts()
        return True

    def get_account(self, account_id: str) -> Optional[Account]:
        """Return the account with the given id, or None if unknown."""
        return self._accounts.get(account_id)

    def get_all_accounts(self) -> List[Account]:
        """Return every stored account."""
        return list(self._accounts.values())

    def get_available_accounts(self) -> List[Account]:
        """Return the accounts whose ``is_available()`` is true."""
        return [acc for acc in self._accounts.values() if acc.is_available()]

    async def get_next_token(self) -> Optional[Account]:
        """Pick the next available account in round-robin order.

        If the chosen account reports ``is_token_expired()``, a refresh is
        attempted first. A failed refresh is printed and the account is
        returned anyway, leaving the upstream API to reject the stale token.

        Returns:
            The selected account, or None when no account is available.
        """
        available = self.get_available_accounts()
        if not available:
            return None
        # The available list can shrink between calls, so wrap the index.
        self._current_index %= len(available)
        account = available[self._current_index]
        self._current_index += 1
        if account.is_token_expired():
            print(f"账号 {account.email} 的 token 即将过期,正在刷新...")
            try:
                await self._refresh_token(account)
            except Exception as e:
                print(f"刷新 token 失败: {e}")
        return account

    async def _refresh_token(self, account: Account) -> None:
        """Exchange the account's refresh token for a new access token.

        Updates the token fields on ``account`` and persists the pool. The
        response is validated before anything is mutated, so a bad response
        leaves the account unchanged.

        Raises:
            RuntimeError: If the token endpoint does not return HTTP 200 or
                the response carries no ``access_token``.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                OAUTH_TOKEN_URL,
                data={
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                    "refresh_token": account.token.refresh_token,
                    "grant_type": "refresh_token",
                },
            )
        if response.status_code != 200:
            raise RuntimeError(f"刷新失败: {response.text}")
        data = response.json()
        access_token = data.get("access_token")
        if not access_token:
            raise RuntimeError(f"刷新失败: {response.text}")
        now = int(datetime.now().timestamp())
        account.token.access_token = access_token
        account.token.expires_in = data.get("expires_in", 3600)
        account.token.expiry_timestamp = now + account.token.expires_in
        # The server may rotate the refresh token (RFC 6749 section 6); keep
        # the new one so later refreshes do not reuse a retired token.
        rotated = data.get("refresh_token")
        if rotated:
            account.token.refresh_token = rotated
        self._save_accounts()
        print(f"Token 刷新成功!有效期: {account.token.expires_in} 秒")

    def update_account_stats(
        self, account_id: str, success: bool, error: Optional[str] = None
    ) -> None:
        """Record the outcome of one request against an account.

        Unknown account ids are ignored.

        Args:
            account_id: Id of the account that served the request.
            success: Whether the request succeeded.
            error: Error text stored as ``last_error`` on failure.
        """
        account = self._accounts.get(account_id)
        if not account:
            return
        account.total_requests += 1
        account.last_used = datetime.now()
        if success:
            account.successful_requests += 1
            account.last_error = None
        else:
            account.failed_requests += 1
            account.last_error = error
        self._save_accounts()

    def set_account_cooldown(self, account_id: str, duration_seconds: int) -> None:
        """Set ``cooldown_until`` to now plus ``duration_seconds``.

        Unknown account ids are ignored.
        """
        account = self._accounts.get(account_id)
        if not account:
            return
        account.cooldown_until = datetime.now() + timedelta(seconds=duration_seconds)
        self._save_accounts()

    def toggle_account(self, account_id: str) -> bool:
        """Flip an account's ``enabled`` flag and return the new value.

        Returns False for an unknown id, which is indistinguishable from
        "now disabled".
        """
        account = self._accounts.get(account_id)
        if not account:
            return False
        account.enabled = not account.enabled
        self._save_accounts()
        return account.enabled

    def get_stats(self) -> AccountStats:
        """Return pool-wide totals and the overall success rate.

        The success rate is 0.0 when no requests have been recorded.
        """
        accounts = list(self._accounts.values())
        total_requests = sum(acc.total_requests for acc in accounts)
        successful = sum(acc.successful_requests for acc in accounts)
        return AccountStats(
            total_accounts=len(accounts),
            available_accounts=sum(1 for acc in accounts if acc.is_available()),
            total_requests=total_requests,
            success_rate=successful / total_requests if total_requests > 0 else 0.0,
        )
class ConfigManager:
    """Store mutable runtime settings (currently the API key) as JSON.

    Settings start from built-in defaults and are overlaid with whatever
    ``CONFIG_FILE`` contains. Load and save failures are printed, not raised.
    """

    def __init__(self):
        # Built-in defaults; values read from CONFIG_FILE override them.
        self._config = {"api_key": "sk-antigravity"}
        self._ensure_data_dir()
        self._load_config()

    def _ensure_data_dir(self):
        """Create ``DATA_DIR`` if it does not exist yet."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_config(self):
        """Merge the contents of ``CONFIG_FILE`` into the in-memory settings."""
        if not os.path.exists(CONFIG_FILE):
            return
        try:
            with open(CONFIG_FILE, "r", encoding="utf-8") as fh:
                stored = json.load(fh)
                self._config.update(stored)
        except Exception as exc:
            print(f"加载配置失败: {exc}")

    def _save_config(self):
        """Write the in-memory settings to ``CONFIG_FILE``."""
        try:
            with open(CONFIG_FILE, "w", encoding="utf-8") as fh:
                json.dump(self._config, fh, ensure_ascii=False, indent=2)
        except Exception as exc:
            print(f"保存配置失败: {exc}")

    def get_api_key(self) -> str:
        """Return the configured API key, or the built-in default."""
        return self._config.get("api_key", "sk-antigravity")

    def set_api_key(self, api_key: str) -> bool:
        """Store a new API key with surrounding whitespace removed.

        Returns:
            False if the key is empty or only whitespace (nothing is
            changed), True once the key has been stored and a save attempted.
        """
        cleaned = api_key.strip() if api_key else ""
        if not cleaned:
            return False
        self._config["api_key"] = cleaned
        self._save_config()
        return True
# Global singletons. Constructing them at import time runs each __init__,
# which creates DATA_DIR if needed and loads any existing JSON state from disk.
account_manager = AccountManager()
config_manager = ConfigManager()