| """Grok Token 管理器 - 单例模式的Token负载均衡和状态管理""" |
|
|
| import orjson |
| import time |
| import asyncio |
| import aiofiles |
| import portalocker |
| from pathlib import Path |
| from curl_cffi.requests import AsyncSession |
| from typing import Dict, Any, Optional, Tuple |
|
|
| from app.models.grok_models import TokenType, Models |
| from app.core.exception import GrokApiException |
| from app.core.logger import logger |
| from app.core.config import setting |
| from app.services.grok.statsig import get_dynamic_headers |
|
|
|
|
| |
# Grok rate-limit endpoint and HTTP client settings.
RATE_LIMIT_API = "https://grok.com/rest/rate-limits"
TIMEOUT = 30            # request timeout, seconds
BROWSER = "chrome133a"  # curl_cffi browser impersonation profile
MAX_FAILURES = 3        # consecutive failures before a token is marked expired
TOKEN_INVALID = 401     # HTTP status: token no longer valid
STATSIG_INVALID = 403   # HTTP status: IP/fingerprint blocked by the server


# Cooldown policy.
COOLDOWN_REQUESTS = 5             # request-count cooldown applied for non-429 errors
COOLDOWN_429_WITH_QUOTA = 3600    # 1 hour (seconds): 429 while the token still has quota
COOLDOWN_429_NO_QUOTA = 36000     # 10 hours (seconds): 429 with quota exhausted
|
|
|
|
class GrokTokenManager:
    """Token manager (singleton): load-balances Grok SSO tokens and tracks their state."""

    # Singleton instance plus an asyncio lock shared by all users of the class.
    _instance: Optional['GrokTokenManager'] = None
    _lock = asyncio.Lock()
|
|
    def __new__(cls) -> 'GrokTokenManager':
        # Classic singleton: create the instance once, then always reuse it.
        # NOTE(review): not guarded against a concurrent first construction from
        # multiple OS threads — fine under a single event loop; confirm if threads are used.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
|
|
    def __init__(self):
        """Initialize state once; repeated constructions are no-ops (singleton)."""
        if hasattr(self, '_initialized'):
            return

        # Token store lives at <repo>/data/token.json, three levels above this file.
        self.token_file = Path(__file__).parents[3] / "data" / "token.json"
        self._file_lock = asyncio.Lock()
        self.token_file.parent.mkdir(parents=True, exist_ok=True)
        self._storage = None    # optional external storage backend (see set_storage)
        self.token_data = None  # populated lazily by _load_data

        # Batched persistence state.
        self._save_pending = False  # dirty flag consumed by _batch_save_worker
        self._save_task = None      # background asyncio.Task, created by start_batch_save
        self._shutdown = False

        # In-memory request-count cooldown bookkeeping (sso -> remaining requests).
        self._cooldown_counts: Dict[str, int] = {}
        self._request_counter = 0

        # Bulk-refresh coordination (single in-flight refresh + polling progress).
        self._refresh_lock = False
        self._refresh_progress: Dict[str, Any] = {"running": False, "current": 0, "total": 0, "success": 0, "failed": 0}

        self._initialized = True
        logger.debug(f"[Token] 初始化完成: {self.token_file}")
|
|
    def set_storage(self, storage) -> None:
        """Install an external storage backend; when set, _save_data delegates to it."""
        self._storage = storage
|
|
    async def _load_data(self) -> None:
        """Load token data from disk (multi-process safe via a shared file lock)."""
        default = {TokenType.NORMAL.value: {}, TokenType.SUPER.value: {}}

        def load_sync():
            # Runs in a worker thread; portalocker's shared lock allows concurrent readers
            # while blocking writers in other processes.
            with open(self.token_file, "r", encoding="utf-8") as f:
                portalocker.lock(f, portalocker.LOCK_SH)
                try:
                    return orjson.loads(f.read())
                finally:
                    portalocker.unlock(f)

        try:
            if self.token_file.exists():
                # _file_lock serializes in-process access while the thread reads the file.
                async with self._file_lock:
                    self.token_data = await asyncio.to_thread(load_sync)
            else:
                self.token_data = default
                logger.debug("[Token] 创建新数据文件")
        except Exception as e:
            # Fall back to an empty structure rather than failing startup.
            logger.error(f"[Token] 加载失败: {e}")
            self.token_data = default
|
|
    async def _save_data(self) -> None:
        """Persist token data (multi-process safe); raises GrokApiException on failure."""
        def save_sync(data):
            # Runs in a worker thread; the exclusive lock blocks readers/writers
            # in other processes until the write completes.
            with open(self.token_file, "w", encoding="utf-8") as f:
                portalocker.lock(f, portalocker.LOCK_EX)
                try:
                    content = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode()
                    f.write(content)
                    f.flush()
                finally:
                    portalocker.unlock(f)

        try:
            if not self._storage:
                async with self._file_lock:
                    await asyncio.to_thread(save_sync, self.token_data)
            else:
                # An external backend, when configured, takes precedence over the local file.
                await self._storage.save_tokens(self.token_data)
        except Exception as e:
            logger.error(f"[Token] 保存失败: {e}")
            raise GrokApiException(f"保存失败: {e}", "TOKEN_SAVE_ERROR")
|
|
    def _mark_dirty(self) -> None:
        """Flag the in-memory data as modified; _batch_save_worker persists it later."""
        self._save_pending = True
|
|
| async def _batch_save_worker(self) -> None: |
| """批量保存后台任务""" |
| from app.core.config import setting |
| |
| interval = setting.global_config.get("batch_save_interval", 1.0) |
| logger.info(f"[Token] 存储任务已启动,间隔: {interval}s") |
| |
| while not self._shutdown: |
| await asyncio.sleep(interval) |
| |
| if self._save_pending and not self._shutdown: |
| try: |
| await self._save_data() |
| self._save_pending = False |
| logger.debug("[Token] 存储完成") |
| except Exception as e: |
| logger.error(f"[Token] 存储失败: {e}") |
|
|
    async def start_batch_save(self) -> None:
        """Start the batch-save background task (idempotent: no-op if already created)."""
        if self._save_task is None:
            self._save_task = asyncio.create_task(self._batch_save_worker())
            logger.info("[Token] 存储任务已创建")
|
|
    async def shutdown(self) -> None:
        """Stop the batch-save task and flush any pending data to storage."""
        self._shutdown = True

        if self._save_task:
            self._save_task.cancel()
            try:
                await self._save_task
            except asyncio.CancelledError:
                pass
        # NOTE(review): _save_task is not reset to None here, so a later
        # start_batch_save() call would not restart the worker — confirm intended.

        # Final synchronous flush of anything the worker had not yet written.
        if self._save_pending:
            await self._save_data()
            logger.info("[Token] 关闭时刷新完成")
|
|
| @staticmethod |
| def _extract_sso(auth_token: str) -> Optional[str]: |
| """提取SSO值""" |
| if "sso=" in auth_token: |
| return auth_token.split("sso=")[1].split(";")[0] |
| logger.warning("[Token] 无法提取SSO值") |
| return None |
|
|
| def _find_token(self, sso: str) -> Tuple[Optional[str], Optional[Dict]]: |
| """查找Token""" |
| for token_type in [TokenType.NORMAL.value, TokenType.SUPER.value]: |
| if sso in self.token_data[token_type]: |
| return token_type, self.token_data[token_type][sso] |
| return None, None |
|
|
| async def add_token(self, tokens: list[str], token_type: TokenType) -> None: |
| """添加Token""" |
| if not tokens: |
| return |
|
|
| count = 0 |
| for token in tokens: |
| if not token or not token.strip(): |
| continue |
|
|
| self.token_data[token_type.value][token] = { |
| "createdTime": int(time.time() * 1000), |
| "remainingQueries": -1, |
| "heavyremainingQueries": -1, |
| "status": "active", |
| "failedCount": 0, |
| "lastFailureTime": None, |
| "lastFailureReason": None, |
| "tags": [], |
| "note": "" |
| } |
| count += 1 |
|
|
| self._mark_dirty() |
| logger.info(f"[Token] 添加 {count} 个 {token_type.value} Token") |
|
|
| async def delete_token(self, tokens: list[str], token_type: TokenType) -> None: |
| """删除Token""" |
| if not tokens: |
| return |
|
|
| count = 0 |
| for token in tokens: |
| if token in self.token_data[token_type.value]: |
| del self.token_data[token_type.value][token] |
| count += 1 |
|
|
| self._mark_dirty() |
| logger.info(f"[Token] 删除 {count} 个 {token_type.value} Token") |
|
|
| async def update_token_tags(self, token: str, token_type: TokenType, tags: list[str]) -> None: |
| """更新Token标签""" |
| if token not in self.token_data[token_type.value]: |
| raise GrokApiException("Token不存在", "TOKEN_NOT_FOUND", {"token": token[:10]}) |
| |
| cleaned = [t.strip() for t in tags if t and t.strip()] |
| self.token_data[token_type.value][token]["tags"] = cleaned |
| self._mark_dirty() |
| logger.info(f"[Token] 更新标签: {token[:10]}... -> {cleaned}") |
|
|
| async def update_token_note(self, token: str, token_type: TokenType, note: str) -> None: |
| """更新Token备注""" |
| if token not in self.token_data[token_type.value]: |
| raise GrokApiException("Token不存在", "TOKEN_NOT_FOUND", {"token": token[:10]}) |
| |
| self.token_data[token_type.value][token]["note"] = note.strip() |
| self._mark_dirty() |
| logger.info(f"[Token] 更新备注: {token[:10]}...") |
| |
| def get_tokens(self) -> Dict[str, Any]: |
| """获取所有Token""" |
| return self.token_data.copy() |
|
|
    async def _reload_if_needed(self) -> None:
        """Re-read token data from disk so writes by other worker processes are seen.

        Skipped entirely when an external storage backend is configured
        (that backend is the single source of truth in that mode).
        """
        if self._storage:
            return

        def reload_sync():
            # Worker-thread read under a shared portalocker lock.
            with open(self.token_file, "r", encoding="utf-8") as f:
                portalocker.lock(f, portalocker.LOCK_SH)
                try:
                    return orjson.loads(f.read())
                finally:
                    portalocker.unlock(f)

        try:
            if self.token_file.exists():
                # NOTE(review): unlike _load_data, this read does not take
                # self._file_lock — confirm a concurrent in-process save is safe.
                self.token_data = await asyncio.to_thread(reload_sync)
        except Exception as e:
            # Best-effort: keep the current in-memory snapshot on failure.
            logger.warning(f"[Token] 重新加载失败: {e}")
|
|
| async def get_token(self, model: str) -> str: |
| """获取Token""" |
| jwt = await self.select_token(model) |
| return f"sso-rw={jwt};sso={jwt}" |
| |
    async def select_token(self, model: str) -> str:
        """Pick the best available SSO token for *model* (multi-process aware).

        Strategy: reload the on-disk snapshot, tick down request-count
        cooldowns, then prefer never-used tokens (remaining == -1) and fall
        back to the candidate with the most remaining quota.

        Raises GrokApiException (NO_AVAILABLE_TOKEN) when nothing qualifies.
        """
        await self._reload_if_needed()

        # Advance the request-count cooldowns; entries reaching zero are released.
        self._request_counter += 1
        for token in list(self._cooldown_counts.keys()):
            self._cooldown_counts[token] -= 1
            if self._cooldown_counts[token] <= 0:
                del self._cooldown_counts[token]
                logger.debug(f"[Token] 冷却结束: {token[:10]}...")

        current_time = time.time() * 1000  # milliseconds, same units as cooldownUntil

        def select_best(tokens: Dict[str, Any], field: str) -> Tuple[Optional[str], Optional[int]]:
            """Return (token_key, remaining): never-used first, else max remaining quota."""
            unused, used = [], []

            for key, data in tokens.items():
                # Skip tokens marked dead.
                if data.get("status") == "expired":
                    continue

                # Skip tokens with too many consecutive failures.
                if data.get("failedCount", 0) >= MAX_FAILURES:
                    continue

                # Skip tokens in request-count cooldown.
                if key in self._cooldown_counts:
                    continue

                # Skip tokens in timed (429) cooldown.
                cooldown_until = data.get("cooldownUntil", 0)
                if cooldown_until and cooldown_until > current_time:
                    continue

                # Quota semantics: 0 = exhausted, -1 = never measured.
                remaining = int(data.get(field, -1))
                if remaining == 0:
                    continue

                if remaining == -1:
                    unused.append(key)
                elif remaining > 0:
                    used.append((key, remaining))

            if unused:
                return unused[0], -1
            if used:
                used.sort(key=lambda x: x[1], reverse=True)
                return used[0][0], used[0][1]
            return None, None

        # Iterate over shallow copies so concurrent mutation of the pools
        # during selection cannot break the loops.
        snapshot = {
            TokenType.NORMAL.value: self.token_data[TokenType.NORMAL.value].copy(),
            TokenType.SUPER.value: self.token_data[TokenType.SUPER.value].copy()
        }

        # The heavy model only runs on super tokens; every other model prefers
        # the normal pool and falls back to super.
        if model == "grok-4-heavy":
            field = "heavyremainingQueries"
            token_key, remaining = select_best(snapshot[TokenType.SUPER.value], field)
        else:
            field = "remainingQueries"
            token_key, remaining = select_best(snapshot[TokenType.NORMAL.value], field)
            if token_key is None:
                token_key, remaining = select_best(snapshot[TokenType.SUPER.value], field)

        if token_key is None:
            raise GrokApiException(
                f"没有可用Token: {model}",
                "NO_AVAILABLE_TOKEN",
                {
                    "model": model,
                    "normal": len(snapshot[TokenType.NORMAL.value]),
                    "super": len(snapshot[TokenType.SUPER.value]),
                    "cooldown_count": len(self._cooldown_counts)
                }
            )

        status = "未使用" if remaining == -1 else f"剩余{remaining}次"
        logger.debug(f"[Token] 分配Token: {model} ({status})")
        return token_key
| |
    async def check_limits(self, auth_token: str, model: str) -> Optional[Dict[str, Any]]:
        """Query Grok's rate-limit endpoint for *auth_token* and update local quota.

        Retry layout: an inner loop retries 403s (rotating the proxy each time
        when the proxy pool is enabled), an outer loop retries the configured
        status codes (default 401/429) with linear backoff.

        Returns the parsed response dict on success, None on failure; never raises.
        """
        try:
            rate_model = Models.to_rate_limit(model)
            payload = {"requestKind": "DEFAULT", "modelName": rate_model}

            cf = setting.grok_config.get("cf_clearance", "")
            headers = get_dynamic_headers("/rest/rate-limits")
            headers["Cookie"] = f"{auth_token};{cf}" if cf else auth_token

            retry_codes = setting.grok_config.get("retry_status_codes", [401, 429])
            MAX_OUTER_RETRY = 3

            for outer_retry in range(MAX_OUTER_RETRY + 1):
                max_403_retries = 5
                retry_403_count = 0

                while retry_403_count <= max_403_retries:
                    # Imported lazily — presumably to avoid a circular import at module load.
                    from app.core.proxy_pool import proxy_pool

                    # On a 403 retry with the pool enabled, force a fresh proxy.
                    if retry_403_count > 0 and proxy_pool._enabled:
                        logger.info(f"[Token] 403重试 {retry_403_count}/{max_403_retries},刷新代理...")
                        proxy = await proxy_pool.force_refresh()
                    else:
                        proxy = await setting.get_proxy_async("service")

                    proxies = {"http": proxy, "https": proxy} if proxy else None

                    async with AsyncSession() as session:
                        response = await session.post(
                            RATE_LIMIT_API,
                            headers=headers,
                            json=payload,
                            impersonate=BROWSER,
                            timeout=TIMEOUT,
                            proxies=proxies
                        )

                    # 403 handling (only when the proxy pool is enabled): rotate and retry.
                    if response.status_code == 403 and proxy_pool._enabled:
                        retry_403_count += 1

                        if retry_403_count <= max_403_retries:
                            logger.warning(f"[Token] 遇到403错误,正在重试 ({retry_403_count}/{max_403_retries})...")
                            await asyncio.sleep(0.5)
                            continue

                        # Retries exhausted: record the block against this token.
                        logger.error(f"[Token] 403错误,已重试{retry_403_count-1}次,放弃")
                        sso = self._extract_sso(auth_token)
                        if sso:
                            await self.record_failure(auth_token, 403, "服务器被Block")
                        # NOTE(review): no return/break here — control falls through to the
                        # generic branches below, which record a *second* failure for the
                        # same 403 response before returning. Confirm this is intended.

                    # Outer-retryable statuses (default 401/429): back off, restart outer loop.
                    if response.status_code in retry_codes:
                        if outer_retry < MAX_OUTER_RETRY:
                            delay = (outer_retry + 1) * 0.1  # linear backoff: 0.1s, 0.2s, 0.3s
                            logger.warning(f"[Token] 遇到{response.status_code}错误,外层重试 ({outer_retry+1}/{MAX_OUTER_RETRY}),等待{delay}s...")
                            await asyncio.sleep(delay)
                            break  # leave the 403 loop; the outer for-loop retries
                        else:
                            logger.error(f"[Token] {response.status_code}错误,已重试{outer_retry}次,放弃")
                            sso = self._extract_sso(auth_token)
                            if sso:
                                if response.status_code == 401:
                                    await self.record_failure(auth_token, 401, "Token失效")
                                else:
                                    await self.record_failure(auth_token, response.status_code, f"错误: {response.status_code}")
                            return None

                    if response.status_code == 200:
                        data = response.json()
                        sso = self._extract_sso(auth_token)

                        if outer_retry > 0 or retry_403_count > 0:
                            logger.info(f"[Token] 重试成功!")

                        if sso:
                            # NOTE(review): heavy quota is read from "remainingQueries" while
                            # the normal quota comes from "remainingTokens" — confirm these
                            # field names against the live API response.
                            if model == "grok-4-heavy":
                                await self.update_limits(sso, normal=None, heavy=data.get("remainingQueries", -1))
                                logger.info(f"[Token] 更新限制: {sso[:10]}..., heavy={data.get('remainingQueries', -1)}")
                            else:
                                await self.update_limits(sso, normal=data.get("remainingTokens", -1), heavy=None)
                                logger.info(f"[Token] 更新限制: {sso[:10]}..., basic={data.get('remainingTokens', -1)}")

                        return data
                    else:
                        # Any other status: record a single failure and give up.
                        logger.warning(f"[Token] 获取限制失败: {response.status_code}")
                        sso = self._extract_sso(auth_token)
                        if sso:
                            await self.record_failure(auth_token, response.status_code, f"错误: {response.status_code}")
                        return None

        except Exception as e:
            logger.error(f"[Token] 检查限制错误: {e}")
            return None
|
|
| async def update_limits(self, sso: str, normal: Optional[int] = None, heavy: Optional[int] = None) -> None: |
| """更新限制""" |
| try: |
| for token_type in [TokenType.NORMAL.value, TokenType.SUPER.value]: |
| if sso in self.token_data[token_type]: |
| if normal is not None: |
| self.token_data[token_type][sso]["remainingQueries"] = normal |
| if heavy is not None: |
| self.token_data[token_type][sso]["heavyremainingQueries"] = heavy |
| self._mark_dirty() |
| logger.info(f"[Token] 更新限制: {sso[:10]}...") |
| return |
| logger.warning(f"[Token] 未找到: {sso[:10]}...") |
| except Exception as e: |
| logger.error(f"[Token] 更新限制错误: {e}") |
| |
    async def record_failure(self, auth_token: str, status: int, msg: str) -> None:
        """Record a failed request against the token carried by *auth_token*.

        A 403 is treated as an IP-level block (not the token's fault) and is
        only logged. After MAX_FAILURES accumulated 4xx failures the token is
        marked expired. Never raises.
        """
        try:
            if status == STATSIG_INVALID:
                # IP/fingerprint block — do not penalize the token itself.
                logger.warning("[Token] IP被Block,请: 1.更换IP 2.使用代理 3.配置CF值")
                return

            sso = self._extract_sso(auth_token)
            if not sso:
                return

            _, data = self._find_token(sso)
            if not data:
                logger.warning(f"[Token] 未找到: {sso[:10]}...")
                return

            # Mutates the stored record in place; persisted via _mark_dirty below.
            data["failedCount"] = data.get("failedCount", 0) + 1
            data["lastFailureTime"] = int(time.time() * 1000)
            data["lastFailureReason"] = f"{status}: {msg}"

            logger.warning(
                f"[Token] 失败: {sso[:10]}... (状态:{status}), "
                f"次数: {data['failedCount']}/{MAX_FAILURES}, 原因: {msg}"
            )

            # Only client errors (4xx) can permanently expire a token.
            if 400 <= status < 500 and data["failedCount"] >= MAX_FAILURES:
                data["status"] = "expired"
                logger.error(f"[Token] 标记失效: {sso[:10]}... (连续{status}错误{data['failedCount']}次)")

            self._mark_dirty()

        except Exception as e:
            logger.error(f"[Token] 记录失败错误: {e}")
|
|
| async def reset_failure(self, auth_token: str) -> None: |
| """重置失败计数""" |
| try: |
| sso = self._extract_sso(auth_token) |
| if not sso: |
| return |
|
|
| _, data = self._find_token(sso) |
| if not data: |
| return |
|
|
| if data.get("failedCount", 0) > 0: |
| data["failedCount"] = 0 |
| data["lastFailureTime"] = None |
| data["lastFailureReason"] = None |
| self._mark_dirty() |
| logger.info(f"[Token] 重置失败计数: {sso[:10]}...") |
|
|
| except Exception as e: |
| logger.error(f"[Token] 重置失败错误: {e}") |
|
|
    async def apply_cooldown(self, auth_token: str, status_code: int) -> None:
        """Apply a cooldown to the token after a failed request.

        - 429: timed cooldown — 1 hour when quota remains (or was never
          measured, remaining == -1), 10 hours when exhausted; persisted on
          the record as ``cooldownUntil`` in milliseconds.
        - Any other status: in-memory cooldown lasting COOLDOWN_REQUESTS
          subsequent requests (see select_token's countdown).
        """
        try:
            sso = self._extract_sso(auth_token)
            if not sso:
                return

            _, data = self._find_token(sso)
            if not data:
                return

            # NOTE(review): always reads the normal-pool quota field, even when the
            # 429 came from the heavy model ("heavyremainingQueries") — confirm intended.
            remaining = data.get("remainingQueries", -1)

            if status_code == 429:
                if remaining > 0 or remaining == -1:
                    # Quota left (or unknown): short cooldown.
                    cooldown_until = time.time() + COOLDOWN_429_WITH_QUOTA
                    logger.info(f"[Token] 429冷却(有额度): {sso[:10]}... 冷却1小时")
                else:
                    # Quota exhausted: long cooldown.
                    cooldown_until = time.time() + COOLDOWN_429_NO_QUOTA
                    logger.info(f"[Token] 429冷却(无额度): {sso[:10]}... 冷却10小时")
                data["cooldownUntil"] = int(cooldown_until * 1000)  # ms, matches select_token
                self._mark_dirty()
            else:
                # Request-count cooldown; skipped when the quota is known to be zero.
                if remaining != 0:
                    self._cooldown_counts[sso] = COOLDOWN_REQUESTS
                    logger.info(f"[Token] 次数冷却: {sso[:10]}... 冷却{COOLDOWN_REQUESTS}次请求")

        except Exception as e:
            logger.error(f"[Token] 应用冷却错误: {e}")
|
|
    async def refresh_all_limits(self) -> Dict[str, Any]:
        """Refresh the remaining-quota counters for every stored token.

        Only one refresh runs at a time, guarded by a plain boolean flag
        (adequate under a single event loop: no await between check and set).
        Progress is published via _refresh_progress for polling by
        get_refresh_progress. Returns a summary dict, or an
        "refresh_in_progress" error dict when a refresh is already running.
        """
        if self._refresh_lock:
            return {"error": "refresh_in_progress", "message": "已有刷新任务在进行中", "progress": self._refresh_progress}

        self._refresh_lock = True

        try:
            # Snapshot the (pool, sso) pairs up front so concurrent add/delete
            # cannot invalidate the iteration.
            all_tokens = []
            for token_type in [TokenType.NORMAL.value, TokenType.SUPER.value]:
                for sso in list(self.token_data[token_type].keys()):
                    all_tokens.append((token_type, sso))

            total = len(all_tokens)
            self._refresh_progress = {"running": True, "current": 0, "total": total, "success": 0, "failed": 0}

            success_count = 0
            fail_count = 0

            for i, (token_type, sso) in enumerate(all_tokens):
                auth_token = f"sso-rw={sso};sso={sso}"
                try:
                    # check_limits updates the stored quota as a side effect.
                    result = await self.check_limits(auth_token, "grok-4-fast")
                    if result:
                        success_count += 1
                    else:
                        fail_count += 1
                except Exception as e:
                    logger.warning(f"[Token] 刷新失败: {sso[:10]}... - {e}")
                    fail_count += 1

                # Replace (not mutate) the progress dict so concurrent readers
                # always observe a consistent snapshot.
                self._refresh_progress = {
                    "running": True,
                    "current": i + 1,
                    "total": total,
                    "success": success_count,
                    "failed": fail_count
                }
                await asyncio.sleep(0.1)  # throttle requests to the upstream API

            logger.info(f"[Token] 批量刷新完成: 成功{success_count}, 失败{fail_count}")
            self._refresh_progress = {"running": False, "current": total, "total": total, "success": success_count, "failed": fail_count}
            return {"success": success_count, "failed": fail_count, "total": total}

        finally:
            self._refresh_lock = False
| |
| def get_refresh_progress(self) -> Dict[str, Any]: |
| """获取刷新进度""" |
| return self._refresh_progress.copy() |
|
|
|
|
| |
| token_manager = GrokTokenManager() |
|
|