File size: 8,618 Bytes
ec41d51
 
 
 
 
 
 
 
 
 
 
 
 
731d915
ec41d51
 
984e901
 
 
 
2fbc7af
 
 
 
 
ec41d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731d915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec41d51
 
731d915
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""
账号池管理器 - 支持 OAuth Token 的增删改查和自动刷新
"""
import json
import os
import httpx
from typing import List, Optional
from datetime import datetime, timedelta
from models import Account, OAuthToken, AccountStats

# Where persistent data is stored (on HF Spaces: the persistent directory).
# Override with the DATA_DIR environment variable; defaults to ./data.
DATA_DIR = os.getenv("DATA_DIR", "./data")
ACCOUNTS_FILE, CONFIG_FILE = (
    os.path.join(DATA_DIR, filename)
    for filename in ("accounts.json", "config.json")
)

# Google OAuth client used for refreshing tokens; both values can be
# overridden through environment variables of the same name.
# Default ID: the client ID used by Antigravity.
# Default secret: a publicly known client secret (labelled as AI Studio's).
# NOTE(review): the ID and secret defaults are attributed to different
# products — confirm they belong to the same OAuth client, otherwise the
# token endpoint will reject refresh requests.
OAUTH_CLIENT_ID = os.getenv(
    "OAUTH_CLIENT_ID",
    "595848968694-r5ng3t6qb9elhe1u1h1hqgq4j2r3hgvk.apps.googleusercontent.com",
)
OAUTH_CLIENT_SECRET = os.getenv(
    "OAUTH_CLIENT_SECRET",
    "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",
)
OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"


class AccountManager:
    """Pool of Google OAuth accounts with JSON persistence and token refresh.

    Accounts are held in memory keyed by ``Account.id``; the whole pool is
    rewritten to ``ACCOUNTS_FILE`` after every mutation.
    """

    def __init__(self):
        self._accounts: dict[str, Account] = {}
        # Round-robin cursor used by get_next_token().
        self._current_index = 0
        self._ensure_data_dir()
        self._load_accounts()

    def _ensure_data_dir(self):
        """Create DATA_DIR (and any parents) if it does not exist."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_accounts(self):
        """Load accounts from ACCOUNTS_FILE, if the file exists.

        Each record is validated on its own: a malformed entry is reported and
        skipped instead of aborting the loop, which would lose every account
        listed after it (and the next save would then erase them from disk).
        Errors are printed, never raised.
        """
        if not os.path.exists(ACCOUNTS_FILE):
            return
        try:
            with open(ACCOUNTS_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
            for item in data:
                try:
                    account = Account(**item)
                except Exception as e:
                    print(f"加载账号失败: {e}")
                    continue
                self._accounts[account.id] = account
        except Exception as e:
            # Unreadable file, invalid JSON, or a top-level value that is not iterable.
            print(f"加载账号失败: {e}")

    def _save_accounts(self):
        """Persist every account to ACCOUNTS_FILE; errors are printed, not raised.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        target with ``os.replace``, so a crash or error mid-write cannot leave
        a truncated accounts file behind.
        """
        tmp_path = f"{ACCOUNTS_FILE}.tmp"
        try:
            data = [acc.model_dump(mode="json") for acc in self._accounts.values()]
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2, default=str)
            os.replace(tmp_path, ACCOUNTS_FILE)
        except Exception as e:
            print(f"保存账号失败: {e}")
            # Best effort: don't leave a half-written temp file lying around.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def add_account(
        self,
        email: str,
        access_token: str,
        refresh_token: str,
        expires_in: int = 3600,
        project_id: Optional[str] = None
    ) -> Account:
        """Create an account from freshly issued OAuth tokens, persist it, and return it.

        ``expiry_timestamp`` is computed as now (Unix seconds) + ``expires_in``.
        """
        now = int(datetime.now().timestamp())
        token = OAuthToken(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=expires_in,
            expiry_timestamp=now + expires_in,
            project_id=project_id
        )
        account = Account(email=email, token=token)
        self._accounts[account.id] = account
        self._save_accounts()
        return account

    def remove_account(self, account_id: str) -> bool:
        """Delete an account; return True if it existed, False otherwise."""
        if account_id in self._accounts:
            del self._accounts[account_id]
            self._save_accounts()
            return True
        return False

    def get_account(self, account_id: str) -> Optional[Account]:
        """Return the account with this id, or None if unknown."""
        return self._accounts.get(account_id)

    def get_all_accounts(self) -> List[Account]:
        """Return a snapshot list of all accounts."""
        return list(self._accounts.values())

    def get_available_accounts(self) -> List[Account]:
        """Return the accounts whose ``is_available()`` reports True."""
        return [acc for acc in self._accounts.values() if acc.is_available()]

    async def get_next_token(self) -> Optional[Account]:
        """Return the next available account in round-robin order.

        If ``is_token_expired()`` reports the account's token as expired, a
        refresh is attempted first. A failed refresh is only logged: the
        account is still returned so the upstream API surfaces the error.
        Returns None when no account is available.
        """
        available = self.get_available_accounts()
        if not available:
            return None

        # The cursor indexes the *current* available list, which can shrink or
        # grow between calls, so rotation is approximate rather than strict.
        self._current_index = self._current_index % len(available)
        account = available[self._current_index]
        self._current_index += 1

        if account.is_token_expired():
            print(f"账号 {account.email} 的 token 即将过期,正在刷新...")
            try:
                await self._refresh_token(account)
            except Exception as e:
                print(f"刷新 token 失败: {e}")
                # Keep going with the possibly stale token; the API will report it.

        return account

    async def _refresh_token(self, account: Account):
        """Exchange the account's refresh token for a new access token.

        Updates ``account.token`` in place and persists the pool.

        Raises:
            Exception: if the token endpoint responds with a non-200 status.
            KeyError: if a 200 response lacks ``access_token``.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                OAUTH_TOKEN_URL,
                data={
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                    "refresh_token": account.token.refresh_token,
                    "grant_type": "refresh_token"
                }
            )

            if response.status_code != 200:
                raise Exception(f"刷新失败: {response.text}")

            data = response.json()
            now = int(datetime.now().timestamp())

            token = account.token
            token.access_token = data["access_token"]
            token.expires_in = data.get("expires_in", 3600)
            token.expiry_timestamp = now + token.expires_in
            # RFC 6749 §6: the server MAY issue a new refresh token, in which
            # case the client must discard the old one. Keep the newest.
            new_refresh_token = data.get("refresh_token")
            if new_refresh_token:
                token.refresh_token = new_refresh_token

            self._save_accounts()
            print(f"Token 刷新成功!有效期: {account.token.expires_in} 秒")

    def update_account_stats(self, account_id: str, success: bool, error: Optional[str] = None):
        """Record the outcome of one request against an account and persist it.

        On success ``last_error`` is cleared; on failure it is set to ``error``.
        Unknown ids are ignored.
        """
        account = self._accounts.get(account_id)
        if account:
            account.total_requests += 1
            account.last_used = datetime.now()
            if success:
                account.successful_requests += 1
                account.last_error = None
            else:
                account.failed_requests += 1
                account.last_error = error
            self._save_accounts()

    def set_account_cooldown(self, account_id: str, duration_seconds: int):
        """Set ``cooldown_until`` to now + ``duration_seconds``; unknown ids are ignored."""
        account = self._accounts.get(account_id)
        if account:
            account.cooldown_until = datetime.now() + timedelta(seconds=duration_seconds)
            self._save_accounts()

    def toggle_account(self, account_id: str) -> bool:
        """Flip an account's ``enabled`` flag and return the new value.

        Returns False for unknown ids (indistinguishable from "now disabled").
        """
        account = self._accounts.get(account_id)
        if account:
            account.enabled = not account.enabled
            self._save_accounts()
            return account.enabled
        return False

    def get_stats(self) -> AccountStats:
        """Aggregate pool statistics; success_rate is 0.0 when there were no requests."""
        accounts = list(self._accounts.values())
        total_requests = sum(acc.total_requests for acc in accounts)
        successful = sum(acc.successful_requests for acc in accounts)

        return AccountStats(
            total_accounts=len(accounts),
            available_accounts=len([a for a in accounts if a.is_available()]),
            total_requests=total_requests,
            success_rate=successful / total_requests if total_requests > 0 else 0.0
        )


class ConfigManager:
    """Mutable runtime settings (currently only the API key), persisted to CONFIG_FILE."""

    # API key used until one is configured via set_api_key() or config.json.
    DEFAULT_API_KEY = "sk-antigravity"

    def __init__(self):
        self._config = {
            "api_key": self.DEFAULT_API_KEY
        }
        self._ensure_data_dir()
        self._load_config()

    def _ensure_data_dir(self):
        """Create DATA_DIR (and any parents) if it does not exist."""
        os.makedirs(DATA_DIR, exist_ok=True)

    def _load_config(self):
        """Overlay values from CONFIG_FILE onto the defaults, if the file exists.

        Errors are printed, never raised, so a bad file leaves the defaults in place.
        """
        if not os.path.exists(CONFIG_FILE):
            return
        try:
            with open(CONFIG_FILE, "r", encoding="utf-8") as f:
                self._config.update(json.load(f))
        except Exception as e:
            print(f"加载配置失败: {e}")

    def _save_config(self):
        """Persist the config to CONFIG_FILE; errors are printed, not raised.

        Written to a sibling ``.tmp`` file and moved into place with
        ``os.replace`` so an interrupted write cannot truncate the config.
        """
        tmp_path = f"{CONFIG_FILE}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._config, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, CONFIG_FILE)
        except Exception as e:
            print(f"保存配置失败: {e}")
            # Best effort: don't leave a half-written temp file lying around.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def get_api_key(self) -> str:
        """Return the configured API key.

        Falls back to DEFAULT_API_KEY when no key is stored or the stored value
        is not a string (e.g. ``null`` hand-edited into config.json).
        """
        api_key = self._config.get("api_key")
        if isinstance(api_key, str):
            return api_key
        return self.DEFAULT_API_KEY

    def set_api_key(self, api_key: str) -> bool:
        """Store a new API key (surrounding whitespace stripped) and persist it.

        Returns False, without changing anything, when ``api_key`` is not a
        string or is empty/whitespace-only; True otherwise.
        """
        if not isinstance(api_key, str) or not api_key.strip():
            return False
        self._config["api_key"] = api_key.strip()
        self._save_config()
        return True


# Module-level singletons shared by everything that imports this module.
# Creating them runs I/O at import time: each constructor ensures DATA_DIR
# exists and loads its JSON file (accounts / config) if present.
account_manager = AccountManager()
config_manager = ConfigManager()