"""
Account pool manager.

Adds, removes, lists and updates Google OAuth accounts (access + refresh
tokens), hands them out round-robin, and refreshes access tokens that are
reported as expired. Also holds a small mutable config store (the API key).
"""
import contextlib
import json
import os
from datetime import datetime, timedelta
from typing import Any, List, Optional

import httpx

from models import Account, OAuthToken, AccountStats

# Data file locations (persistent directory on HF Spaces; override via DATA_DIR)
DATA_DIR = os.environ.get("DATA_DIR", "./data")
ACCOUNTS_FILE = os.path.join(DATA_DIR, "accounts.json")
CONFIG_FILE = os.path.join(DATA_DIR, "config.json")

# Google OAuth client configuration (the Client ID used by Antigravity).
# Both values can be overridden through environment variables.
OAUTH_CLIENT_ID = os.environ.get(
    "OAUTH_CLIENT_ID",
    "595848968694-r5ng3t6qb9elhe1u1h1hqgq4j2r3hgvk.apps.googleusercontent.com"
)
# Defaults to the publicly known AI Studio client secret
OAUTH_CLIENT_SECRET = os.environ.get(
    "OAUTH_CLIENT_SECRET",
    "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"
)
OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"


def _atomic_write_json(path: str, payload: Any, **dump_kwargs: Any) -> None:
    """Write ``payload`` as JSON to ``path`` without leaving a partial file.

    The data is first written to a sibling ``<path>.tmp`` file, which is then
    moved over ``path`` with ``os.replace`` (atomic on success per the POSIX
    rename contract). If serialization or the write fails part-way, the
    previous file is left untouched instead of being truncated.

    Args:
        path: Destination file path.
        payload: JSON-serializable object to write.
        **dump_kwargs: Extra keyword arguments forwarded to ``json.dump``.

    Raises:
        Whatever ``open``/``json.dump``/``os.replace`` raise; the temp file
        is removed before the exception propagates.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, **dump_kwargs)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup so a failed write does not leave debris behind
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


class AccountManager:
    """Account manager for OAuth-token-backed accounts.

    Accounts are kept in memory keyed by ``Account.id`` and mirrored to
    ACCOUNTS_FILE after every mutation.
    """

    def __init__(self):
        # account id -> Account
        self._accounts: dict[str, Account] = {}
        # Round-robin cursor into the list of currently available accounts
        self._current_index = 0
        self._ensure_data_dir()
        self._load_accounts()

    def _ensure_data_dir(self):
        """Create DATA_DIR if it does not exist yet."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_accounts(self):
        """Load accounts from ACCOUNTS_FILE, if it exists.

        An unreadable or non-iterable file is reported and leaves the pool
        empty. A single malformed record is reported and skipped, so it no
        longer prevents the records after it from loading.
        """
        if not os.path.exists(ACCOUNTS_FILE):
            return
        try:
            with open(ACCOUNTS_FILE, "r", encoding="utf-8") as f:
                # list() also rejects non-iterable JSON such as a bare number
                records = list(json.load(f))
        except Exception as e:
            print(f"加载账号失败: {e}")
            return
        for item in records:
            try:
                account = Account(**item)
            except Exception as e:
                print(f"加载账号失败: {e}")
                continue
            self._accounts[account.id] = account

    def _save_accounts(self):
        """Persist all accounts to ACCOUNTS_FILE atomically.

        Failures are printed rather than raised, so callers are never
        interrupted by a persistence error.
        """
        try:
            data = [acc.model_dump(mode="json") for acc in self._accounts.values()]
            _atomic_write_json(
                ACCOUNTS_FILE, data, ensure_ascii=False, indent=2, default=str
            )
        except Exception as e:
            print(f"保存账号失败: {e}")

    def add_account(
        self,
        email: str,
        access_token: str,
        refresh_token: str,
        expires_in: int = 3600,
        project_id: Optional[str] = None
    ) -> Account:
        """Create, store and persist a new account.

        Args:
            email: Account e-mail address.
            access_token: Current OAuth access token.
            refresh_token: OAuth refresh token used by ``_refresh_token``.
            expires_in: Access-token lifetime in seconds, counted from now.
            project_id: Optional project id stored on the token.

        Returns:
            The newly created Account.
        """
        now = int(datetime.now().timestamp())
        token = OAuthToken(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=expires_in,
            expiry_timestamp=now + expires_in,
            project_id=project_id
        )
        account = Account(email=email, token=token)
        self._accounts[account.id] = account
        self._save_accounts()
        return account

    def remove_account(self, account_id: str) -> bool:
        """Delete an account; return True if it existed."""
        if self._accounts.pop(account_id, None) is None:
            return False
        self._save_accounts()
        return True

    def get_account(self, account_id: str) -> Optional[Account]:
        """Return the account with ``account_id``, or None."""
        return self._accounts.get(account_id)

    def get_all_accounts(self) -> List[Account]:
        """Return every stored account."""
        return list(self._accounts.values())

    def get_available_accounts(self) -> List[Account]:
        """Return the accounts whose ``is_available()`` is true."""
        return [acc for acc in self._accounts.values() if acc.is_available()]

    async def get_next_token(self) -> Optional[Account]:
        """Pick the next available account round-robin.

        If ``account.is_token_expired()`` is true, a refresh is attempted
        first. A failed refresh is printed and the account is still returned,
        so the upstream API surfaces the auth error.

        Returns:
            The selected Account, or None when no account is available.
        """
        available = self.get_available_accounts()
        if not available:
            return None

        # Round robin; the modulo keeps the cursor valid as the pool shrinks
        self._current_index = self._current_index % len(available)
        account = available[self._current_index]
        self._current_index += 1

        if account.is_token_expired():
            print(f"账号 {account.email} 的 token 即将过期,正在刷新...")
            try:
                await self._refresh_token(account)
            except Exception as e:
                print(f"刷新 token 失败: {e}")
                # Keep going with the possibly expired token; the API will reject it

        return account

    async def _refresh_token(self, account: Account):
        """Exchange the account's refresh token for a new access token.

        Updates the token in place and persists it. If the token endpoint
        returns a new refresh token, it replaces the old one: RFC 6749 §6
        lets the server rotate it, and the old value must then be discarded.

        Raises:
            Exception: when the token endpoint answers with a non-200 status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                OAUTH_TOKEN_URL,
                data={
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                    "refresh_token": account.token.refresh_token,
                    "grant_type": "refresh_token"
                }
            )

        if response.status_code != 200:
            raise Exception(f"刷新失败: {response.text}")

        data = response.json()
        now = int(datetime.now().timestamp())

        account.token.access_token = data["access_token"]
        account.token.expires_in = data.get("expires_in", 3600)
        account.token.expiry_timestamp = now + account.token.expires_in
        rotated_refresh_token = data.get("refresh_token")
        if rotated_refresh_token:
            account.token.refresh_token = rotated_refresh_token

        self._save_accounts()
        print(f"Token 刷新成功!有效期: {account.token.expires_in} 秒")

    def update_account_stats(self, account_id: str, success: bool, error: Optional[str] = None):
        """Record one request outcome for an account and persist it.

        Unknown ids are ignored. A success clears ``last_error``; a failure
        stores ``error`` there.
        """
        account = self._accounts.get(account_id)
        if account:
            account.total_requests += 1
            account.last_used = datetime.now()
            if success:
                account.successful_requests += 1
                account.last_error = None
            else:
                account.failed_requests += 1
                account.last_error = error
            self._save_accounts()

    def set_account_cooldown(self, account_id: str, duration_seconds: int):
        """Set ``cooldown_until`` to now + ``duration_seconds``; unknown ids are ignored."""
        account = self._accounts.get(account_id)
        if account:
            # NOTE(review): naive local time; presumably Account.is_available
            # compares against naive datetime.now() too - confirm in models
            account.cooldown_until = datetime.now() + timedelta(seconds=duration_seconds)
            self._save_accounts()

    def toggle_account(self, account_id: str) -> bool:
        """Flip an account's ``enabled`` flag.

        Returns:
            The new ``enabled`` value, or False if the account does not exist
            (callers cannot distinguish that from "now disabled").
        """
        account = self._accounts.get(account_id)
        if account:
            account.enabled = not account.enabled
            self._save_accounts()
            return account.enabled
        return False

    def get_stats(self) -> AccountStats:
        """Aggregate account counts and the overall request success rate."""
        accounts = list(self._accounts.values())
        total_requests = sum(acc.total_requests for acc in accounts)
        successful = sum(acc.successful_requests for acc in accounts)
        return AccountStats(
            total_accounts=len(accounts),
            available_accounts=sum(1 for a in accounts if a.is_available()),
            total_requests=total_requests,
            success_rate=successful / total_requests if total_requests > 0 else 0.0
        )


class ConfigManager:
    """Config manager for mutable settings such as the API key.

    Values loaded from CONFIG_FILE override the built-in defaults.
    """

    def __init__(self):
        self._config = {
            "api_key": "sk-antigravity"
        }
        self._ensure_data_dir()
        self._load_config()

    def _ensure_data_dir(self):
        """Create DATA_DIR if it does not exist yet."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_config(self):
        """Merge CONFIG_FILE (if present) over the defaults; errors are printed."""
        if os.path.exists(CONFIG_FILE):
            try:
                with open(CONFIG_FILE, "r", encoding="utf-8") as f:
                    self._config.update(json.load(f))
            except Exception as e:
                print(f"加载配置失败: {e}")

    def _save_config(self):
        """Persist the config to CONFIG_FILE atomically; errors are printed."""
        try:
            _atomic_write_json(CONFIG_FILE, self._config, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"保存配置失败: {e}")

    def get_api_key(self) -> str:
        """Return the configured API key."""
        return self._config.get("api_key", "sk-antigravity")

    def set_api_key(self, api_key: str) -> bool:
        """Store a new API key (whitespace-stripped).

        Returns:
            False if ``api_key`` is empty or whitespace-only, else True.
        """
        if not api_key or not api_key.strip():
            return False
        self._config["api_key"] = api_key.strip()
        self._save_config()
        return True


# Module-level singletons
account_manager = AccountManager()
config_manager = ConfigManager()