"""Admin endpoints for managing tokens and running token batch jobs."""

import asyncio
import secrets

import orjson
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse

from app.core.auth import get_app_key, verify_app_key
from app.core.batch import create_task, expire_task, get_task
from app.core.logger import logger
from app.core.storage import get_storage
from app.services.grok.batch_services.usage import UsageService
from app.services.grok.batch_services.nsfw import NSFWService
from app.services.token.manager import get_token_manager

router = APIRouter()

# Delay (seconds) handed to ``expire_task`` once a background batch ends.
_TASK_EXPIRE_DELAY = 300

# Seconds without an event after which the SSE stream sends a keep-alive
# comment and re-checks whether the task already reached a final state.
_SSE_PING_INTERVAL = 15

# Event types after which the SSE stream stops forwarding events.
_TERMINAL_EVENT_TYPES = ("done", "error", "cancelled")

# Strong references to fire-and-forget tasks. The asyncio event loop only
# keeps weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_background_tasks = set()


def _spawn(coro):
    """Schedule *coro* as a background task and keep it referenced until done."""
    bg = asyncio.create_task(coro)
    _background_tasks.add(bg)
    bg.add_done_callback(_background_tasks.discard)
    return bg


def _strip_sso_prefix(token):
    """Return *token* without a leading ``sso=``; other values pass through."""
    if isinstance(token, str) and token.startswith("sso="):
        return token[4:]
    return token


def _normalize_token_item(item):
    """Coerce one pool entry into a token dict, or ``None`` if unusable.

    Strings become ``{"token": item}`` and dicts are shallow-copied. Any other
    type yields ``None``. A leading ``sso=`` on the ``token`` value is removed.
    """
    if isinstance(item, str):
        token_data = {"token": item}
    elif isinstance(item, dict):
        token_data = dict(item)
    else:
        return None
    if "token" in token_data:
        token_data["token"] = _strip_sso_prefix(token_data["token"])
    return token_data


def _index_existing_tokens(existing):
    """Build ``{pool_name: {token: token_dict}}`` from the stored pools.

    Non-list pools, unrecognised entries and entries whose token is not a
    string are ignored.
    """
    index = {}
    for pool_name, items in existing.items():
        if not isinstance(items, list):
            continue
        pool_map = {}
        for item in items:
            token_data = _normalize_token_item(item)
            if token_data is None:
                continue
            token_key = token_data.get("token")
            if isinstance(token_key, str):
                pool_map[token_key] = token_data
        index[pool_name] = pool_map
    return index


def _collect_tokens(data):
    """Gather the tokens named in a request body.

    Reads the ``token`` string and the ``tokens`` list, strips whitespace and
    drops blank values. Order is kept; duplicates are NOT removed here.
    """
    tokens = []
    single = data.get("token")
    if isinstance(single, str) and single.strip():
        tokens.append(single.strip())
    many = data.get("tokens")
    if isinstance(many, list):
        for raw in many:
            cleaned = str(raw).strip()
            if cleaned:
                tokens.append(cleaned)
    return tokens


def _pool_tokens(mgr):
    """List every token in *mgr*'s pools, with any ``sso=`` prefix removed."""
    return [
        _strip_sso_prefix(info.token)
        for _pool_name, pool in mgr.pools.items()
        for info in pool.list()
    ]


def _mask_token(token):
    """Shorten tokens longer than 20 chars to ``first8...last8`` for output."""
    return f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token


def _nsfw_succeeded(res):
    """True when an NSFW result has a truthy ``ok`` and ``data["success"]``.

    A missing or non-dict ``data`` counts as failure instead of raising.
    """
    data = res.get("data")
    return bool(res.get("ok") and isinstance(data, dict) and data.get("success"))


def _summarize_refresh(raw_results, total):
    """Build the final payload of an async refresh batch.

    A token counts as ok only when ``ok`` is truthy and ``data`` is exactly True.
    """
    results = {}
    ok_count = 0
    fail_count = 0
    for token, res in raw_results.items():
        succeeded = bool(res.get("ok") and res.get("data") is True)
        results[token] = succeeded
        if succeeded:
            ok_count += 1
        else:
            fail_count += 1
    return {
        "status": "success",
        "summary": {"total": total, "ok": ok_count, "fail": fail_count},
        "results": results,
    }


def _summarize_nsfw(raw_results, total):
    """Build the NSFW batch payload; result keys are masked tokens."""
    results = {}
    ok_count = 0
    fail_count = 0
    for token, res in raw_results.items():
        masked = _mask_token(token)
        if _nsfw_succeeded(res):
            ok_count += 1
            results[masked] = res.get("data", {})
        else:
            fail_count += 1
            results[masked] = res.get("data") or {"error": res.get("error")}
    return {
        "status": "success",
        "summary": {"total": total, "ok": ok_count, "fail": fail_count},
        "results": results,
    }


async def _run_batch_task(task, mgr, service, tokens, is_ok, summarize):
    """Run ``service.batch`` for an async endpoint and publish the outcome.

    *is_ok(res)* decides how each per-item result is recorded as progress.
    *summarize(raw_results)* builds the payload passed to ``task.finish``.
    Whether the batch finishes, fails or is cancelled, the task is scheduled
    for expiry afterwards.
    """

    async def _on_item(item, res):
        task.record(is_ok(res))

    try:
        raw_results = await service.batch(
            tokens,
            mgr,
            on_item=_on_item,
            should_cancel=lambda: task.cancelled,
        )
        if task.cancelled:
            task.finish_cancelled()
            return
        result = summarize(raw_results)
        await mgr._save(force=True)
        task.finish(result)
    except Exception as e:
        logger.error(f"Batch task {task.id} failed: {e}")
        task.fail_task(str(e))
    finally:
        _spawn(expire_task(task.id, _TASK_EXPIRE_DELAY))


def _sse_data(payload):
    """Encode *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {orjson.dumps(payload).decode()}\n\n"


@router.get("/tokens", dependencies=[Depends(verify_app_key)])
async def get_tokens():
    """Return all stored token pools (an empty dict when none are stored)."""
    storage = get_storage()
    tokens = await storage.load_tokens()
    return tokens or {}


@router.post("/tokens", dependencies=[Depends(verify_app_key)])
async def update_tokens(data: dict):
    """Replace the stored token pools with the pools in *data*.

    Each entry may be a raw token string or a dict of token fields. An entry
    is merged over the stored record for the same token in the same pool, and
    request fields win. The result is filtered to ``TokenInfo`` fields and
    validated; invalid entries are logged and skipped. Pools absent from
    *data* are not carried over. The token manager is reloaded afterwards.
    """
    storage = get_storage()
    try:
        from app.services.token.models import TokenInfo

        allowed_fields = set(TokenInfo.model_fields.keys())
        async with storage.acquire_lock("tokens_save", timeout=10):
            existing = await storage.load_tokens() or {}
            existing_map = _index_existing_tokens(existing)
            normalized = {}
            for pool_name, items in (data or {}).items():
                if not isinstance(items, list):
                    continue
                known = existing_map.get(pool_name, {})
                pool_list = []
                for item in items:
                    token_data = _normalize_token_item(item)
                    if token_data is None:
                        continue
                    token_key = token_data.get("token")
                    # Only string tokens can match a stored record. Checking the
                    # type first also avoids a TypeError on unhashable values,
                    # which previously failed the whole update.
                    base = known.get(token_key, {}) if isinstance(token_key, str) else {}
                    merged = {**base, **token_data}
                    if merged.get("tags") is None:
                        merged["tags"] = []
                    filtered = {k: v for k, v in merged.items() if k in allowed_fields}
                    try:
                        info = TokenInfo(**filtered)
                        pool_list.append(info.model_dump())
                    except Exception as e:
                        logger.warning(f"Skip invalid token in pool '{pool_name}': {e}")
                normalized[pool_name] = pool_list
            await storage.save_tokens(normalized)
        # Reload after releasing the storage lock. NOTE(review): if the manager
        # acquires "tokens_save" itself, reloading inside the lock could block.
        mgr = await get_token_manager()
        await mgr.reload()
        return {"status": "success", "message": "Token 已更新"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Update tokens failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)])
async def refresh_tokens(data: dict):
    """Refresh the status of the given tokens and wait for the result.

    Raises:
        HTTPException: 400 when no tokens are given (previously swallowed
            into a 500); 500 for any other failure.
    """
    try:
        mgr = await get_token_manager()
        tokens = _collect_tokens(data)
        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens provided")
        unique_tokens = list(dict.fromkeys(tokens))
        raw_results = await UsageService.batch(unique_tokens, mgr)
        results = {
            token: res.get("data", False) if res.get("ok") else False
            for token, res in raw_results.items()
        }
        return {"status": "success", "results": results}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Refresh tokens failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)])
async def refresh_tokens_async(data: dict):
    """Refresh token status in the background, with progress over SSE.

    Returns immediately with a ``task_id`` that clients can follow via
    ``/batch/{task_id}/stream``.
    """
    mgr = await get_token_manager()
    tokens = _collect_tokens(data)
    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens provided")
    unique_tokens = list(dict.fromkeys(tokens))
    total = len(unique_tokens)
    task = create_task(total)
    # NOTE(review): progress counts any truthy "ok", but the final summary
    # also requires data is True, so the two counts may differ.
    _spawn(
        _run_batch_task(
            task,
            mgr,
            UsageService,
            unique_tokens,
            is_ok=lambda res: bool(res.get("ok")),
            summarize=lambda raw: _summarize_refresh(raw, total),
        )
    )
    return {"status": "success", "task_id": task.id, "total": total}


@router.get("/batch/{task_id}/stream")
async def batch_stream(task_id: str, request: Request):
    """Stream a batch task's progress as Server-Sent Events.

    Authenticates with the ``app_key`` query parameter instead of the usual
    dependency (presumably because browser EventSource cannot set headers;
    confirm). Sends a snapshot first, then forwards events until a terminal
    one. A ``: ping`` comment is sent after each idle interval.
    """
    app_key = get_app_key()
    if app_key:
        key = request.query_params.get("app_key")
        # Constant-time comparison so the key cannot be probed via timing.
        if key is None or not secrets.compare_digest(
            key.encode("utf-8"), str(app_key).encode("utf-8")
        ):
            raise HTTPException(status_code=401, detail="Invalid authentication token")

    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    async def event_stream():
        queue = task.attach()
        try:
            yield _sse_data({"type": "snapshot", **task.snapshot()})
            final = task.final_event()
            if final:
                yield _sse_data(final)
                return
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=_SSE_PING_INTERVAL)
                except asyncio.TimeoutError:
                    yield ": ping\n\n"
                    final = task.final_event()
                    if final:
                        yield _sse_data(final)
                        return
                    continue
                yield _sse_data(event)
                if event.get("type") in _TERMINAL_EVENT_TYPES:
                    return
        finally:
            task.detach(queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")


@router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)])
async def batch_cancel(task_id: str):
    """Request cancellation of a running batch task (404 if it is unknown)."""
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    task.cancel()
    return {"status": "success"}


@router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)])
async def enable_nsfw(data: dict):
    """Enable NSFW (Unhinged) mode for tokens and wait for the result.

    Uses the tokens in the body or, if none are given, every pooled token.
    Result keys are masked tokens.
    """
    try:
        mgr = await get_token_manager()
        tokens = _collect_tokens(data) or _pool_tokens(mgr)
        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens available")
        unique_tokens = list(dict.fromkeys(tokens))
        raw_results = await NSFWService.batch(unique_tokens, mgr)
        return _summarize_nsfw(raw_results, len(unique_tokens))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Enable NSFW failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)])
async def enable_nsfw_async(data: dict):
    """Enable NSFW (Unhinged) mode in the background, with progress over SSE.

    Token selection matches ``enable_nsfw``. Returns a ``task_id`` that
    clients can follow via ``/batch/{task_id}/stream``.
    """
    mgr = await get_token_manager()
    tokens = _collect_tokens(data) or _pool_tokens(mgr)
    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens available")
    unique_tokens = list(dict.fromkeys(tokens))
    total = len(unique_tokens)
    task = create_task(total)
    _spawn(
        _run_batch_task(
            task,
            mgr,
            NSFWService,
            unique_tokens,
            is_ok=_nsfw_succeeded,
            summarize=lambda raw: _summarize_nsfw(raw, total),
        )
    )
    return {"status": "success", "task_id": task.id, "total": total}