"""Admin router: login/config/token/cache pages and APIs, voice token, imagine streams."""
from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    Request,
    Query,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import HTMLResponse, StreamingResponse, RedirectResponse
from typing import Optional
from pydantic import BaseModel
from app.core.auth import verify_api_key, verify_app_key, get_admin_api_key
from app.core.config import config, get_config
from app.core.batch_tasks import create_task, get_task, expire_task
from app.core.storage import get_storage, LocalStorage, RedisStorage, SQLStorage
from app.core.exceptions import AppException
from app.services.token.manager import get_token_manager
from app.services.grok.utils.batch import run_in_batches
import os
import time
import uuid
from pathlib import Path
import aiofiles
import asyncio
import orjson
from app.core.logger import logger
from app.api.v1.image import resolve_aspect_ratio
from app.services.grok.services.voice import VoiceService
from app.services.grok.services.image import image_service
from app.services.grok.models.model import ModelService
from app.services.grok.processors.image_ws_processors import ImageWSCollectProcessor
from app.services.token import EffortType

# Directory holding the admin HTML templates (three levels up from this file, then /static).
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "static"

router = APIRouter()

# In-memory "imagine" sessions: task_id -> {"prompt", "aspect_ratio", "created_at"}.
# Guarded by _IMAGINE_SESSIONS_LOCK; entries expire after IMAGINE_SESSION_TTL seconds.
# NOTE(review): process-local state — with multiple workers a session created on one
# worker is invisible to another; confirm single-worker deployment.
IMAGINE_SESSION_TTL = 600
_IMAGINE_SESSIONS: dict[str, dict] = {}
_IMAGINE_SESSIONS_LOCK = asyncio.Lock()


async def _cleanup_imagine_sessions(now: float) -> None:
    """Drop sessions older than IMAGINE_SESSION_TTL. Caller must hold the lock."""
    expired = [
        key
        for key, info in _IMAGINE_SESSIONS.items()
        if now - float(info.get("created_at") or 0) > IMAGINE_SESSION_TTL
    ]
    for key in expired:
        _IMAGINE_SESSIONS.pop(key, None)


async def _create_imagine_session(prompt: str, aspect_ratio: str) -> str:
    """Register a new imagine session and return its hex task id."""
    task_id = uuid.uuid4().hex
    now = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _cleanup_imagine_sessions(now)
        _IMAGINE_SESSIONS[task_id] = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            "created_at": now,
        }
    return task_id


async def _get_imagine_session(task_id: str) -> Optional[dict]:
    """Return a copy of the session dict, or None if missing or expired."""
    if not task_id:
        return None
    now = time.time()
    async with _IMAGINE_SESSIONS_LOCK:
        await _cleanup_imagine_sessions(now)
        info = _IMAGINE_SESSIONS.get(task_id)
        if not info:
            return None
        created_at = float(info.get("created_at") or 0)
        if now - created_at > IMAGINE_SESSION_TTL:
            # Expired between cleanup and lookup; remove eagerly.
            _IMAGINE_SESSIONS.pop(task_id, None)
            return None
        # Copy so callers cannot mutate the shared session state.
        return dict(info)


async def _delete_imagine_session(task_id: str) -> None:
    """Remove a single session; silently ignores unknown ids."""
    if not task_id:
        return
    async with _IMAGINE_SESSIONS_LOCK:
        _IMAGINE_SESSIONS.pop(task_id, None)


async def _delete_imagine_sessions(task_ids: list[str]) -> int:
    """Remove several sessions at once; returns how many actually existed."""
    if not task_ids:
        return 0
    removed = 0
    async with _IMAGINE_SESSIONS_LOCK:
        for task_id in task_ids:
            if task_id and task_id in _IMAGINE_SESSIONS:
                _IMAGINE_SESSIONS.pop(task_id, None)
                removed += 1
    return removed


def _collect_tokens(data: dict) -> list[str]:
    """Collect the token list from request data (both "token" and "tokens" keys)."""
    tokens = []
    if isinstance(data.get("token"), str) and data["token"].strip():
        tokens.append(data["token"].strip())
    if isinstance(data.get("tokens"), list):
        tokens.extend([str(t).strip() for t in data["tokens"] if str(t).strip()])
    return tokens


def _truncate_tokens(
    tokens: list[str], max_tokens: int, operation: str = "operation"
) -> tuple[list[str], bool, int]:
    """Dedupe (order-preserving) and truncate; returns (unique_tokens, truncated, original_count)."""
    unique_tokens = list(dict.fromkeys(tokens))
    original_count = len(unique_tokens)
    truncated = False
    if len(unique_tokens) > max_tokens:
        unique_tokens = unique_tokens[:max_tokens]
        truncated = True
        logger.warning(
            f"{operation}: truncated from {original_count} to {max_tokens} tokens"
        )
    return unique_tokens, truncated, original_count


def _mask_token(token: str) -> str:
    """Mask a token for display: keep first/last 8 chars when longer than 20."""
    return f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token


async def render_template(filename: str):
    """Read a template file under TEMPLATE_DIR and return it as an HTMLResponse."""
    template_path = TEMPLATE_DIR / filename
    if not template_path.exists():
        # NOTE(review): this f-string has no placeholder — the requested filename is
        # not included in the 404 body; confirm whether {filename} was intended.
        return HTMLResponse(f"Template (unknown) not found.", status_code=404)
    async with aiofiles.open(template_path, "r", encoding="utf-8") as f:
        content = await f.read()
    return HTMLResponse(content)


def _sse_event(payload: dict) -> str:
    """Serialize a payload as one Server-Sent-Events `data:` frame."""
    return f"data: {orjson.dumps(payload).decode()}\n\n"


def _verify_stream_api_key(request: Request) -> None:
    """Auth for SSE endpoints (EventSource cannot send headers): api_key query param.

    When no admin api key is configured, access is allowed.
    Raises HTTPException(401) on mismatch.
    """
    api_key = get_admin_api_key()
    if not api_key:
        return
    key = request.query_params.get("api_key")
    if key != api_key:
        raise HTTPException(status_code=401, detail="Invalid authentication token")


@router.get("/api/v1/admin/batch/{task_id}/stream")
async def stream_batch(task_id: str, request: Request):
    """Stream batch-task progress over SSE; ends on a done/error/cancelled event."""
    _verify_stream_api_key(request)
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    async def event_stream():
        queue = task.attach()
        try:
            # Initial snapshot so late subscribers see current progress.
            yield _sse_event({"type": "snapshot", **task.snapshot()})
            final = task.final_event()
            if final:
                yield _sse_event(final)
                return
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # Keep-alive comment frame; also re-check for a final event.
                    yield ": ping\n\n"
                    final = task.final_event()
                    if final:
                        yield _sse_event(final)
                        return
                    continue
                yield _sse_event(event)
                if event.get("type") in ("done", "error", "cancelled"):
                    return
        finally:
            task.detach(queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")


@router.post(
    "/api/v1/admin/batch/{task_id}/cancel", dependencies=[Depends(verify_api_key)]
)
async def cancel_batch(task_id: str):
    """Request cancellation of a running batch task."""
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    task.cancel()
    return {"status": "success"}


@router.get("/admin", response_class=HTMLResponse, include_in_schema=False)
async def admin_login_page():
    """Admin console login page."""
    return await render_template("login/login.html")


@router.get("/", include_in_schema=False)
async def root_redirect():
    """Redirect the site root to the admin console."""
    return RedirectResponse(url="/admin")


@router.get("/admin/config", response_class=HTMLResponse, include_in_schema=False)
async def admin_config_page():
    """Configuration management page."""
    return await render_template("config/config.html")


@router.get("/admin/token", response_class=HTMLResponse, include_in_schema=False)
async def admin_token_page():
    """Token management page."""
    return await render_template("token/token.html")


@router.get("/admin/voice", response_class=HTMLResponse, include_in_schema=False)
async def admin_voice_page():
    """Voice Live debugging page."""
    return await render_template("voice/voice.html")


@router.get("/admin/imagine", response_class=HTMLResponse, include_in_schema=False)
async def admin_imagine_page():
    """Imagine image waterfall page."""
    return await render_template("imagine/imagine.html")


class VoiceTokenResponse(BaseModel):
    # LiveKit access token plus connection metadata for Grok Voice Mode.
    token: str
    url: str
    participant_name: str = ""
    room_name: str = ""


@router.get(
    "/api/v1/admin/voice/token",
    dependencies=[Depends(verify_api_key)],
    response_model=VoiceTokenResponse,
)
async def admin_voice_token(
    voice: str = "ara",
    personality: str = "assistant",
    speed: float = 1.0,
):
    """Obtain a Grok Voice Mode (LiveKit) token.

    Picks the first available SSO token from the ssoBasic then ssoSuper pools.
    Raises AppException 503 when no token is available, 502 when the upstream
    returns no voice token, 500 for any other upstream failure.
    """
    token_mgr = await get_token_manager()
    sso_token = None
    for pool_name in ("ssoBasic", "ssoSuper"):
        sso_token = token_mgr.get_token(pool_name)
        if sso_token:
            break
    if not sso_token:
        raise AppException(
            "No available tokens for voice mode",
            code="no_token",
            status_code=503,
        )
    service = VoiceService()
    try:
        data = await service.get_token(
            token=sso_token,
            voice=voice,
            personality=personality,
            speed=speed,
        )
        token = data.get("token")
        if not token:
            raise AppException(
                "Upstream returned no voice token",
                code="upstream_error",
                status_code=502,
            )
        return VoiceTokenResponse(
            token=token,
            url="wss://livekit.grok.com",
            participant_name="",
            room_name="",
        )
    except Exception as e:
        # Re-raise AppExceptions untouched; wrap anything else as a 500.
        if isinstance(e, AppException):
            raise
        raise AppException(
            f"Voice token error: {str(e)}",
            code="voice_error",
            status_code=500,
        )


async def _verify_imagine_ws_auth(websocket: WebSocket) -> tuple[bool, Optional[str]]:
    """Authorize the imagine WebSocket.

    Returns (authorized, session_id). A valid session task_id grants access and
    binds the connection to that session; otherwise the admin api_key query
    param is checked (open access when no key is configured).
    """
    task_id = websocket.query_params.get("task_id")
    if task_id:
        info = await _get_imagine_session(task_id)
        if info:
            return True, task_id
    api_key = get_admin_api_key()
    if not api_key:
        return True, None
    key = websocket.query_params.get("api_key")
    return key == api_key, None
@router.websocket("/api/v1/admin/imagine/ws")
async def admin_imagine_ws(websocket: WebSocket):
    """Imagine waterfall WebSocket: repeatedly generates 6-image batches until stopped.

    Protocol (JSON text frames): client sends {"type": "start"|"stop"|"ping", ...};
    server replies with "status", "image", "pong", or "error" frames.
    """
    ok, session_id = await _verify_imagine_ws_auth(websocket)
    if not ok:
        await websocket.close(code=1008)
        return
    await websocket.accept()
    stop_event = asyncio.Event()
    run_task: Optional[asyncio.Task] = None

    async def _send(payload: dict) -> bool:
        # Best-effort send; returns False once the socket is gone.
        try:
            await websocket.send_text(orjson.dumps(payload).decode())
            return True
        except Exception:
            return False

    async def _stop_run():
        # Cancel the generation loop (if any) and reset for a new run.
        nonlocal run_task
        stop_event.set()
        if run_task and not run_task.done():
            run_task.cancel()
            try:
                await run_task
            except Exception:
                pass
        run_task = None
        stop_event.clear()

    async def _run(prompt: str, aspect_ratio: str):
        # Background generation loop: one upstream call -> 6 images per iteration.
        model_id = "grok-imagine-1.0"
        model_info = ModelService.get(model_id)
        if not model_info or not model_info.is_image:
            await _send(
                {
                    "type": "error",
                    "message": "Image model is not available.",
                    "code": "model_not_supported",
                }
            )
            return
        token_mgr = await get_token_manager()
        enable_nsfw = bool(get_config("image.image_ws_nsfw", True))
        sequence = 0
        run_id = uuid.uuid4().hex
        await _send(
            {
                "type": "status",
                "status": "running",
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "run_id": run_id,
            }
        )
        while not stop_event.is_set():
            try:
                await token_mgr.reload_if_stale()
                token = None
                for pool_name in ModelService.pool_candidates_for_model(
                    model_info.model_id
                ):
                    token = token_mgr.get_token(pool_name)
                    if token:
                        break
                if not token:
                    await _send(
                        {
                            "type": "error",
                            "message": "No available tokens. Please try again later.",
                            "code": "rate_limit_exceeded",
                        }
                    )
                    await asyncio.sleep(2)
                    continue
                upstream = image_service.stream(
                    token=token,
                    prompt=prompt,
                    aspect_ratio=aspect_ratio,
                    n=6,
                    enable_nsfw=enable_nsfw,
                )
                processor = ImageWSCollectProcessor(
                    model_info.model_id,
                    token,
                    n=6,
                    response_format="b64_json",
                )
                start_at = time.time()
                images = await processor.process(upstream)
                elapsed_ms = int((time.time() - start_at) * 1000)
                if images and all(img and img != "error" for img in images):
                    # Send all 6 images of the batch
                    for img_b64 in images:
                        sequence += 1
                        await _send(
                            {
                                "type": "image",
                                "b64_json": img_b64,
                                "sequence": sequence,
                                "created_at": int(time.time() * 1000),
                                "elapsed_ms": elapsed_ms,
                                "aspect_ratio": aspect_ratio,
                                "run_id": run_id,
                            }
                        )
                    # Consume the token (6 images billed at high cost)
                    try:
                        effort = (
                            EffortType.HIGH
                            if (model_info and model_info.cost.value == "high")
                            else EffortType.LOW
                        )
                        await token_mgr.consume(token, effort)
                    except Exception as e:
                        logger.warning(f"Failed to consume token: {e}")
                else:
                    await _send(
                        {
                            "type": "error",
                            "message": "Image generation returned empty data.",
                            "code": "empty_image",
                        }
                    )
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Imagine stream error: {e}")
                await _send(
                    {
                        "type": "error",
                        "message": str(e),
                        "code": "internal_error",
                    }
                )
            await asyncio.sleep(1.5)
        await _send({"type": "status", "status": "stopped", "run_id": run_id})

    try:
        while True:
            try:
                raw = await websocket.receive_text()
            except (RuntimeError, WebSocketDisconnect):
                # WebSocket already closed or disconnected
                break
            try:
                payload = orjson.loads(raw)
            except Exception:
                await _send(
                    {
                        "type": "error",
                        "message": "Invalid message format.",
                        "code": "invalid_payload",
                    }
                )
                continue
            msg_type = payload.get("type")
            if msg_type == "start":
                prompt = str(payload.get("prompt") or "").strip()
                if not prompt:
                    await _send(
                        {
                            "type": "error",
                            "message": "Prompt cannot be empty.",
                            "code": "empty_prompt",
                        }
                    )
                    continue
                ratio = str(payload.get("aspect_ratio") or "2:3").strip()
                if not ratio:
                    ratio = "2:3"
                ratio = resolve_aspect_ratio(ratio)
                # Restart: stop any existing run before launching a new one.
                await _stop_run()
                stop_event.clear()
                run_task = asyncio.create_task(_run(prompt, ratio))
            elif msg_type == "stop":
                await _stop_run()
            elif msg_type == "ping":
                await _send({"type": "pong"})
            else:
                await _send(
                    {
                        "type": "error",
                        "message": "Unknown command.",
                        "code": "unknown_command",
                    }
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected by client")
    except Exception as e:
        logger.warning(f"WebSocket error: {e}")
    finally:
        await _stop_run()
        try:
            from starlette.websockets import WebSocketState

            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.close(code=1000, reason="Server closing connection")
        except Exception as e:
            logger.debug(f"WebSocket close ignored: {e}")
        # One-shot sessions: consumed on disconnect.
        if session_id:
            await _delete_imagine_session(session_id)


class ImagineStartRequest(BaseModel):
    # Prompt for generation; aspect_ratio defaults to portrait 2:3.
    prompt: str
    aspect_ratio: Optional[str] = "2:3"


@router.post("/api/v1/admin/imagine/start", dependencies=[Depends(verify_api_key)])
async def admin_imagine_start(data: ImagineStartRequest):
    """Create an imagine session; the returned task_id authorizes the WS/SSE stream."""
    prompt = (data.prompt or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    ratio = resolve_aspect_ratio(str(data.aspect_ratio or "2:3").strip() or "2:3")
    task_id = await _create_imagine_session(prompt, ratio)
    return {"task_id": task_id, "aspect_ratio": ratio}


class ImagineStopRequest(BaseModel):
    # Session ids to invalidate.
    task_ids: list[str]


@router.post("/api/v1/admin/imagine/stop", dependencies=[Depends(verify_api_key)])
async def admin_imagine_stop(data: ImagineStopRequest):
    """Invalidate imagine sessions; running SSE streams notice and stop."""
    removed = await _delete_imagine_sessions(data.task_ids or [])
    return {"status": "success", "removed": removed}


@router.get("/api/v1/admin/imagine/sse")
async def admin_imagine_sse(
    request: Request,
    task_id: str = Query(""),
    prompt: str = Query(""),
    aspect_ratio: str = Query("2:3"),
):
    """Imagine image waterfall (SSE fallback for clients without WebSocket).

    Authorized either by a valid session task_id or by the api_key query param.
    Generates batches until the client disconnects or the session is deleted.
    """
    session = None
    if task_id:
        session = await _get_imagine_session(task_id)
        if not session:
            raise HTTPException(status_code=404, detail="Task not found")
    else:
        _verify_stream_api_key(request)
    if session:
        # Session parameters override the query parameters.
        prompt = str(session.get("prompt") or "").strip()
        ratio = str(session.get("aspect_ratio") or "2:3").strip() or "2:3"
    else:
        prompt = (prompt or "").strip()
        if not prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ratio = str(aspect_ratio or "2:3").strip() or "2:3"
    ratio = resolve_aspect_ratio(ratio)

    async def event_stream():
        try:
            model_id = "grok-imagine-1.0"
            model_info = ModelService.get(model_id)
            if not model_info or not model_info.is_image:
                yield _sse_event(
                    {
                        "type": "error",
                        "message": "Image model is not available.",
                        "code": "model_not_supported",
                    }
                )
                return
            token_mgr = await get_token_manager()
            enable_nsfw = bool(get_config("image.image_ws_nsfw", True))
            sequence = 0
            run_id = uuid.uuid4().hex
            yield _sse_event(
                {
                    "type": "status",
                    "status": "running",
                    "prompt": prompt,
                    "aspect_ratio": ratio,
                    "run_id": run_id,
                }
            )
            while True:
                # Stop when the client disconnects or the session is revoked.
                if await request.is_disconnected():
                    break
                if task_id:
                    session_alive = await _get_imagine_session(task_id)
                    if not session_alive:
                        break
                try:
                    await token_mgr.reload_if_stale()
                    token = None
                    for pool_name in ModelService.pool_candidates_for_model(
                        model_info.model_id
                    ):
                        token = token_mgr.get_token(pool_name)
                        if token:
                            break
                    if not token:
                        yield _sse_event(
                            {
                                "type": "error",
                                "message": "No available tokens. Please try again later.",
                                "code": "rate_limit_exceeded",
                            }
                        )
                        await asyncio.sleep(2)
                        continue
                    upstream = image_service.stream(
                        token=token,
                        prompt=prompt,
                        aspect_ratio=ratio,
                        n=6,
                        enable_nsfw=enable_nsfw,
                    )
                    processor = ImageWSCollectProcessor(
                        model_info.model_id,
                        token,
                        n=6,
                        response_format="b64_json",
                    )
                    start_at = time.time()
                    images = await processor.process(upstream)
                    elapsed_ms = int((time.time() - start_at) * 1000)
                    if images and all(img and img != "error" for img in images):
                        for img_b64 in images:
                            sequence += 1
                            yield _sse_event(
                                {
                                    "type": "image",
                                    "b64_json": img_b64,
                                    "sequence": sequence,
                                    "created_at": int(time.time() * 1000),
                                    "elapsed_ms": elapsed_ms,
                                    "aspect_ratio": ratio,
                                    "run_id": run_id,
                                }
                            )
                        try:
                            effort = (
                                EffortType.HIGH
                                if (model_info and model_info.cost.value == "high")
                                else EffortType.LOW
                            )
                            await token_mgr.consume(token, effort)
                        except Exception as e:
                            logger.warning(f"Failed to consume token: {e}")
                    else:
                        yield _sse_event(
                            {
                                "type": "error",
                                "message": "Image generation returned empty data.",
                                "code": "empty_image",
                            }
                        )
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.warning(f"Imagine SSE error: {e}")
                    yield _sse_event(
                        {"type": "error", "message": str(e), "code": "internal_error"}
                    )
                await asyncio.sleep(1.5)
            yield _sse_event({"type": "status", "status": "stopped", "run_id": run_id})
        finally:
            # One-shot session: consumed when the stream ends for any reason.
            if task_id:
                await _delete_imagine_session(task_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


@router.post("/api/v1/admin/login", dependencies=[Depends(verify_app_key)])
async def admin_login_api():
    """Admin console login check (authenticated by app_key); returns the admin api_key."""
    return {"status": "success", "api_key": get_admin_api_key()}


@router.get("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
async def get_config_api():
    """Return the current configuration."""
    # Exposes the raw config dict (private attribute of the config object).
    return config._config


@router.post("/api/v1/admin/config", dependencies=[Depends(verify_api_key)])
async def update_config_api(data: dict):
    """Update the configuration."""
    try:
        await config.update(data)
        return {"status": "success", "message": "配置已更新"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/api/v1/admin/storage", dependencies=[Depends(verify_api_key)])
async def get_storage_info():
    """Report the active storage backend type.

    Resolution order: SERVER_STORAGE_TYPE env var, then storage.type config,
    then the runtime class of the storage instance; defaults to "local".
    """
    storage_type = os.getenv("SERVER_STORAGE_TYPE", "").lower()
    if not storage_type:
        storage_type = str(get_config("storage.type")).lower()
    if not storage_type:
        storage = get_storage()
        if isinstance(storage, LocalStorage):
            storage_type = "local"
        elif isinstance(storage, RedisStorage):
            storage_type = "redis"
        elif isinstance(storage, SQLStorage):
            # Normalize SQL dialect names to the two values the UI understands.
            storage_type = {
                "mysql": "mysql",
                "mariadb": "mysql",
                "postgres": "pgsql",
                "postgresql": "pgsql",
                "pgsql": "pgsql",
            }.get(storage.dialect, storage.dialect)
    return {"type": storage_type or "local"}


@router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
async def get_tokens_api():
    """Return all tokens grouped by pool."""
    storage = get_storage()
    tokens = await storage.load_tokens()
    return tokens or {}


@router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
async def update_tokens_api(data: dict):
    """Replace the stored token pools with the posted data.

    Incoming entries (bare strings or dicts) are normalized: an "sso=" prefix is
    stripped, existing per-token fields are merged in, and each entry is
    validated through TokenInfo before being saved. Invalid entries are skipped
    with a warning.
    """
    storage = get_storage()
    try:
        from app.services.token.manager import get_token_manager
        from app.services.token.models import TokenInfo

        async with storage.acquire_lock("tokens_save", timeout=10):
            existing = await storage.load_tokens() or {}
            normalized = {}
            allowed_fields = set(TokenInfo.model_fields.keys())
            # Index existing entries by pool and token so updates merge fields.
            existing_map = {}
            for pool_name, tokens in existing.items():
                if not isinstance(tokens, list):
                    continue
                pool_map = {}
                for item in tokens:
                    if isinstance(item, str):
                        token_data = {"token": item}
                    elif isinstance(item, dict):
                        token_data = dict(item)
                    else:
                        continue
                    raw_token = token_data.get("token")
                    if isinstance(raw_token, str) and raw_token.startswith("sso="):
                        token_data["token"] = raw_token[4:]
                    token_key = token_data.get("token")
                    if isinstance(token_key, str):
                        pool_map[token_key] = token_data
                existing_map[pool_name] = pool_map
            for pool_name, tokens in (data or {}).items():
                if not isinstance(tokens, list):
                    continue
                pool_list = []
                for item in tokens:
                    if isinstance(item, str):
                        token_data = {"token": item}
                    elif isinstance(item, dict):
                        token_data = dict(item)
                    else:
                        continue
                    raw_token = token_data.get("token")
                    if isinstance(raw_token, str) and raw_token.startswith("sso="):
                        token_data["token"] = raw_token[4:]
                    # Merge posted fields over the stored entry for this token.
                    base = existing_map.get(pool_name, {}).get(
                        token_data.get("token"), {}
                    )
                    merged = dict(base)
                    merged.update(token_data)
                    if merged.get("tags") is None:
                        merged["tags"] = []
                    filtered = {k: v for k, v in merged.items() if k in allowed_fields}
                    try:
                        info = TokenInfo(**filtered)
                        pool_list.append(info.model_dump())
                    except Exception as e:
                        logger.warning(f"Skip invalid token in pool '{pool_name}': {e}")
                        continue
                normalized[pool_name] = pool_list
            await storage.save_tokens(normalized)
        mgr = await get_token_manager()
        await mgr.reload()
        return {"status": "success", "message": "Token 已更新"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/api/v1/admin/tokens/refresh", dependencies=[Depends(verify_api_key)])
async def refresh_tokens_api(data: dict):
    """Refresh token usage status synchronously (batched)."""
    try:
        mgr = await get_token_manager()
        tokens = _collect_tokens(data)
        if not tokens:
            # NOTE(review): this 400 is raised inside the try and re-wrapped as a
            # 500 by the handler below — confirm whether that is intended.
            raise HTTPException(status_code=400, detail="No tokens provided")
        # Dedupe and truncate
        max_tokens = int(get_config("performance.usage_max_tokens"))
        unique_tokens, truncated, original_count = _truncate_tokens(
            tokens, max_tokens, "Usage refresh"
        )
        # Batch execution settings
        max_concurrent = get_config("performance.usage_max_concurrent")
        batch_size = get_config("performance.usage_batch_size")

        async def _refresh_one(t):
            return await mgr.sync_usage(
                t, "grok-3", consume_on_fail=False, is_usage=False
            )

        raw_results = await run_in_batches(
            unique_tokens,
            _refresh_one,
            max_concurrent=max_concurrent,
            batch_size=batch_size,
        )
        results = {}
        for token, res in raw_results.items():
            if res.get("ok"):
                results[token] = res.get("data", False)
            else:
                results[token] = False
        response = {"status": "success", "results": results}
        if truncated:
            response["warning"] = (
                f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
            )
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/api/v1/admin/tokens/refresh/async", dependencies=[Depends(verify_api_key)]
)
async def refresh_tokens_api_async(data: dict):
    """Refresh token status (async batch with SSE progress via /batch/{id}/stream)."""
    mgr = await get_token_manager()
    tokens = _collect_tokens(data)
    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens provided")
    # Dedupe and truncate
    max_tokens = int(get_config("performance.usage_max_tokens"))
    unique_tokens, truncated, original_count = _truncate_tokens(
        tokens, max_tokens, "Usage refresh"
    )
    max_concurrent = get_config("performance.usage_max_concurrent")
    batch_size = get_config("performance.usage_batch_size")
    task = create_task(len(unique_tokens))

    async def _run():
        try:
            async def _refresh_one(t: str):
                return await mgr.sync_usage(
                    t, "grok-3", consume_on_fail=False, is_usage=False
                )

            async def _on_item(item: str, res: dict):
                # Per-item progress callback for the SSE stream.
                task.record(bool(res.get("ok")))

            raw_results = await run_in_batches(
                unique_tokens,
                _refresh_one,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )
            if task.cancelled:
                task.finish_cancelled()
                return
            results: dict[str, bool] = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                if res.get("ok") and res.get("data") is True:
                    ok_count += 1
                    results[token] = True
                else:
                    fail_count += 1
                    results[token] = False
            # Persist refreshed usage state (private manager API).
            await mgr._save()
            result = {
                "status": "success",
                "summary": {
                    "total": len(unique_tokens),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            # Expire the task record 5 minutes after completion.
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())
    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }


@router.post("/api/v1/admin/tokens/nsfw/enable", dependencies=[Depends(verify_api_key)])
async def enable_nsfw_api(data: dict):
    """Batch-enable NSFW (Unhinged) mode."""
    from app.services.grok.services.nsfw import NSFWService

    try:
        mgr = await get_token_manager()
        nsfw_service = NSFWService()
        # Collect the token list
        tokens = _collect_tokens(data)
        # If none specified, use every token from every pool
        if not tokens:
            for pool_name, pool in mgr.pools.items():
                for info in pool.list():
                    raw = (
                        info.token[4:] if info.token.startswith("sso=") else info.token
                    )
                    tokens.append(raw)
        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens available")
        # Dedupe and truncate
        max_tokens = int(get_config("performance.nsfw_max_tokens"))
        unique_tokens, truncated, original_count = _truncate_tokens(
            tokens, max_tokens, "NSFW enable"
        )
        # Batch execution settings
        max_concurrent = get_config("performance.nsfw_max_concurrent")
        batch_size = get_config("performance.nsfw_batch_size")

        # Worker
        async def _enable(token: str):
            result = await nsfw_service.enable(token)
            # Tag the token on success
            if result.success:
                await mgr.add_tag(token, "nsfw")
            return {
                "success": result.success,
                "http_status": result.http_status,
                "grpc_status": result.grpc_status,
                "grpc_message": result.grpc_message,
                "error": result.error,
            }

        # Run the batch
        raw_results = await run_in_batches(
            unique_tokens, _enable, max_concurrent=max_concurrent, batch_size=batch_size
        )
        # Build the response (mask tokens)
        results = {}
        ok_count = 0
        fail_count = 0
        for token, res in raw_results.items():
            masked = _mask_token(token)
            if res.get("ok") and res.get("data", {}).get("success"):
                ok_count += 1
                results[masked] = res.get("data", {})
            else:
                fail_count += 1
                results[masked] = res.get("data") or {"error": res.get("error")}
        response = {
            "status": "success",
            "summary": {
                "total": len(unique_tokens),
                "ok": ok_count,
                "fail": fail_count,
            },
            "results": results,
        }
        # Truncation notice
        if truncated:
            response["warning"] = (
                f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
            )
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Enable NSFW failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/api/v1/admin/tokens/nsfw/enable/async", dependencies=[Depends(verify_api_key)]
)
async def enable_nsfw_api_async(data: dict):
    """Batch-enable NSFW (Unhinged) mode (async batch with SSE progress)."""
    from app.services.grok.services.nsfw import NSFWService

    mgr = await get_token_manager()
    nsfw_service = NSFWService()
    tokens = _collect_tokens(data)
    if not tokens:
        for pool_name, pool in mgr.pools.items():
            for info in pool.list():
                raw = info.token[4:] if info.token.startswith("sso=") else info.token
                tokens.append(raw)
    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens available")
    # Dedupe and truncate
    max_tokens = int(get_config("performance.nsfw_max_tokens"))
    unique_tokens, truncated, original_count = _truncate_tokens(
        tokens, max_tokens, "NSFW enable"
    )
    max_concurrent = get_config("performance.nsfw_max_concurrent")
    batch_size = get_config("performance.nsfw_batch_size")
    task = create_task(len(unique_tokens))

    async def _run():
        try:
            async def _enable(token: str):
                result = await nsfw_service.enable(token)
                if result.success:
                    await mgr.add_tag(token, "nsfw")
                return {
                    "success": result.success,
                    "http_status": result.http_status,
                    "grpc_status": result.grpc_status,
                    "grpc_message": result.grpc_message,
                    "error": result.error,
                }

            async def _on_item(item: str, res: dict):
                ok = bool(res.get("ok") and res.get("data", {}).get("success"))
                task.record(ok)

            raw_results = await run_in_batches(
                unique_tokens,
                _enable,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )
            if task.cancelled:
                task.finish_cancelled()
                return
            results = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                # Inline masking (same rule as _mask_token).
                masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
                if res.get("ok") and res.get("data", {}).get("success"):
                    ok_count += 1
                    results[masked] = res.get("data", {})
                else:
                    fail_count += 1
                    results[masked] = res.get("data") or {"error": res.get("error")}
            # Persist tag changes (private manager API).
            await mgr._save()
            result = {
                "status": "success",
                "summary": {
                    "total": len(unique_tokens),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())
    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }


@router.get("/admin/cache", response_class=HTMLResponse, include_in_schema=False)
async def admin_cache_page():
    """Cache management page."""
    return await render_template("cache/cache.html")


@router.get("/api/v1/admin/cache", dependencies=[Depends(verify_api_key)])
async def get_cache_stats_api(request: Request):
    """Cache statistics: local image/video stores plus online asset counts.

    Query params: scope ("all" or unset), token (single), tokens (comma list).
    Online counts are fetched per token, batched; failures are reported per
    token rather than failing the whole request.
    """
    from app.services.grok.services.assets import DownloadService, ListService
    from app.services.token.manager import get_token_manager
    from app.services.grok.utils.batch import run_in_batches

    try:
        dl_service = DownloadService()
        image_stats = dl_service.get_stats("image")
        video_stats = dl_service.get_stats("video")
        mgr = await get_token_manager()
        pools = mgr.pools
        accounts = []
        for pool_name, pool in pools.items():
            for info in pool.list():
                raw_token = (
                    info.token[4:] if info.token.startswith("sso=") else info.token
                )
                masked = (
                    f"{raw_token[:8]}...{raw_token[-16:]}"
                    if len(raw_token) > 24
                    else raw_token
                )
                accounts.append(
                    {
                        "token": raw_token,
                        "token_masked": masked,
                        "pool": pool_name,
                        "status": info.status,
                        "last_asset_clear_at": info.last_asset_clear_at,
                    }
                )
        scope = request.query_params.get("scope")
        selected_token = request.query_params.get("token")
        tokens_param = request.query_params.get("tokens")
        selected_tokens = []
        if tokens_param:
            selected_tokens = [t.strip() for t in tokens_param.split(",") if t.strip()]
        online_stats = {
            "count": 0,
            "status": "unknown",
            "token": None,
            "last_asset_clear_at": None,
        }
        online_details = []
        account_map = {a["token"]: a for a in accounts}
        max_concurrent = max(1, int(get_config("performance.assets_max_concurrent")))
        batch_size = max(1, int(get_config("performance.assets_batch_size")))
        max_tokens = int(get_config("performance.assets_max_tokens"))
        truncated = False
        original_count = 0

        async def _fetch_assets(token: str):
            # One-shot asset count for a token; always closes the service.
            list_service = ListService()
            try:
                return await list_service.count(token)
            finally:
                await list_service.close()

        async def _fetch_detail(token: str):
            # Per-token detail record; errors become a "status: error: ..." entry.
            account = account_map.get(token)
            try:
                count = await _fetch_assets(token)
                return {
                    "detail": {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": count,
                        "status": "ok",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    },
                    "count": count,
                }
            except Exception as e:
                return {
                    "detail": {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {str(e)}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    },
                    "count": 0,
                }

        if selected_tokens:
            # Explicit token list from the "tokens" query param.
            selected_tokens, truncated, original_count = _truncate_tokens(
                selected_tokens, max_tokens, "Assets fetch"
            )
            total = 0
            raw_results = await run_in_batches(
                selected_tokens,
                _fetch_detail,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
            )
            for token, res in raw_results.items():
                if res.get("ok"):
                    data = res.get("data", {})
                    detail = data.get("detail")
                    total += data.get("count", 0)
                else:
                    account = account_map.get(token)
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {res.get('error')}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                if detail:
                    online_details.append(detail)
            online_stats = {
                "count": total,
                "status": "ok" if selected_tokens else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
            scope = "selected"
        elif scope == "all":
            # Every known account.
            total = 0
            tokens = list(dict.fromkeys([account["token"] for account in accounts]))
            original_count = len(tokens)
            if len(tokens) > max_tokens:
                tokens = tokens[:max_tokens]
                truncated = True
            raw_results = await run_in_batches(
                tokens,
                _fetch_detail,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
            )
            for token, res in raw_results.items():
                if res.get("ok"):
                    data = res.get("data", {})
                    detail = data.get("detail")
                    total += data.get("count", 0)
                else:
                    account = account_map.get(token)
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {res.get('error')}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                if detail:
                    online_details.append(detail)
            online_stats = {
                "count": total,
                "status": "ok" if accounts else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
        else:
            # Single token (or nothing loaded at all).
            token = selected_token
            if token:
                try:
                    count = await _fetch_assets(token)
                    match = next((a for a in accounts if a["token"] == token), None)
                    online_stats = {
                        "count": count,
                        "status": "ok",
                        "token": token,
                        "token_masked": match["token_masked"] if match else token,
                        "last_asset_clear_at": match["last_asset_clear_at"]
                        if match
                        else None,
                    }
                except Exception as e:
                    match = next((a for a in accounts if a["token"] == token), None)
                    online_stats = {
                        "count": 0,
                        "status": f"error: {str(e)}",
                        "token": token,
                        "token_masked": match["token_masked"] if match else token,
                        "last_asset_clear_at": match["last_asset_clear_at"]
                        if match
                        else None,
                    }
            else:
                online_stats = {
                    "count": 0,
                    "status": "not_loaded",
                    "token": None,
                    "last_asset_clear_at": None,
                }
        response = {
            "local_image": image_stats,
            "local_video": video_stats,
            "online": online_stats,
            "online_accounts": accounts,
            "online_scope": scope or "none",
            "online_details": online_details,
        }
        if truncated:
            response["warning"] = (
                f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
            )
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post(
    "/api/v1/admin/cache/online/load/async", dependencies=[Depends(verify_api_key)]
)
async def load_online_cache_api_async(data: dict):
    """Online asset statistics (async batch with SSE progress)."""
    from app.services.grok.services.assets import DownloadService, ListService
    from app.services.token.manager import get_token_manager
    from app.services.grok.utils.batch import run_in_batches

    mgr = await get_token_manager()
    # Account list
    accounts = []
    for pool_name, pool in mgr.pools.items():
        for info in pool.list():
            raw_token = info.token[4:] if info.token.startswith("sso=") else info.token
            masked = (
                f"{raw_token[:8]}...{raw_token[-16:]}"
                if len(raw_token) > 24
                else raw_token
            )
            accounts.append(
                {
                    "token": raw_token,
                    "token_masked": masked,
                    "pool": pool_name,
                    "status": info.status,
                    "last_asset_clear_at": info.last_asset_clear_at,
                }
            )
    account_map = {a["token"]: a for a in accounts}
    tokens = data.get("tokens")
    scope = data.get("scope")
    selected_tokens: list[str] = []
    if isinstance(tokens, list):
        selected_tokens = [str(t).strip() for t in tokens if str(t).strip()]
    if not selected_tokens and scope == "all":
        selected_tokens = [account["token"] for account in accounts]
        scope = "all"
    elif selected_tokens:
        scope = "selected"
    else:
        raise HTTPException(status_code=400, detail="No tokens provided")
    max_tokens = int(get_config("performance.assets_max_tokens"))
    selected_tokens, truncated, original_count = _truncate_tokens(
        selected_tokens, max_tokens, "Assets load"
    )
    max_concurrent = get_config("performance.assets_max_concurrent")
    batch_size = get_config("performance.assets_batch_size")
    task = create_task(len(selected_tokens))

    async def _run():
        try:
            dl_service = DownloadService()
            image_stats = dl_service.get_stats("image")
            video_stats = dl_service.get_stats("video")

            async def _fetch_detail(token: str):
                # Per-token online asset count; errors become error-status details.
                account = account_map.get(token)
                list_service = ListService()
                try:
                    count = await list_service.count(token)
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": count,
                        "status": "ok",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                    return {"ok": True, "detail": detail, "count": count}
                except Exception as e:
                    detail = {
                        "token": token,
                        "token_masked": account["token_masked"] if account else token,
                        "count": 0,
                        "status": f"error: {str(e)}",
                        "last_asset_clear_at": account["last_asset_clear_at"]
                        if account
                        else None,
                    }
                    return {"ok": False, "detail": detail, "count": 0}
                finally:
                    await list_service.close()

            async def _on_item(item: str, res: dict):
                ok = bool(res.get("data", {}).get("ok"))
                task.record(ok)

            raw_results = await run_in_batches(
                selected_tokens,
                _fetch_detail,
                max_concurrent=max_concurrent,
                batch_size=batch_size,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )
            if task.cancelled:
                task.finish_cancelled()
                return
            online_details = []
            total = 0
            for token, res in raw_results.items():
                data = res.get("data", {})
                detail = data.get("detail")
                if detail:
                    online_details.append(detail)
                total += data.get("count", 0)
            online_stats = {
                "count": total,
                "status": "ok" if selected_tokens else "no_token",
                "token": None,
                "last_asset_clear_at": None,
            }
            result = {
                "local_image": image_stats,
                "local_video": video_stats,
                "online": online_stats,
                "online_accounts": accounts,
                "online_scope": scope or "none",
                "online_details": online_details,
            }
            warning = None
            if truncated:
                warning = (
                    f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)"
                )
            task.finish(result, warning=warning)
        except Exception as e:
            task.fail_task(str(e))
        finally:
            asyncio.create_task(expire_task(task.id, 300))

    asyncio.create_task(_run())
    return {
        "status": "success",
        "task_id": task.id,
        "total": len(selected_tokens),
    }


@router.post("/api/v1/admin/cache/clear", dependencies=[Depends(verify_api_key)])
async def clear_local_cache_api(data: dict):
    """Clear the local cache."""
    from app.services.grok.services.assets import DownloadService

    cache_type = data.get("type", "image")
    try:
        dl_service = DownloadService()
        result = dl_service.clear(cache_type)
        return {"status": "success", "result": result}
    # NOTE(review): this definition is truncated at the end of the visible chunk;
    # the except clause continues past this point and is preserved as-is.
    except Exception as
e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/api/v1/admin/cache/list", dependencies=[Depends(verify_api_key)]) async def list_local_cache_api( cache_type: str = "image", type_: str = Query(default=None, alias="type"), page: int = 1, page_size: int = 1000, ): """列出本地缓存文件""" from app.services.grok.services.assets import DownloadService try: if type_: cache_type = type_ dl_service = DownloadService() result = dl_service.list_files(cache_type, page, page_size) return {"status": "success", **result} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.post("/api/v1/admin/cache/item/delete", dependencies=[Depends(verify_api_key)]) async def delete_local_cache_item_api(data: dict): """删除单个本地缓存文件""" from app.services.grok.services.assets import DownloadService cache_type = data.get("type", "image") name = data.get("name") if not name: raise HTTPException(status_code=400, detail="Missing file name") try: dl_service = DownloadService() result = dl_service.delete_file(cache_type, name) return {"status": "success", "result": result} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.post("/api/v1/admin/cache/online/clear", dependencies=[Depends(verify_api_key)]) async def clear_online_cache_api(data: dict): """清理在线缓存""" from app.services.grok.services.assets import DeleteService from app.services.token.manager import get_token_manager from app.services.grok.utils.batch import run_in_batches delete_service = None try: mgr = await get_token_manager() tokens = data.get("tokens") delete_service = DeleteService() if isinstance(tokens, list): token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] if not token_list: raise HTTPException(status_code=400, detail="No tokens provided") # 去重并保持顺序 token_list = list(dict.fromkeys(token_list)) # 最大数量限制 max_tokens = int(get_config("performance.assets_max_tokens")) token_list, truncated, original_count = _truncate_tokens( token_list, 
max_tokens, "Clear online cache" ) results = {} max_concurrent = max( 1, int(get_config("performance.assets_max_concurrent")) ) batch_size = max(1, int(get_config("performance.assets_batch_size"))) async def _clear_one(t: str): try: result = await delete_service.delete_all(t) await mgr.mark_asset_clear(t) return {"status": "success", "result": result} except Exception as e: return {"status": "error", "error": str(e)} raw_results = await run_in_batches( token_list, _clear_one, max_concurrent=max_concurrent, batch_size=batch_size, ) for token, res in raw_results.items(): if res.get("ok"): results[token] = res.get("data", {}) else: results[token] = {"status": "error", "error": res.get("error")} response = {"status": "success", "results": results} if truncated: response["warning"] = ( f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" ) return response token = data.get("token") or mgr.get_token() if not token: raise HTTPException( status_code=400, detail="No available token to perform cleanup" ) result = await delete_service.delete_all(token) await mgr.mark_asset_clear(token) return {"status": "success", "result": result} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) finally: if delete_service: await delete_service.close() @router.post( "/api/v1/admin/cache/online/clear/async", dependencies=[Depends(verify_api_key)] ) async def clear_online_cache_api_async(data: dict): """清理在线缓存(异步批量 + SSE 进度)""" from app.services.grok.services.assets import DeleteService from app.services.token.manager import get_token_manager from app.services.grok.utils.batch import run_in_batches mgr = await get_token_manager() tokens = data.get("tokens") if not isinstance(tokens, list): raise HTTPException(status_code=400, detail="No tokens provided") token_list = [t.strip() for t in tokens if isinstance(t, str) and t.strip()] if not token_list: raise HTTPException(status_code=400, detail="No tokens provided") max_tokens = 
int(get_config("performance.assets_max_tokens")) token_list, truncated, original_count = _truncate_tokens( token_list, max_tokens, "Clear online cache async" ) max_concurrent = get_config("performance.assets_max_concurrent") batch_size = get_config("performance.assets_batch_size") task = create_task(len(token_list)) async def _run(): delete_service = DeleteService() try: async def _clear_one(t: str): try: result = await delete_service.delete_all(t) await mgr.mark_asset_clear(t) return {"ok": True, "result": result} except Exception as e: return {"ok": False, "error": str(e)} async def _on_item(item: str, res: dict): ok = bool(res.get("data", {}).get("ok")) task.record(ok) raw_results = await run_in_batches( token_list, _clear_one, max_concurrent=max_concurrent, batch_size=batch_size, on_item=_on_item, should_cancel=lambda: task.cancelled, ) if task.cancelled: task.finish_cancelled() return results = {} ok_count = 0 fail_count = 0 for token, res in raw_results.items(): data = res.get("data", {}) if data.get("ok"): ok_count += 1 results[token] = {"status": "success", "result": data.get("result")} else: fail_count += 1 results[token] = {"status": "error", "error": data.get("error")} result = { "status": "success", "summary": { "total": len(token_list), "ok": ok_count, "fail": fail_count, }, "results": results, } warning = None if truncated: warning = ( f"数量超出限制,仅处理前 {max_tokens} 个(共 {original_count} 个)" ) task.finish(result, warning=warning) except Exception as e: task.fail_task(str(e)) finally: await delete_service.close() asyncio.create_task(expire_task(task.id, 300)) asyncio.create_task(_run()) return { "status": "success", "task_id": task.id, "total": len(token_list), }