"""Grok Token 管理器 - 单例模式的Token负载均衡和状态管理""" import orjson import time import asyncio import aiofiles import portalocker from pathlib import Path from curl_cffi.requests import AsyncSession from typing import Dict, Any, Optional, Tuple from app.models.grok_models import TokenType, Models from app.core.exception import GrokApiException from app.core.logger import logger from app.core.config import setting from app.services.grok.statsig import get_dynamic_headers # 常量 RATE_LIMIT_API = "https://grok.com/rest/rate-limits" TIMEOUT = 30 BROWSER = "chrome133a" MAX_FAILURES = 3 TOKEN_INVALID = 401 STATSIG_INVALID = 403 # 冷却常量 COOLDOWN_REQUESTS = 5 # 普通失败冷却请求数 COOLDOWN_429_WITH_QUOTA = 3600 # 429+有额度冷却1小时(秒) COOLDOWN_429_NO_QUOTA = 36000 # 429+无额度冷却10小时(秒) class GrokTokenManager: """Token管理器(单例)""" _instance: Optional['GrokTokenManager'] = None _lock = asyncio.Lock() def __new__(cls) -> 'GrokTokenManager': if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): if hasattr(self, '_initialized'): return self.token_file = Path(__file__).parents[3] / "data" / "token.json" self._file_lock = asyncio.Lock() self.token_file.parent.mkdir(parents=True, exist_ok=True) self._storage = None self.token_data = None # 延迟加载 # 批量保存队列 self._save_pending = False # 标记是否有待保存的数据 self._save_task = None # 后台保存任务 self._shutdown = False # 关闭标志 # 冷却状态 self._cooldown_counts: Dict[str, int] = {} # Token -> 剩余冷却次数 self._request_counter = 0 # 全局请求计数器 # 刷新状态 self._refresh_lock = False # 刷新锁 self._refresh_progress: Dict[str, Any] = {"running": False, "current": 0, "total": 0, "success": 0, "failed": 0} self._initialized = True logger.debug(f"[Token] 初始化完成: {self.token_file}") def set_storage(self, storage) -> None: """设置存储实例""" self._storage = storage async def _load_data(self) -> None: """异步加载Token数据(支持多进程)""" default = {TokenType.NORMAL.value: {}, TokenType.SUPER.value: {}} def load_sync(): with open(self.token_file, "r", encoding="utf-8") as f: portalocker.lock(f, portalocker.LOCK_SH) try: return orjson.loads(f.read()) finally: portalocker.unlock(f) try: if self.token_file.exists(): # 使用进程锁读取文件 async with self._file_lock: self.token_data = await asyncio.to_thread(load_sync) else: self.token_data = default logger.debug("[Token] 创建新数据文件") except Exception as e: logger.error(f"[Token] 加载失败: {e}") self.token_data = default async def _save_data(self) -> None: """保存Token数据(支持多进程)""" def save_sync(data): with open(self.token_file, "w", encoding="utf-8") as f: portalocker.lock(f, portalocker.LOCK_EX) try: content = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode() f.write(content) f.flush() finally: portalocker.unlock(f) try: if not self._storage: async with self._file_lock: await asyncio.to_thread(save_sync, self.token_data) else: await self._storage.save_tokens(self.token_data) except Exception as e: logger.error(f"[Token] 保存失败: {e}") raise GrokApiException(f"保存失败: {e}", "TOKEN_SAVE_ERROR") def _mark_dirty(self) -> None: """标记有待保存的数据""" self._save_pending = True async def _batch_save_worker(self) -> None: """批量保存后台任务""" from app.core.config import setting interval = setting.global_config.get("batch_save_interval", 1.0) logger.info(f"[Token] 存储任务已启动,间隔: {interval}s") while not self._shutdown: await asyncio.sleep(interval) if self._save_pending and not self._shutdown: try: await self._save_data() self._save_pending = False logger.debug("[Token] 存储完成") except Exception as e: logger.error(f"[Token] 存储失败: {e}") async def start_batch_save(self) -> None: """启动批量保存任务""" if self._save_task is None: self._save_task = asyncio.create_task(self._batch_save_worker()) logger.info("[Token] 存储任务已创建") async def shutdown(self) -> None: """关闭并刷新所有待保存数据""" self._shutdown = True if self._save_task: self._save_task.cancel() try: await self._save_task except asyncio.CancelledError: pass # 最终刷新 if self._save_pending: await self._save_data() logger.info("[Token] 关闭时刷新完成") @staticmethod def _extract_sso(auth_token: str) -> Optional[str]: """提取SSO值""" if "sso=" in auth_token: return auth_token.split("sso=")[1].split(";")[0] logger.warning("[Token] 无法提取SSO值") return None def _find_token(self, sso: str) -> Tuple[Optional[str], Optional[Dict]]: """查找Token""" for token_type in [TokenType.NORMAL.value, TokenType.SUPER.value]: if sso in self.token_data[token_type]: return token_type, self.token_data[token_type][sso] return None, None async def add_token(self, tokens: list[str], token_type: TokenType) -> None: """添加Token""" if not tokens: return count = 0 for token in tokens: if not token or not token.strip(): continue self.token_data[token_type.value][token] = { "createdTime": int(time.time() * 1000), "remainingQueries": -1, "heavyremainingQueries": -1, "status": "active", "failedCount": 0, "lastFailureTime": None, "lastFailureReason": None, "tags": [], "note": "" } count += 1 self._mark_dirty() # 批量保存 logger.info(f"[Token] 添加 {count} 个 {token_type.value} Token") async def delete_token(self, tokens: list[str], token_type: TokenType) -> None: """删除Token""" if not tokens: return count = 0 for token in tokens: if token in self.token_data[token_type.value]: del self.token_data[token_type.value][token] count += 1 self._mark_dirty() # 批量保存 logger.info(f"[Token] 删除 {count} 个 {token_type.value} Token") async def update_token_tags(self, token: str, token_type: TokenType, tags: list[str]) -> None: """更新Token标签""" if token not in self.token_data[token_type.value]: raise GrokApiException("Token不存在", "TOKEN_NOT_FOUND", {"token": token[:10]}) cleaned = [t.strip() for t in tags if t and t.strip()] self.token_data[token_type.value][token]["tags"] = cleaned self._mark_dirty() # 批量保存 logger.info(f"[Token] 更新标签: {token[:10]}... -> {cleaned}") async def update_token_note(self, token: str, token_type: TokenType, note: str) -> None: """更新Token备注""" if token not in self.token_data[token_type.value]: raise GrokApiException("Token不存在", "TOKEN_NOT_FOUND", {"token": token[:10]}) self.token_data[token_type.value][token]["note"] = note.strip() self._mark_dirty() # 批量保存 logger.info(f"[Token] 更新备注: {token[:10]}...") def get_tokens(self) -> Dict[str, Any]: """获取所有Token""" return self.token_data.copy() async def _reload_if_needed(self) -> None: """在多进程模式下重新加载数据""" # 只在文件模式且多进程环境下才重新加载 if self._storage: return def reload_sync(): with open(self.token_file, "r", encoding="utf-8") as f: portalocker.lock(f, portalocker.LOCK_SH) try: return orjson.loads(f.read()) finally: portalocker.unlock(f) try: if self.token_file.exists(): self.token_data = await asyncio.to_thread(reload_sync) except Exception as e: logger.warning(f"[Token] 重新加载失败: {e}") async def get_token(self, model: str) -> str: """获取Token""" jwt = await self.select_token(model) return f"sso-rw={jwt};sso={jwt}" async def select_token(self, model: str) -> str: """选择最优Token(多进程安全,支持冷却)""" # 重新加载最新数据(多进程模式) await self._reload_if_needed() # 递减所有次数冷却计数 self._request_counter += 1 for token in list(self._cooldown_counts.keys()): self._cooldown_counts[token] -= 1 if self._cooldown_counts[token] <= 0: del self._cooldown_counts[token] logger.debug(f"[Token] 冷却结束: {token[:10]}...") current_time = time.time() * 1000 # 毫秒 def select_best(tokens: Dict[str, Any], field: str) -> Tuple[Optional[str], Optional[int]]: """选择最佳Token""" unused, used = [], [] for key, data in tokens.items(): # 跳过已失效的token if data.get("status") == "expired": continue # 跳过失败次数过多的token(任何错误状态码) if data.get("failedCount", 0) >= MAX_FAILURES: continue # 跳过次数冷却中的token if key in self._cooldown_counts: continue # 跳过时间冷却中的token(429) cooldown_until = data.get("cooldownUntil", 0) if cooldown_until and cooldown_until > current_time: continue remaining = int(data.get(field, -1)) if remaining == 0: continue if remaining == -1: unused.append(key) elif remaining > 0: used.append((key, remaining)) if unused: return unused[0], -1 if used: used.sort(key=lambda x: x[1], reverse=True) return used[0][0], used[0][1] return None, None # 快照 snapshot = { TokenType.NORMAL.value: self.token_data[TokenType.NORMAL.value].copy(), TokenType.SUPER.value: self.token_data[TokenType.SUPER.value].copy() } # 选择策略 if model == "grok-4-heavy": field = "heavyremainingQueries" token_key, remaining = select_best(snapshot[TokenType.SUPER.value], field) else: field = "remainingQueries" token_key, remaining = select_best(snapshot[TokenType.NORMAL.value], field) if token_key is None: token_key, remaining = select_best(snapshot[TokenType.SUPER.value], field) if token_key is None: raise GrokApiException( f"没有可用Token: {model}", "NO_AVAILABLE_TOKEN", { "model": model, "normal": len(snapshot[TokenType.NORMAL.value]), "super": len(snapshot[TokenType.SUPER.value]), "cooldown_count": len(self._cooldown_counts) } ) status = "未使用" if remaining == -1 else f"剩余{remaining}次" logger.debug(f"[Token] 分配Token: {model} ({status})") return token_key async def check_limits(self, auth_token: str, model: str) -> Optional[Dict[str, Any]]: """检查速率限制""" try: rate_model = Models.to_rate_limit(model) payload = {"requestKind": "DEFAULT", "modelName": rate_model} cf = setting.grok_config.get("cf_clearance", "") headers = get_dynamic_headers("/rest/rate-limits") headers["Cookie"] = f"{auth_token};{cf}" if cf else auth_token # 外层重试:可配置状态码(401/429等) retry_codes = setting.grok_config.get("retry_status_codes", [401, 429]) MAX_OUTER_RETRY = 3 for outer_retry in range(MAX_OUTER_RETRY + 1): # +1 确保实际重试3次 # 内层重试:403代理池重试 max_403_retries = 5 retry_403_count = 0 while retry_403_count <= max_403_retries: # 异步获取代理(支持代理池) from app.core.proxy_pool import proxy_pool # 如果是403重试且使用代理池,强制刷新代理 if retry_403_count > 0 and proxy_pool._enabled: logger.info(f"[Token] 403重试 {retry_403_count}/{max_403_retries},刷新代理...") proxy = await proxy_pool.force_refresh() else: proxy = await setting.get_proxy_async("service") proxies = {"http": proxy, "https": proxy} if proxy else None async with AsyncSession() as session: response = await session.post( RATE_LIMIT_API, headers=headers, json=payload, impersonate=BROWSER, timeout=TIMEOUT, proxies=proxies ) # 内层403重试:仅当有代理池时触发 if response.status_code == 403 and proxy_pool._enabled: retry_403_count += 1 if retry_403_count <= max_403_retries: logger.warning(f"[Token] 遇到403错误,正在重试 ({retry_403_count}/{max_403_retries})...") await asyncio.sleep(0.5) continue # 内层重试全部失败 logger.error(f"[Token] 403错误,已重试{retry_403_count-1}次,放弃") sso = self._extract_sso(auth_token) if sso: await self.record_failure(auth_token, 403, "服务器被Block") # 检查可配置状态码错误 - 外层重试 if response.status_code in retry_codes: if outer_retry < MAX_OUTER_RETRY: delay = (outer_retry + 1) * 0.1 # 渐进延迟:0.1s, 0.2s, 0.3s logger.warning(f"[Token] 遇到{response.status_code}错误,外层重试 ({outer_retry+1}/{MAX_OUTER_RETRY}),等待{delay}s...") await asyncio.sleep(delay) break # 跳出内层循环,进入外层重试 else: logger.error(f"[Token] {response.status_code}错误,已重试{outer_retry}次,放弃") sso = self._extract_sso(auth_token) if sso: if response.status_code == 401: await self.record_failure(auth_token, 401, "Token失效") else: await self.record_failure(auth_token, response.status_code, f"错误: {response.status_code}") return None if response.status_code == 200: data = response.json() sso = self._extract_sso(auth_token) if outer_retry > 0 or retry_403_count > 0: logger.info(f"[Token] 重试成功!") if sso: if model == "grok-4-heavy": await self.update_limits(sso, normal=None, heavy=data.get("remainingQueries", -1)) logger.info(f"[Token] 更新限制: {sso[:10]}..., heavy={data.get('remainingQueries', -1)}") else: await self.update_limits(sso, normal=data.get("remainingTokens", -1), heavy=None) logger.info(f"[Token] 更新限制: {sso[:10]}..., basic={data.get('remainingTokens', -1)}") return data else: # 其他错误 logger.warning(f"[Token] 获取限制失败: {response.status_code}") sso = self._extract_sso(auth_token) if sso: await self.record_failure(auth_token, response.status_code, f"错误: {response.status_code}") return None except Exception as e: logger.error(f"[Token] 检查限制错误: {e}") return None async def update_limits(self, sso: str, normal: Optional[int] = None, heavy: Optional[int] = None) -> None: """更新限制""" try: for token_type in [TokenType.NORMAL.value, TokenType.SUPER.value]: if sso in self.token_data[token_type]: if normal is not None: self.token_data[token_type][sso]["remainingQueries"] = normal if heavy is not None: self.token_data[token_type][sso]["heavyremainingQueries"] = heavy self._mark_dirty() # 批量保存 logger.info(f"[Token] 更新限制: {sso[:10]}...") return logger.warning(f"[Token] 未找到: {sso[:10]}...") except Exception as e: logger.error(f"[Token] 更新限制错误: {e}") async def record_failure(self, auth_token: str, status: int, msg: str) -> None: """记录失败""" try: if status == STATSIG_INVALID: logger.warning("[Token] IP被Block,请: 1.更换IP 2.使用代理 3.配置CF值") return sso = self._extract_sso(auth_token) if not sso: return _, data = self._find_token(sso) if not data: logger.warning(f"[Token] 未找到: {sso[:10]}...") return data["failedCount"] = data.get("failedCount", 0) + 1 data["lastFailureTime"] = int(time.time() * 1000) data["lastFailureReason"] = f"{status}: {msg}" logger.warning( f"[Token] 失败: {sso[:10]}... (状态:{status}), " f"次数: {data['failedCount']}/{MAX_FAILURES}, 原因: {msg}" ) if 400 <= status < 500 and data["failedCount"] >= MAX_FAILURES: data["status"] = "expired" logger.error(f"[Token] 标记失效: {sso[:10]}... (连续{status}错误{data['failedCount']}次)") self._mark_dirty() # 批量保存 except Exception as e: logger.error(f"[Token] 记录失败错误: {e}") async def reset_failure(self, auth_token: str) -> None: """重置失败计数""" try: sso = self._extract_sso(auth_token) if not sso: return _, data = self._find_token(sso) if not data: return if data.get("failedCount", 0) > 0: data["failedCount"] = 0 data["lastFailureTime"] = None data["lastFailureReason"] = None self._mark_dirty() # 批量保存 logger.info(f"[Token] 重置失败计数: {sso[:10]}...") except Exception as e: logger.error(f"[Token] 重置失败错误: {e}") async def apply_cooldown(self, auth_token: str, status_code: int) -> None: """应用冷却策略 - 429 错误:使用时间冷却(有额度1小时,无额度10小时) - 其他错误:使用次数冷却(5次请求) """ try: sso = self._extract_sso(auth_token) if not sso: return _, data = self._find_token(sso) if not data: return remaining = data.get("remainingQueries", -1) if status_code == 429: # 429 使用时间冷却 if remaining > 0 or remaining == -1: # 有额度:冷却1小时 cooldown_until = time.time() + COOLDOWN_429_WITH_QUOTA logger.info(f"[Token] 429冷却(有额度): {sso[:10]}... 冷却1小时") else: # 无额度:冷却10小时 cooldown_until = time.time() + COOLDOWN_429_NO_QUOTA logger.info(f"[Token] 429冷却(无额度): {sso[:10]}... 冷却10小时") data["cooldownUntil"] = int(cooldown_until * 1000) self._mark_dirty() else: # 其他错误使用次数冷却(有额度时才冷却) if remaining != 0: self._cooldown_counts[sso] = COOLDOWN_REQUESTS logger.info(f"[Token] 次数冷却: {sso[:10]}... 冷却{COOLDOWN_REQUESTS}次请求") except Exception as e: logger.error(f"[Token] 应用冷却错误: {e}") async def refresh_all_limits(self) -> Dict[str, Any]: """刷新所有 Token 的剩余次数""" # 检查是否已在刷新 if self._refresh_lock: return {"error": "refresh_in_progress", "message": "已有刷新任务在进行中", "progress": self._refresh_progress} # 获取锁 self._refresh_lock = True try: # 计算总数 all_tokens = [] for token_type in [TokenType.NORMAL.value, TokenType.SUPER.value]: for sso in list(self.token_data[token_type].keys()): all_tokens.append((token_type, sso)) total = len(all_tokens) self._refresh_progress = {"running": True, "current": 0, "total": total, "success": 0, "failed": 0} success_count = 0 fail_count = 0 for i, (token_type, sso) in enumerate(all_tokens): auth_token = f"sso-rw={sso};sso={sso}" try: result = await self.check_limits(auth_token, "grok-4-fast") if result: success_count += 1 else: fail_count += 1 except Exception as e: logger.warning(f"[Token] 刷新失败: {sso[:10]}... - {e}") fail_count += 1 # 更新进度 self._refresh_progress = { "running": True, "current": i + 1, "total": total, "success": success_count, "failed": fail_count } await asyncio.sleep(0.1) # 避免请求过快 logger.info(f"[Token] 批量刷新完成: 成功{success_count}, 失败{fail_count}") self._refresh_progress = {"running": False, "current": total, "total": total, "success": success_count, "failed": fail_count} return {"success": success_count, "failed": fail_count, "total": total} finally: self._refresh_lock = False def get_refresh_progress(self) -> Dict[str, Any]: """获取刷新进度""" return self._refresh_progress.copy() # 全局实例 token_manager = GrokTokenManager()