#!/usr/bin/env python3
"""FastAPI application: login auth, self-built GUI, translation jobs, internal OpenAI proxy, billing."""
from __future__ import annotations

import asyncio
import contextlib
import html
import json
import logging
import os
import shutil
import sqlite3  # needed for the sqlite3.Row annotation in _build_settings_for_job
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any

import httpx
from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    RedirectResponse,
    Response,
    StreamingResponse,
)
from pdf2zh_next import (
    BasicSettings,
    OpenAISettings,
    PDFSettings,
    SettingsModel,
    TranslationSettings,
)
from pdf2zh_next.high_level import do_translate_async_stream

import auth
import billing
import jobs
import proxy
import storage
from web.template_loader import get_static_path, load_template

# ── Configuration ─────────────────────────────────────────────────────────────
INTERNAL_OPENAI_BASE_URL = os.environ.get(
    "INTERNAL_OPENAI_BASE_URL", "http://127.0.0.1:7860/internal/openai/v1"
)
FIXED_TRANSLATION_MODEL = "SiliconFlowFree"
DEFAULT_LANG_IN = os.environ.get("DEFAULT_LANG_IN", "en").strip()
DEFAULT_LANG_OUT = os.environ.get("DEFAULT_LANG_OUT", "zh").strip()
TRANSLATION_QPS = int(os.environ.get("TRANSLATION_QPS", "4"))

# Upload and job-execution constraints.
MAX_UPLOAD_MB = int(os.environ.get("MAX_UPLOAD_MB", "100"))
MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024 * 1024
MAX_JOB_RUNTIME_SECONDS = int(os.environ.get("MAX_JOB_RUNTIME_SECONDS", "7200"))
UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB per chunk

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
)
logger = logging.getLogger("gateway")

# ── Job execution state ───────────────────────────────────────────────────────
# Single-worker FIFO of pending job ids.
_job_queue: asyncio.Queue[str] = asyncio.Queue()
_worker_task: asyncio.Task[None] | None = None
# job_id -> running asyncio task (used for cancellation).
_running_tasks: dict[str, asyncio.Task[None]] = {}
# username -> job_id currently translating for that user.
_active_job_by_user: dict[str, str] = {}
# username -> set of per-connection SSE event queues.
_job_subscribers: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set)


def _build_settings_for_job(row: sqlite3.Row) -> SettingsModel:
    """Build a validated pdf2zh settings object from a job DB row.

    The translate engine is pointed at the internal OpenAI proxy and
    authenticated with a per-user internal API key.
    """
    username = row["username"]
    internal_key = auth._make_internal_api_key(username)
    settings = SettingsModel(
        basic=BasicSettings(debug=False, gui=False),
        translation=TranslationSettings(
            lang_in=row["lang_in"],
            lang_out=row["lang_out"],
            output=row["output_dir"],
            qps=TRANSLATION_QPS,
        ),
        pdf=PDFSettings(),
        translate_engine_settings=OpenAISettings(
            openai_model=row["model"],
            openai_base_url=INTERNAL_OPENAI_BASE_URL,
            openai_api_key=internal_key,
        ),
    )
    settings.validate_settings()
    return settings


async def _consume_translation_stream(
    job_id: str,
    settings: SettingsModel,
    input_path: Path,
    output_dir: Path,
) -> None:
    """Consume translation stream events and drive the job state machine.

    Note: this function does not enforce a timeout; the caller bounds the
    maximum runtime via asyncio.wait_for.
    """
    async for event in do_translate_async_stream(settings, input_path):
        event_type = event.get("type")
        if event_type in {"progress_start", "progress_update", "progress_end"}:
            progress = float(event.get("overall_progress", 0.0))
            stage = event.get("stage", "")
            await _transition_and_notify(
                job_id,
                "progress",
                # Clamp in case the library reports values outside [0, 100].
                progress=max(0.0, min(100.0, progress)),
                message=f"{stage}" if stage else "Running",
            )
        elif event_type == "error":
            error_msg = str(event.get("error", "Unknown translation error"))
            await _transition_and_notify(
                job_id,
                "finish_error",
                error=error_msg,
                message="Translation failed",
                finished_at=storage.now_iso(),
            )
            return
        elif event_type == "finish":
            result = event.get("translate_result")
            mono_path = str(getattr(result, "mono_pdf_path", "") or "")
            dual_path = str(getattr(result, "dual_pdf_path", "") or "")
            glossary_path = str(
                getattr(result, "auto_extracted_glossary_path", "") or ""
            )
            # Fallback: if the result object did not carry paths, scan the
            # output directory for the conventional file names.
            if not mono_path or not dual_path:
                files = list(output_dir.glob("*.pdf"))
                for file in files:
                    name = file.name.lower()
                    if ".mono.pdf" in name and not mono_path:
                        mono_path = str(file)
                    elif ".dual.pdf" in name and not dual_path:
                        dual_path = str(file)
            await _transition_and_notify(
                job_id,
                "finish_ok",
                progress=100.0,
                message="Translation finished",
                finished_at=storage.now_iso(),
                mono_pdf_path=mono_path or None,
                dual_pdf_path=dual_path or None,
                glossary_path=glossary_path or None,
            )
            return
    # Stream ended without an explicit error/finish event: treat as failure.
    await _transition_and_notify(
        job_id,
        "finish_error",
        error="Translation stream ended unexpectedly",
        message="Translation failed",
        finished_at=storage.now_iso(),
    )


async def _run_single_job(job_id: str) -> None:
    """Run one queued job to completion, handling timeout, cancel and errors."""
    row = jobs.get_job_row(job_id)
    if row is None:
        return
    # Only queued jobs may start; anything else was already handled elsewhere.
    if row["status"] != jobs.STATUS_QUEUED:
        return
    if row["cancel_requested"]:
        await _transition_and_notify(
            job_id,
            "cancel_before_start",
            message="Cancelled before start",
            finished_at=storage.now_iso(),
        )
        return
    username = row["username"]
    await _transition_and_notify(
        job_id,
        "start",
        started_at=storage.now_iso(),
        message="Translation started",
        progress=0.0,
    )
    # Mark the user's active job so the internal proxy can attribute usage.
    _active_job_by_user[username] = job_id
    input_path = Path(row["input_path"])
    output_dir = Path(row["output_dir"])
    try:
        settings = _build_settings_for_job(row)
        await asyncio.wait_for(
            _consume_translation_stream(
                job_id=job_id,
                settings=settings,
                input_path=input_path,
                output_dir=output_dir,
            ),
            timeout=MAX_JOB_RUNTIME_SECONDS,
        )
    except asyncio.TimeoutError:
        logger.warning("Translation job timed out: job_id=%s", job_id)
        await _transition_and_notify(
            job_id,
            "finish_error",
            error="Translation timed out",
            message="Translation timed out",
            finished_at=storage.now_iso(),
        )
    except asyncio.CancelledError:
        # Record the cancellation, then re-raise so the awaiting worker/task
        # machinery observes the cancel.
        await _transition_and_notify(
            job_id,
            "cancel_running",
            message="Cancelled by user",
            finished_at=storage.now_iso(),
        )
        raise
    except Exception as exc:  # noqa: BLE001
        logger.exception("Translation job failed: %s", job_id)
        await _transition_and_notify(
            job_id,
            "finish_error",
            error=str(exc),
            message="Translation failed",
            finished_at=storage.now_iso(),
        )
    finally:
        # Only clear the marker if it still points at this job.
        if _active_job_by_user.get(username) == job_id:
            _active_job_by_user.pop(username, None)


async def _job_worker() -> None:
    """Sequentially drain the job queue, running one translation at a time."""
    logger.info("Job worker started")
    while True:
        job_id = await _job_queue.get()
        try:
            task = asyncio.create_task(_run_single_job(job_id), name=f"job-{job_id}")
            # Registered so /cancel can cancel the task while it runs.
            _running_tasks[job_id] = task
            await task
        except asyncio.CancelledError:
            raise
        except Exception:  # noqa: BLE001
            logger.exception("Unhandled worker error for job=%s", job_id)
        finally:
            _running_tasks.pop(job_id, None)
            _job_queue.task_done()
def _enqueue_pending_jobs() -> None:
    """Re-seed the worker queue after a restart.

    Jobs that were mid-flight when the process died can never resume, so they
    are marked failed; still-queued jobs are pushed back onto the queue.
    """
    ts = storage.now_iso()
    storage.db_execute(
        """
        UPDATE jobs
        SET status='failed', error='Service restarted while running',
            message='Failed due to restart', finished_at=?, updated_at=?
        WHERE status='running'
        """,
        (ts, ts),
    )
    pending = storage.db_fetchall(
        "SELECT id FROM jobs WHERE status='queued' ORDER BY created_at ASC"
    )
    for record in pending:
        _job_queue.put_nowait(record["id"])


async def _publish_job_event(job: dict[str, Any]) -> None:
    """Fan a job update out to every SSE subscriber of the job's owner."""
    username = job.get("username")
    if not username:
        return
    payload = {
        "id": job["id"],
        "username": username,
        "status": job.get("status"),
        "progress": job.get("progress"),
        "message": job.get("message"),
        "error": job.get("error"),
        "updated_at": job.get("updated_at"),
        "artifact_urls": job.get("artifact_urls") or {},
        "model": job.get("model"),
        "filename": job.get("filename"),
        "created_at": job.get("created_at"),
    }
    for subscriber in list(_job_subscribers.get(username, ())):
        try:
            subscriber.put_nowait(payload)
            continue
        except asyncio.QueueFull:
            pass
        # Queue is full: drop the oldest entry, then retry once so the worker
        # is never blocked by a slow consumer.
        with contextlib.suppress(asyncio.QueueEmpty):
            subscriber.get_nowait()
        try:
            subscriber.put_nowait(payload)
        except asyncio.QueueFull:
            logger.warning(
                "Dropping job event for user=%s job_id=%s due to full queue",
                username,
                job.get("id"),
            )


async def _transition_and_notify(
    job_id: str,
    event: str,
    **fields: Any,
) -> dict[str, Any] | None:
    """Apply a state-machine transition and push the result to subscribers.

    Returns the updated job dict, or None when the transition was rejected.
    """
    job = jobs.transition_job(job_id, event, **fields)
    if job is None:
        logger.warning(
            "Invalid job transition: job_id=%s event=%s", job_id, event
        )
        return None
    await _publish_job_event(job)
    return job
def _subscribe_user_jobs(username: str) -> asyncio.Queue[dict[str, Any]]:
    """Register and return an SSE subscription queue for *username*."""
    # maxsize bounds memory per subscriber; overflow is handled by the publisher.
    q: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=100)
    _job_subscribers[username].add(q)
    return q


def _unsubscribe_user_jobs(username: str, queue: asyncio.Queue[dict[str, Any]]) -> None:
    """Remove *queue* from the user's SSE subscriptions."""
    queues = _job_subscribers.get(username)
    if not queues:
        return
    queues.discard(queue)
    if not queues:
        # Last subscriber gone: drop the user key entirely.
        _job_subscribers.pop(username, None)


def _login_page(error: str = "") -> str:
    """Render the login page HTML, substituting an optional error message."""
    tpl = load_template("login.html")
    # NOTE(review): the error-block literal appears to have lost surrounding
    # markup in transit; confirm the exact string against login.html's
    # __ERROR_BLOCK__ placeholder.
    error_block = f'{html.escape(error)}' if error else ""
    return tpl.replace("__ERROR_BLOCK__", error_block)
def _dashboard_page(username: str) -> str:
    """Render the dashboard HTML with the user and default languages substituted."""
    safe_user = html.escape(username)
    safe_lang_in = html.escape(DEFAULT_LANG_IN)
    safe_lang_out = html.escape(DEFAULT_LANG_OUT)
    tpl = load_template("dashboard.html")
    return (
        tpl.replace("__USERNAME__", safe_user)
        .replace("__LANG_IN__", safe_lang_in)
        .replace("__LANG_OUT__", safe_lang_out)
    )


# ── FastAPI App ───────────────────────────────────────────────────────────────
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
_http_client: httpx.AsyncClient | None = None


@app.on_event("startup")
async def _startup() -> None:
    """Initialise storage, requeue pending jobs and start the worker."""
    global _http_client, _worker_task
    storage.init_db()
    _enqueue_pending_jobs()
    _http_client = httpx.AsyncClient(timeout=httpx.Timeout(180.0))
    _worker_task = asyncio.create_task(_job_worker(), name="job-worker")
    if not proxy.OPENAI_REAL_API_KEY:
        # Misconfiguration worth surfacing prominently: proxied calls will fail.
        logger.warning(
            "OPENAI_API_KEY is empty, non-routed OpenAI models will fail"
        )
    logger.info("Gateway started. Data dir: %s", storage.DATA_DIR)


@app.on_event("shutdown")
async def _shutdown() -> None:
    """Stop the worker, cancel running jobs and release shared resources."""
    global _worker_task, _http_client
    if _worker_task:
        _worker_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await _worker_task
        _worker_task = None
    # Cancel in-flight jobs and wait for their cancel handlers to run, so
    # their 'cancelled' transitions are persisted before the DB closes.
    running = list(_running_tasks.values())
    for task in running:
        task.cancel()
    if running:
        await asyncio.gather(*running, return_exceptions=True)
    if _http_client:
        await _http_client.aclose()
        _http_client = None
    storage.close_db()


# ── Routes: basics & auth (current Space prototype; no backward-compat guarantee)
@app.get("/healthz")
async def healthz() -> Response:
    """Liveness probe."""
    return Response("ok", media_type="text/plain")


@app.get("/login", response_class=HTMLResponse)
async def login_page(request: Request) -> HTMLResponse:
    """Show the login form; already-authenticated users go to the dashboard."""
    if auth._get_session_user(request):
        return RedirectResponse("/", status_code=302)
    return HTMLResponse(_login_page())


@app.post("/login")
async def login(
    request: Request,
    username: str = Form(...),
    password: str = Form(...),
) -> Response:
    """Validate credentials; on success set the session cookie and redirect.

    Returns 401 with the login page re-rendered on bad credentials.
    """
    next_url = request.query_params.get("next", "/")
    # Open-redirect guard: only honour same-site relative paths
    # ("//evil.com" would be treated as protocol-relative by browsers).
    if not next_url.startswith("/") or next_url.startswith("//"):
        next_url = "/"
    if auth._verify_credentials(username, password):
        token = auth._make_session(username)
        resp = RedirectResponse(next_url, status_code=303)
        resp.set_cookie(
            auth.SESSION_COOKIE,
            token,
            max_age=auth.SESSION_MAX_AGE,
            httponly=True,
            samesite="lax",
        )
        logger.info("Login successful: %s", username)
        return resp
    logger.warning("Login failed: %s", username)
    return HTMLResponse(_login_page("用户名或密码错误。"), status_code=401)
@app.get("/logout")
async def logout() -> Response:
    """Clear the session cookie and return to the login page."""
    resp = RedirectResponse("/login", status_code=302)
    resp.delete_cookie(auth.SESSION_COOKIE)
    return resp


# ── Routes: page rendering (HTML) ─────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def index(request: Request) -> Response:
    """Dashboard for logged-in users; anonymous visitors are sent to /login."""
    username = auth._get_session_user(request)
    if not username:
        return RedirectResponse("/login", status_code=302)
    return HTMLResponse(_dashboard_page(username))


# ── Routes: job API ───────────────────────────────────────────────────────────
@app.get("/api/me")
async def api_me(username: str = Depends(auth._require_user)) -> dict[str, str]:
    """Return the authenticated username."""
    return {"username": username}


@app.get("/api/jobs")
async def api_list_jobs(
    limit: int = 50,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """List the caller's jobs (limit clamped to 1..200)."""
    limit = max(1, min(limit, 200))
    jobs_list = jobs.get_jobs_for_user(username=username, limit=limit)
    return {"jobs": jobs_list}


@app.get("/api/jobs/{job_id}")
async def api_get_job(
    job_id: str,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Fetch one job owned by the caller, or 404."""
    job = jobs.get_job_for_user(job_id=job_id, username=username)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return {"job": job}


@app.post("/api/jobs")
async def api_create_job(
    request: Request,
    file: UploadFile = File(...),
    lang_in: str = Form(DEFAULT_LANG_IN),
    lang_out: str = Form(DEFAULT_LANG_OUT),
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Accept a PDF upload, persist it under a size limit and enqueue a job.

    Raises 400 for non-PDF filenames and 413 when the size limit is exceeded.
    """
    filename = file.filename or "input.pdf"
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="仅支持 PDF 文件")
    # If the client sent Content-Length, do a cheap pre-check to reject
    # obviously oversized requests before reading the body.
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            total_len = int(content_length)
        except ValueError:
            total_len = 0
        if total_len > MAX_UPLOAD_BYTES * 2:
            logger.warning(
                "Upload rejected by Content-Length: username=%s size=%s limit=%s",
                username,
                total_len,
                MAX_UPLOAD_BYTES,
            )
            raise HTTPException(
                status_code=413,
                detail=f"上传文件过大,最大 {MAX_UPLOAD_MB}MB",
            )
    job_id = uuid.uuid4().hex
    safe_filename = Path(filename).name
    input_path = (storage.UPLOAD_DIR / f"{job_id}.pdf").resolve()
    output_dir = (storage.JOB_DIR / job_id).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    total_bytes = 0
    too_large = False
    try:
        with input_path.open("wb") as f:
            while True:
                # Async read: the previous file.file.read() call was a
                # blocking sync read that stalled the event loop.
                chunk = await file.read(UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                total_bytes += len(chunk)
                if total_bytes > MAX_UPLOAD_BYTES:
                    too_large = True
                    break
                f.write(chunk)
    finally:
        await file.close()
    if too_large:
        # Remove the partially-written file and the now-unused output
        # directory so rejected uploads leave nothing behind.
        with contextlib.suppress(FileNotFoundError):
            input_path.unlink()
        shutil.rmtree(output_dir, ignore_errors=True)
        logger.warning(
            "Upload too large: username=%s job_id=%s size=%s limit=%s",
            username,
            job_id,
            total_bytes,
            MAX_UPLOAD_BYTES,
        )
        raise HTTPException(
            status_code=413,
            detail=f"上传文件过大,最大 {MAX_UPLOAD_MB}MB",
        )
    job_dict = jobs.create_job_record(
        job_id=job_id,
        username=username,
        filename=safe_filename,
        input_path=input_path,
        output_dir=output_dir,
        model=FIXED_TRANSLATION_MODEL,
        lang_in=lang_in.strip() or DEFAULT_LANG_IN,
        lang_out=lang_out.strip() or DEFAULT_LANG_OUT,
    )
    await _job_queue.put(job_id)
    return {"job": job_dict}
@app.post("/api/jobs/{job_id}/cancel")
async def api_cancel_job(
    job_id: str,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Request cancellation of one of the caller's jobs.

    Finished jobs are left untouched; queued jobs are cancelled immediately;
    running jobs get their asyncio task cancelled and finish asynchronously.
    """
    row = jobs.get_job_row(job_id)
    # Hide other users' jobs behind the same 404 as missing ones.
    if row is None or row["username"] != username:
        raise HTTPException(status_code=404, detail="Job not found")
    status = row["status"]
    if status in {
        jobs.STATUS_SUCCEEDED,
        jobs.STATUS_FAILED,
        jobs.STATUS_CANCELLED,
    }:
        return {"status": status, "message": "Job already finished"}
    # Record the request flag first, so a job dequeued concurrently by the
    # worker sees cancel_requested before it starts.
    jobs.update_job(job_id, cancel_requested=1, message="Cancel requested")
    if status == jobs.STATUS_QUEUED:
        await _transition_and_notify(
            job_id,
            "cancel_before_start",
            finished_at=storage.now_iso(),
            progress=0.0,
            message="Job cancelled",
        )
        return {"status": "cancelled", "message": "Job cancelled"}
    # Running: cancel the task; _run_single_job records the final transition.
    task = _running_tasks.get(job_id)
    if task:
        task.cancel()
    return {"status": "cancelling", "message": "Cancellation requested"}
@app.get("/api/jobs/stream")
async def api_jobs_stream(
    request: Request,
    username: str = Depends(auth._require_user),
) -> StreamingResponse:
    """SSE stream pushing job status updates for the current user only.

    NOTE(review): this route is registered after GET /api/jobs/{job_id};
    Starlette matches routes in registration order, so /api/jobs/stream is
    likely captured by the parameterized route first (returning 404). Verify,
    and if so register this endpoint before the parameterized one.
    """
    queue = _subscribe_user_jobs(username)

    async def event_generator() -> Any:
        try:
            while True:
                if await request.is_disconnected():
                    break
                try:
                    payload = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # SSE heartbeat: keeps the connection alive and gives
                    # is_disconnected() a chance to observe a dropped client.
                    yield ": heartbeat\n\n"
                    continue
                yield f"data: {json.dumps(payload)}\n\n"
        except asyncio.CancelledError:
            logger.info("SSE connection cancelled for user=%s", username)
            raise
        finally:
            # Always detach the queue, whatever ended the stream.
            _unsubscribe_user_jobs(username, queue)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )


@app.get("/api/jobs/{job_id}/artifacts/{artifact_type}")
async def api_download_artifact(
    job_id: str,
    artifact_type: str,
    username: str = Depends(auth._require_user),
) -> Response:
    """Serve a finished job's artifact (mono/dual PDF or glossary) to its owner."""
    row = jobs.get_job_row(job_id)
    # Treat other users' jobs as non-existent.
    if row is not None and row["username"] != username:
        row = None
    if row is None:
        raise HTTPException(status_code=404, detail="Job not found")
    col_map = {
        "mono": "mono_pdf_path",
        "dual": "dual_pdf_path",
        "glossary": "glossary_path",
    }
    column = col_map.get(artifact_type)
    if column is None:
        raise HTTPException(status_code=404, detail="Unknown artifact")
    # presumably resolve_artifact_path confines the path to output_dir —
    # confirm in jobs.resolve_artifact_path.
    output_dir = Path(row["output_dir"]).resolve()
    path = jobs.resolve_artifact_path(row[column], output_dir)
    if path is None:
        raise HTTPException(status_code=404, detail="Artifact not found")
    return FileResponse(path)
# ── Routes: billing API ───────────────────────────────────────────────────────
@app.get("/api/billing/me")
async def api_billing_summary(
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Billing summary for the authenticated user."""
    return billing.get_billing_summary(username)


@app.get("/api/billing/me/records")
async def api_billing_records(
    limit: int = 50,
    username: str = Depends(auth._require_user),
) -> dict[str, Any]:
    """Most recent billing records; limit is clamped into [1, 200]."""
    clamped = min(max(limit, 1), 200)
    return {"records": billing.get_billing_records(username=username, limit=clamped)}


@app.post("/internal/openai/v1/chat/completions")
async def internal_openai_chat_completions(request: Request) -> Response:
    """Forward internal chat-completion calls to the billing-aware proxy."""
    return await proxy.handle_internal_chat_completions(
        request=request,
        http_client=_http_client,
        active_job_by_user=_active_job_by_user,
    )


# ── Routes: static assets ─────────────────────────────────────────────────────
@app.get("/static/dashboard.js")
async def dashboard_js() -> FileResponse:
    """Serve the dashboard frontend script."""
    js_path = get_static_path("dashboard.js")
    return FileResponse(js_path, media_type="application/javascript")