Spaces:
Sleeping
Sleeping
| """ | |
| 账号池管理器 - 支持 OAuth Token 的增删改查和自动刷新 | |
| """ | |
| import json | |
| import os | |
| import httpx | |
| from typing import List, Optional | |
| from datetime import datetime, timedelta | |
| from models import Account, OAuthToken, AccountStats | |
# Storage locations (DATA_DIR is meant to point at the HF Spaces persistent
# directory; it falls back to ./data when the variable is unset).
DATA_DIR = os.getenv("DATA_DIR", "./data")
ACCOUNTS_FILE = os.path.join(DATA_DIR, "accounts.json")
CONFIG_FILE = os.path.join(DATA_DIR, "config.json")

# Google OAuth settings (the client ID used by Antigravity).
OAUTH_CLIENT_ID = os.getenv(
    "OAUTH_CLIENT_ID",
    "595848968694-r5ng3t6qb9elhe1u1h1hqgq4j2r3hgvk.apps.googleusercontent.com",
)
# Defaults to AI Studio's public client secret.
# NOTE(review): a secret default is committed in source; prefer supplying it
# through the OAUTH_CLIENT_SECRET environment variable.
OAUTH_CLIENT_SECRET = os.getenv(
    "OAUTH_CLIENT_SECRET",
    "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",
)
OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
class AccountManager:
    """Manage a pool of OAuth accounts persisted as JSON.

    Accounts are held in memory keyed by ``Account.id`` and written back to
    ``ACCOUNTS_FILE`` after every mutation.  ``get_next_token`` hands accounts
    out round-robin and refreshes a token when the account reports it as
    expired.
    """

    def __init__(self):
        """Create the data directory if needed and load any saved accounts."""
        self._accounts: dict[str, Account] = {}
        # Position of the next account to hand out in get_next_token().
        self._current_index = 0
        self._ensure_data_dir()
        self._load_accounts()

    def _ensure_data_dir(self):
        """Make sure ``DATA_DIR`` exists."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_accounts(self):
        """Load accounts from ``ACCOUNTS_FILE`` if the file exists.

        Problems are reported with ``print`` and never raised.  Each record is
        parsed on its own so that one malformed entry does not discard the
        valid entries that follow it.
        """
        if not os.path.exists(ACCOUNTS_FILE):
            return

        try:
            with open(ACCOUNTS_FILE, "r", encoding="utf-8") as f:
                records = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both invalid JSON and undecodable bytes.
            print(f"加载账号失败: {e}")
            return

        if not isinstance(records, list):
            print(f"加载账号失败: 期望 JSON 数组,实际为 {type(records).__name__}")
            return

        for record in records:
            try:
                account = Account(**record)
            except Exception as e:
                # Best-effort: skip the bad record, keep loading the rest.
                print(f"加载账号失败: {e}")
                continue
            self._accounts[account.id] = account

    def _save_accounts(self):
        """Write all accounts to ``ACCOUNTS_FILE``.

        The data is written to a sibling temporary file and then moved into
        place with ``os.replace`` so that a failure part-way through the dump
        cannot truncate the existing file.  Failures are reported with
        ``print`` and never raised.
        """
        tmp_path = ACCOUNTS_FILE + ".tmp"
        try:
            data = [acc.model_dump(mode="json") for acc in self._accounts.values()]
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2, default=str)
            os.replace(tmp_path, ACCOUNTS_FILE)
        except Exception as e:
            print(f"保存账号失败: {e}")

    def add_account(
        self,
        email: str,
        access_token: str,
        refresh_token: str,
        expires_in: int = 3600,
        project_id: Optional[str] = None
    ) -> Account:
        """Create an account from the given token data, store and persist it.

        Args:
            email: Address identifying the account.
            access_token: Current OAuth access token.
            refresh_token: OAuth refresh token used by later refreshes.
            expires_in: Access token lifetime in seconds, counted from now.
            project_id: Optional project identifier stored with the token.

        Returns:
            The newly created account.
        """
        now = int(datetime.now().timestamp())
        token = OAuthToken(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=expires_in,
            expiry_timestamp=now + expires_in,
            project_id=project_id
        )
        account = Account(email=email, token=token)
        self._accounts[account.id] = account
        self._save_accounts()
        return account

    def remove_account(self, account_id: str) -> bool:
        """Delete an account; return ``True`` if it existed."""
        if account_id not in self._accounts:
            return False
        del self._accounts[account_id]
        self._save_accounts()
        return True

    def get_account(self, account_id: str) -> Optional[Account]:
        """Return the account with this id, or ``None`` if unknown."""
        return self._accounts.get(account_id)

    def get_all_accounts(self) -> List[Account]:
        """Return every stored account."""
        return list(self._accounts.values())

    def get_available_accounts(self) -> List[Account]:
        """Return the accounts whose ``is_available()`` is true."""
        return [acc for acc in self._accounts.values() if acc.is_available()]

    async def get_next_token(self) -> Optional[Account]:
        """Return the next available account in round-robin order.

        If the chosen account reports its token as expired, a refresh is
        attempted first.  A failed refresh is only reported: the account is
        still returned so that the upstream API surfaces the error.

        Returns:
            The selected account, or ``None`` when no account is available.
        """
        available = self.get_available_accounts()
        if not available:
            return None

        # Wrap on every call: the available list can change size between calls.
        self._current_index %= len(available)
        account = available[self._current_index]
        self._current_index += 1

        if account.is_token_expired():
            print(f"账号 {account.email} 的 token 即将过期,正在刷新...")
            try:
                await self._refresh_token(account)
            except Exception as e:
                print(f"刷新 token 失败: {e}")
                # Keep going with the possibly expired token and let the API
                # return the error.
        return account

    async def _refresh_token(self, account: Account):
        """Exchange the account's refresh token for a new access token.

        On success the account's token fields are updated in place and the
        pool is persisted.

        Raises:
            RuntimeError: If the token endpoint does not answer with HTTP 200
                or the response carries no ``access_token``.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                OAUTH_TOKEN_URL,
                data={
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                    "refresh_token": account.token.refresh_token,
                    "grant_type": "refresh_token"
                }
            )

        if response.status_code != 200:
            raise RuntimeError(f"刷新失败: {response.text}")

        data = response.json()
        access_token = data.get("access_token")
        if not access_token:
            # Do not echo the body here: a 200 response may contain tokens.
            raise RuntimeError("刷新失败: 响应缺少 access_token")

        now = int(datetime.now().timestamp())
        account.token.access_token = access_token
        account.token.expires_in = data.get("expires_in", 3600)
        account.token.expiry_timestamp = now + account.token.expires_in

        # The server may rotate the refresh token (RFC 6749, section 6); when
        # it does, the old one must be replaced.
        rotated_refresh_token = data.get("refresh_token")
        if rotated_refresh_token:
            account.token.refresh_token = rotated_refresh_token

        self._save_accounts()
        print(f"Token 刷新成功!有效期: {account.token.expires_in} 秒")

    def update_account_stats(
        self, account_id: str, success: bool, error: Optional[str] = None
    ):
        """Record the outcome of one request against an account.

        Unknown account ids are ignored.

        Args:
            account_id: Id of the account that served the request.
            success: Whether the request succeeded.
            error: Error text stored as ``last_error`` on failure; a success
                clears ``last_error``.
        """
        account = self._accounts.get(account_id)
        if not account:
            return

        account.total_requests += 1
        account.last_used = datetime.now()
        if success:
            account.successful_requests += 1
            account.last_error = None
        else:
            account.failed_requests += 1
            account.last_error = error
        self._save_accounts()

    def set_account_cooldown(self, account_id: str, duration_seconds: int):
        """Set ``cooldown_until`` to now plus ``duration_seconds``.

        Unknown account ids are ignored.
        """
        account = self._accounts.get(account_id)
        if not account:
            return
        account.cooldown_until = datetime.now() + timedelta(seconds=duration_seconds)
        self._save_accounts()

    def toggle_account(self, account_id: str) -> bool:
        """Flip an account's ``enabled`` flag and return the new value.

        Returns ``False`` as well when the account id is unknown, so the
        result alone does not tell "now disabled" from "not found".
        """
        account = self._accounts.get(account_id)
        if not account:
            return False
        account.enabled = not account.enabled
        self._save_accounts()
        return account.enabled

    def get_stats(self) -> AccountStats:
        """Return totals aggregated over all accounts.

        ``success_rate`` is successful requests divided by total requests, or
        ``0.0`` when no request has been recorded.
        """
        accounts = list(self._accounts.values())
        total_requests = sum(acc.total_requests for acc in accounts)
        successful = sum(acc.successful_requests for acc in accounts)
        return AccountStats(
            total_accounts=len(accounts),
            available_accounts=sum(1 for acc in accounts if acc.is_available()),
            total_requests=total_requests,
            success_rate=successful / total_requests if total_requests > 0 else 0.0
        )
class ConfigManager:
    """Manage mutable settings (currently the API key) persisted as JSON.

    Settings live in an in-memory dict that starts from the defaults, is
    overlaid with the contents of ``CONFIG_FILE`` at construction, and is
    written back whenever a setting changes.
    """

    # Key in effect until another one is configured.
    DEFAULT_API_KEY = "sk-antigravity"

    def __init__(self):
        """Start from the defaults, then overlay the saved configuration."""
        self._config = {
            "api_key": self.DEFAULT_API_KEY
        }
        self._ensure_data_dir()
        self._load_config()

    def _ensure_data_dir(self):
        """Make sure ``DATA_DIR`` exists."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_config(self):
        """Overlay the in-memory config with ``CONFIG_FILE`` if it exists.

        Problems are reported with ``print`` and never raised.  Only a JSON
        object is accepted: ``dict.update`` would otherwise interpret some
        lists (for example ``["ab"]``) as key/value pairs.
        """
        if not os.path.exists(CONFIG_FILE):
            return

        try:
            with open(CONFIG_FILE, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both invalid JSON and undecodable bytes.
            print(f"加载配置失败: {e}")
            return

        if not isinstance(loaded, dict):
            print(f"加载配置失败: 期望 JSON 对象,实际为 {type(loaded).__name__}")
            return
        self._config.update(loaded)

    def _save_config(self):
        """Write the config to ``CONFIG_FILE``.

        The data is written to a sibling temporary file and then moved into
        place with ``os.replace`` so that a failed write cannot truncate the
        existing file.  Failures are reported with ``print`` and never raised.
        """
        tmp_path = CONFIG_FILE + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._config, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, CONFIG_FILE)
        except (OSError, TypeError, ValueError) as e:
            print(f"保存配置失败: {e}")

    def get_api_key(self) -> str:
        """Return the configured API key, or the default if none is stored."""
        return self._config.get("api_key", self.DEFAULT_API_KEY)

    def set_api_key(self, api_key: str) -> bool:
        """Store a new API key after stripping surrounding whitespace.

        Returns:
            ``False`` if ``api_key`` is not a string or is blank (nothing is
            changed); ``True`` once the in-memory config has been updated and
            a save has been attempted.
        """
        if not isinstance(api_key, str):
            return False
        cleaned = api_key.strip()
        if not cleaned:
            return False
        self._config["api_key"] = cleaned
        self._save_config()
        return True
# Module-level singletons shared by importers of this module.
account_manager: AccountManager = AccountManager()
config_manager: ConfigManager = ConfigManager()