BirkhoffLee commited on
Commit
b2d8381
·
unverified ·
1 Parent(s): d3a7520

refactor: 完成了阶段二的修复

Browse files
Files changed (3) hide show
  1. src/gateway.py +144 -21
  2. src/jobs.py +96 -1
  3. src/web/static/dashboard.js +75 -4
src/gateway.py CHANGED
@@ -6,16 +6,24 @@ from __future__ import annotations
6
  import asyncio
7
  import contextlib
8
  import html
 
9
  import logging
10
  import os
11
  import shutil
12
  import uuid
 
13
  from pathlib import Path
14
  from typing import Any
15
 
16
  import httpx
17
  from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
18
- from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse, Response
 
 
 
 
 
 
19
  from pdf2zh_next import BasicSettings
20
  from pdf2zh_next import OpenAISettings
21
  from pdf2zh_next import PDFSettings
@@ -53,6 +61,7 @@ _job_queue: asyncio.Queue[str] = asyncio.Queue()
53
  _worker_task: asyncio.Task[None] | None = None
54
  _running_tasks: dict[str, asyncio.Task[None]] = {}
55
  _active_job_by_user: dict[str, str] = {}
 
56
 
57
 
58
  def _build_settings_for_job(row: sqlite3.Row) -> SettingsModel:
@@ -82,16 +91,21 @@ async def _run_single_job(job_id: str) -> None:
82
  row = jobs.get_job_row(job_id)
83
  if row is None:
84
  return
85
- if row["status"] != "queued":
86
  return
87
  if row["cancel_requested"]:
88
- _update_job(job_id, status="cancelled", message="Cancelled before start")
 
 
 
 
 
89
  return
90
 
91
  username = row["username"]
92
- jobs.update_job(
93
  job_id,
94
- status="running",
95
  started_at=storage.now_iso(),
96
  message="Translation started",
97
  progress=0.0,
@@ -108,16 +122,17 @@ async def _run_single_job(job_id: str) -> None:
108
  if event_type in {"progress_start", "progress_update", "progress_end"}:
109
  progress = float(event.get("overall_progress", 0.0))
110
  stage = event.get("stage", "")
111
- jobs.update_job(
112
  job_id,
 
113
  progress=max(0.0, min(100.0, progress)),
114
  message=f"{stage}" if stage else "Running",
115
  )
116
  elif event_type == "error":
117
  error_msg = str(event.get("error", "Unknown translation error"))
118
- jobs.update_job(
119
  job_id,
120
- status="failed",
121
  error=error_msg,
122
  message="Translation failed",
123
  finished_at=storage.now_iso(),
@@ -141,9 +156,9 @@ async def _run_single_job(job_id: str) -> None:
141
  elif ".dual.pdf" in name and not dual_path:
142
  dual_path = str(file)
143
 
144
- jobs.update_job(
145
  job_id,
146
- status="succeeded",
147
  progress=100.0,
148
  message="Translation finished",
149
  finished_at=storage.now_iso(),
@@ -153,26 +168,26 @@ async def _run_single_job(job_id: str) -> None:
153
  )
154
  return
155
 
156
- jobs.update_job(
157
  job_id,
158
- status="failed",
159
  error="Translation stream ended unexpectedly",
160
  message="Translation failed",
161
  finished_at=storage.now_iso(),
162
- )
163
  except asyncio.CancelledError:
164
- jobs.update_job(
165
  job_id,
166
- status="cancelled",
167
  message="Cancelled by user",
168
  finished_at=storage.now_iso(),
169
  )
170
  raise
171
  except Exception as exc: # noqa: BLE001
172
  logger.exception("Translation job failed: %s", job_id)
173
- jobs.update_job(
174
  job_id,
175
- status="failed",
176
  error=str(exc),
177
  message="Translation failed",
178
  finished_at=storage.now_iso(),
@@ -222,6 +237,79 @@ def _enqueue_pending_jobs() -> None:
222
  _job_queue.put_nowait(row["id"])
223
 
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  def _login_page(error: str = "") -> str:
226
  """渲染登录页 HTML。"""
227
  tpl = load_template("login.html")
@@ -413,16 +501,21 @@ async def api_cancel_job(
413
  raise HTTPException(status_code=404, detail="Job not found")
414
 
415
  status = row["status"]
416
- if status in {"succeeded", "failed", "cancelled"}:
 
 
 
 
417
  return {"status": status, "message": "Job already finished"}
418
 
419
  jobs.update_job(job_id, cancel_requested=1, message="Cancel requested")
420
- if status == "queued":
421
- jobs.update_job(
422
  job_id,
423
- status="cancelled",
424
  finished_at=storage.now_iso(),
425
  progress=0.0,
 
426
  )
427
  return {"status": "cancelled", "message": "Job cancelled"}
428
 
@@ -433,6 +526,36 @@ async def api_cancel_job(
433
  return {"status": "cancelling", "message": "Cancellation requested"}
434
 
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  @app.get("/api/jobs/{job_id}/artifacts/{artifact_type}")
437
  async def api_download_artifact(
438
  job_id: str,
 
6
  import asyncio
7
  import contextlib
8
  import html
9
+ import json
10
  import logging
11
  import os
12
  import shutil
13
  import uuid
14
+ from collections import defaultdict
15
  from pathlib import Path
16
  from typing import Any
17
 
18
  import httpx
19
  from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
20
+ from fastapi.responses import (
21
+ FileResponse,
22
+ HTMLResponse,
23
+ RedirectResponse,
24
+ Response,
25
+ StreamingResponse,
26
+ )
27
  from pdf2zh_next import BasicSettings
28
  from pdf2zh_next import OpenAISettings
29
  from pdf2zh_next import PDFSettings
 
61
  _worker_task: asyncio.Task[None] | None = None
62
  _running_tasks: dict[str, asyncio.Task[None]] = {}
63
  _active_job_by_user: dict[str, str] = {}
64
+ _job_subscribers: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set)
65
 
66
 
67
  def _build_settings_for_job(row: sqlite3.Row) -> SettingsModel:
 
91
  row = jobs.get_job_row(job_id)
92
  if row is None:
93
  return
94
+ if row["status"] != jobs.STATUS_QUEUED:
95
  return
96
  if row["cancel_requested"]:
97
+ await _transition_and_notify(
98
+ job_id,
99
+ "cancel_before_start",
100
+ message="Cancelled before start",
101
+ finished_at=storage.now_iso(),
102
+ )
103
  return
104
 
105
  username = row["username"]
106
+ await _transition_and_notify(
107
  job_id,
108
+ "start",
109
  started_at=storage.now_iso(),
110
  message="Translation started",
111
  progress=0.0,
 
122
  if event_type in {"progress_start", "progress_update", "progress_end"}:
123
  progress = float(event.get("overall_progress", 0.0))
124
  stage = event.get("stage", "")
125
+ await _transition_and_notify(
126
  job_id,
127
+ "progress",
128
  progress=max(0.0, min(100.0, progress)),
129
  message=f"{stage}" if stage else "Running",
130
  )
131
  elif event_type == "error":
132
  error_msg = str(event.get("error", "Unknown translation error"))
133
+ await _transition_and_notify(
134
  job_id,
135
+ "finish_error",
136
  error=error_msg,
137
  message="Translation failed",
138
  finished_at=storage.now_iso(),
 
156
  elif ".dual.pdf" in name and not dual_path:
157
  dual_path = str(file)
158
 
159
+ await _transition_and_notify(
160
  job_id,
161
+ "finish_ok",
162
  progress=100.0,
163
  message="Translation finished",
164
  finished_at=storage.now_iso(),
 
168
  )
169
  return
170
 
171
+ await _transition_and_notify(
172
  job_id,
173
+ "finish_error",
174
  error="Translation stream ended unexpectedly",
175
  message="Translation failed",
176
  finished_at=storage.now_iso(),
177
+ )
178
  except asyncio.CancelledError:
179
+ await _transition_and_notify(
180
  job_id,
181
+ "cancel_running",
182
  message="Cancelled by user",
183
  finished_at=storage.now_iso(),
184
  )
185
  raise
186
  except Exception as exc: # noqa: BLE001
187
  logger.exception("Translation job failed: %s", job_id)
188
+ await _transition_and_notify(
189
  job_id,
190
+ "finish_error",
191
  error=str(exc),
192
  message="Translation failed",
193
  finished_at=storage.now_iso(),
 
237
  _job_queue.put_nowait(row["id"])
238
 
239
 
240
+ async def _publish_job_event(job: dict[str, Any]) -> None:
241
+ """将任务更新推送给所有订阅该用户的 SSE 连接。"""
242
+ username = job.get("username")
243
+ if not username:
244
+ return
245
+
246
+ payload = {
247
+ "id": job["id"],
248
+ "username": username,
249
+ "status": job.get("status"),
250
+ "progress": job.get("progress"),
251
+ "message": job.get("message"),
252
+ "error": job.get("error"),
253
+ "updated_at": job.get("updated_at"),
254
+ "artifact_urls": job.get("artifact_urls") or {},
255
+ "model": job.get("model"),
256
+ "filename": job.get("filename"),
257
+ "created_at": job.get("created_at"),
258
+ }
259
+
260
+ queues = list(_job_subscribers.get(username, ()))
261
+ for q in queues:
262
+ try:
263
+ q.put_nowait(payload)
264
+ except asyncio.QueueFull:
265
+ # 简单策略:丢弃最旧一条再塞新事件,防止阻塞 worker
266
+ try:
267
+ _ = q.get_nowait()
268
+ except asyncio.QueueEmpty:
269
+ pass
270
+ try:
271
+ q.put_nowait(payload)
272
+ except asyncio.QueueFull:
273
+ logger.warning(
274
+ "Dropping job event for user=%s job_id=%s due to full queue",
275
+ username,
276
+ job.get("id"),
277
+ )
278
+
279
+
280
+ async def _transition_and_notify(
281
+ job_id: str,
282
+ event: str,
283
+ **fields: Any,
284
+ ) -> dict[str, Any] | None:
285
+ """状态机迁移并推送事件给订阅者。"""
286
+ job = jobs.transition_job(job_id, event, **fields)
287
+ if job is not None:
288
+ await _publish_job_event(job)
289
+ else:
290
+ logger.warning(
291
+ "Invalid job transition: job_id=%s event=%s", job_id, event
292
+ )
293
+ return job
294
+
295
+
296
+ def _subscribe_user_jobs(username: str) -> asyncio.Queue[dict[str, Any]]:
297
+ """注册一个用户的 SSE 订阅队列。"""
298
+ q: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=100)
299
+ _job_subscribers[username].add(q)
300
+ return q
301
+
302
+
303
+ def _unsubscribe_user_jobs(username: str, queue: asyncio.Queue[dict[str, Any]]) -> None:
304
+ """取消用户的 SSE 订阅队列。"""
305
+ queues = _job_subscribers.get(username)
306
+ if not queues:
307
+ return
308
+ queues.discard(queue)
309
+ if not queues:
310
+ _job_subscribers.pop(username, None)
311
+
312
+
313
  def _login_page(error: str = "") -> str:
314
  """渲染登录页 HTML。"""
315
  tpl = load_template("login.html")
 
501
  raise HTTPException(status_code=404, detail="Job not found")
502
 
503
  status = row["status"]
504
+ if status in {
505
+ jobs.STATUS_SUCCEEDED,
506
+ jobs.STATUS_FAILED,
507
+ jobs.STATUS_CANCELLED,
508
+ }:
509
  return {"status": status, "message": "Job already finished"}
510
 
511
  jobs.update_job(job_id, cancel_requested=1, message="Cancel requested")
512
+ if status == jobs.STATUS_QUEUED:
513
+ await _transition_and_notify(
514
  job_id,
515
+ "cancel_before_start",
516
  finished_at=storage.now_iso(),
517
  progress=0.0,
518
+ message="Job cancelled",
519
  )
520
  return {"status": "cancelled", "message": "Job cancelled"}
521
 
 
526
  return {"status": "cancelling", "message": "Cancellation requested"}
527
 
528
 
529
+ @app.get("/api/jobs/stream")
530
+ async def api_jobs_stream(
531
+ request: Request,
532
+ username: str = Depends(auth._require_user),
533
+ ) -> StreamingResponse:
534
+ """任务状态 SSE 推送,仅推送当前用户的任务更新。"""
535
+
536
+ queue = _subscribe_user_jobs(username)
537
+
538
+ async def event_generator() -> Any:
539
+ try:
540
+ while True:
541
+ # 支持客户端主动断开
542
+ if await request.is_disconnected():
543
+ break
544
+ payload = await queue.get()
545
+ yield f"data: {json.dumps(payload)}\n\n"
546
+ except asyncio.CancelledError:
547
+ logger.info("Job SSE connection cancelled for user=%s", username)
548
+ raise
549
+ finally:
550
+ _unsubscribe_user_jobs(username, queue)
551
+
552
+ return StreamingResponse(
553
+ event_generator(),
554
+ media_type="text/event-stream",
555
+ headers={"Cache-Control": "no-cache"},
556
+ )
557
+
558
+
559
  @app.get("/api/jobs/{job_id}/artifacts/{artifact_type}")
560
  async def api_download_artifact(
561
  job_id: str,
src/jobs.py CHANGED
@@ -10,6 +10,24 @@ import sqlite3
10
  import storage
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def row_to_job_dict(row: sqlite3.Row) -> dict[str, Any]:
14
  """将任务行转换为对外暴露的字典结构。"""
15
  job = dict(row)
@@ -28,7 +46,11 @@ def row_to_job_dict(row: sqlite3.Row) -> dict[str, Any]:
28
 
29
 
30
  def update_job(job_id: str, **fields: Any) -> None:
31
- """更新任务记录指定字段。"""
 
 
 
 
32
  if not fields:
33
  return
34
  fields["updated_at"] = storage.now_iso()
@@ -133,3 +155,76 @@ def resolve_artifact_path(raw_path: str | None, output_dir: Path) -> Path | None
133
  return None
134
  return path
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import storage
11
 
12
 
13
+ # ── 任务状态与状态机 ────────────────────────────────────────────────────────────
14
+
15
+ # 约定的任务状态枚举,避免在业务层随意写字符串
16
+ STATUS_QUEUED = "queued"
17
+ STATUS_RUNNING = "running"
18
+ STATUS_SUCCEEDED = "succeeded"
19
+ STATUS_FAILED = "failed"
20
+ STATUS_CANCELLED = "cancelled"
21
+
22
+ ALLOWED_STATUSES: set[str] = {
23
+ STATUS_QUEUED,
24
+ STATUS_RUNNING,
25
+ STATUS_SUCCEEDED,
26
+ STATUS_FAILED,
27
+ STATUS_CANCELLED,
28
+ }
29
+
30
+
31
  def row_to_job_dict(row: sqlite3.Row) -> dict[str, Any]:
32
  """将任务行转换为对外暴露的字典结构。"""
33
  job = dict(row)
 
46
 
47
 
48
  def update_job(job_id: str, **fields: Any) -> None:
49
+ """更新任务记录指定字段。
50
+
51
+ 注意:业务代码应该优先通过 transition_job 做状态机驱动更新,
52
+ 直接调用本函数仅用于与状态无关的字段(例如 cancel_requested)。
53
+ """
54
  if not fields:
55
  return
56
  fields["updated_at"] = storage.now_iso()
 
155
  return None
156
  return path
157
 
158
+
159
+ def transition_job(job_id: str, event: str, **extra_fields: Any) -> dict[str, Any] | None:
160
+ """基于事件驱动的任务状态迁移。
161
+
162
+ 这里只负责:
163
+ * 校验当前状态是否允许执行给定事件
164
+ * 决定目标状态(如果有)
165
+ * 写入数据库
166
+ * 返回更新后的任务字典(用于推送给前端)
167
+
168
+ 状态枚举固定为 queued/running/succeeded/failed/cancelled,避免状态空间爆炸。
169
+ """
170
+ row = get_job_row(job_id)
171
+ if row is None:
172
+ return None
173
+
174
+ current_status = row["status"]
175
+ if current_status not in ALLOWED_STATUSES:
176
+ # 非法状态一律拒绝迁移,由调用方记录日志
177
+ return None
178
+
179
+ # 简单的事件 -> 允许来源状态集合、目标状态映射
180
+ # 对于 progress 这类事件,目标状态为 None,只更新进度等字段。
181
+ transitions: dict[str, dict[str, Any]] = {
182
+ "start": {
183
+ "from": {STATUS_QUEUED},
184
+ "to": STATUS_RUNNING,
185
+ },
186
+ "progress": {
187
+ "from": {STATUS_RUNNING},
188
+ "to": None,
189
+ },
190
+ "finish_ok": {
191
+ "from": {STATUS_RUNNING},
192
+ "to": STATUS_SUCCEEDED,
193
+ },
194
+ "finish_error": {
195
+ "from": {STATUS_QUEUED, STATUS_RUNNING},
196
+ "to": STATUS_FAILED,
197
+ },
198
+ "cancel_before_start": {
199
+ "from": {STATUS_QUEUED},
200
+ "to": STATUS_CANCELLED,
201
+ },
202
+ "cancel_running": {
203
+ "from": {STATUS_RUNNING},
204
+ "to": STATUS_CANCELLED,
205
+ },
206
+ # 预留重启失败事件,当前在 gateway 中直接 SQL 处理,不走这里
207
+ "restart_fail": {
208
+ "from": {STATUS_RUNNING},
209
+ "to": STATUS_FAILED,
210
+ },
211
+ }
212
+
213
+ cfg = transitions.get(event)
214
+ if cfg is None:
215
+ return None
216
+
217
+ if current_status not in cfg["from"]:
218
+ return None
219
+
220
+ fields: dict[str, Any] = dict(extra_fields)
221
+ target_status = cfg["to"]
222
+ if target_status is not None:
223
+ fields["status"] = target_status
224
+
225
+ update_job(job_id, **fields)
226
+ new_row = get_job_row(job_id)
227
+ if new_row is None:
228
+ return None
229
+ return row_to_job_dict(new_row)
230
+
src/web/static/dashboard.js CHANGED
@@ -42,6 +42,9 @@ async function refreshBilling() {
42
  }
43
  }
44
 
 
 
 
45
  function actionButtons(job) {
46
  const actions = [];
47
  if (job.status === "queued" || job.status === "running") {
@@ -78,12 +81,16 @@ function statusText(status) {
78
  return statusMap[status] || status;
79
  }
80
 
81
- async function refreshJobs() {
82
- const data = await apiJson("/api/jobs?limit=50");
83
  const body = document.getElementById("jobsBody");
84
  body.innerHTML = "";
85
 
86
- for (const job of data.jobs) {
 
 
 
 
 
87
  const tr = document.createElement("tr");
88
  tr.innerHTML = `
89
  <td class="mono">${esc(job.id)}</td>
@@ -98,6 +105,21 @@ async function refreshJobs() {
98
  }
99
  }
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  async function cancelJob(jobId) {
102
  try {
103
  await apiJson(`/api/jobs/${jobId}/cancel`, { method: "POST" });
@@ -127,6 +149,55 @@ async function refreshAll() {
127
  await Promise.all([refreshJobs(), refreshBilling()]);
128
  }
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  refreshAll();
131
- setInterval(refreshAll, 3000);
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  }
43
  }
44
 
45
+ // 任务状态缓存:在前端维护一个简单的内存表,方便 SSE/轮询统一渲染
46
+ const jobsState = new Map();
47
+
48
  function actionButtons(job) {
49
  const actions = [];
50
  if (job.status === "queued" || job.status === "running") {
 
81
  return statusMap[status] || status;
82
  }
83
 
84
+ function renderJobsFromState() {
 
85
  const body = document.getElementById("jobsBody");
86
  body.innerHTML = "";
87
 
88
+ const jobs = Array.from(jobsState.values());
89
+ jobs.sort((a, b) =>
90
+ (b.created_at || "").localeCompare(a.created_at || ""),
91
+ );
92
+
93
+ for (const job of jobs) {
94
  const tr = document.createElement("tr");
95
  tr.innerHTML = `
96
  <td class="mono">${esc(job.id)}</td>
 
105
  }
106
  }
107
 
108
+ function upsertJob(jobPatch) {
109
+ const existing = jobsState.get(jobPatch.id) || {};
110
+ jobsState.set(jobPatch.id, { ...existing, ...jobPatch });
111
+ renderJobsFromState();
112
+ }
113
+
114
+ async function refreshJobs() {
115
+ const data = await apiJson("/api/jobs?limit=50");
116
+ jobsState.clear();
117
+ for (const job of data.jobs) {
118
+ jobsState.set(job.id, job);
119
+ }
120
+ renderJobsFromState();
121
+ }
122
+
123
  async function cancelJob(jobId) {
124
  try {
125
  await apiJson(`/api/jobs/${jobId}/cancel`, { method: "POST" });
 
149
  await Promise.all([refreshJobs(), refreshBilling()]);
150
  }
151
 
152
+ let jobEventSource = null;
153
+ let pollingEnabled = true;
154
+ const POLL_INTERVAL_MS = 10000;
155
+
156
+ function setupJobStream() {
157
+ if (!("EventSource" in window)) {
158
+ console.warn("EventSource not supported, fallback to polling");
159
+ pollingEnabled = true;
160
+ return;
161
+ }
162
+
163
+ jobEventSource = new EventSource("/api/jobs/stream");
164
+
165
+ jobEventSource.onmessage = (event) => {
166
+ try {
167
+ const payload = JSON.parse(event.data);
168
+ if (!payload || !payload.id) {
169
+ return;
170
+ }
171
+ upsertJob(payload);
172
+ } catch (err) {
173
+ console.error("Failed to parse job SSE payload:", err);
174
+ }
175
+ };
176
+
177
+ jobEventSource.onerror = () => {
178
+ console.error("Job SSE error, switching back to polling");
179
+ if (jobEventSource) {
180
+ jobEventSource.close();
181
+ jobEventSource = null;
182
+ }
183
+ pollingEnabled = true;
184
+ };
185
+
186
+ pollingEnabled = false;
187
+ }
188
+
189
  refreshAll();
190
+ setupJobStream();
191
 
192
+ setInterval(async () => {
193
+ if (document.hidden) {
194
+ // 页面不可见时降低刷新频率:完全跳过本轮
195
+ return;
196
+ }
197
+
198
+ if (pollingEnabled) {
199
+ await refreshAll();
200
+ } else {
201
+ await refreshBilling();
202
+ }
203
+ }, POLL_INTERVAL_MS);