#!/usr/bin/env python3
"""FastAPI 应用:登录鉴权、自研 GUI、翻译任务、内部 OpenAI 代理与计费。"""
from __future__ import annotations

import asyncio
import contextlib
import html
import json
import logging
import os
import shutil
import sqlite3
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any

import httpx
from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    RedirectResponse,
    Response,
    StreamingResponse,
)
from pdf2zh_next import (
    BasicSettings,
    OpenAISettings,
    PDFSettings,
    SettingsModel,
    TranslationSettings,
)
from pdf2zh_next.high_level import do_translate_async_stream

import auth
import billing
import jobs
import proxy
import storage
from web.template_loader import get_static_path, load_template
# ── Configuration ─────────────────────────────────────────────────────────────
# Base URL of the internal OpenAI-compatible proxy that translation jobs call
# back into (usage is attributed per user via an internal API key).
INTERNAL_OPENAI_BASE_URL = os.environ.get(
    "INTERNAL_OPENAI_BASE_URL", "http://127.0.0.1:7860/internal/openai/v1"
)
# Model name is fixed for every job; actual routing happens inside the proxy.
FIXED_TRANSLATION_MODEL = "SiliconFlowFree"
DEFAULT_LANG_IN = os.environ.get("DEFAULT_LANG_IN", "en").strip()
DEFAULT_LANG_OUT = os.environ.get("DEFAULT_LANG_OUT", "zh").strip()
TRANSLATION_QPS = int(os.environ.get("TRANSLATION_QPS", "4"))
# Upload size and job runtime limits.
MAX_UPLOAD_MB = int(os.environ.get("MAX_UPLOAD_MB", "100"))
MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024 * 1024
MAX_JOB_RUNTIME_SECONDS = int(os.environ.get("MAX_JOB_RUNTIME_SECONDS", "7200"))
UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB per chunk

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
logger = logging.getLogger("gateway")

# ── Job execution state ───────────────────────────────────────────────────────
# FIFO queue of job ids awaiting the single background worker.
_job_queue: asyncio.Queue[str] = asyncio.Queue()
# Background worker task, created at startup (None until then).
_worker_task: asyncio.Task[None] | None = None
# job_id -> running asyncio task, used for cooperative cancellation.
_running_tasks: dict[str, asyncio.Task[None]] = {}
# username -> currently running job id (consumed by the OpenAI proxy route).
_active_job_by_user: dict[str, str] = {}
# username -> set of SSE subscriber queues receiving that user's job updates.
_job_subscribers: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set)
def _build_settings_for_job(row: sqlite3.Row) -> SettingsModel:
    """Build a validated pdf2zh settings object for one job database row.

    The translation engine is pointed at the internal OpenAI proxy with a
    per-user internal API key so usage can be attributed and billed.

    Raises whatever ``SettingsModel.validate_settings`` raises on invalid
    configuration.
    """
    username = row["username"]
    internal_key = auth._make_internal_api_key(username)
    settings = SettingsModel(
        basic=BasicSettings(debug=False, gui=False),
        translation=TranslationSettings(
            lang_in=row["lang_in"],
            lang_out=row["lang_out"],
            output=row["output_dir"],
            qps=TRANSLATION_QPS,
        ),
        pdf=PDFSettings(),
        translate_engine_settings=OpenAISettings(
            openai_model=row["model"],
            openai_base_url=INTERNAL_OPENAI_BASE_URL,
            openai_api_key=internal_key,
        ),
    )
    settings.validate_settings()
    return settings
async def _consume_translation_stream(
    job_id: str,
    settings: SettingsModel,
    input_path: Path,
    output_dir: Path,
) -> None:
    """Consume translation stream events and drive the job state machine.

    Note: this function does not enforce a timeout; the caller bounds the
    maximum runtime via ``asyncio.wait_for``.
    """
    async for event in do_translate_async_stream(settings, input_path):
        event_type = event.get("type")
        if event_type in {"progress_start", "progress_update", "progress_end"}:
            progress = float(event.get("overall_progress", 0.0))
            stage = event.get("stage", "")
            await _transition_and_notify(
                job_id,
                "progress",
                # Clamp in case the engine reports out-of-range progress.
                progress=max(0.0, min(100.0, progress)),
                message=str(stage) if stage else "Running",
            )
        elif event_type == "error":
            error_msg = str(event.get("error", "Unknown translation error"))
            await _transition_and_notify(
                job_id,
                "finish_error",
                error=error_msg,
                message="Translation failed",
                finished_at=storage.now_iso(),
            )
            return
        elif event_type == "finish":
            result = event.get("translate_result")
            mono_path = str(getattr(result, "mono_pdf_path", "") or "")
            dual_path = str(getattr(result, "dual_pdf_path", "") or "")
            glossary_path = str(
                getattr(result, "auto_extracted_glossary_path", "") or ""
            )
            # Fallback: if the result carries no paths, scan the output
            # directory for conventionally named .mono.pdf / .dual.pdf files.
            if not mono_path or not dual_path:
                for candidate in output_dir.glob("*.pdf"):
                    name = candidate.name.lower()
                    if ".mono.pdf" in name and not mono_path:
                        mono_path = str(candidate)
                    elif ".dual.pdf" in name and not dual_path:
                        dual_path = str(candidate)
            await _transition_and_notify(
                job_id,
                "finish_ok",
                progress=100.0,
                message="Translation finished",
                finished_at=storage.now_iso(),
                mono_pdf_path=mono_path or None,
                dual_pdf_path=dual_path or None,
                glossary_path=glossary_path or None,
            )
            return
    # The stream terminated without an explicit "finish" or "error" event.
    await _transition_and_notify(
        job_id,
        "finish_error",
        error="Translation stream ended unexpectedly",
        message="Translation failed",
        finished_at=storage.now_iso(),
    )
async def _run_single_job(job_id: str) -> None:
    """Execute one queued translation job end-to-end.

    Drives the job through start -> progress -> terminal state, bounding the
    total runtime with ``asyncio.wait_for(..., MAX_JOB_RUNTIME_SECONDS)`` and
    mapping timeout, cancellation, and unexpected exceptions onto the job
    state machine.
    """
    row = jobs.get_job_row(job_id)
    if row is None:
        return
    # Only still-queued jobs are runnable; anything else was handled already.
    if row["status"] != jobs.STATUS_QUEUED:
        return
    if row["cancel_requested"]:
        await _transition_and_notify(
            job_id,
            "cancel_before_start",
            message="Cancelled before start",
            finished_at=storage.now_iso(),
        )
        return
    username = row["username"]
    await _transition_and_notify(
        job_id,
        "start",
        started_at=storage.now_iso(),
        message="Translation started",
        progress=0.0,
    )
    # Record the active job so the internal OpenAI proxy can attribute usage.
    _active_job_by_user[username] = job_id
    input_path = Path(row["input_path"])
    output_dir = Path(row["output_dir"])
    try:
        settings = _build_settings_for_job(row)
        await asyncio.wait_for(
            _consume_translation_stream(
                job_id=job_id,
                settings=settings,
                input_path=input_path,
                output_dir=output_dir,
            ),
            timeout=MAX_JOB_RUNTIME_SECONDS,
        )
    except asyncio.TimeoutError:
        logger.warning("Translation job timed out: job_id=%s", job_id)
        await _transition_and_notify(
            job_id,
            "finish_error",
            error="Translation timed out",
            message="Translation timed out",
            finished_at=storage.now_iso(),
        )
    except asyncio.CancelledError:
        await _transition_and_notify(
            job_id,
            "cancel_running",
            message="Cancelled by user",
            finished_at=storage.now_iso(),
        )
        # Re-raise so the awaiting worker observes the cancellation.
        raise
    except Exception as exc:  # noqa: BLE001
        logger.exception("Translation job failed: %s", job_id)
        await _transition_and_notify(
            job_id,
            "finish_error",
            error=str(exc),
            message="Translation failed",
            finished_at=storage.now_iso(),
        )
    finally:
        # Clear the active-job marker only if it still points at this job.
        if _active_job_by_user.get(username) == job_id:
            _active_job_by_user.pop(username, None)
async def _job_worker() -> None:
    """Drain the job queue forever, running one job at a time.

    Each job runs in its own task and is registered in ``_running_tasks`` so
    the cancel endpoint can cancel it without cancelling this worker loop.
    """
    logger.info("Job worker started")
    while True:
        job_id = await _job_queue.get()
        try:
            task = asyncio.create_task(_run_single_job(job_id), name=f"job-{job_id}")
            _running_tasks[job_id] = task
            await task
        except asyncio.CancelledError:
            # Propagate cancellation of the worker itself (e.g. at shutdown).
            raise
        except Exception:  # noqa: BLE001
            logger.exception("Unhandled worker error for job=%s", job_id)
        finally:
            _running_tasks.pop(job_id, None)
            _job_queue.task_done()
def _enqueue_pending_jobs() -> None:
    """Recover job state after a restart and refill the worker queue."""
    # Jobs left in 'running' state cannot be resumed; mark them failed.
    now = storage.now_iso()
    storage.db_execute(
        """
        UPDATE jobs
        SET status='failed',
            error='Service restarted while running',
            message='Failed due to restart',
            finished_at=?,
            updated_at=?
        WHERE status='running'
        """,
        (now, now),
    )
    # Queued jobs survive a restart; re-enqueue them oldest first.
    queued = storage.db_fetchall(
        "SELECT id FROM jobs WHERE status='queued' ORDER BY created_at ASC"
    )
    for queued_row in queued:
        _job_queue.put_nowait(queued_row["id"])
async def _publish_job_event(job: dict[str, Any]) -> None:
    """Fan a job update out to every SSE subscriber of the job's owner."""
    username = job.get("username")
    if not username:
        return
    event = {
        "id": job["id"],
        "username": username,
        "status": job.get("status"),
        "progress": job.get("progress"),
        "message": job.get("message"),
        "error": job.get("error"),
        "updated_at": job.get("updated_at"),
        "artifact_urls": job.get("artifact_urls") or {},
        "model": job.get("model"),
        "filename": job.get("filename"),
        "created_at": job.get("created_at"),
    }
    # Snapshot the subscriber set; it may change while we iterate.
    for subscriber in list(_job_subscribers.get(username, ())):
        try:
            subscriber.put_nowait(event)
            continue
        except asyncio.QueueFull:
            pass
        # Slow consumer: evict its oldest event to make room, then retry once
        # so this never blocks the worker.
        with contextlib.suppress(asyncio.QueueEmpty):
            subscriber.get_nowait()
        try:
            subscriber.put_nowait(event)
        except asyncio.QueueFull:
            logger.warning(
                "Dropping job event for user=%s job_id=%s due to full queue",
                username,
                job.get("id"),
            )
async def _transition_and_notify(
    job_id: str,
    event: str,
    **fields: Any,
) -> dict[str, Any] | None:
    """Apply a job state-machine transition, then push it to subscribers.

    Returns the updated job dict, or None when the transition was rejected.
    """
    job = jobs.transition_job(job_id, event, **fields)
    if job is None:
        logger.warning(
            "Invalid job transition: job_id=%s event=%s", job_id, event
        )
        return None
    await _publish_job_event(job)
    return job
def _subscribe_user_jobs(username: str) -> asyncio.Queue[dict[str, Any]]:
    """Create and register a bounded SSE subscription queue for *username*."""
    subscription: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=100)
    _job_subscribers[username].add(subscription)
    return subscription
def _unsubscribe_user_jobs(username: str, queue: asyncio.Queue[dict[str, Any]]) -> None:
    """Drop one SSE subscription; forget the user once no queues remain."""
    remaining = _job_subscribers.get(username)
    if not remaining:
        return
    remaining.discard(queue)
    if not remaining:
        _job_subscribers.pop(username, None)
def _login_page(error: str = "") -> str:
    """Render the login page HTML, optionally injecting an error banner.

    The original f-string here was broken across lines (a syntax error);
    rebuilt as a single escaped HTML block substituted into the template.
    """
    tpl = load_template("login.html")
    # Escape the message before embedding it in HTML to prevent XSS.
    error_block = (
        f'<div class="error">{html.escape(error)}</div>' if error else ""
    )
    return tpl.replace("__ERROR_BLOCK__", error_block)
def _dashboard_page(username: str) -> str:
    """Render the dashboard HTML with user and language placeholders filled."""
    tpl = load_template("dashboard.html")
    replacements = {
        "__USERNAME__": html.escape(username),
        "__LANG_IN__": html.escape(DEFAULT_LANG_IN),
        "__LANG_OUT__": html.escape(DEFAULT_LANG_OUT),
    }
    for placeholder, value in replacements.items():
        tpl = tpl.replace(placeholder, value)
    return tpl
# ── FastAPI App ───────────────────────────────────────────────────────────────
# API docs endpoints are disabled: this gateway serves only its own GUI/API.
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
# Shared HTTP client for proxying; created at startup, closed at shutdown.
_http_client: httpx.AsyncClient | None = None
@app.on_event("startup")
async def _startup() -> None:
    """Initialize storage, recover pending jobs, and start background tasks."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
    # favor of lifespan handlers — consider migrating when convenient.
    global _http_client, _worker_task
    storage.init_db()
    # Fail restart-interrupted jobs and re-enqueue still-queued ones.
    _enqueue_pending_jobs()
    _http_client = httpx.AsyncClient(timeout=httpx.Timeout(180.0))
    _worker_task = asyncio.create_task(_job_worker(), name="job-worker")
    if not proxy.OPENAI_REAL_API_KEY:
        logger.info(
            "OPENAI_API_KEY is empty, non-routed OpenAI models will fail"
        )
    logger.info("Gateway started. Data dir: %s", storage.DATA_DIR)
@app.on_event("shutdown")
async def _shutdown() -> None:
    """Stop the worker, cancel in-flight jobs, and release shared resources."""
    global _worker_task, _http_client
    if _worker_task:
        _worker_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await _worker_task
        _worker_task = None
    # Cancel any still-running job tasks; each handles its own cleanup.
    for task in list(_running_tasks.values()):
        task.cancel()
    if _http_client:
        await _http_client.aclose()
        _http_client = None
    storage.close_db()
# ── 路由:基础与认证(当前 Space 原型,不保证向后兼容) ─────────────────────
@app.get("/healthz")
async def healthz() -> Response:
    """Liveness probe: always answers 200 with plain-text 'ok'."""
    return Response(content="ok", media_type="text/plain")
@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request) -> HTMLResponse:
    """Show the login form; bounce already-authenticated users to the home page."""
    current_user = auth._get_session_user(request)
    if current_user:
        return RedirectResponse("/", status_code=302)
    return HTMLResponse(_login_page())
@app.post("/login")
async def login(
    request: Request,
    username: str = Form(...),
    password: str = Form(...),
) -> Response:
    """Handle the login form submission.

    On success, sets the session cookie and 303-redirects to the ``next``
    target; on bad credentials, re-renders the login page with HTTP 401.
    """
    next_url = request.query_params.get("next", "/")
    # Only allow same-origin relative redirect targets: reject absolute URLs
    # ("http://...") and protocol-relative ones ("//evil.com") to prevent an
    # open redirect through the untrusted "next" query parameter.
    if not next_url.startswith("/") or next_url.startswith("//"):
        next_url = "/"
    if auth._verify_credentials(username, password):
        token = auth._make_session(username)
        resp = RedirectResponse(next_url, status_code=303)
        resp.set_cookie(
            auth.SESSION_COOKIE,
            token,
            max_age=auth.SESSION_MAX_AGE,
            httponly=True,
            samesite="lax",
        )
        logger.info("Login successful: %s", username)
        return resp
    logger.warning("Login failed: %s", username)
    return HTMLResponse(_login_page("用户名或密码错误。"), status_code=401)
@app.get("/logout")
async def logout() -> Response:
    """Clear the session cookie and send the browser back to the login page."""
    response = RedirectResponse("/login", status_code=302)
    response.delete_cookie(auth.SESSION_COOKIE)
    return response
# ── 路由:页面渲染(HTML) ────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> Response:
    """Serve the dashboard to authenticated users; others go to /login."""
    username = auth._get_session_user(request)
    if username:
        return HTMLResponse(_dashboard_page(username))
    return RedirectResponse("/login", status_code=302)
# ── 路由:任务 API ─────────────────────────────────────────────────────────────
@app.get("/api/me")
async def api_me(username: str = Depends(auth._require_user)) -> dict[str, str]:
    """Return the authenticated caller's username."""
    payload = {"username": username}
    return payload
@app.get("/api/jobs")
async def api_list_jobs(
    limit: int = 50,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """List up to ``limit`` (clamped to 1..200) of the caller's jobs."""
    clamped = min(max(limit, 1), 200)
    return {"jobs": jobs.get_jobs_for_user(username=username, limit=clamped)}
@app.get("/api/jobs/{job_id}")
async def api_get_job(
    job_id: str,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Return one job owned by the caller, or 404 if absent or not theirs.

    NOTE(review): this parameterized route is registered before the fixed
    GET /api/jobs/stream route, so FastAPI matches /api/jobs/stream here with
    job_id="stream" and returns 404 — the stream route must be registered
    above this one to be reachable.
    """
    job = jobs.get_job_for_user(job_id=job_id, username=username)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return {"job": job}
@app.post("/api/jobs")
async def api_create_job(
    request: Request,
    file: UploadFile = File(...),
    lang_in: str = Form(DEFAULT_LANG_IN),
    lang_out: str = Form(DEFAULT_LANG_OUT),
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Accept a PDF upload, persist it, and enqueue a translation job.

    Raises 400 for non-PDF filenames and 413 when the upload exceeds
    MAX_UPLOAD_MB (checked via a Content-Length pre-check and again while
    streaming the body to disk).
    """
    filename = file.filename or "input.pdf"
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="仅支持 PDF 文件")
    # Rough pre-check when the client supplies Content-Length, rejecting
    # obviously oversized requests before reading the body. The 2x factor
    # presumably leaves headroom for multipart framing — confirm if tightened.
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            total_len = int(content_length)
        except ValueError:
            total_len = 0
        if total_len > MAX_UPLOAD_BYTES * 2:
            logger.warning(
                "Upload rejected by Content-Length: username=%s size=%s limit=%s",
                username,
                total_len,
                MAX_UPLOAD_BYTES,
            )
            raise HTTPException(
                status_code=413,
                detail=f"上传文件过大,最大 {MAX_UPLOAD_MB}MB",
            )
    job_id = uuid.uuid4().hex
    # Path(...).name strips any directory components from the client filename.
    safe_filename = Path(filename).name
    input_path = (storage.UPLOAD_DIR / f"{job_id}.pdf").resolve()
    output_dir = (storage.JOB_DIR / job_id).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    total_bytes = 0
    too_large = False
    try:
        with input_path.open("wb") as f:
            # Use the async UploadFile API so a large upload does not block
            # the event loop (the previous sync file.file.read() call did).
            while chunk := await file.read(UPLOAD_CHUNK_SIZE):
                total_bytes += len(chunk)
                if total_bytes > MAX_UPLOAD_BYTES:
                    too_large = True
                    break
                f.write(chunk)
    finally:
        await file.close()
    if too_large:
        # Remove the partially written file so no orphan data remains.
        with contextlib.suppress(FileNotFoundError):
            input_path.unlink()
        logger.warning(
            "Upload too large: username=%s job_id=%s size=%s limit=%s",
            username,
            job_id,
            total_bytes,
            MAX_UPLOAD_BYTES,
        )
        raise HTTPException(
            status_code=413,
            detail=f"上传文件过大,最大 {MAX_UPLOAD_MB}MB",
        )
    job_dict = jobs.create_job_record(
        job_id=job_id,
        username=username,
        filename=safe_filename,
        input_path=input_path,
        output_dir=output_dir,
        model=FIXED_TRANSLATION_MODEL,
        lang_in=lang_in.strip() or DEFAULT_LANG_IN,
        lang_out=lang_out.strip() or DEFAULT_LANG_OUT,
    )
    await _job_queue.put(job_id)
    return {"job": job_dict}
@app.post("/api/jobs/{job_id}/cancel")
async def api_cancel_job(
    job_id: str,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Request cancellation of one of the caller's jobs.

    Finished jobs are left untouched; queued jobs are cancelled immediately;
    running jobs get their task cancelled and finish asynchronously.
    """
    row = jobs.get_job_row(job_id)
    # Hide other users' jobs behind a 404 rather than a 403.
    if row is None or row["username"] != username:
        raise HTTPException(status_code=404, detail="Job not found")
    status = row["status"]
    if status in {
        jobs.STATUS_SUCCEEDED,
        jobs.STATUS_FAILED,
        jobs.STATUS_CANCELLED,
    }:
        return {"status": status, "message": "Job already finished"}
    # Persist the flag first so a job starting concurrently will see it.
    jobs.update_job(job_id, cancel_requested=1, message="Cancel requested")
    if status == jobs.STATUS_QUEUED:
        await _transition_and_notify(
            job_id,
            "cancel_before_start",
            finished_at=storage.now_iso(),
            progress=0.0,
            message="Job cancelled",
        )
        return {"status": "cancelled", "message": "Job cancelled"}
    # Running job: cancel its task; _run_single_job maps the CancelledError
    # onto the "cancel_running" transition.
    task = _running_tasks.get(job_id)
    if task:
        task.cancel()
    return {"status": "cancelling", "message": "Cancellation requested"}
@app.get("/api/jobs/stream")
async def api_jobs_stream(
    request: Request,
    username: str = Depends(auth._require_user),
) -> StreamingResponse:
    """SSE stream of job updates, limited to the current user's jobs.

    NOTE(review): this fixed route is registered AFTER GET /api/jobs/{job_id},
    so requests to /api/jobs/stream are captured by that parameterized route
    (job_id="stream") and 404 — this handler is currently unreachable. It must
    be registered before the {job_id} route.
    """
    queue = _subscribe_user_jobs(username)

    async def event_generator() -> Any:
        try:
            while True:
                if await request.is_disconnected():
                    break
                try:
                    payload = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # SSE heartbeat: keeps the connection alive and gives
                    # is_disconnected() a chance to detect a dropped client.
                    yield ": heartbeat\n\n"
                    continue
                yield f"data: {json.dumps(payload)}\n\n"
        except asyncio.CancelledError:
            logger.info("SSE connection cancelled for user=%s", username)
            raise
        finally:
            # Always drop the subscription, even on abrupt disconnects.
            _unsubscribe_user_jobs(username, queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )
@app.get("/api/jobs/{job_id}/artifacts/{artifact_type}")
async def api_download_artifact(
    job_id: str,
    artifact_type: str,
    username: str = Depends(auth._require_user),
) -> Response:
    """Serve one of a job's artifacts (mono/dual PDF or glossary) to its owner."""
    row = jobs.get_job_row(job_id)
    # Treat another user's job the same as a missing one (no existence leak).
    if row is None or row["username"] != username:
        raise HTTPException(status_code=404, detail="Job not found")
    column = {
        "mono": "mono_pdf_path",
        "dual": "dual_pdf_path",
        "glossary": "glossary_path",
    }.get(artifact_type)
    if column is None:
        raise HTTPException(status_code=404, detail="Unknown artifact")
    output_dir = Path(row["output_dir"]).resolve()
    artifact_path = jobs.resolve_artifact_path(row[column], output_dir)
    if artifact_path is None:
        raise HTTPException(status_code=404, detail="Artifact not found")
    return FileResponse(artifact_path)
# ── 路由:计费 API ─────────────────────────────────────────────────────────────
@app.get("/api/billing/me")
async def api_billing_summary(
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Return the caller's billing summary."""
    summary = billing.get_billing_summary(username)
    return summary
@app.get("/api/billing/me/records")
async def api_billing_records(
    limit: int = 50,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Return up to ``limit`` (clamped to 1..200) billing records for the caller."""
    clamped = min(200, max(1, limit))
    records = billing.get_billing_records(username=username, limit=clamped)
    return {"records": records}
@app.post("/internal/openai/v1/chat/completions")
async def internal_openai_chat_completions(request: Request) -> Response:
    """Forward an internal chat-completions call to the billing-aware proxy."""
    response = await proxy.handle_internal_chat_completions(
        http_client=_http_client,
        active_job_by_user=_active_job_by_user,
        request=request,
    )
    return response
# ── 路由:静态资源 ─────────────────────────────────────────────────────────────
@app.get("/static/dashboard.js")
async def dashboard_js() -> FileResponse:
    """Serve the dashboard front-end script."""
    script_path = get_static_path("dashboard.js")
    return FileResponse(script_path, media_type="application/javascript")