"""Token management service."""

import asyncio
import time
from datetime import datetime
from typing import Dict, List, Optional

from app.core.logger import logger
from app.services.token.models import (
    TokenInfo,
    EffortType,
    FAIL_THRESHOLD,
    TokenStatus,
    BASIC__DEFAULT_QUOTA,
    SUPER_DEFAULT_QUOTA,
)
from app.core.storage import get_storage, LocalStorage
from app.core.config import get_config
from app.core.exceptions import UpstreamException
from app.services.token.pool import TokenPool
from app.services.grok.batch_services.usage import UsageService


DEFAULT_REFRESH_BATCH_SIZE = 10
DEFAULT_REFRESH_CONCURRENCY = 5
DEFAULT_SUPER_REFRESH_INTERVAL_HOURS = 2
DEFAULT_REFRESH_INTERVAL_HOURS = 8
DEFAULT_RELOAD_INTERVAL_SEC = 30
DEFAULT_SAVE_DELAY_MS = 500
DEFAULT_USAGE_FLUSH_INTERVAL_SEC = 5

SUPER_POOL_NAME = "ssoSuper"
BASIC_POOL_NAME = "ssoBasic"

_SSO_PREFIX = "sso="


def _default_quota_for_pool(pool_name: str) -> int:
    """Return the default quota assigned to tokens of the given pool."""
    if pool_name == SUPER_POOL_NAME:
        return SUPER_DEFAULT_QUOTA
    return BASIC__DEFAULT_QUOTA


def _strip_sso_prefix(token: str) -> str:
    """Return ``token`` without a single leading ``sso=`` prefix.

    Only the prefix is removed; an ``sso=`` substring elsewhere in the
    value is left untouched so the token itself is never altered.
    """
    return token.removeprefix(_SSO_PREFIX)


class TokenManager:
    """Manage token CRUD operations and quota synchronisation."""

    _instance: Optional["TokenManager"] = None
    _lock = asyncio.Lock()

    def __init__(self):
        self.pools: Dict[str, TokenPool] = {}
        self.initialized = False
        self._save_lock = asyncio.Lock()
        self._dirty = False
        self._save_task: Optional[asyncio.Task] = None
        self._save_delay = DEFAULT_SAVE_DELAY_MS / 1000.0
        self._last_reload_at = 0.0
        self._has_state_changes = False
        self._has_usage_changes = False
        self._state_change_seq = 0
        self._usage_change_seq = 0
        self._last_usage_flush_at = 0.0
        # bare token -> (pool_name, change_kind), change_kind is "state" or "usage"
        self._dirty_tokens: Dict[str, tuple[str, str]] = {}
        # bare tokens whose deletion has not been persisted yet
        self._dirty_deletes: set[str] = set()

    @classmethod
    async def get_instance(cls) -> "TokenManager":
        """Return the shared instance, creating and loading it on first use."""
        if cls._instance is None:
            async with cls._lock:
                # Re-check under the lock: another coroutine may have created it.
                if cls._instance is None:
                    cls._instance = cls()
                    await cls._instance._load()
        return cls._instance

    async def _load(self):
        """Load all pools from storage unless already initialised.

        When the configured storage returns nothing, tokens are read from
        ``LocalStorage`` and, if any exist, written to the configured storage.
        The new pools are published only after a complete, successful load,
        so a storage failure during ``reload()`` keeps the previously loaded
        pools instead of emptying them.
        """
        if self.initialized:
            return

        try:
            storage = get_storage()
            data = await storage.load_tokens()

            # Backend returned None/empty: try to seed it from local storage.
            if not data:
                local_storage = LocalStorage()
                local_data = await local_storage.load_tokens()
                if local_data:
                    data = local_data
                    await storage.save_tokens(local_data)
                    logger.info(
                        f"Initialized remote token storage ({storage.__class__.__name__}) with local tokens."
                    )
                else:
                    data = {}

            pools: Dict[str, TokenPool] = {}
            for pool_name, tokens in data.items():
                pool = TokenPool(pool_name)
                for token_data in tokens:
                    quota_missing = not (
                        isinstance(token_data, dict) and "quota" in token_data
                    )
                    try:
                        # Always store the bare token (without "sso=").
                        if isinstance(token_data, dict):
                            raw_token = token_data.get("token")
                            if isinstance(raw_token, str) and raw_token.startswith(
                                "sso="
                            ):
                                token_data["token"] = raw_token[4:]
                        token_info = TokenInfo(**token_data)
                        if quota_missing and pool_name == SUPER_POOL_NAME:
                            token_info.quota = SUPER_DEFAULT_QUOTA
                        pool.add(token_info)
                    except Exception as e:
                        # One malformed entry must not abort the whole load.
                        logger.warning(
                            f"Failed to load token in pool '{pool_name}': {e}"
                        )
                        continue
                pool._rebuild_index()
                pools[pool_name] = pool

            self.pools = pools
            self.initialized = True
            self._last_reload_at = time.monotonic()
            total = sum(p.count() for p in self.pools.values())
            logger.info(
                f"TokenManager initialized: {len(self.pools)} pools with {total} tokens"
            )
        except Exception as e:
            logger.error(f"Failed to initialize TokenManager: {e}")
            # Keep whatever pools are already in memory (empty on first load).
            # _last_reload_at is left untouched so reload_if_stale() retries.
            self.initialized = True

    async def reload(self):
        """Reload all token pools from storage."""
        async with self.__class__._lock:
            self.initialized = False
            await self._load()

    async def reload_if_stale(self):
        """Reload when the last successful load is older than the configured interval.

        The interval is read from ``token.reload_interval_sec``; a value
        that is not a number falls back to the default, and a value <= 0
        disables the periodic reload. Intended to keep multiple workers
        roughly consistent.
        """
        interval = get_config("token.reload_interval_sec", DEFAULT_RELOAD_INTERVAL_SEC)
        try:
            interval = float(interval)
        except Exception:
            interval = float(DEFAULT_RELOAD_INTERVAL_SEC)
        if interval <= 0:
            return
        if time.monotonic() - self._last_reload_at < interval:
            return
        await self.reload()

    def _mark_state_change(self):
        """Flag a pending state change and bump its sequence number."""
        self._has_state_changes = True
        self._state_change_seq += 1

    def _mark_usage_change(self):
        """Flag a pending usage change and bump its sequence number."""
        self._has_usage_changes = True
        self._usage_change_seq += 1

    def _track_token_change(
        self, token: TokenInfo, pool_name: str, change_kind: str
    ):
        """Queue ``token`` for the next delta save.

        Args:
            token: The changed token.
            pool_name: Name of the pool the token belongs to.
            change_kind: ``"state"`` or any other value (treated as usage).
                A pending ``"state"`` entry is never downgraded.
        """
        token_key = _strip_sso_prefix(token.token)
        # A change to the token cancels a pending delete of the same key.
        self._dirty_deletes.discard(token_key)
        existing = self._dirty_tokens.get(token_key)
        if existing and existing[1] == "state":
            change_kind = "state"
        self._dirty_tokens[token_key] = (pool_name, change_kind)
        if change_kind == "state":
            self._mark_state_change()
        else:
            self._mark_usage_change()

    def _track_token_delete(self, token_str: str):
        """Queue deletion of ``token_str`` and drop any pending update for it."""
        token_key = _strip_sso_prefix(token_str)
        self._dirty_deletes.add(token_key)
        self._dirty_tokens.pop(token_key, None)
        self._mark_state_change()

    def _requeue_failed_batch(
        self,
        dirty_tokens: Dict[str, tuple[str, str]],
        dirty_deletes: set[str],
    ):
        """Merge a batch whose save failed back into the pending change sets."""
        # Deletes for tokens that were re-added while the save was in flight
        # must not be replayed, otherwise the next save would delete a token
        # that exists in memory again.
        readded = set()
        for token_key in dirty_deletes:
            current = self._dirty_tokens.get(token_key)
            if current is None:
                continue
            pool = self.pools.get(current[0])
            if pool is not None and pool.get(token_key):
                readded.add(token_key)

        for token_key, meta in dirty_tokens.items():
            existing = self._dirty_tokens.get(token_key)
            # A newer "state" entry already covers the failed one.
            if existing and existing[1] == "state":
                continue
            self._dirty_tokens[token_key] = meta

        for token_key in dirty_deletes:
            if token_key in readded:
                continue
            self._dirty_deletes.add(token_key)
            self._dirty_tokens.pop(token_key, None)

    async def _save(self, force: bool = False):
        """Persist pending token updates and deletions as a delta.

        Usage-only changes are throttled by
        ``token.usage_flush_interval_sec`` unless ``force`` is true or a
        state change is pending. On failure the batch is queued again and
        the error is logged; no exception is raised to the caller.

        Args:
            force: Skip the usage-flush throttling.
        """
        async with self._save_lock:
            # Empty until the snapshot below is taken, so the failure path
            # is a no-op when the error happens before the snapshot.
            dirty_tokens: Dict[str, tuple[str, str]] = {}
            dirty_deletes: set[str] = set()
            try:
                if not self._dirty_tokens and not self._dirty_deletes:
                    return

                if not force and not self._has_state_changes:
                    interval_sec = get_config(
                        "token.usage_flush_interval_sec",
                        DEFAULT_USAGE_FLUSH_INTERVAL_SEC,
                    )
                    try:
                        interval_sec = float(interval_sec)
                    except Exception:
                        interval_sec = float(DEFAULT_USAGE_FLUSH_INTERVAL_SEC)
                    if interval_sec > 0:
                        now = time.monotonic()
                        if now - self._last_usage_flush_at < interval_sec:
                            # Too early: leave the flag set so the flush
                            # loop tries again later.
                            self._dirty = True
                            return

                # Snapshot and reset, so changes made while awaiting the
                # storage call land in a fresh batch.
                state_seq = self._state_change_seq
                usage_seq = self._usage_change_seq
                dirty_tokens = self._dirty_tokens
                dirty_deletes = self._dirty_deletes
                self._dirty_tokens = {}
                self._dirty_deletes = set()

                updates = []
                deleted = list(dirty_deletes)
                for token_key, meta in dirty_tokens.items():
                    if token_key in dirty_deletes:
                        continue
                    pool_name, change_kind = meta
                    pool = self.pools.get(pool_name)
                    if not pool:
                        continue
                    info = pool.get(token_key)
                    if not info:
                        continue
                    payload = info.model_dump()
                    payload["pool_name"] = pool_name
                    payload["_update_kind"] = change_kind
                    updates.append(payload)

                storage = get_storage()
                async with storage.acquire_lock("tokens_save", timeout=10):
                    await storage.save_tokens_delta(updates, deleted)

                # Clear a flag only if nothing new was marked during the save.
                if state_seq == self._state_change_seq:
                    self._has_state_changes = False
                if usage_seq == self._usage_change_seq:
                    self._has_usage_changes = False
                self._last_usage_flush_at = time.monotonic()
            except Exception as e:
                logger.error(f"Failed to save tokens: {e}")
                self._dirty = True
                self._requeue_failed_batch(dirty_tokens, dirty_deletes)

    def _schedule_save(self):
        """Coalesce frequent save requests to reduce write overhead.

        The delay is re-read from ``token.save_delay_ms`` on every call.
        A delay of 0 starts a direct save task; otherwise the flush loop
        is started. Nothing is started while a previous task is running.
        """
        delay_ms = get_config("token.save_delay_ms", DEFAULT_SAVE_DELAY_MS)
        try:
            delay_ms = float(delay_ms)
        except Exception:
            delay_ms = float(DEFAULT_SAVE_DELAY_MS)
        self._save_delay = max(0.0, delay_ms / 1000.0)
        self._dirty = True

        if self._save_task and not self._save_task.done():
            return
        if self._save_delay == 0:
            self._save_task = asyncio.create_task(self._save())
        else:
            self._save_task = asyncio.create_task(self._flush_loop())

    async def _flush_loop(self):
        """Save repeatedly, one delay apart, until nothing is dirty."""
        try:
            while True:
                await asyncio.sleep(self._save_delay)
                if not self._dirty:
                    break
                self._dirty = False
                await self._save()
        finally:
            self._save_task = None
            # Something was marked dirty while the loop was ending.
            if self._dirty:
                self._schedule_save()

    def get_token(
        self, pool_name: str = "ssoBasic", exclude: Optional[set] = None
    ) -> Optional[str]:
        """Select an available token from a pool.

        Args:
            pool_name: Name of the token pool.
            exclude: Token strings to skip (forwarded to ``pool.select``).

        Returns:
            The token string without the ``sso=`` prefix, or None when the
            pool does not exist or yields no token.
        """
        pool = self.pools.get(pool_name)
        if not pool:
            logger.warning(f"Pool '{pool_name}' not found")
            return None

        token_info = pool.select(exclude=exclude)
        if not token_info:
            logger.warning(f"No available token in pool '{pool_name}'")
            return None

        return _strip_sso_prefix(token_info.token)

    def get_token_info(self, pool_name: str = "ssoBasic") -> Optional["TokenInfo"]:
        """Select an available token from a pool and return its full record.

        Args:
            pool_name: Name of the token pool.

        Returns:
            The selected TokenInfo, or None when the pool does not exist or
            yields no token.
        """
        pool = self.pools.get(pool_name)
        if not pool:
            logger.warning(f"Pool '{pool_name}' not found")
            return None

        token_info = pool.select()
        if not token_info:
            logger.warning(f"No available token in pool '{pool_name}'")
            return None

        return token_info

    def get_token_for_video(
        self,
        resolution: str = "480p",
        video_length: int = 6,
        pool_candidates: Optional[List[str]] = None,
    ) -> Optional["TokenInfo"]:
        """Pick a token for a video request, routing by pool.

        Routing:
            - ``resolution == "720p"`` or ``video_length > 6`` prefers the
              "ssoSuper" pool; anything else prefers "ssoBasic".
            - With ``pool_candidates``, the candidates are tried in order,
              with the preferred pool moved to the front if it is listed.
            - Without candidates, the other pool is used as fallback.

        Args:
            resolution: Video resolution ("480p" or "720p").
            video_length: Video length in seconds.
            pool_candidates: Candidate pool names in priority order.

        Returns:
            A TokenInfo, or None when no pool yields a token.
        """
        # Decide the preferred pool.
        requires_super = resolution == "720p" or video_length > 6
        primary_pool = SUPER_POOL_NAME if requires_super else BASIC_POOL_NAME

        if pool_candidates:
            ordered_pools = list(pool_candidates)
            if primary_pool in ordered_pools:
                ordered_pools.remove(primary_pool)
                ordered_pools.insert(0, primary_pool)
        else:
            fallback_pool = BASIC_POOL_NAME if requires_super else SUPER_POOL_NAME
            ordered_pools = [primary_pool, fallback_pool]

        for idx, pool_name in enumerate(ordered_pools):
            token_info = self.get_token_info(pool_name)
            if token_info:
                if idx == 0:
                    logger.info(
                        f"Video token routing: resolution={resolution}, length={video_length}s -> "
                        f"pool={pool_name} (token={token_info.token[:10]}...)"
                    )
                else:
                    logger.info(
                        f"Video token routing: fallback from {ordered_pools[0]} -> {pool_name} "
                        f"(token={token_info.token[:10]}...)"
                    )
                return token_info

            if idx == 0 and requires_super and pool_name == primary_pool:
                next_pool = ordered_pools[1] if len(ordered_pools) > 1 else None
                if next_pool:
                    logger.warning(
                        f"Video token routing: {primary_pool} pool has no available token for "
                        f"resolution={resolution}, length={video_length}s. "
                        f"Falling back to {next_pool} pool."
                    )

        # None of the pools yielded a token.
        logger.warning(
            f"Video token routing: no available token in any pool "
            f"(resolution={resolution}, length={video_length}s)"
        )
        return None

    def get_pool_name_for_token(self, token_str: str) -> Optional[str]:
        """Return the name of the pool holding ``token_str``, or None."""
        raw_token = _strip_sso_prefix(token_str)
        for pool_name, pool in self.pools.items():
            if pool.get(raw_token):
                return pool_name
        return None

    async def consume(
        self, token_str: str, effort: EffortType = EffortType.LOW
    ) -> bool:
        """Deduct quota locally (an estimate, without asking the API).

        Args:
            token_str: Token string, with or without the ``sso=`` prefix.
            effort: Effort level passed to ``TokenInfo.consume``.

        Returns:
            True when the token was found, False otherwise.
        """
        raw_token = _strip_sso_prefix(token_str)

        for pool in self.pools.values():
            token = pool.get(raw_token)
            if token:
                old_status = token.status
                consumed = token.consume(effort)
                logger.debug(
                    f"Token {raw_token[:10]}...: consumed {consumed} quota, use_count={token.use_count}"
                )
                # A status transition must be flushed promptly; a plain
                # quota change may be throttled as "usage".
                change_kind = "state" if token.status != old_status else "usage"
                self._track_token_change(token, pool.name, change_kind)
                self._schedule_save()
                return True

        logger.warning(f"Token {raw_token[:10]}...: not found for consumption")
        return False

    async def sync_usage(
        self,
        token_str: str,
        fallback_effort: EffortType = EffortType.LOW,
        consume_on_fail: bool = True,
        is_usage: bool = True,
    ) -> bool:
        """Synchronise a token's quota with the usage API.

        The remaining quota is fetched through ``UsageService``; when that
        fails or returns no usable value, the call optionally falls back
        to a local estimate via ``consume``.

        Args:
            token_str: Token string, with or without the ``sso=`` prefix.
            fallback_effort: Effort level used for the local fallback.
            consume_on_fail: Whether to deduct locally when the sync fails.
            is_usage: Forwarded to ``TokenInfo.record_success``.

        Returns:
            True when the quota was synced or locally consumed.
        """
        raw_token = _strip_sso_prefix(token_str)

        # Locate the token and the pool it lives in.
        target_token: Optional[TokenInfo] = None
        target_pool_name: Optional[str] = None
        for pool in self.pools.values():
            target_token = pool.get(raw_token)
            if target_token:
                target_pool_name = pool.name
                break

        if not target_token:
            logger.warning(f"Token {raw_token[:10]}...: not found for sync")
            return False

        # Try the API first.
        try:
            usage_service = UsageService()
            result = await usage_service.get(token_str)

            if result and "remainingTokens" in result:
                new_quota = result.get("remainingTokens")
                if new_quota is None:
                    new_quota = result.get("remainingQueries")
                if new_quota is None:
                    return False

                old_quota = target_token.quota
                old_status = target_token.status
                target_token.update_quota(new_quota)
                target_token.record_success(is_usage=is_usage)

                consumed = max(0, old_quota - new_quota)
                logger.info(
                    f"Token {raw_token[:10]}...: synced quota "
                    f"{old_quota} -> {new_quota} (consumed: {consumed}, use_count: {target_token.use_count})"
                )

                if target_pool_name:
                    change_kind = (
                        "state" if target_token.status != old_status else "usage"
                    )
                    self._track_token_change(
                        target_token, target_pool_name, change_kind
                    )
                self._schedule_save()
                return True
        except Exception as e:
            if isinstance(e, UpstreamException):
                status = None
                if e.details and "status" in e.details:
                    status = e.details["status"]
                else:
                    status = getattr(e, "status_code", None)
                if status == 401:
                    await self.record_fail(token_str, status, "rate_limits_auth_failed")
            logger.warning(
                f"Token {raw_token[:10]}...: API sync failed, fallback to local ({e})"
            )

        # Fallback: deduct a local estimate.
        if consume_on_fail:
            logger.debug(f"Token {raw_token[:10]}...: using local consumption")
            return await self.consume(token_str, fallback_effort)

        logger.debug(
            f"Token {raw_token[:10]}...: sync failed, skipping local consumption"
        )
        return False

    async def record_fail(
        self, token_str: str, status_code: int = 401, reason: str = ""
    ) -> bool:
        """Record a failure for a token.

        Only status 401 is counted against the token; other codes are
        logged and ignored. The threshold comes from
        ``token.fail_threshold`` and is clamped to at least 1.

        Args:
            token_str: Token string, with or without the ``sso=`` prefix.
            status_code: HTTP status code of the failure.
            reason: Free-text failure reason.

        Returns:
            True when the token was found, False otherwise.
        """
        raw_token = _strip_sso_prefix(token_str)

        for pool in self.pools.values():
            token = pool.get(raw_token)
            if token:
                if status_code == 401:
                    threshold = get_config("token.fail_threshold", FAIL_THRESHOLD)
                    try:
                        threshold = int(threshold)
                    except (TypeError, ValueError):
                        threshold = FAIL_THRESHOLD
                    if threshold < 1:
                        threshold = 1

                    token.record_fail(status_code, reason, threshold=threshold)
                    logger.warning(
                        f"Token {raw_token[:10]}...: recorded {status_code} failure "
                        f"({token.fail_count}/{threshold}) - {reason}"
                    )
                    self._track_token_change(token, pool.name, "state")
                    self._schedule_save()
                else:
                    logger.info(
                        f"Token {raw_token[:10]}...: non-auth error ({status_code}) - {reason} (not counted)"
                    )
                return True

        logger.warning(f"Token {raw_token[:10]}...: not found for failure record")
        return False

    async def mark_rate_limited(self, token_str: str) -> bool:
        """Mark a token as quota-exhausted (COOLING).

        Meant to be called when the upstream API answers 429: the quota is
        set to 0 and the status to COOLING until a later refresh restores
        the token.

        Args:
            token_str: Token string, with or without the ``sso=`` prefix.

        Returns:
            True when the token was found, False otherwise.
        """
        raw_token = _strip_sso_prefix(token_str)

        for pool in self.pools.values():
            token = pool.get(raw_token)
            if token:
                old_quota = token.quota
                token.quota = 0
                token.status = TokenStatus.COOLING
                logger.warning(
                    f"Token {raw_token[:10]}...: marked as rate limited "
                    f"(quota {old_quota} -> 0, status -> cooling)"
                )
                self._track_token_change(token, pool.name, "state")
                self._schedule_save()
                return True

        logger.warning(f"Token {raw_token[:10]}...: not found for rate limit marking")
        return False

    # ========== Administration ==========

    async def add(self, token: str, pool_name: str = "ssoBasic") -> bool:
        """Add a token to a pool, creating the pool if necessary.

        Args:
            token: Token string; a leading ``sso=`` is stripped.
            pool_name: Target pool name.

        Returns:
            True when added, False when the pool already holds the token.
        """
        if pool_name not in self.pools:
            self.pools[pool_name] = TokenPool(pool_name)
            logger.info(f"Pool '{pool_name}': created")

        pool = self.pools[pool_name]
        token = _strip_sso_prefix(token)

        if pool.get(token):
            logger.warning(f"Pool '{pool_name}': token already exists")
            return False

        token_info = TokenInfo(token=token, quota=_default_quota_for_pool(pool_name))
        pool.add(token_info)
        self._track_token_change(token_info, pool_name, "state")
        await self._save(force=True)
        logger.info(f"Pool '{pool_name}': token added")
        return True

    async def mark_asset_clear(self, token: str) -> bool:
        """Record the current time (epoch ms) as the token's last asset cleanup."""
        raw_token = _strip_sso_prefix(token)
        for pool in self.pools.values():
            info = pool.get(raw_token)
            if info:
                info.last_asset_clear_at = int(datetime.now().timestamp() * 1000)
                self._track_token_change(info, pool.name, "state")
                self._schedule_save()
                return True
        return False

    async def add_tag(self, token: str, tag: str) -> bool:
        """Attach a tag to a token.

        Args:
            token: Token string, with or without the ``sso=`` prefix.
            tag: Tag name.

        Returns:
            True when the token was found (even if the tag already
            existed), False otherwise.
        """
        raw_token = _strip_sso_prefix(token)
        for pool in self.pools.values():
            info = pool.get(raw_token)
            if info:
                if tag not in info.tags:
                    info.tags.append(tag)
                    self._track_token_change(info, pool.name, "state")
                    self._schedule_save()
                    logger.debug(f"Token {raw_token[:10]}...: added tag '{tag}'")
                return True
        return False

    async def remove_tag(self, token: str, tag: str) -> bool:
        """Detach a tag from a token.

        Args:
            token: Token string, with or without the ``sso=`` prefix.
            tag: Tag name.

        Returns:
            True when the token was found (even if it did not carry the
            tag), False otherwise.
        """
        raw_token = _strip_sso_prefix(token)
        for pool in self.pools.values():
            info = pool.get(raw_token)
            if info:
                if tag in info.tags:
                    info.tags.remove(tag)
                    self._track_token_change(info, pool.name, "state")
                    self._schedule_save()
                    logger.debug(f"Token {raw_token[:10]}...: removed tag '{tag}'")
                return True
        return False

    async def remove(self, token: str) -> bool:
        """Delete a token from whichever pool holds it.

        Args:
            token: Token string, with or without the ``sso=`` prefix.

        Returns:
            True when a token was removed, False otherwise.
        """
        # Pools store bare tokens, so normalise before the lookup.
        raw_token = _strip_sso_prefix(token)
        for pool_name, pool in self.pools.items():
            if pool.remove(raw_token):
                self._track_token_delete(raw_token)
                await self._save(force=True)
                logger.info(f"Pool '{pool_name}': token removed")
                return True

        logger.warning("Token not found for removal")
        return False

    async def reset_all(self):
        """Reset every token to its pool's default quota and save immediately."""
        count = 0
        for pool_name, pool in self.pools.items():
            default_quota = _default_quota_for_pool(pool_name)
            for token in pool:
                token.reset(default_quota)
                self._track_token_change(token, pool_name, "state")
                count += 1

        await self._save(force=True)
        logger.info(f"Reset all: {count} tokens updated")

    async def reset_token(self, token_str: str) -> bool:
        """Reset a single token to its pool's default quota.

        Args:
            token_str: Token string, with or without the ``sso=`` prefix.

        Returns:
            True when the token was found, False otherwise.
        """
        raw_token = _strip_sso_prefix(token_str)

        for pool in self.pools.values():
            token = pool.get(raw_token)
            if token:
                default_quota = _default_quota_for_pool(pool.name)
                token.reset(default_quota)
                self._track_token_change(token, pool.name, "state")
                await self._save(force=True)
                logger.info(f"Token {raw_token[:10]}...: reset completed")
                return True

        logger.warning(f"Token {raw_token[:10]}...: not found for reset")
        return False

    def get_stats(self) -> Dict[str, dict]:
        """Return each pool's statistics as a plain dict keyed by pool name."""
        stats = {}
        for name, pool in self.pools.items():
            pool_stats = pool.get_stats()
            stats[name] = pool_stats.model_dump()
        return stats

    def get_pool_tokens(self, pool_name: str = "ssoBasic") -> List[TokenInfo]:
        """Return all tokens of a pool.

        Args:
            pool_name: Pool name.

        Returns:
            The pool's tokens, or an empty list for an unknown pool.
        """
        pool = self.pools.get(pool_name)
        if not pool:
            return []
        return pool.list()

    async def refresh_cooling_tokens(self) -> Dict[str, int]:
        """Refresh the quota of every token that reports it needs a refresh.

        Tokens are selected with ``TokenInfo.need_refresh`` using a
        per-pool interval, refreshed in batches with bounded concurrency,
        then saved in one forced write.

        Returns:
            ``{"checked": int, "refreshed": int, "recovered": int,
            "expired": int}``. ``refreshed`` counts every processed token,
            including those whose refresh failed.
        """
        # Collect the tokens that need a refresh.
        to_refresh: List[tuple[str, TokenInfo]] = []
        for pool in self.pools.values():
            if pool.name == SUPER_POOL_NAME:
                interval_hours = get_config(
                    "token.super_refresh_interval_hours",
                    DEFAULT_SUPER_REFRESH_INTERVAL_HOURS,
                )
            else:
                interval_hours = get_config(
                    "token.refresh_interval_hours",
                    DEFAULT_REFRESH_INTERVAL_HOURS,
                )
            for token in pool:
                if token.need_refresh(interval_hours):
                    to_refresh.append((pool.name, token))

        if not to_refresh:
            logger.debug("Refresh check: no tokens need refresh")
            return {"checked": 0, "refreshed": 0, "recovered": 0, "expired": 0}

        logger.info(f"Refresh check: found {len(to_refresh)} cooling tokens to refresh")

        # Refresh concurrently, bounded by the semaphore.
        semaphore = asyncio.Semaphore(DEFAULT_REFRESH_CONCURRENCY)
        usage_service = UsageService()
        refreshed = 0
        recovered = 0
        expired = 0

        async def _refresh_one(item: tuple[str, TokenInfo]) -> dict:
            """Refresh one token; return its ``recovered``/``expired`` flags."""
            _, token_info = item
            async with semaphore:
                token_str = _strip_sso_prefix(token_info.token)

                # One initial attempt plus at most two retries.
                for retry in range(3):
                    try:
                        result = await usage_service.get(token_str)

                        if result and "remainingTokens" in result:
                            new_quota = result.get("remainingTokens")
                            if new_quota is None:
                                new_quota = result.get("remainingQueries")
                            if new_quota is None:
                                return {"recovered": False, "expired": False}

                            old_quota = token_info.quota
                            old_status = token_info.status
                            token_info.update_quota(new_quota)
                            token_info.mark_synced()
                            logger.info(
                                f"Token {token_info.token[:10]}...: refreshed "
                                f"{old_quota} -> {new_quota}, status: {old_status} -> {token_info.status}"
                            )
                            return {
                                "recovered": new_quota > 0 and old_quota == 0,
                                "expired": False,
                            }

                        return {"recovered": False, "expired": False}
                    except Exception as e:
                        error_str = str(e)

                        # Auth failures are detected from the error text.
                        if "401" in error_str or "Unauthorized" in error_str:
                            if retry < 2:
                                logger.warning(
                                    f"Token {token_info.token[:10]}...: 401 error, "
                                    f"retry {retry + 1}/2..."
                                )
                                await asyncio.sleep(0.5)
                                continue

                            # Still 401 after two retries: mark as expired.
                            logger.error(
                                f"Token {token_info.token[:10]}...: 401 after 2 retries, "
                                f"marking as expired"
                            )
                            token_info.status = TokenStatus.EXPIRED
                            return {"recovered": False, "expired": True}

                        logger.warning(
                            f"Token {token_info.token[:10]}...: refresh failed ({e})"
                        )
                        return {"recovered": False, "expired": False}

                return {"recovered": False, "expired": False}

        # Process in batches.
        for i in range(0, len(to_refresh), DEFAULT_REFRESH_BATCH_SIZE):
            batch = to_refresh[i : i + DEFAULT_REFRESH_BATCH_SIZE]
            results = await asyncio.gather(*[_refresh_one(t) for t in batch])
            refreshed += len(batch)
            recovered += sum(r["recovered"] for r in results)
            expired += sum(r["expired"] for r in results)

            # Pause between batches.
            if i + DEFAULT_REFRESH_BATCH_SIZE < len(to_refresh):
                await asyncio.sleep(1)

        for pool_name, token_info in to_refresh:
            self._track_token_change(token_info, pool_name, "state")
        await self._save(force=True)

        logger.info(
            f"Refresh completed: "
            f"checked={len(to_refresh)}, refreshed={refreshed}, "
            f"recovered={recovered}, expired={expired}"
        )

        return {
            "checked": len(to_refresh),
            "refreshed": refreshed,
            "recovered": recovered,
            "expired": expired,
        }


# Convenience accessor
async def get_token_manager() -> TokenManager:
    """Return the TokenManager singleton."""
    return await TokenManager.get_instance()


__all__ = ["TokenManager", "get_token_manager"]