File size: 6,612 Bytes
d3a7520 b2d8381 d3a7520 b2d8381 d3a7520 b2d8381 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | """任务相关的持久化操作与辅助函数。"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import sqlite3
import storage
# ── Job statuses and state machine ──────────────────────────────────────────
# Canonical job status values. Business code should reference these constants
# rather than scattering raw status strings around.
STATUS_QUEUED = "queued"
STATUS_RUNNING = "running"
STATUS_SUCCEEDED = "succeeded"
STATUS_FAILED = "failed"
STATUS_CANCELLED = "cancelled"

# Every status value a job row is expected to carry.
ALLOWED_STATUSES: set[str] = {
    STATUS_QUEUED,
    STATUS_RUNNING,
    STATUS_SUCCEEDED,
    STATUS_FAILED,
    STATUS_CANCELLED,
}
def row_to_job_dict(row: sqlite3.Row) -> dict[str, Any]:
    """Convert a job row into the dictionary shape exposed to callers.

    The row is copied into a plain dict and an ``artifact_urls`` entry is
    added. Each artifact kind maps to its download URL when the matching
    path column holds a truthy value, and to ``None`` otherwise.
    """
    job = dict(row)

    # (artifact kind used in the URL, column that stores the artifact path)
    artifact_columns = (
        ("mono", "mono_pdf_path"),
        ("dual", "dual_pdf_path"),
        ("glossary", "glossary_path"),
    )

    urls = {}
    for kind, column in artifact_columns:
        if job.get(column):
            # The id is only looked up when a URL is actually produced.
            urls[kind] = f"/api/jobs/{job['id']}/artifacts/{kind}"
        else:
            urls[kind] = None

    job["artifact_urls"] = urls
    return job
def update_job(job_id: str, **fields: Any) -> None:
    """Update the given columns of one job record.

    Business code should prefer ``transition_job`` for state-machine driven
    updates; call this directly only for fields unrelated to status (for
    example ``cancel_requested``).

    Does nothing when no fields are supplied. Otherwise ``updated_at`` is
    always set to the current timestamp, replacing any caller-supplied value.

    Column names are interpolated into the SQL text as-is, so keyword names
    must be trusted column identifiers; only the values are bound as
    parameters.
    """
    if fields:
        stamped = {**fields, "updated_at": storage.now_iso()}
        assignments = ", ".join(f"{column} = ?" for column in stamped)
        storage.db_execute(
            f"UPDATE jobs SET {assignments} WHERE id = ?",
            (*stamped.values(), job_id),
        )
def get_job_row(job_id: str) -> sqlite3.Row | None:
    """Fetch the raw ``jobs`` row whose id equals ``job_id``.

    Returns whatever ``storage.db_fetchone`` yields for the lookup
    (presumably ``None`` when no row matches — confirm against ``storage``).
    """
    query = "SELECT * FROM jobs WHERE id = ?"
    params = (job_id,)
    return storage.db_fetchone(query, params)
def get_job_for_user(job_id: str, username: str) -> dict[str, Any] | None:
    """Return the job as seen by ``username``.

    The lookup matches on both id and username, so a job that does not exist
    and a job owned by someone else both yield ``None``.
    """
    record = storage.db_fetchone(
        "SELECT * FROM jobs WHERE id = ? AND username = ?",
        (job_id, username),
    )
    return None if record is None else row_to_job_dict(record)
def get_jobs_for_user(username: str, limit: int) -> list[dict[str, Any]]:
    """List a user's jobs, newest first (ordered by ``created_at``).

    ``limit`` is bound directly to the SQL ``LIMIT`` clause.
    """
    query = (
        "\n"
        "SELECT * FROM jobs\n"
        "WHERE username = ?\n"
        "ORDER BY created_at DESC\n"
        "LIMIT ?\n"
    )
    records = storage.db_fetchall(query, (username, limit))
    return list(map(row_to_job_dict, records))
def create_job_record(
    *,
    job_id: str,
    username: str,
    filename: str,
    input_path: Path,
    output_dir: Path,
    model: str,
    lang_in: str,
    lang_out: str,
) -> dict[str, Any]:
    """Insert a new job and return it as a job dictionary.

    The job starts in the queued state with progress ``0.0``, message
    ``"Queued"``, no error and ``cancel_requested`` cleared; ``created_at``
    and ``updated_at`` share the same timestamp.

    Args:
        job_id: Primary key for the new job.
        username: Owner of the job.
        filename: File name stored with the job.
        input_path: Location of the input file; stored as a string.
        output_dir: Directory for the job's outputs; stored as a string.
        model: Model identifier stored with the job.
        lang_in: Source language stored with the job.
        lang_out: Target language stored with the job.

    Returns:
        The freshly inserted job, in the shape produced by
        ``row_to_job_dict``.

    Raises:
        RuntimeError: If the row cannot be read back after the insert.
    """
    now = storage.now_iso()
    insert_sql = (
        "\n"
        "INSERT INTO jobs(\n"
        "id, username, filename, input_path, output_dir,\n"
        "status, progress, message, error,\n"
        "model, lang_in, lang_out,\n"
        "cancel_requested,\n"
        "created_at, updated_at\n"
        ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n"
    )
    storage.db_execute(
        insert_sql,
        (
            job_id,
            username,
            filename,
            str(input_path),
            str(output_dir),
            # Use the shared constant so the initial status cannot drift from
            # the values the state machine recognises.
            STATUS_QUEUED,
            0.0,
            "Queued",
            None,
            model,
            lang_in,
            lang_out,
            0,
            now,
            now,
        ),
    )

    # Read the row back through the module's single by-id lookup helper.
    row = get_job_row(job_id)
    if row is None:
        raise RuntimeError("Failed to fetch job after insert")
    return row_to_job_dict(row)
def resolve_artifact_path(raw_path: str | None, output_dir: Path) -> Path | None:
    """Resolve a stored artifact path and confine it to ``output_dir``.

    Args:
        raw_path: Path recorded for the artifact. Relative values are
            interpreted relative to ``output_dir``. ``None`` or an empty
            string means there is no artifact.
        output_dir: Directory the artifact must live in.

    Returns:
        The fully resolved artifact path when it lies inside ``output_dir``
        and exists on disk; otherwise ``None``.
    """
    if not raw_path:
        return None

    # The candidate is compared after resolve(), so the base must be resolved
    # as well. An unresolved base (relative, containing "..", or reached
    # through a symlink) is never a prefix of the resolved candidate, which
    # would make every valid artifact look like it is outside the directory.
    base_dir = output_dir.resolve()

    candidate = Path(raw_path)
    if not candidate.is_absolute():
        candidate = base_dir / candidate
    candidate = candidate.resolve()

    # Check containment before touching the filesystem so that paths outside
    # the output directory are never probed for existence.
    try:
        candidate.relative_to(base_dir)
    except ValueError:
        return None

    return candidate if candidate.exists() else None
def transition_job(job_id: str, event: str, **extra_fields: Any) -> dict[str, Any] | None:
    """Apply an event-driven status transition to a job.

    Responsibilities:
      * check that the job's current status allows the given event;
      * decide the target status, if the event has one;
      * write the change to the database;
      * return the updated job dictionary (e.g. for pushing to the frontend).

    Statuses are limited to queued/running/succeeded/failed/cancelled to keep
    the state space small.

    Returns ``None`` when the job does not exist, its stored status is not a
    known status, the event is unknown, the event is not allowed from the
    current status, or the row cannot be read back after the update. Callers
    are expected to log rejected transitions themselves.

    NOTE(review): the status check and the update are separate storage calls;
    whether that is safe under concurrent writers depends on ``storage`` —
    confirm.
    NOTE(review): for events without a target status, a ``status`` key passed
    in ``extra_fields`` is written through unchanged — confirm no caller
    relies on that.
    """
    # event -> (statuses the event may fire from, resulting status).
    # A resulting status of None (e.g. "progress") means only the extra
    # fields are written and the status is left alone.
    rules = {
        "start": ({STATUS_QUEUED}, STATUS_RUNNING),
        "progress": ({STATUS_RUNNING}, None),
        "finish_ok": ({STATUS_RUNNING}, STATUS_SUCCEEDED),
        "finish_error": ({STATUS_QUEUED, STATUS_RUNNING}, STATUS_FAILED),
        "cancel_before_start": ({STATUS_QUEUED}, STATUS_CANCELLED),
        "cancel_running": ({STATUS_RUNNING}, STATUS_CANCELLED),
        # Reserved for restart failures; the gateway currently handles these
        # with direct SQL and does not go through this function.
        "restart_fail": ({STATUS_RUNNING}, STATUS_FAILED),
    }

    existing = get_job_row(job_id)
    if existing is None:
        return None

    status_now = existing["status"]
    rule = rules.get(event)
    # Reject rows holding an unrecognised status as well as unknown events.
    if status_now not in ALLOWED_STATUSES or rule is None:
        return None

    allowed_sources, next_status = rule
    if status_now not in allowed_sources:
        return None

    changes: dict[str, Any] = {**extra_fields}
    if next_status is not None:
        changes["status"] = next_status
    update_job(job_id, **changes)

    refreshed = get_job_row(job_id)
    return None if refreshed is None else row_to_job_dict(refreshed)
|