| from fastapi import ( |
| APIRouter, |
| Depends, |
| HTTPException, |
| Request, |
| Query, |
| WebSocket, |
| WebSocketDisconnect, |
| ) |
| from fastapi.responses import HTMLResponse, StreamingResponse, RedirectResponse |
| from typing import Optional |
| from pydantic import BaseModel |
| from app.core.auth import verify_api_key, verify_app_key, get_admin_api_key |
| from app.core.config import config, get_config |
| from app.core.batch_tasks import create_task, get_task, expire_task |
| from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage |
| from app.core.exceptions import AppException |
| from app.services.token.manager import get_token_manager |
| from app.services.grok.utils.batch import run_in_batches |
| import os |
| import time |
| import uuid |
| from pathlib import Path |
| import aiofiles |
| import asyncio |
| import orjson |
| from app.core.logger import logger |
| from app.api.v1.image import resolve_aspect_ratio |
| from app.services.grok.services.voice import VoiceService |
| from app.services.grok.services.image import image_service |
| from app.services.grok.models.model import ModelService |
| from app.services.grok.processors.image_ws_processors import ImageWSCollectProcessor |
| from app.services.token import EffortType |
|
|
# Root directory of the admin HTML templates served by render_template().
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static"


router = APIRouter()


# In-memory, short-lived sessions backing the /admin/imagine waterfall.
# Each entry maps task_id -> {"prompt", "aspect_ratio", "created_at"}.
IMAGINE_SESSION_TTL = 600  # seconds before a session expires
_IMAGINE_SESSIONS: dict[str, dict] = {}
_IMAGINE_SESSIONS_LOCK = asyncio.Lock()
|
|
|
|
async def _cleanup_imagine_sessions(now: float) -> None:
    """Drop every imagine session older than IMAGINE_SESSION_TTL.

    Callers hold _IMAGINE_SESSIONS_LOCK; this helper does not lock itself.
    """
    stale_keys = []
    for session_key, session in _IMAGINE_SESSIONS.items():
        created = float(session.get("created_at") or 0)
        if now - created > IMAGINE_SESSION_TTL:
            stale_keys.append(session_key)
    for session_key in stale_keys:
        _IMAGINE_SESSIONS.pop(session_key, None)
|
|
|
|
async def _create_imagine_session(prompt: str, aspect_ratio: str) -> str:
    """Register a new imagine session and return its opaque task id.

    Expired sessions are opportunistically purged under the same lock.
    """
    session_id = uuid.uuid4().hex
    created_at = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _cleanup_imagine_sessions(created_at)
        _IMAGINE_SESSIONS[session_id] = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "created_at": created_at,
        }
    return session_id
|
|
|
|
async def _get_imagine_session(task_id: str) -> Optional[dict]:
    """Return a copy of the session for *task_id*, or None if absent/expired."""
    if not task_id:
        return None
    checked_at = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _cleanup_imagine_sessions(checked_at)
        session = _IMAGINE_SESSIONS.get(task_id)
        if not session:
            return None
        # Double-check TTL in case the session aged between cleanup and get.
        born = float(session.get("created_at") or 0)
        if checked_at - born > IMAGINE_SESSION_TTL:
            _IMAGINE_SESSIONS.pop(task_id, None)
            return None
        return dict(session)
|
|
|
|
async def _delete_imagine_session(task_id: str) -> None:
    """Remove one imagine session; unknown or empty ids are ignored."""
    if task_id:
        async with _IMAGINE_SESSIONS_LOCK:
            _IMAGINE_SESSIONS.pop(task_id, None)
|
|
|
|
async def _delete_imagine_sessions(task_ids: list[str]) -> int:
    """Remove several imagine sessions at once.

    Returns the number of sessions that actually existed and were removed.
    """
    if not task_ids:
        return 0
    removed = 0
    async with _IMAGINE_SESSIONS_LOCK:
        for session_id in task_ids:
            # pop() with default both tests membership and deletes in one step.
            if session_id and _IMAGINE_SESSIONS.pop(session_id, None) is not None:
                removed += 1
    return removed
|
|
|
|
| def _collect_tokens(data: dict) -> list[str]: |
| """从请求数据中收集 token 列表""" |
| tokens = [] |
| if isinstance(data.get("token"), str) and data["token"].strip(): |
| tokens.append(data["token"].strip()) |
| if isinstance(data.get("tokens"), list): |
| tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()]) |
| return tokens |
|
|
|
|
| def _truncate_tokens( |
| tokens: list[str], max_tokens: int, operation: str = "operation" |
| ) -> tuple[list[str], bool, int]: |
| """去重并截断 token 列表,返回 (unique_tokens, truncated, original_count)""" |
| unique_tokens = list(dict.fromkeys(tokens)) |
| original_count = len(unique_tokens) |
| truncated = False |
|
|
| if len(unique_tokens) > max_tokens: |
| unique_tokens = unique_tokens[:max_tokens] |
| truncated = True |
| logger.warning( |
| f"{operation}: truncated from {original_count} to {max_tokens} tokens" |
| ) |
|
|
| return unique_tokens, truncated, original_count |
|
|
|
|
| def _mask_token(token: str) -> str: |
| """掩码 token 显示""" |
| return f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token |
|
|
|
|
async def render_template(filename: str):
    """Read an HTML template under TEMPLATE_DIR and return it as a response.

    Args:
        filename: Template path relative to TEMPLATE_DIR (e.g. "login/login.html").

    Returns:
        HTMLResponse with the file contents, or a 404 HTMLResponse naming the
        missing template.
    """
    template_path = TEMPLATE_DIR / filename
    if not template_path.exists():
        # Bug fix: the original f-string had no placeholder and always said
        # "(unknown)"; include the actual filename so the 404 is actionable.
        return HTMLResponse(f"Template {filename} not found.", status_code=404)

    async with aiofiles.open(template_path, "r", encoding="utf-8") as f:
        content = await f.read()
    return HTMLResponse(content)
|
|
|
|
def _sse_event(payload: dict) -> str:
    """Serialize *payload* as a single SSE ``data:`` frame."""
    body = orjson.dumps(payload).decode()
    return "data: " + body + "\n\n"
|
|
|
|
def _verify_stream_api_key(request: Request) -> None:
    """Authenticate an SSE request via its ``api_key`` query parameter.

    No-op when no admin API key is configured; otherwise raises 401 on
    mismatch (SSE clients cannot send custom headers, hence query auth).
    """
    expected = get_admin_api_key()
    if not expected:
        return
    provided = request.query_params.get("api_key")
    if provided != expected:
        raise HTTPException(status_code=401, detail="Invalid authentication token")
|
|
|
|
@router.get("/api/v1/admin/batch/{task_id}/stream")
async def stream_batch(task_id: str, request: Request):
    """Stream progress events for a batch task over SSE.

    Auth is via the ``api_key`` query parameter (see _verify_stream_api_key).
    Emits a snapshot first, then live events until a terminal event
    ("done"/"error"/"cancelled") arrives; pings every 15s while idle.
    """
    _verify_stream_api_key(request)
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    async def event_stream():
        # attach() subscribes a queue to the task; must be detached on exit.
        queue = task.attach()
        try:
            # Current state first so late subscribers catch up immediately.
            yield _sse_event({"type": "snapshot", **task.snapshot()})

            # If the task already finished, emit its final event and stop.
            final = task.final_event()
            if final:
                yield _sse_event(final)
                return

            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # Keep-alive comment frame; also re-check for a final
                    # event that may have landed without hitting our queue.
                    yield ": ping\n\n"
                    final = task.final_event()
                    if final:
                        yield _sse_event(final)
                        return
                    continue

                yield _sse_event(event)
                if event.get("type") in ("done", "error", "cancelled"):
                    return
        finally:
            task.detach(queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
@router.post(
    "/api/v1/admin/batch/{task_id}/cancel", dependencies=[Depends(verify_api_key)]
)
async def cancel_batch(task_id: str):
    """Request cancellation of a running batch task."""
    target = get_task(task_id)
    if not target:
        raise HTTPException(status_code=404, detail="Task not found")
    target.cancel()
    return {"status": "success"}
|
|
|
|
@router.get("/admin", response_class=HTMLResponse, include_in_schema=False)
async def admin_login_page():
    """Serve the admin console login page."""
    page = await render_template("login/login.html")
    return page
|
|
|
|
@router.get("/", include_in_schema=False)
async def root_redirect():
    """Send bare-root visits to the admin console."""
    response = RedirectResponse(url="/admin")
    return response
|
|
|
|
@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False)
async def admin_config_page():
    """Serve the configuration management page."""
    page = await render_template("config/config.html")
    return page
|
|
|
|
@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False)
async def admin_token_page():
    """Serve the token management page."""
    page = await render_template("token/token.html")
    return page
|
|
|
|
@router.get("/admin/voice", response_class=HTMLResponse, include_in_schema=False)
async def admin_voice_page():
    """Serve the Voice Live debugging page."""
    page = await render_template("voice/voice.html")
    return page
|
|
|
|
@router.get("/admin/imagine", response_class=HTMLResponse, include_in_schema=False)
async def admin_imagine_page():
    """Serve the Imagine image-waterfall page."""
    page = await render_template("imagine/imagine.html")
    return page
|
|
|
|
class VoiceTokenResponse(BaseModel):
    """Response schema for the Grok Voice Mode (LiveKit) token endpoint."""

    # LiveKit access token issued by the upstream voice service.
    token: str
    # LiveKit websocket endpoint the client should connect to.
    url: str
    # Participant/room metadata; currently always returned empty.
    participant_name: str = ""
    room_name: str = ""
|
|
|
|
@router.get(
    "/api/v1/admin/voice/token",
    dependencies=[Depends(verify_api_key)],
    response_model=VoiceTokenResponse,
)
async def admin_voice_token(
    voice: str = "ara",
    personality: str = "assistant",
    speed: float = 1.0,
):
    """Fetch a Grok Voice Mode (LiveKit) token.

    Picks the first available SSO token from the basic/super pools, then asks
    the upstream voice service to mint a LiveKit token.

    Raises:
        AppException: 503 when no SSO token is available, 502 when upstream
            returns no token, 500 for any other upstream failure.
    """
    token_mgr = await get_token_manager()
    sso_token = None
    for pool_name in ("ssoBasic", "ssoSuper"):
        sso_token = token_mgr.get_token(pool_name)
        if sso_token:
            break

    if not sso_token:
        raise AppException(
            "No available tokens for voice mode",
            code="no_token",
            status_code=503,
        )

    service = VoiceService()
    try:
        data = await service.get_token(
            token=sso_token,
            voice=voice,
            personality=personality,
            speed=speed,
        )
    except AppException:
        # Idiom fix: catch AppException directly instead of the original
        # `except Exception` + isinstance() re-raise dance.
        raise
    except Exception as e:
        # Normalize unexpected upstream failures; chain for real tracebacks.
        raise AppException(
            f"Voice token error: {str(e)}",
            code="voice_error",
            status_code=500,
        ) from e

    token = data.get("token")
    if not token:
        raise AppException(
            "Upstream returned no voice token",
            code="upstream_error",
            status_code=502,
        )

    return VoiceTokenResponse(
        token=token,
        url="wss://livekit.grok.com",
        participant_name="",
        room_name="",
    )
|
|
|
|
async def _verify_imagine_ws_auth(websocket: WebSocket) -> tuple[bool, Optional[str]]:
    """Authorize an imagine websocket connection.

    Accepts either a live ``task_id`` session (returned as the second tuple
    element) or the admin ``api_key`` query parameter. When no admin key is
    configured at all, access is open.
    """
    task_id = websocket.query_params.get("task_id")
    if task_id and await _get_imagine_session(task_id):
        return True, task_id

    expected = get_admin_api_key()
    if not expected:
        return True, None
    provided = websocket.query_params.get("api_key")
    return provided == expected, None
|
|
|
|
@router.websocket("/api/v1/admin/imagine/ws")
async def admin_imagine_ws(websocket: WebSocket):
    """Imagine image-waterfall websocket.

    Protocol (client -> server JSON messages):
      - {"type": "start", "prompt": ..., "aspect_ratio": ...}: begin looping
        image generation (cancels any previous run first).
      - {"type": "stop"}: cancel the current run.
      - {"type": "ping"}: answered with {"type": "pong"}.

    Server -> client messages are "status", "image", "error", "pong".
    Auth: either a valid imagine session task_id or the admin api_key
    (see _verify_imagine_ws_auth); the session is deleted on disconnect.
    """
    ok, session_id = await _verify_imagine_ws_auth(websocket)
    if not ok:
        # 1008 = policy violation (failed auth).
        await websocket.close(code=1008)
        return

    await websocket.accept()
    stop_event = asyncio.Event()
    run_task: Optional[asyncio.Task] = None

    async def _send(payload: dict) -> bool:
        # Best-effort send; returns False when the socket is already gone.
        try:
            await websocket.send_text(orjson.dumps(payload).decode())
            return True
        except Exception:
            return False

    async def _stop_run():
        # Cancel and await the current generation loop, then reset state.
        nonlocal run_task
        stop_event.set()
        if run_task and not run_task.done():
            run_task.cancel()
            try:
                await run_task
            except Exception:
                pass
        run_task = None
        stop_event.clear()

    async def _run(prompt: str, aspect_ratio: str):
        # Generation loop: repeatedly request batches of 6 images until
        # stop_event is set or the task is cancelled.
        model_id = "grok-imagine-1.0"
        model_info = ModelService.get(model_id)
        if not model_info or not model_info.is_image:
            await _send(
                {
                    "type": "error",
                    "message": "Image model is not available.",
                    "code": "model_not_supported",
                }
            )
            return

        token_mgr = await get_token_manager()
        enable_nsfw = bool(get_config("image.image_ws_nsfw", True))
        sequence = 0
        run_id = uuid.uuid4().hex

        await _send(
            {
                "type": "status",
                "status": "running",
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "run_id": run_id,
            }
        )

        while not stop_event.is_set():
            try:
                await token_mgr.reload_if_stale()
                # Take the first available token from the model's pools.
                token = None
                for pool_name in ModelService.pool_candidates_for_model(
                    model_info.model_id
                ):
                    token = token_mgr.get_token(pool_name)
                    if token:
                        break

                if not token:
                    await _send(
                        {
                            "type": "error",
                            "message": "No available tokens. Please try again later.",
                            "code": "rate_limit_exceeded",
                        }
                    )
                    await asyncio.sleep(2)
                    continue

                upstream = image_service.stream(
                    token=token,
                    prompt=prompt,
                    aspect_ratio=aspect_ratio,
                    n=6,
                    enable_nsfw=enable_nsfw,
                )

                processor = ImageWSCollectProcessor(
                    model_info.model_id,
                    token,
                    n=6,
                    response_format="b64_json",
                )

                start_at = time.time()
                images = await processor.process(upstream)
                elapsed_ms = int((time.time() - start_at) * 1000)

                # Only forward a batch in which every image succeeded.
                if images and all(img and img != "error" for img in images):
                    for img_b64 in images:
                        sequence += 1
                        await _send(
                            {
                                "type": "image",
                                "b64_json": img_b64,
                                "sequence": sequence,
                                "created_at": int(time.time() * 1000),
                                "elapsed_ms": elapsed_ms,
                                "aspect_ratio": aspect_ratio,
                                "run_id": run_id,
                            }
                        )

                    # Charge the token only after a fully successful batch.
                    try:
                        effort = (
                            EffortType.HIGH
                            if (model_info and model_info.cost.value == "high")
                            else EffortType.LOW
                        )
                        await token_mgr.consume(token, effort)
                    except Exception as e:
                        logger.warning(f"Failed to consume token: {e}")
                else:
                    await _send(
                        {
                            "type": "error",
                            "message": "Image generation returned empty data.",
                            "code": "empty_image",
                        }
                    )

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Imagine stream error: {e}")
                await _send(
                    {
                        "type": "error",
                        "message": str(e),
                        "code": "internal_error",
                    }
                )
                await asyncio.sleep(1.5)

        await _send({"type": "status", "status": "stopped", "run_id": run_id})

    try:
        # Command loop: read client messages until disconnect.
        while True:
            try:
                raw = await websocket.receive_text()
            except (RuntimeError, WebSocketDisconnect):
                break

            try:
                payload = orjson.loads(raw)
            except Exception:
                await _send(
                    {
                        "type": "error",
                        "message": "Invalid message format.",
                        "code": "invalid_payload",
                    }
                )
                continue

            msg_type = payload.get("type")
            if msg_type == "start":
                prompt = str(payload.get("prompt") or "").strip()
                if not prompt:
                    await _send(
                        {
                            "type": "error",
                            "message": "Prompt cannot be empty.",
                            "code": "empty_prompt",
                        }
                    )
                    continue
                ratio = str(payload.get("aspect_ratio") or "2:3").strip()
                if not ratio:
                    ratio = "2:3"
                ratio = resolve_aspect_ratio(ratio)
                await _stop_run()
                # _stop_run already cleared stop_event; this is a no-op kept
                # for safety before launching the new run.
                stop_event.clear()
                run_task = asyncio.create_task(_run(prompt, ratio))
            elif msg_type == "stop":
                await _stop_run()
            elif msg_type == "ping":
                await _send({"type": "pong"})
            else:
                await _send(
                    {
                        "type": "error",
                        "message": "Unknown command.",
                        "code": "unknown_command",
                    }
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client")
    except Exception as e:
        logger.warning(f"WebSocket error: {e}")
    finally:
        await _stop_run()

        # Close gracefully if the client has not already disconnected.
        try:
            from starlette.websockets import WebSocketState
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close(code=1000, reason="Server closing connection")
        except Exception as e:
            logger.debug(f"WebSocket close ignored: {e}")
        if session_id:
            await _delete_imagine_session(session_id)
|
|
|
|
class ImagineStartRequest(BaseModel):
    """Body for POST /api/v1/admin/imagine/start."""

    # Text prompt for generation; must be non-empty after stripping.
    prompt: str
    # Requested aspect ratio (e.g. "2:3"); normalized via resolve_aspect_ratio.
    aspect_ratio: Optional[str] = "2:3"
|
|
|
|
@router.post("/api/v1/admin/imagine/start", dependencies=[Depends(verify_api_key)])
async def admin_imagine_start(data: ImagineStartRequest):
    """Create a short-lived imagine session and return its task id."""
    cleaned_prompt = (data.prompt or "").strip()
    if not cleaned_prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    requested_ratio = str(data.aspect_ratio or "2:3").strip() or "2:3"
    ratio = resolve_aspect_ratio(requested_ratio)
    session_id = await _create_imagine_session(cleaned_prompt, ratio)
    return {"task_id": session_id, "aspect_ratio": ratio}
|
|
|
|
class ImagineStopRequest(BaseModel):
    """Body for POST /api/v1/admin/imagine/stop."""

    # Session ids to invalidate; unknown ids are silently ignored.
    task_ids: list[str]
|
|
|
|
@router.post("/api/v1/admin/imagine/stop", dependencies=[Depends(verify_api_key)])
async def admin_imagine_stop(data: ImagineStopRequest):
    """Invalidate the given imagine sessions; reports how many were removed."""
    count = await _delete_imagine_sessions(data.task_ids or [])
    return {"status": "success", "removed": count}
|
|
|
|
@router.get("/api/v1/admin/imagine/sse")
async def admin_imagine_sse(
    request: Request,
    task_id: str = Query(""),
    prompt: str = Query(""),
    aspect_ratio: str = Query("2:3"),
):
    """Imagine image waterfall (SSE fallback for non-websocket clients).

    Two auth modes: a valid ``task_id`` session (whose stored prompt/ratio
    override the query params), or the admin ``api_key`` query parameter with
    an explicit prompt. Streams "status"/"image"/"error" events until the
    client disconnects or the session is deleted; the session (if any) is
    removed when the stream ends.
    """
    session = None
    if task_id:
        session = await _get_imagine_session(task_id)
        if not session:
            raise HTTPException(status_code=404, detail="Task not found")
    else:
        _verify_stream_api_key(request)

    # Session parameters win over query parameters.
    if session:
        prompt = str(session.get("prompt") or "").strip()
        ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3"
    else:
        prompt = (prompt or "").strip()
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ratio = str(aspect_ratio or "2:3").strip() or "2:3"
    ratio = resolve_aspect_ratio(ratio)

    async def event_stream():
        # Mirrors the websocket generation loop in admin_imagine_ws, but
        # driven by client disconnect / session deletion instead of messages.
        try:
            model_id = "grok-imagine-1.0"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_image:
                yield _sse_event(
                    {
                        "type": "error",
                        "message": "Image model is not available.",
                        "code": "model_not_supported",
                    }
                )
                return

            token_mgr = await get_token_manager()
            enable_nsfw = bool(get_config("image.image_ws_nsfw", True))
            sequence = 0
            run_id = uuid.uuid4().hex

            yield _sse_event(
                {
                    "type": "status",
                    "status": "running",
                    "prompt": prompt,
                    "aspect_ratio": ratio,
                    "run_id": run_id,
                }
            )

            while True:
                # Stop when the client goes away or the session is revoked
                # (e.g. via /imagine/stop).
                if await request.is_disconnected():
                    break
                if task_id:
                    session_alive = await _get_imagine_session(task_id)
                    if not session_alive:
                        break

                try:
                    await token_mgr.reload_if_stale()
                    # First available token across the model's pools.
                    token = None
                    for pool_name in ModelService.pool_candidates_for_model(
                        model_info.model_id
                    ):
                        token = token_mgr.get_token(pool_name)
                        if token:
                            break

                    if not token:
                        yield _sse_event(
                            {
                                "type": "error",
                                "message": "No available tokens. Please try again later.",
                                "code": "rate_limit_exceeded",
                            }
                        )
                        await asyncio.sleep(2)
                        continue

                    upstream = image_service.stream(
                        token=token,
                        prompt=prompt,
                        aspect_ratio=ratio,
                        n=6,
                        enable_nsfw=enable_nsfw,
                    )

                    processor = ImageWSCollectProcessor(
                        model_info.model_id,
                        token,
                        n=6,
                        response_format="b64_json",
                    )

                    start_at = time.time()
                    images = await processor.process(upstream)
                    elapsed_ms = int((time.time() - start_at) * 1000)

                    # Only forward a batch in which every image succeeded.
                    if images and all(img and img != "error" for img in images):
                        for img_b64 in images:
                            sequence += 1
                            yield _sse_event(
                                {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "sequence": sequence,
                                    "created_at": int(time.time() * 1000),
                                    "elapsed_ms": elapsed_ms,
                                    "aspect_ratio": ratio,
                                    "run_id": run_id,
                                }
                            )

                        # Charge the token only after a fully successful batch.
                        try:
                            effort = (
                                EffortType.HIGH
                                if (model_info and model_info.cost.value == "high")
                                else EffortType.LOW
                            )
                            await token_mgr.consume(token, effort)
                        except Exception as e:
                            logger.warning(f"Failed to consume token: {e}")
                    else:
                        yield _sse_event(
                            {
                                "type": "error",
                                "message": "Image generation returned empty data.",
                                "code": "empty_image",
                            }
                        )
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.warning(f"Imagine SSE error: {e}")
                    yield _sse_event(
                        {"type": "error", "message": str(e), "code": "internal_error"}
                    )
                    await asyncio.sleep(1.5)

            yield _sse_event({"type": "status", "status": "stopped", "run_id": run_id})
        finally:
            # One-shot sessions: clean up once the stream terminates.
            if task_id:
                await _delete_imagine_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
|
|
|
|
@router.post("/api/v1/admin/login", dependencies=[Depends(verify_app_key)])
async def admin_login_api():
    """Exchange a valid app_key (checked by the dependency) for the admin api_key."""
    admin_key = get_admin_api_key()
    return {"status": "success", "api_key": admin_key}
|
|
|
|
@router.get("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
async def get_config_api():
    """Return the current configuration mapping.

    NOTE(review): reads the private ``config._config`` attribute directly —
    presumably the full in-memory config dict; confirm nothing sensitive
    leaks through this admin-only endpoint.
    """
    return config._config
|
|
|
|
@router.post("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
async def update_config_api(data: dict):
    """Apply a configuration update.

    Args:
        data: Arbitrary configuration mapping forwarded to ``config.update``.

    Raises:
        HTTPException: 500 carrying the underlying error message on failure.
    """
    try:
        await config.update(data)
    except Exception as e:
        # Chain the original exception so server logs keep the real traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "message": "配置已更新"}
|
|
|
|
@router.get("/api/v1/admin/storage", dependencies=[Depends(verify_api_key)])
async def get_storage_info():
    """Report the active storage backend type.

    Resolution order: SERVER_STORAGE_TYPE env var, then the ``storage.type``
    config key, then inference from the live storage instance; defaults to
    "local".
    """
    storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower()
    if not storage_type:
        # NOTE(review): str() turns a None config value into the string
        # "none" — confirm get_config returns "" (not None) when unset.
        storage_type = str(get_config("storage.type")).lower()
    if not storage_type:
        backend = get_storage()
        if isinstance(backend, LocalStorage):
            storage_type = "local"
        elif isinstance(backend, RedisStorage):
            storage_type = "redis"
        elif isinstance(backend, SQLStorage):
            dialect_aliases = {
                "mysql": "mysql",
                "mariadb": "mysql",
                "postgres": "pgsql",
                "postgresql": "pgsql",
                "pgsql": "pgsql",
            }
            storage_type = dialect_aliases.get(backend.dialect, backend.dialect)
    return {"type": storage_type or "local"}
|
|
|
|
@router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
async def get_tokens_api():
    """Return all stored token pools, or an empty mapping when none exist."""
    loaded = await get_storage().load_tokens()
    return loaded or {}
|
|
|
|
@router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
async def update_tokens_api(data: dict):
    """Replace the stored token pools with the submitted data.

    Each incoming entry is merged over any existing entry for the same token
    (so fields the client omits are preserved), validated through TokenInfo,
    then persisted; the token manager is reloaded afterwards.

    Raises:
        HTTPException: 500 carrying the underlying error message on failure.
    """
    storage = get_storage()
    try:
        from app.services.token.manager import get_token_manager
        from app.services.token.models import TokenInfo

        def _coerce(item) -> Optional[dict]:
            # Normalize a raw entry (str or dict) into a token dict, stripping
            # a leading "sso=" prefix; returns None for any other type.
            # (Previously duplicated for the existing and incoming loops.)
            if isinstance(item, str):
                token_data = {"token": item}
            elif isinstance(item, dict):
                token_data = dict(item)
            else:
                return None
            raw_token = token_data.get("token")
            if isinstance(raw_token, str) and raw_token.startswith("sso="):
                token_data["token"] = raw_token[4:]
            return token_data

        async with storage.acquire_lock("tokens_save", timeout=10):
            existing = await storage.load_tokens() or {}
            allowed_fields = set(TokenInfo.model_fields.keys())

            # Index existing entries by pool and token so incoming partial
            # updates can be merged on top of what is already stored.
            existing_map: dict[str, dict] = {}
            for pool_name, tokens in existing.items():
                if not isinstance(tokens, list):
                    continue
                pool_map = {}
                for item in tokens:
                    token_data = _coerce(item)
                    if token_data is None:
                        continue
                    token_key = token_data.get("token")
                    if isinstance(token_key, str):
                        pool_map[token_key] = token_data
                existing_map[pool_name] = pool_map

            normalized = {}
            for pool_name, tokens in (data or {}).items():
                if not isinstance(tokens, list):
                    continue
                pool_list = []
                for item in tokens:
                    token_data = _coerce(item)
                    if token_data is None:
                        continue

                    base = existing_map.get(pool_name, {}).get(
                        token_data.get("token"), {}
                    )
                    merged = dict(base)
                    merged.update(token_data)
                    if merged.get("tags") is None:
                        merged["tags"] = []

                    # Drop unknown keys before pydantic validation.
                    filtered = {k: v for k, v in merged.items() if k in allowed_fields}
                    try:
                        info = TokenInfo(**filtered)
                        pool_list.append(info.model_dump())
                    except Exception as e:
                        logger.warning(f"Skip invalid token in pool '{pool_name}': {e}")
                        continue
                normalized[pool_name] = pool_list

            await storage.save_tokens(normalized)
            mgr = await get_token_manager()
            await mgr.reload()
            return {"status": "success", "message": "Token 已更新"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@router.post("/api/v1/admin/tokens/refresh", dependencies=[Depends(verify_api_key)])
async def refresh_tokens_api(data: dict):
    """Synchronously refresh usage state for the submitted tokens.

    Tokens come from the body's ``token``/``tokens`` fields, are
    de-duplicated and capped by ``performance.usage_max_tokens``, then
    refreshed in concurrent batches.

    Raises:
        HTTPException: 400 when no tokens are provided, 500 on other errors.
    """
    try:
        mgr = await get_token_manager()
        tokens = _collect_tokens(data)

        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens provided")

        max_tokens = int(get_config("performance.usage_max_tokens"))
        unique_tokens, truncated, original_count = _truncate_tokens(
            tokens, max_tokens, "Usage refresh"
        )

        max_concurrent = get_config("performance.usage_max_concurrent")
        batch_size = get_config("performance.usage_batch_size")

        async def _refresh_one(t):
            return await mgr.sync_usage(
                t, "grok-3", consume_on_fail=False, is_usage=False
            )

        raw_results = await run_in_batches(
            unique_tokens,
            _refresh_one,
            max_concurrent=max_concurrent,
            batch_size=batch_size,
        )

        results = {}
        for token, res in raw_results.items():
            results[token] = res.get("data", False) if res.get("ok") else False

        response = {"status": "success", "results": results}
        if truncated:
            response["warning"] = (
                f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
            )
        return response
    except HTTPException:
        # Bug fix: previously the 400 above was swallowed by the generic
        # handler and re-raised as a 500; let HTTP errors propagate as-is
        # (matches enable_nsfw_api's handling).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@router.post(
    "/api/v1/admin/tokens/refresh/async", dependencies=[Depends(verify_api_key)]
)
async def refresh_tokens_api_async(data: dict):
    """Refresh token usage state asynchronously (batch + SSE progress).

    Creates a batch task, kicks off the refresh in the background, and
    immediately returns the task id; progress is consumed via
    /api/v1/admin/batch/{task_id}/stream and the task expires 300s after
    completion.
    """
    mgr = await get_token_manager()
    tokens = _collect_tokens(data)

    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens provided")

    # De-duplicate and cap the workload before creating the task.
    max_tokens = int(get_config("performance.usage_max_tokens"))
    unique_tokens, truncated, original_count = _truncate_tokens(
        tokens, max_tokens, "Usage refresh"
    )

    max_concurrent = get_config("performance.usage_max_concurrent")
    batch_size = get_config("performance.usage_batch_size")

    task = create_task(len(unique_tokens))

    async def _run():
        try:

            async def _refresh_one(t: str):
                return await mgr.sync_usage(
                    t, "grok-3", consume_on_fail=False, is_usage=False
                )

            async def _on_item(item: str, res: dict):
                # Per-item progress callback feeding the SSE stream.
                task.record(bool(res.get("ok")))

            raw_results = await run_in_batches(
                unique_tokens,
                _refresh_one,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )

            if task.cancelled:
                task.finish_cancelled()
                return

            # Success requires both ok transport and data == True.
            results: dict[str, bool] = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                if res.get("ok") and res.get("data") is True:
                    ok_count += 1
                    results[token] = True
                else:
                    fail_count += 1
                    results[token] = False

            # NOTE(review): calls the manager's private _save(); presumably
            # persists refreshed usage state — confirm intended API.
            await mgr._save()

            result = {
                "status": "success",
                "summary": {
                    "total": len(unique_tokens),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            # Let the finished task linger 5 minutes for SSE consumers.
            asyncio.create_task(expire_task(task.id, 300))

    # NOTE(review): the created task object is not stored anywhere; asyncio
    # only keeps weak references to tasks, so this fire-and-forget run could
    # in principle be garbage-collected mid-flight — consider holding a ref.
    asyncio.create_task(_run())

    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }
|
|
|
|
@router.post("/api/v1/admin/tokens/nsfw/enable", dependencies=[Depends(verify_api_key)])
async def enable_nsfw_api(data: dict):
    """Batch-enable NSFW (Unhinged) mode for tokens.

    Uses tokens from the request body, or — when none are supplied — every
    token across the manager's pools. Results are keyed by masked token.

    Raises:
        HTTPException: 400 when no tokens are available at all, 500 on
            unexpected errors.
    """
    from app.services.grok.services.nsfw import NSFWService

    try:
        mgr = await get_token_manager()
        nsfw_service = NSFWService()

        tokens = _collect_tokens(data)

        # Fall back to every pooled token when the request names none.
        if not tokens:
            for pool_name, pool in mgr.pools.items():
                for info in pool.list():
                    # Stored tokens may carry an "sso=" cookie prefix.
                    raw = (
                        info.token[4:] if info.token.startswith("sso=") else info.token
                    )
                    tokens.append(raw)

        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens available")

        # De-duplicate and cap the workload.
        max_tokens = int(get_config("performance.nsfw_max_tokens"))
        unique_tokens, truncated, original_count = _truncate_tokens(
            tokens, max_tokens, "NSFW enable"
        )

        max_concurrent = get_config("performance.nsfw_max_concurrent")
        batch_size = get_config("performance.nsfw_batch_size")

        async def _enable(token: str):
            result = await nsfw_service.enable(token)
            # Tag successfully enabled tokens so pools can filter on "nsfw".
            if result.success:
                await mgr.add_tag(token, "nsfw")
            return {
                "success": result.success,
                "http_status": result.http_status,
                "grpc_status": result.grpc_status,
                "grpc_message": result.grpc_message,
                "error": result.error,
            }

        raw_results = await run_in_batches(
            unique_tokens, _enable, max_concurrent=max_concurrent, batch_size=batch_size
        )

        # Aggregate per-token outcomes, keyed by masked token for display.
        results = {}
        ok_count = 0
        fail_count = 0

        for token, res in raw_results.items():
            masked = _mask_token(token)
            if res.get("ok") and res.get("data", {}).get("success"):
                ok_count += 1
                results[masked] = res.get("data", {})
            else:
                fail_count += 1
                results[masked] = res.get("data") or {"error": res.get("error")}

        response = {
            "status": "success",
            "summary": {
                "total": len(unique_tokens),
                "ok": ok_count,
                "fail": fail_count,
            },
            "results": results,
        }

        if truncated:
            response["warning"] = (
                f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
            )

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Enable NSFW failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.post(
    "/api/v1/admin/tokens/nsfw/enable/async", dependencies=[Depends(verify_api_key)]
)
async def enable_nsfw_api_async(data: dict):
    """Batch-enable NSFW (Unhinged) mode asynchronously (batch + SSE progress).

    Same semantics as enable_nsfw_api, but runs in the background: returns a
    task id immediately; progress is streamed via
    /api/v1/admin/batch/{task_id}/stream and the task expires 300s after
    completion.

    Raises:
        HTTPException: 400 when no tokens are available at all.
    """
    from app.services.grok.services.nsfw import NSFWService

    mgr = await get_token_manager()
    nsfw_service = NSFWService()

    tokens = _collect_tokens(data)

    # Fall back to every pooled token when the request names none.
    if not tokens:
        for pool_name, pool in mgr.pools.items():
            for info in pool.list():
                # Stored tokens may carry an "sso=" cookie prefix.
                raw = info.token[4:] if info.token.startswith("sso=") else info.token
                tokens.append(raw)

    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens available")

    # De-duplicate and cap the workload before creating the task.
    max_tokens = int(get_config("performance.nsfw_max_tokens"))
    unique_tokens, truncated, original_count = _truncate_tokens(
        tokens, max_tokens, "NSFW enable"
    )

    max_concurrent = get_config("performance.nsfw_max_concurrent")
    batch_size = get_config("performance.nsfw_batch_size")

    task = create_task(len(unique_tokens))

    async def _run():
        try:

            async def _enable(token: str):
                result = await nsfw_service.enable(token)
                # Tag successfully enabled tokens so pools can filter on "nsfw".
                if result.success:
                    await mgr.add_tag(token, "nsfw")
                return {
                    "success": result.success,
                    "http_status": result.http_status,
                    "grpc_status": result.grpc_status,
                    "grpc_message": result.grpc_message,
                    "error": result.error,
                }

            async def _on_item(item: str, res: dict):
                # Per-item progress callback feeding the SSE stream.
                ok = bool(res.get("ok") and res.get("data", {}).get("success"))
                task.record(ok)

            raw_results = await run_in_batches(
                unique_tokens,
                _enable,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )

            if task.cancelled:
                task.finish_cancelled()
                return

            # Aggregate per-token outcomes, keyed by masked token for display.
            results = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                # Consistency fix: reuse the shared _mask_token helper (the
                # sync endpoint already does) instead of re-inlining the mask.
                masked = _mask_token(token)
                if res.get("ok") and res.get("data", {}).get("success"):
                    ok_count += 1
                    results[masked] = res.get("data", {})
                else:
                    fail_count += 1
                    results[masked] = res.get("data") or {"error": res.get("error")}

            # NOTE(review): calls the manager's private _save(); presumably
            # persists the new "nsfw" tags — confirm intended API.
            await mgr._save()

            result = {
                "status": "success",
                "summary": {
                    "total": len(unique_tokens),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            # Let the finished task linger 5 minutes for SSE consumers.
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())

    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }
|
|
|
|
@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False)
async def admin_cache_page():
    """Serve the cache management page."""
    page = await render_template("cache/cache.html")
    return page
|
|
|
|
@router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)])
async def get_cache_stats_api(request: Request):
    """Get cache statistics.

    Reports local image/video download-cache stats plus online (remote)
    asset counts for account tokens.

    Query params (by precedence):
        tokens: comma-separated tokens to count.
        scope:  "all" — count assets for every known account token.
        token:  a single token to count assets for.
        (none): online stats are reported with status "not_loaded".

    Raises:
        HTTPException(500): on any unexpected failure.
    """
    from app.services.grok.services.assets import DownloadService, ListService
    from app.services.token.manager import get_token_manager
    from app.services.grok.utils.batch import run_in_batches

    try:
        dl_service = DownloadService()
        image_stats = dl_service.get_stats("image")
        video_stats = dl_service.get_stats("video")

        mgr = await get_token_manager()
        accounts = []
        for pool_name, pool in mgr.pools.items():
            for info in pool.list():
                # Tokens may be stored with a leading "sso=" cookie prefix.
                raw_token = (
                    info.token[4:] if info.token.startswith("sso=") else info.token
                )
                masked = (
                    f"{raw_token[:8]}...{raw_token[-16:]}"
                    if len(raw_token) > 24
                    else raw_token
                )
                accounts.append(
                    {
                        "token": raw_token,
                        "token_masked": masked,
                        "pool": pool_name,
                        "status": info.status,
                        "last_asset_clear_at": info.last_asset_clear_at,
                    }
                )

        scope = request.query_params.get("scope")
        selected_token = request.query_params.get("token")
        tokens_param = request.query_params.get("tokens")
        selected_tokens = []
        if tokens_param:
            selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()]

        online_stats = {
            "count": 0,
            "status": "unknown",
            "token": None,
            "last_asset_clear_at": None,
        }
        online_details = []
        account_map = {a["token"]: a for a in accounts}
        max_concurrent = max(1, int(get_config("performance.assets_max_concurrent")))
        batch_size = max(1, int(get_config("performance.assets_batch_size")))
        max_tokens = int(get_config("performance.assets_max_tokens"))

        truncated = False
        original_count = 0

        async def _fetch_assets(token: str):
            """Count online assets for one token; always closes the service."""
            list_service = ListService()
            try:
                return await list_service.count(token)
            finally:
                await list_service.close()

        def _error_detail(token: str, message) -> dict:
            """Build the per-token detail entry for a failed asset count."""
            account = account_map.get(token)
            return {
                "token": token,
                "token_masked": account["token_masked"] if account else token,
                "count": 0,
                "status": f"error: {message}",
                "last_asset_clear_at": account["last_asset_clear_at"]
                if account
                else None,
            }

        async def _fetch_detail(token: str) -> dict:
            """Fetch one token's asset count; errors become an error detail."""
            account = account_map.get(token)
            try:
                count = await _fetch_assets(token)
            except Exception as e:
                return {"detail": _error_detail(token, str(e)), "count": 0}
            return {
                "detail": {
                    "token": token,
                    "token_masked": account["token_masked"] if account else token,
                    "count": count,
                    "status": "ok",
                    "last_asset_clear_at": account["last_asset_clear_at"]
                    if account
                    else None,
                },
                "count": count,
            }

        async def _count_tokens(token_list) -> int:
            """Count assets for token_list in batches.

            Appends one detail dict per token to online_details and returns
            the summed asset count. (Shared by the "tokens" and "all" paths.)
            """
            total = 0
            raw_results = await run_in_batches(
                token_list,
                _fetch_detail,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
            )
            for token, res in raw_results.items():
                if res.get("ok"):
                    data = res.get("data", {})
                    detail = data.get("detail")
                    total += data.get("count", 0)
                else:
                    detail = _error_detail(token, res.get("error"))
                if detail:
                    online_details.append(detail)
            return total

        if selected_tokens:
            selected_tokens, truncated, original_count = _truncate_tokens(
                selected_tokens, max_tokens, "Assets fetch"
            )
            total = await _count_tokens(selected_tokens)
            online_stats = {
                "count": total,
                "status": "ok" if selected_tokens else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
            scope = "selected"
        elif scope == "all":
            # De-duplicate while preserving order, then truncate inline.
            tokens = list(dict.fromkeys([account["token"] for account in accounts]))
            original_count = len(tokens)
            if len(tokens) > max_tokens:
                tokens = tokens[:max_tokens]
                truncated = True
            total = await _count_tokens(tokens)
            online_stats = {
                "count": total,
                "status": "ok" if accounts else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
        else:
            token = selected_token
            if token:
                match = next((a for a in accounts if a["token"] == token), None)
                try:
                    count = await _fetch_assets(token)
                    online_stats = {
                        "count": count,
                        "status": "ok",
                        "token": token,
                        "token_masked": match["token_masked"] if match else token,
                        "last_asset_clear_at": match["last_asset_clear_at"]
                        if match
                        else None,
                    }
                except Exception as e:
                    online_stats = {
                        "count": 0,
                        "status": f"error: {str(e)}",
                        "token": token,
                        "token_masked": match["token_masked"] if match else token,
                        "last_asset_clear_at": match["last_asset_clear_at"]
                        if match
                        else None,
                    }
            else:
                online_stats = {
                    "count": 0,
                    "status": "not_loaded",
                    "token": None,
                    "last_asset_clear_at": None,
                }

        response = {
            "local_image": image_stats,
            "local_video": video_stats,
            "online": online_stats,
            "online_accounts": accounts,
            "online_scope": scope or "none",
            "online_details": online_details,
        }
        if truncated:
            response["warning"] = (
                f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
            )
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@router.post(
    "/api/v1/admin/cache/online/load/async", dependencies=[Depends(verify_api_key)]
)
async def load_online_cache_api_async(data: dict):
    """Online asset counting (async batch + SSE progress).

    Body:
        tokens: optional list of tokens to count.
        scope:  "all" — count every known account token (when tokens absent).

    Returns a task id; progress and the final result are retrieved through
    the batch-task API.

    Raises:
        HTTPException(400): when neither tokens nor scope="all" is supplied.
    """
    from app.services.grok.services.assets import DownloadService, ListService
    from app.services.token.manager import get_token_manager
    from app.services.grok.utils.batch import run_in_batches

    mgr = await get_token_manager()

    accounts = []
    for pool_name, pool in mgr.pools.items():
        for info in pool.list():
            # Tokens may be stored with a leading "sso=" cookie prefix.
            raw_token = info.token[4:] if info.token.startswith("sso=") else info.token
            masked = (
                f"{raw_token[:8]}...{raw_token[-16:]}"
                if len(raw_token) > 24
                else raw_token
            )
            accounts.append(
                {
                    "token": raw_token,
                    "token_masked": masked,
                    "pool": pool_name,
                    "status": info.status,
                    "last_asset_clear_at": info.last_asset_clear_at,
                }
            )

    account_map = {a["token"]: a for a in accounts}

    tokens = data.get("tokens")
    scope = data.get("scope")
    selected_tokens: list[str] = []
    if isinstance(tokens, list):
        selected_tokens = [str(t).strip() for t in tokens if str(t).strip()]

    if not selected_tokens and scope == "all":
        selected_tokens = [account["token"] for account in accounts]
        scope = "all"
    elif selected_tokens:
        scope = "selected"
    else:
        raise HTTPException(status_code=400, detail="No tokens provided")

    max_tokens = int(get_config("performance.assets_max_tokens"))
    selected_tokens, truncated, original_count = _truncate_tokens(
        selected_tokens, max_tokens, "Assets load"
    )

    # Coerce and clamp, consistent with the synchronous stats endpoint —
    # config values may arrive as strings or be <= 0.
    max_concurrent = max(1, int(get_config("performance.assets_max_concurrent")))
    batch_size = max(1, int(get_config("performance.assets_batch_size")))

    task = create_task(len(selected_tokens))

    async def _run():
        """Background worker: counts assets per token and finishes the task."""
        try:
            dl_service = DownloadService()
            image_stats = dl_service.get_stats("image")
            video_stats = dl_service.get_stats("video")

            async def _fetch_detail(token: str):
                """Count assets for one token; never raises (errors in detail)."""
                account = account_map.get(token)
                list_service = ListService()
                try:
                    count = await list_service.count(token)
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": count,
                        "status": "ok",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                    return {"ok": True, "detail": detail, "count": count}
                except Exception as e:
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {str(e)}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                    return {"ok": False, "detail": detail, "count": 0}
                finally:
                    await list_service.close()

            async def _on_item(item: str, res: dict):
                # Guard against an explicit None "data" payload.
                ok = bool((res.get("data") or {}).get("ok"))
                task.record(ok)

            raw_results = await run_in_batches(
                selected_tokens,
                _fetch_detail,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )

            if task.cancelled:
                task.finish_cancelled()
                return

            online_details = []
            total = 0
            for token, res in raw_results.items():
                payload = res.get("data") or {}
                detail = payload.get("detail")
                if detail:
                    online_details.append(detail)
                total += payload.get("count", 0)

            online_stats = {
                "count": total,
                "status": "ok" if selected_tokens else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }

            result = {
                "local_image": image_stats,
                "local_video": video_stats,
                "online": online_stats,
                "online_accounts": accounts,
                "online_scope": scope or "none",
                "online_details": online_details,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            # Expire the finished task after a grace period.
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())

    return {
        "status": "success",
        "task_id": task.id,
        "total": len(selected_tokens),
    }
|
|
|
|
@router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)])
async def clear_local_cache_api(data: dict):
    """Clear the local download cache of the given type (default "image")."""
    from app.services.grok.services.assets import DownloadService

    cache_type = data.get("type", "image")

    try:
        cleared = DownloadService().clear(cache_type)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "result": cleared}
|
|
|
|
@router.get("/api/v1/admin/cache/list", dependencies=[Depends(verify_api_key)])
async def list_local_cache_api(
    cache_type: str = "image",
    type_: str = Query(default=None, alias="type"),
    page: int = 1,
    page_size: int = 1000,
):
    """List locally cached files of the given type, paginated."""
    from app.services.grok.services.assets import DownloadService

    # A "type" query parameter, when supplied, takes precedence.
    effective_type = type_ or cache_type
    try:
        listing = DownloadService().list_files(effective_type, page, page_size)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", **listing}
|
|
|
|
@router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)])
async def delete_local_cache_item_api(data: dict):
    """Delete one locally cached file by name (400 when name is missing)."""
    from app.services.grok.services.assets import DownloadService

    name = data.get("name")
    cache_type = data.get("type", "image")
    if not name:
        raise HTTPException(status_code=400, detail="Missing file name")
    try:
        outcome = DownloadService().delete_file(cache_type, name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "result": outcome}
|
|
|
|
@router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)])
async def clear_online_cache_api(data: dict):
    """Clear online (remote) cached assets.

    Body:
        tokens: optional list — clears assets for each token in batches.
        token:  single token; falls back to mgr.get_token() when absent.

    Raises:
        HTTPException(400): when no usable token is provided/available.
        HTTPException(500): on any unexpected failure.
    """
    from app.services.grok.services.assets import DeleteService
    from app.services.token.manager import get_token_manager
    from app.services.grok.utils.batch import run_in_batches

    delete_service = None
    try:
        mgr = await get_token_manager()
        tokens = data.get("tokens")
        delete_service = DeleteService()

        if isinstance(tokens, list):
            token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
            if not token_list:
                raise HTTPException(status_code=400, detail="No tokens provided")

            # De-duplicate while preserving order.
            token_list = list(dict.fromkeys(token_list))

            max_tokens = int(get_config("performance.assets_max_tokens"))
            token_list, truncated, original_count = _truncate_tokens(
                token_list, max_tokens, "Clear online cache"
            )

            results = {}
            max_concurrent = max(
                1, int(get_config("performance.assets_max_concurrent"))
            )
            batch_size = max(1, int(get_config("performance.assets_batch_size")))

            async def _clear_one(t: str):
                """Delete all remote assets for one token; never raises."""
                try:
                    result = await delete_service.delete_all(t)
                    await mgr.mark_asset_clear(t)
                    return {"status": "success", "result": result}
                except Exception as e:
                    return {"status": "error", "error": str(e)}

            raw_results = await run_in_batches(
                token_list,
                _clear_one,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
            )
            for token, res in raw_results.items():
                if res.get("ok"):
                    results[token] = res.get("data", {})
                else:
                    results[token] = {"status": "error", "error": res.get("error")}

            response = {"status": "success", "results": results}
            if truncated:
                response["warning"] = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            return response

        token = data.get("token") or mgr.get_token()
        if not token:
            raise HTTPException(
                status_code=400, detail="No available token to perform cleanup"
            )

        result = await delete_service.delete_all(token)
        await mgr.mark_asset_clear(token)
        return {"status": "success", "result": result}
    except HTTPException:
        # Re-raise deliberate 4xx responses; the blanket handler below would
        # otherwise convert them into misleading 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if delete_service:
            await delete_service.close()
|
|
|
|
@router.post(
    "/api/v1/admin/cache/online/clear/async", dependencies=[Depends(verify_api_key)]
)
async def clear_online_cache_api_async(data: dict):
    """Clear online cache (async batch + SSE progress).

    Body:
        tokens: required list of tokens whose remote assets are deleted.

    Returns a task id; progress and the final result are retrieved through
    the batch-task API.

    Raises:
        HTTPException(400): when no usable tokens are provided.
    """
    from app.services.grok.services.assets import DeleteService
    from app.services.token.manager import get_token_manager
    from app.services.grok.utils.batch import run_in_batches

    mgr = await get_token_manager()
    tokens = data.get("tokens")
    if not isinstance(tokens, list):
        raise HTTPException(status_code=400, detail="No tokens provided")

    token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()]
    if not token_list:
        raise HTTPException(status_code=400, detail="No tokens provided")

    max_tokens = int(get_config("performance.assets_max_tokens"))
    token_list, truncated, original_count = _truncate_tokens(
        token_list, max_tokens, "Clear online cache async"
    )

    # Coerce and clamp, consistent with the synchronous clear endpoint —
    # config values may arrive as strings or be <= 0.
    max_concurrent = max(1, int(get_config("performance.assets_max_concurrent")))
    batch_size = max(1, int(get_config("performance.assets_batch_size")))

    task = create_task(len(token_list))

    async def _run():
        """Background worker: deletes assets per token and finishes the task."""
        delete_service = DeleteService()
        try:

            async def _clear_one(t: str):
                """Delete all remote assets for one token; never raises."""
                try:
                    result = await delete_service.delete_all(t)
                    await mgr.mark_asset_clear(t)
                    return {"ok": True, "result": result}
                except Exception as e:
                    return {"ok": False, "error": str(e)}

            async def _on_item(item: str, res: dict):
                # Guard against an explicit None "data" payload.
                ok = bool((res.get("data") or {}).get("ok"))
                task.record(ok)

            raw_results = await run_in_batches(
                token_list,
                _clear_one,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )

            if task.cancelled:
                task.finish_cancelled()
                return

            results = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                # "payload" avoids shadowing the endpoint's `data` parameter.
                payload = res.get("data") or {}
                if payload.get("ok"):
                    ok_count += 1
                    results[token] = {
                        "status": "success",
                        "result": payload.get("result"),
                    }
                else:
                    fail_count += 1
                    results[token] = {"status": "error", "error": payload.get("error")}

            result = {
                "status": "success",
                "summary": {
                    "total": len(token_list),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            await delete_service.close()
            # Expire the finished task after a grace period.
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())

    return {
        "status": "success",
        "task_id": task.id,
        "total": len(token_list),
    }
|
|