| from __future__ import annotations |
|
|
| import time |
| from enum import Enum |
| from typing import Any, Dict, Optional |
|
|
| from pydantic import BaseModel, Field, model_validator |
|
|
|
|
class TaskStatus(str, Enum):
    """Enumeration of the status values a task record can carry.

    The class mixes in ``str``, so every member is itself a string and
    compares equal to its raw value (``TaskStatus.queued == "queued"``).
    Each member's value is spelled exactly like its name.
    """

    queued = "queued"

    running = "running"
    retrying = "retrying"
    fallback_running = "fallback_running"

    waiting_rpa = "waiting_rpa"
    rpa_running = "rpa_running"
    rpa_imported = "rpa_imported"
    rpa_failed = "rpa_failed"

    risk_paused = "risk_paused"

    succeeded = "succeeded"
    failed = "failed"
|
|
|
|
# Immutable set of the statuses this module classifies as "in progress".
# ``TaskStatus.queued`` is deliberately not listed here.
IN_PROGRESS_STATUSES = frozenset(
    {
        TaskStatus.running,
        TaskStatus.retrying,
        TaskStatus.fallback_running,
        TaskStatus.waiting_rpa,
        TaskStatus.rpa_running,
        TaskStatus.risk_paused,
    }
)
|
|
# Immutable set of the statuses this module classifies as terminal.
TERMINAL_STATUSES = frozenset(
    {
        TaskStatus.succeeded,
        TaskStatus.failed,
        TaskStatus.rpa_imported,
        TaskStatus.rpa_failed,
    }
)
|
|
|
|
def canonical_task_status_value(value: Any) -> str | None:
    """Map a status-like input onto a canonical ``TaskStatus`` value string.

    Args:
        value: A ``TaskStatus`` member, any object exposing a ``.value``
            attribute, or something whose ``str()`` form names a status.
            Matching ignores surrounding whitespace and letter case.

    Returns:
        The matching ``TaskStatus`` value (for example ``"running"``), after
        translating the legacy aliases listed below. ``None`` when the input
        is ``None``, blank, or not a recognised status.
    """
    if value is None:
        return None

    # Enum members (and anything else with ``.value``) are unwrapped first.
    unwrapped = getattr(value, "value", value)
    text = str(unwrapped).strip().lower()
    if not text:
        return None

    # Legacy spellings and the canonical status each one stands for.
    legacy_aliases = {
        "pending": TaskStatus.queued.value,
        "waiting": TaskStatus.queued.value,
        "in_progress": TaskStatus.running.value,
        "processing": TaskStatus.running.value,
        "done": TaskStatus.succeeded.value,
        "success": TaskStatus.succeeded.value,
        "error": TaskStatus.failed.value,
        "failed": TaskStatus.failed.value,
        "running": TaskStatus.running.value,
        "queued": TaskStatus.queued.value,
    }
    candidate = legacy_aliases.get(text, text)

    # Looking the string up by value raises ValueError for unknown statuses.
    try:
        return TaskStatus(candidate).value
    except ValueError:
        return None
|
|
|
|
|
|
class CallbackStatus(str, Enum):
    """Status values for a callback; the type of ``CallbackState.status``.

    Members are ``str`` instances and compare equal to their raw values.
    """

    pending = "pending"
    succeeded = "succeeded"
    failed = "failed"
|
|
|
|
class CallbackState(BaseModel):
    """State kept for a task's callback (see ``TaskRecord.callback``).

    Every field has a default, so ``CallbackState()`` is a valid instance
    with ``status`` set to ``pending`` and ``attempts`` set to 0.
    """

    # Current callback status; a new state starts out as ``pending``.
    status: CallbackStatus = CallbackStatus.pending

    # Where the callback is addressed and the key sent with it.
    # NOTE(review): exact semantics are defined by the sender — confirm there.
    callback_url: str | None = None
    idempotency_key: str | None = None

    # Attempt counter, starting at zero.
    attempts: int = 0

    # Details of the most recent attempt; all unset until one is recorded.
    # NOTE(review): timestamps are presumably epoch seconds — confirm.
    last_attempt_at: float | None = None
    last_http_status: int | None = None
    last_error: str | None = None

    # When the next retry is due, if any (same time unit as above, presumably).
    next_retry_at: float | None = None
|
|
|
|
# Payload keys consulted, in priority order, when a record has no ``target``
# and one has to be inferred from the payload for a known ``task_type``.
_TARGET_PAYLOAD_KEYS = {
    "note_url": ("note_url", "url"),
    "search": ("query", "keyword"),
    "user_profile": ("user_id", "uid", "user_url", "url"),
}


class TaskRecord(BaseModel):
    """A task record, tolerant of legacy input shapes.

    Before field validation runs, dict input is normalised by
    ``_coerce_legacy_fields``: legacy key names are mapped to the current
    ones, the status is canonicalised (or inferred), a string ``error`` is
    wrapped into a dict, and ``payload`` / ``target`` are filled in.
    """

    id: str
    status: TaskStatus
    task_type: str
    # Empty string when no target was given and none could be inferred.
    target: str = ""
    payload: Dict[str, Any] = Field(default_factory=dict)
    engine: Optional[str] = None
    callback: Optional[CallbackState] = None
    # Defaults to ``time.time()`` at construction; the lambda defers the
    # lookup of ``time.time`` until the default is actually needed.
    created: float = Field(default_factory=lambda: time.time())
    started: float | None = None
    finished: float | None = None
    retry_count: int = 0
    # Either ``None`` or a dict; plain strings are wrapped as {"message": ...}.
    error: Dict[str, Any] | None = None

    @model_validator(mode="before")
    @classmethod
    def _coerce_legacy_fields(cls, data: Any) -> Any:
        """Normalise raw dict input before pydantic validates the fields.

        Non-dict input is returned untouched. For dict input a shallow copy
        is normalised and returned, so the caller's dict is never modified.

        Steps, in order:
          1. ``error``: a non-blank string becomes ``{"message": <string>}``;
             a blank or whitespace-only string becomes ``None``.
          2. ``status``: canonicalised via ``canonical_task_status_value``.
             When that yields nothing, the status is inferred: ``failed`` if
             the record has a finish time or an error, ``running`` if it has
             a start time, otherwise ``queued``.
          3. ``created_at`` / ``started_at`` / ``finished_at`` are copied to
             ``created`` / ``started`` / ``finished`` when the latter are
             absent.
          4. ``retries`` is used for ``retry_count`` when the latter is
             missing or ``None``.
          5. A missing or non-dict ``payload`` is replaced by ``{}``.
          6. A missing/empty ``target`` is inferred from the payload based on
             ``task_type`` (empty string when nothing applies).
        """
        if not isinstance(data, dict):
            return data

        # Shallow copy: every write below targets top-level keys only, and
        # the caller's input must not be modified as a side effect.
        data = dict(data)

        # Normalise the error first so the status heuristic below and the
        # field validation agree on what counts as "no error". A blank
        # string previously stayed a str and failed ``Dict | None``.
        err = data.get("error")
        if isinstance(err, str):
            data["error"] = {"message": err} if err.strip() else None

        status_val = canonical_task_status_value(data.get("status"))
        if status_val is None:
            finished_any = data.get("finished") or data.get("finished_at")
            started_any = data.get("started") or data.get("started_at")
            if finished_any or data.get("error") not in (None, "", {}):
                status_val = TaskStatus.failed.value
            elif started_any:
                status_val = TaskStatus.running.value
            else:
                status_val = TaskStatus.queued.value
        data["status"] = status_val

        # Legacy timestamp names only apply when the current name is absent.
        for current_key, legacy_key in (
            ("created", "created_at"),
            ("started", "started_at"),
            ("finished", "finished_at"),
        ):
            if current_key not in data and legacy_key in data:
                data[current_key] = data[legacy_key]

        # Fall back to the legacy counter whenever no usable ``retry_count``
        # was supplied, including an explicit ``None``.
        if data.get("retry_count") is None and data.get("retries") is not None:
            data["retry_count"] = data["retries"]

        # Best effort: anything that is not a dict is discarded.
        if not isinstance(data.get("payload"), dict):
            data["payload"] = {}

        if not data.get("target"):
            task_type = str(data.get("task_type") or "").strip().lower()
            payload = data["payload"]
            candidate_keys = _TARGET_PAYLOAD_KEYS.get(task_type, ())
            # First truthy payload value wins; "" when none is available.
            data["target"] = next(
                (str(payload[key]) for key in candidate_keys if payload.get(key)),
                "",
            )

        return data
|
|
|
|
def now_ts() -> float:
    """Return the current time as given by ``time.time()``."""
    current = time.time()
    return current
|
|