| import asyncio |
|
|
| import orjson |
| from fastapi import APIRouter, Depends, HTTPException, Request |
| from fastapi.responses import StreamingResponse |
|
|
| from app.core.auth import get_app_key, verify_app_key |
| from app.core.batch import create_task, expire_task, get_task |
| from app.core.logger import logger |
| from app.core.storage import get_storage |
| from app.services.grok.batch_services.usage import UsageService |
| from app.services.grok.batch_services.nsfw import NSFWService |
| from app.services.token.manager import get_token_manager |
|
|
| router = APIRouter() |
|
|
|
|
| @router.get("/tokens", dependencies=[Depends(verify_app_key)]) |
| async def get_tokens(): |
| """获取所有 Token""" |
| storage = get_storage() |
| tokens = await storage.load_tokens() |
| return tokens or {} |
|
|
|
|
| @router.post("/tokens", dependencies=[Depends(verify_app_key)]) |
| async def update_tokens(data: dict): |
| """更新 Token 信息""" |
| storage = get_storage() |
| try: |
| from app.services.token.models import TokenInfo |
|
|
| async with storage.acquire_lock("tokens_save", timeout=10): |
| existing = await storage.load_tokens() or {} |
| normalized = {} |
| allowed_fields = set(TokenInfo.model_fields.keys()) |
| existing_map = {} |
| for pool_name, tokens in existing.items(): |
| if not isinstance(tokens, list): |
| continue |
| pool_map = {} |
| for item in tokens: |
| if isinstance(item, str): |
| token_data = {"token": item} |
| elif isinstance(item, dict): |
| token_data = dict(item) |
| else: |
| continue |
| raw_token = token_data.get("token") |
| if isinstance(raw_token, str) and raw_token.startswith("sso="): |
| token_data["token"] = raw_token[4:] |
| token_key = token_data.get("token") |
| if isinstance(token_key, str): |
| pool_map[token_key] = token_data |
| existing_map[pool_name] = pool_map |
| for pool_name, tokens in (data or {}).items(): |
| if not isinstance(tokens, list): |
| continue |
| pool_list = [] |
| for item in tokens: |
| if isinstance(item, str): |
| token_data = {"token": item} |
| elif isinstance(item, dict): |
| token_data = dict(item) |
| else: |
| continue |
|
|
| raw_token = token_data.get("token") |
| if isinstance(raw_token, str) and raw_token.startswith("sso="): |
| token_data["token"] = raw_token[4:] |
|
|
| base = existing_map.get(pool_name, {}).get( |
| token_data.get("token"), {} |
| ) |
| merged = dict(base) |
| merged.update(token_data) |
| if merged.get("tags") is None: |
| merged["tags"] = [] |
|
|
| filtered = {k: v for k, v in merged.items() if k in allowed_fields} |
| try: |
| info = TokenInfo(**filtered) |
| pool_list.append(info.model_dump()) |
| except Exception as e: |
| logger.warning(f"Skip invalid token in pool '{pool_name}': {e}") |
| continue |
| normalized[pool_name] = pool_list |
|
|
| await storage.save_tokens(normalized) |
| mgr = await get_token_manager() |
| await mgr.reload() |
| return {"status": "success", "message": "Token 已更新"} |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
| @router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)]) |
| async def refresh_tokens(data: dict): |
| """刷新 Token 状态""" |
| try: |
| mgr = await get_token_manager() |
| tokens = [] |
| if isinstance(data.get("token"), str) and data["token"].strip(): |
| tokens.append(data["token"].strip()) |
| if isinstance(data.get("tokens"), list): |
| tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) |
|
|
| if not tokens: |
| raise HTTPException(status_code=400, detail="No tokens provided") |
|
|
| unique_tokens = list(dict.fromkeys(tokens)) |
|
|
| raw_results = await UsageService.batch( |
| unique_tokens, |
| mgr, |
| ) |
|
|
| results = {} |
| for token, res in raw_results.items(): |
| if res.get("ok"): |
| results[token] = res.get("data", False) |
| else: |
| results[token] = False |
|
|
| response = {"status": "success", "results": results} |
| return response |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
| @router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)]) |
| async def refresh_tokens_async(data: dict): |
| """刷新 Token 状态(异步批量 + SSE 进度)""" |
| mgr = await get_token_manager() |
| tokens = [] |
| if isinstance(data.get("token"), str) and data["token"].strip(): |
| tokens.append(data["token"].strip()) |
| if isinstance(data.get("tokens"), list): |
| tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) |
|
|
| if not tokens: |
| raise HTTPException(status_code=400, detail="No tokens provided") |
|
|
| unique_tokens = list(dict.fromkeys(tokens)) |
|
|
| task = create_task(len(unique_tokens)) |
|
|
| async def _run(): |
| try: |
|
|
| async def _on_item(item: str, res: dict): |
| task.record(bool(res.get("ok"))) |
|
|
| raw_results = await UsageService.batch( |
| unique_tokens, |
| mgr, |
| on_item=_on_item, |
| should_cancel=lambda: task.cancelled, |
| ) |
|
|
| if task.cancelled: |
| task.finish_cancelled() |
| return |
|
|
| results: dict[str, bool] = {} |
| ok_count = 0 |
| fail_count = 0 |
| for token, res in raw_results.items(): |
| if res.get("ok") and res.get("data") is True: |
| ok_count += 1 |
| results[token] = True |
| else: |
| fail_count += 1 |
| results[token] = False |
|
|
| await mgr._save(force=True) |
|
|
| result = { |
| "status": "success", |
| "summary": { |
| "total": len(unique_tokens), |
| "ok": ok_count, |
| "fail": fail_count, |
| }, |
| "results": results, |
| } |
| task.finish(result) |
| except Exception as e: |
| task.fail_task(str(e)) |
| finally: |
| import asyncio |
| asyncio.create_task(expire_task(task.id, 300)) |
|
|
| import asyncio |
| asyncio.create_task(_run()) |
|
|
| return { |
| "status": "success", |
| "task_id": task.id, |
| "total": len(unique_tokens), |
| } |
|
|
|
|
| @router.get("/batch/{task_id}/stream") |
| async def batch_stream(task_id: str, request: Request): |
| app_key = get_app_key() |
| if app_key: |
| key = request.query_params.get("app_key") |
| if key != app_key: |
| raise HTTPException(status_code=401, detail="Invalid authentication token") |
| task = get_task(task_id) |
| if not task: |
| raise HTTPException(status_code=404, detail="Task not found") |
|
|
| async def event_stream(): |
| queue = task.attach() |
| try: |
| yield f"data: {orjson.dumps({'type': 'snapshot', **task.snapshot()}).decode()}\n\n" |
|
|
| final = task.final_event() |
| if final: |
| yield f"data: {orjson.dumps(final).decode()}\n\n" |
| return |
|
|
| while True: |
| try: |
| event = await asyncio.wait_for(queue.get(), timeout=15) |
| except asyncio.TimeoutError: |
| yield ": ping\n\n" |
| final = task.final_event() |
| if final: |
| yield f"data: {orjson.dumps(final).decode()}\n\n" |
| return |
| continue |
|
|
| yield f"data: {orjson.dumps(event).decode()}\n\n" |
| if event.get("type") in ("done", "error", "cancelled"): |
| return |
| finally: |
| task.detach(queue) |
|
|
| return StreamingResponse(event_stream(), media_type="text/event-stream") |
|
|
|
|
| @router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)]) |
| async def batch_cancel(task_id: str): |
| task = get_task(task_id) |
| if not task: |
| raise HTTPException(status_code=404, detail="Task not found") |
| task.cancel() |
| return {"status": "success"} |
|
|
|
|
| @router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)]) |
| async def enable_nsfw(data: dict): |
| """批量开启 NSFW (Unhinged) 模式""" |
| try: |
| mgr = await get_token_manager() |
|
|
| tokens = [] |
| if isinstance(data.get("token"), str) and data["token"].strip(): |
| tokens.append(data["token"].strip()) |
| if isinstance(data.get("tokens"), list): |
| tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) |
|
|
| if not tokens: |
| for pool_name, pool in mgr.pools.items(): |
| for info in pool.list(): |
| raw = ( |
| info.token[4:] if info.token.startswith("sso=") else info.token |
| ) |
| tokens.append(raw) |
|
|
| if not tokens: |
| raise HTTPException(status_code=400, detail="No tokens available") |
|
|
| unique_tokens = list(dict.fromkeys(tokens)) |
|
|
| raw_results = await NSFWService.batch( |
| unique_tokens, |
| mgr, |
| ) |
|
|
| results = {} |
| ok_count = 0 |
| fail_count = 0 |
|
|
| for token, res in raw_results.items(): |
| masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token |
| if res.get("ok") and res.get("data", {}).get("success"): |
| ok_count += 1 |
| results[masked] = res.get("data", {}) |
| else: |
| fail_count += 1 |
| results[masked] = res.get("data") or {"error": res.get("error")} |
|
|
| response = { |
| "status": "success", |
| "summary": { |
| "total": len(unique_tokens), |
| "ok": ok_count, |
| "fail": fail_count, |
| }, |
| "results": results, |
| } |
|
|
| return response |
|
|
| except HTTPException: |
| raise |
| except Exception as e: |
| logger.error(f"Enable NSFW failed: {e}") |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
| @router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)]) |
| async def enable_nsfw_async(data: dict): |
| """批量开启 NSFW (Unhinged) 模式(异步批量 + SSE 进度)""" |
| mgr = await get_token_manager() |
|
|
| tokens = [] |
| if isinstance(data.get("token"), str) and data["token"].strip(): |
| tokens.append(data["token"].strip()) |
| if isinstance(data.get("tokens"), list): |
| tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) |
|
|
| if not tokens: |
| for pool_name, pool in mgr.pools.items(): |
| for info in pool.list(): |
| raw = info.token[4:] if info.token.startswith("sso=") else info.token |
| tokens.append(raw) |
|
|
| if not tokens: |
| raise HTTPException(status_code=400, detail="No tokens available") |
|
|
| unique_tokens = list(dict.fromkeys(tokens)) |
|
|
| task = create_task(len(unique_tokens)) |
|
|
| async def _run(): |
| try: |
|
|
| async def _on_item(item: str, res: dict): |
| ok = bool(res.get("ok") and res.get("data", {}).get("success")) |
| task.record(ok) |
|
|
| raw_results = await NSFWService.batch( |
| unique_tokens, |
| mgr, |
| on_item=_on_item, |
| should_cancel=lambda: task.cancelled, |
| ) |
|
|
| if task.cancelled: |
| task.finish_cancelled() |
| return |
|
|
| results = {} |
| ok_count = 0 |
| fail_count = 0 |
| for token, res in raw_results.items(): |
| masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token |
| if res.get("ok") and res.get("data", {}).get("success"): |
| ok_count += 1 |
| results[masked] = res.get("data", {}) |
| else: |
| fail_count += 1 |
| results[masked] = res.get("data") or {"error": res.get("error")} |
|
|
| await mgr._save(force=True) |
|
|
| result = { |
| "status": "success", |
| "summary": { |
| "total": len(unique_tokens), |
| "ok": ok_count, |
| "fail": fail_count, |
| }, |
| "results": results, |
| } |
| task.finish(result) |
| except Exception as e: |
| task.fail_task(str(e)) |
| finally: |
| import asyncio |
| asyncio.create_task(expire_task(task.id, 300)) |
|
|
| import asyncio |
| asyncio.create_task(_run()) |
|
|
| return { |
| "status": "success", |
| "task_id": task.id, |
| "total": len(unique_tokens), |
| } |
|
|