# ggload: app/api/v1/admin_api/token.py
# Uploaded by f2d90b38 ("Upload 120 files"), commit 8cdca00 (verified).
import asyncio
import orjson
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.core.auth import get_app_key, verify_app_key
from app.core.batch import create_task, expire_task, get_task
from app.core.logger import logger
from app.core.storage import get_storage
from app.services.grok.batch_services.usage import UsageService
from app.services.grok.batch_services.nsfw import NSFWService
from app.services.token.manager import get_token_manager
# Router on which every admin token/batch endpoint in this module is registered.
router = APIRouter()
@router.get("/tokens", dependencies=[Depends(verify_app_key)])
async def get_tokens():
    """Return all stored token pools.

    Whatever ``storage.load_tokens()`` yields is returned as-is; a falsy
    result (e.g. ``None``) is replaced by an empty dict.
    """
    loaded = await get_storage().load_tokens()
    if not loaded:
        return {}
    return loaded
def _normalize_token_item(item):
    """Coerce one raw token entry into a fresh dict, or return None if unusable.

    Accepts a bare token string or a dict (copied, never mutated). A leading
    ``sso=`` prefix on a string ``token`` value is stripped so tokens are
    handled in raw form. Any other item type yields None.
    """
    if isinstance(item, str):
        token_data = {"token": item}
    elif isinstance(item, dict):
        token_data = dict(item)
    else:
        return None
    raw_token = token_data.get("token")
    if isinstance(raw_token, str) and raw_token.startswith("sso="):
        token_data["token"] = raw_token[4:]
    return token_data


@router.post("/tokens", dependencies=[Depends(verify_app_key)])
async def update_tokens(data: dict):
    """Save the submitted token pools, merging in already-stored fields.

    ``data`` maps pool names to lists of token entries (bare strings or
    dicts). For each entry, fields already stored for the same token in the
    same pool are used as a base and overridden by the submitted fields.
    ``tags`` defaults to ``[]`` and only fields declared on ``TokenInfo`` are
    kept. Entries that fail ``TokenInfo`` validation are logged and skipped.
    Non-list pools are ignored. The normalized mapping contains only the
    submitted pools and is passed to ``storage.save_tokens`` (presumably
    replacing the stored payload — confirm against the storage backend).
    The token manager is reloaded afterwards.

    Raises:
        HTTPException: 500 if loading, saving or reloading fails.
    """
    storage = get_storage()
    try:
        from app.services.token.models import TokenInfo

        allowed_fields = set(TokenInfo.model_fields.keys())
        async with storage.acquire_lock("tokens_save", timeout=10):
            existing = await storage.load_tokens() or {}

            # pool name -> {raw token -> stored fields}; used as the merge base.
            existing_map = {}
            for pool_name, tokens in existing.items():
                if not isinstance(tokens, list):
                    continue
                pool_map = {}
                for item in tokens:
                    token_data = _normalize_token_item(item)
                    if token_data is None:
                        continue
                    token_key = token_data.get("token")
                    if isinstance(token_key, str):
                        pool_map[token_key] = token_data
                existing_map[pool_name] = pool_map

            normalized = {}
            for pool_name, tokens in (data or {}).items():
                if not isinstance(tokens, list):
                    continue
                existing_pool = existing_map.get(pool_name, {})
                pool_list = []
                for item in tokens:
                    token_data = _normalize_token_item(item)
                    if token_data is None:
                        continue
                    token_key = token_data.get("token")
                    # Only string keys can match stored tokens; looking up an
                    # unhashable value (e.g. a list) would raise TypeError and
                    # fail the whole request instead of skipping this entry.
                    if isinstance(token_key, str):
                        base = existing_pool.get(token_key, {})
                    else:
                        base = {}
                    merged = {**base, **token_data}
                    if merged.get("tags") is None:
                        merged["tags"] = []
                    filtered = {k: v for k, v in merged.items() if k in allowed_fields}
                    try:
                        info = TokenInfo(**filtered)
                        pool_list.append(info.model_dump())
                    except Exception as e:
                        logger.warning(f"Skip invalid token in pool '{pool_name}': {e}")
                        continue
                normalized[pool_name] = pool_list

            await storage.save_tokens(normalized)
        mgr = await get_token_manager()
        await mgr.reload()
        return {"status": "success", "message": "Token 已更新"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Update tokens failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/tokens/refresh", dependencies=[Depends(verify_app_key)])
async def refresh_tokens(data: dict):
    """Refresh token status synchronously via ``UsageService.batch``.

    Tokens come from ``data["token"]`` (a single string) and/or
    ``data["tokens"]`` (a list). Values are stripped, blanks dropped and
    duplicates removed in first-seen order.

    Returns:
        ``{"status": "success", "results": {token: value}}``, where ``value``
        is the service's ``data`` payload (default ``False``) when the call
        reported ``ok``, and ``False`` otherwise.

    Raises:
        HTTPException: 400 if no tokens were supplied; 500 on any other failure.
    """
    try:
        mgr = await get_token_manager()
        tokens = []
        single = data.get("token")
        if isinstance(single, str) and single.strip():
            tokens.append(single.strip())
        many = data.get("tokens")
        if isinstance(many, list):
            tokens.extend(s for s in (str(t).strip() for t in many) if s)
        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens provided")
        unique_tokens = list(dict.fromkeys(tokens))
        raw_results = await UsageService.batch(unique_tokens, mgr)
        results = {
            token: (res.get("data", False) if res.get("ok") else False)
            for token, res in raw_results.items()
        }
        return {"status": "success", "results": results}
    except HTTPException:
        # Keep deliberate client errors (the 400 above) instead of masking
        # them as a 500 in the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Refresh tokens failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Strong references to fire-and-forget tasks spawned by refresh_tokens_async.
# The asyncio event loop keeps only weak references to tasks, so a task nobody
# references can be garbage-collected before it finishes.
_refresh_background_tasks = set()


def _spawn_refresh_background(coro):
    """Schedule ``coro`` as a task and hold a reference to it until it is done."""
    bg = asyncio.create_task(coro)
    _refresh_background_tasks.add(bg)
    bg.add_done_callback(_refresh_background_tasks.discard)
    return bg


def _refresh_result_ok(res):
    """Return True only when a usage result is ``ok`` and its ``data`` is True."""
    return bool(res.get("ok")) and res.get("data") is True


@router.post("/tokens/refresh/async", dependencies=[Depends(verify_app_key)])
async def refresh_tokens_async(data: dict):
    """Start a background token-status refresh (async batch + SSE progress).

    Tokens are collected like ``/tokens/refresh`` (``data["token"]`` and/or
    ``data["tokens"]``, stripped, de-duplicated in order). A batch task is
    created and processed in the background. Per-item progress and the final
    summary both count a token as ok only when the result is ``ok`` and its
    ``data`` is ``True``. The token manager is force-saved on completion and
    the task is scheduled to expire 300 s after it ends.

    Returns:
        ``{"status": "success", "task_id": ..., "total": n}`` immediately.

    Raises:
        HTTPException: 400 if no tokens were supplied.
    """
    mgr = await get_token_manager()
    tokens = []
    single = data.get("token")
    if isinstance(single, str) and single.strip():
        tokens.append(single.strip())
    many = data.get("tokens")
    if isinstance(many, list):
        tokens.extend(s for s in (str(t).strip() for t in many) if s)
    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens provided")
    unique_tokens = list(dict.fromkeys(tokens))
    task = create_task(len(unique_tokens))

    async def _run():
        try:
            async def _on_item(item: str, res: dict):
                # Same predicate as the final summary so progress and summary agree.
                task.record(_refresh_result_ok(res))

            raw_results = await UsageService.batch(
                unique_tokens,
                mgr,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )
            if task.cancelled:
                task.finish_cancelled()
                return
            results: dict[str, bool] = {
                token: _refresh_result_ok(res) for token, res in raw_results.items()
            }
            ok_count = sum(results.values())
            fail_count = len(results) - ok_count
            await mgr._save(force=True)
            result = {
                "status": "success",
                "summary": {
                    "total": len(unique_tokens),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            task.finish(result)
        except Exception as e:
            logger.error(f"Refresh tokens task failed: {e}")
            task.fail_task(str(e))
        finally:
            _spawn_refresh_background(expire_task(task.id, 300))

    _spawn_refresh_background(_run())
    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }
@router.get("/batch/{task_id}/stream")
async def batch_stream(task_id: str, request: Request):
    """Stream progress events for a batch task as Server-Sent Events.

    Authentication uses the ``app_key`` query parameter rather than the
    router dependency (presumably because browser ``EventSource`` cannot set
    headers — confirm). It is skipped when no app key is configured. The
    stream sends a ``snapshot`` event first. If the task already has a final
    event, that event follows and the stream closes. Otherwise queued events
    are relayed until one of type ``done``, ``error`` or ``cancelled``
    arrives. After each 15 s without events, a ``: ping`` comment is sent and
    the final event is re-checked.

    Raises:
        HTTPException: 401 for a missing/wrong key, 404 for an unknown task.
    """
    import hmac

    app_key = get_app_key()
    if app_key:
        key = request.query_params.get("app_key") or ""
        # Constant-time comparison of the untrusted key. Comparing UTF-8 bytes
        # also avoids compare_digest's TypeError on non-ASCII str input.
        if not hmac.compare_digest(key.encode("utf-8"), str(app_key).encode("utf-8")):
            raise HTTPException(status_code=401, detail="Invalid authentication token")
    task = get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    async def event_stream():
        # Attach before the snapshot (presumably so later events get queued).
        queue = task.attach()
        try:
            yield f"data: {orjson.dumps({'type': 'snapshot', **task.snapshot()}).decode()}\n\n"
            # Task already finished: emit its terminal event and close.
            final = task.final_event()
            if final:
                yield f"data: {orjson.dumps(final).decode()}\n\n"
                return
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # SSE comment line (ignored by clients) used as a keep-alive.
                    yield ": ping\n\n"
                    final = task.final_event()
                    if final:
                        yield f"data: {orjson.dumps(final).decode()}\n\n"
                        return
                    continue
                yield f"data: {orjson.dumps(event).decode()}\n\n"
                if event.get("type") in ("done", "error", "cancelled"):
                    return
        finally:
            task.detach(queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
@router.post("/batch/{task_id}/cancel", dependencies=[Depends(verify_app_key)])
async def batch_cancel(task_id: str):
    """Ask the batch task ``task_id`` to cancel.

    Raises:
        HTTPException: 404 if no such task exists.
    """
    target = get_task(task_id)
    if target:
        target.cancel()
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Task not found")
@router.post("/tokens/nsfw/enable", dependencies=[Depends(verify_app_key)])
async def enable_nsfw(data: dict):
    """Enable NSFW (Unhinged) mode for a batch of tokens, synchronously.

    Tokens come from ``data["token"]`` and/or ``data["tokens"]`` (stripped,
    blanks dropped). When neither yields a token, every token in every pool
    of the token manager is used, with an ``sso=`` prefix stripped.
    Duplicates are removed in first-seen order.

    A result counts as ok when it reports ``ok`` and its ``data`` is a dict
    with a truthy ``success``. A missing, None or non-dict ``data`` counts as
    a failure rather than raising.

    Returns:
        ``{"status": "success", "summary": {...}, "results": {...}}``. Result
        keys are masked tokens: the first 8 chars, ``...``, and the last 8
        chars for tokens longer than 20 characters.

    Raises:
        HTTPException: 400 if no tokens are available, 500 on other failures.
    """
    try:
        mgr = await get_token_manager()
        tokens = []
        single = data.get("token")
        if isinstance(single, str) and single.strip():
            tokens.append(single.strip())
        many = data.get("tokens")
        if isinstance(many, list):
            tokens.extend(s for s in (str(t).strip() for t in many) if s)
        if not tokens:
            # Fall back to every token the manager holds, in raw form.
            for pool in mgr.pools.values():
                for info in pool.list():
                    raw = info.token[4:] if info.token.startswith("sso=") else info.token
                    tokens.append(raw)
        if not tokens:
            raise HTTPException(status_code=400, detail="No tokens available")
        unique_tokens = list(dict.fromkeys(tokens))
        raw_results = await NSFWService.batch(
            unique_tokens,
            mgr,
        )
        results = {}
        ok_count = 0
        fail_count = 0
        for token, res in raw_results.items():
            masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
            payload = res.get("data")
            if res.get("ok") and isinstance(payload, dict) and payload.get("success"):
                ok_count += 1
                results[masked] = payload
            else:
                fail_count += 1
                results[masked] = payload or {"error": res.get("error")}
        response = {
            "status": "success",
            "summary": {
                "total": len(unique_tokens),
                "ok": ok_count,
                "fail": fail_count,
            },
            "results": results,
        }
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Enable NSFW failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Strong references to fire-and-forget tasks spawned by enable_nsfw_async.
# The asyncio event loop keeps only weak references to tasks, so a task nobody
# references can be garbage-collected before it finishes.
_nsfw_background_tasks = set()


def _spawn_nsfw_background(coro):
    """Schedule ``coro`` as a task and hold a reference to it until it is done."""
    bg = asyncio.create_task(coro)
    _nsfw_background_tasks.add(bg)
    bg.add_done_callback(_nsfw_background_tasks.discard)
    return bg


def _nsfw_result_ok(res):
    """Return True when an NSFW result is ``ok`` and ``data["success"]`` is truthy.

    A missing, None or non-dict ``data`` counts as failure instead of raising
    AttributeError.
    """
    payload = res.get("data")
    return bool(res.get("ok") and isinstance(payload, dict) and payload.get("success"))


@router.post("/tokens/nsfw/enable/async", dependencies=[Depends(verify_app_key)])
async def enable_nsfw_async(data: dict):
    """Start enabling NSFW (Unhinged) mode in the background (async batch + SSE progress).

    Tokens are collected like ``/tokens/nsfw/enable``: from ``data["token"]``
    and/or ``data["tokens"]``, falling back to every managed token (with an
    ``sso=`` prefix stripped), then de-duplicated in order. Progress and the
    final summary use the same success predicate. Result keys are masked
    tokens. The token manager is force-saved on completion and the task is
    scheduled to expire 300 s after it ends.

    Returns:
        ``{"status": "success", "task_id": ..., "total": n}`` immediately.

    Raises:
        HTTPException: 400 if no tokens are available.
    """
    mgr = await get_token_manager()
    tokens = []
    single = data.get("token")
    if isinstance(single, str) and single.strip():
        tokens.append(single.strip())
    many = data.get("tokens")
    if isinstance(many, list):
        tokens.extend(s for s in (str(t).strip() for t in many) if s)
    if not tokens:
        # Fall back to every token the manager holds, in raw form.
        for pool in mgr.pools.values():
            for info in pool.list():
                raw = info.token[4:] if info.token.startswith("sso=") else info.token
                tokens.append(raw)
    if not tokens:
        raise HTTPException(status_code=400, detail="No tokens available")
    unique_tokens = list(dict.fromkeys(tokens))
    task = create_task(len(unique_tokens))

    async def _run():
        try:
            async def _on_item(item: str, res: dict):
                task.record(_nsfw_result_ok(res))

            raw_results = await NSFWService.batch(
                unique_tokens,
                mgr,
                on_item=_on_item,
                should_cancel=lambda: task.cancelled,
            )
            if task.cancelled:
                task.finish_cancelled()
                return
            results = {}
            ok_count = 0
            fail_count = 0
            for token, res in raw_results.items():
                masked = f"{token[:8]}...{token[-8:]}" if len(token) > 20 else token
                if _nsfw_result_ok(res):
                    ok_count += 1
                    results[masked] = res.get("data", {})
                else:
                    fail_count += 1
                    results[masked] = res.get("data") or {"error": res.get("error")}
            await mgr._save(force=True)
            result = {
                "status": "success",
                "summary": {
                    "total": len(unique_tokens),
                    "ok": ok_count,
                    "fail": fail_count,
                },
                "results": results,
            }
            task.finish(result)
        except Exception as e:
            logger.error(f"Enable NSFW task failed: {e}")
            task.fail_task(str(e))
        finally:
            _spawn_nsfw_background(expire_task(task.id, 300))

    _spawn_nsfw_background(_run())
    return {
        "status": "success",
        "task_id": task.id,
        "total": len(unique_tokens),
    }