from __future__ import annotations

import time
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator

# Task and callback state models, plus helpers that normalise loosely-shaped
# input records (alternative status spellings, ``*_at`` timestamp keys,
# ``retries``, string errors) into the canonical ``TaskRecord`` layout.


class TaskStatus(str, Enum):
    """Canonical task lifecycle states (string-valued)."""

    queued = "queued"
    running = "running"
    retrying = "retrying"
    fallback_running = "fallback_running"
    waiting_rpa = "waiting_rpa"
    rpa_running = "rpa_running"
    rpa_imported = "rpa_imported"
    rpa_failed = "rpa_failed"
    risk_paused = "risk_paused"
    succeeded = "succeeded"
    failed = "failed"


# Statuses grouped as "in progress". ``queued`` belongs to neither group.
IN_PROGRESS_STATUSES = frozenset(
    {
        TaskStatus.running,
        TaskStatus.retrying,
        TaskStatus.fallback_running,
        TaskStatus.waiting_rpa,
        TaskStatus.rpa_running,
        TaskStatus.risk_paused,
    }
)

# Statuses grouped as "terminal".
TERMINAL_STATUSES = frozenset(
    {
        TaskStatus.succeeded,
        TaskStatus.failed,
        TaskStatus.rpa_imported,
        TaskStatus.rpa_failed,
    }
)

# Alternative spellings mapped onto canonical status values. Spellings that
# already equal a canonical value need no entry: they pass the membership
# check in ``canonical_task_status_value`` unchanged.
_STATUS_ALIASES: Dict[str, str] = {
    "pending": TaskStatus.queued.value,
    "waiting": TaskStatus.queued.value,
    "in_progress": TaskStatus.running.value,
    "processing": TaskStatus.running.value,
    "done": TaskStatus.succeeded.value,
    "success": TaskStatus.succeeded.value,
    "error": TaskStatus.failed.value,
}

# Built once at import time instead of on every call.
_VALID_STATUS_VALUES = frozenset(status.value for status in TaskStatus)

# For each task type, the payload keys consulted (in priority order) when a
# record carries no explicit ``target``.
_TARGET_PAYLOAD_KEYS: Dict[str, tuple] = {
    "note_url": ("note_url", "url"),
    "search": ("query", "keyword"),
    "user_profile": ("user_id", "uid", "user_url", "url"),
}


def canonical_task_status_value(value: Any) -> Optional[str]:
    """Normalise an arbitrary status-like value to a canonical status string.

    Args:
        value: A ``TaskStatus``, any object exposing ``.value``, a string, or
            anything else convertible with ``str()``.

    Returns:
        The matching ``TaskStatus`` value, or ``None`` when ``value`` is
        ``None``, blank, or not a recognised status or alias.
    """
    if value is None:
        return None
    # Unwrap enum members (or anything enum-like) to their raw value.
    if hasattr(value, "value"):
        value = value.value
    raw = str(value).strip().lower()
    if not raw:
        return None
    raw = _STATUS_ALIASES.get(raw, raw)
    return raw if raw in _VALID_STATUS_VALUES else None


class CallbackStatus(str, Enum):
    """Delivery state of a task callback."""

    pending = "pending"
    succeeded = "succeeded"
    failed = "failed"


class CallbackState(BaseModel):
    """Bookkeeping for a task's callback delivery attempts."""

    status: CallbackStatus = CallbackStatus.pending
    callback_url: Optional[str] = None
    idempotency_key: Optional[str] = None
    attempts: int = 0
    last_attempt_at: Optional[float] = None
    last_http_status: Optional[int] = None
    last_error: Optional[str] = None
    next_retry_at: Optional[float] = None


def _infer_status(data: Dict[str, Any]) -> str:
    """Guess a status for a record whose ``status`` is missing or unknown.

    A finish timestamp or a non-empty error yields ``failed``; otherwise a
    start timestamp yields ``running``; otherwise ``queued``.
    """
    # NOTE(review): a finished record with no error is still reported as
    # "failed" here, not "succeeded" - confirm that this is intended.
    finished_any = data.get("finished") or data.get("finished_at")
    if finished_any or data.get("error") not in (None, "", {}):
        return TaskStatus.failed.value
    if data.get("started") or data.get("started_at"):
        return TaskStatus.running.value
    return TaskStatus.queued.value


def _backfill(data: Dict[str, Any], current: str, legacy: str) -> None:
    """Copy ``data[legacy]`` into ``data[current]`` when only the former is set.

    "Set" means present and not ``None``, so a ``None`` placeholder under the
    current key does not hide a usable legacy value, and a ``None`` legacy
    value is never copied over (letting field defaults apply).
    """
    if data.get(current) is None and data.get(legacy) is not None:
        data[current] = data[legacy]


def _infer_target(task_type: Any, payload: Dict[str, Any]) -> str:
    """Derive a target string from the payload based on the task type.

    Returns the first truthy payload value among the keys registered for the
    task type, converted with ``str()``; ``""`` when nothing applies.
    """
    kind = str(task_type or "").strip().lower()
    for key in _TARGET_PAYLOAD_KEYS.get(kind, ()):
        candidate = payload.get(key)
        if candidate:
            return str(candidate)
    return ""


class TaskRecord(BaseModel):
    """A task together with its status, timing, payload and error details."""

    id: str
    status: TaskStatus
    task_type: str
    target: str = ""
    payload: Dict[str, Any] = Field(default_factory=dict)
    engine: Optional[str] = None
    callback: Optional[CallbackState] = None
    # The lambda resolves ``time.time`` at call time, not at class creation.
    created: float = Field(default_factory=lambda: time.time())
    # ``Optional[...]`` is used for every field because pydantic evaluates
    # these annotations at runtime, where ``X | None`` needs Python 3.10+.
    started: Optional[float] = None
    finished: Optional[float] = None
    retry_count: int = 0
    error: Optional[Dict[str, Any]] = None

    @model_validator(mode="before")
    @classmethod
    def _coerce_legacy_fields(cls, data: Any) -> Any:
        """Normalise raw input before field validation.

        Non-dict input is returned untouched. For dict input a shallow copy is
        normalised and returned, so the caller's dict is never modified:

        * ``status`` is canonicalised, or inferred when missing/unknown.
        * ``created_at`` / ``started_at`` / ``finished_at`` / ``retries`` fill
          in ``created`` / ``started`` / ``finished`` / ``retry_count`` when
          the latter are absent or ``None``.
        * A string ``error`` becomes ``{"message": <string>}``; a blank string
          becomes ``None``.
        * A missing or non-dict ``payload`` becomes ``{}``.
        * A falsy ``target`` is inferred from ``task_type`` and the payload.
        """
        if not isinstance(data, dict):
            return data

        # Work on a copy: validating a record must not rewrite the input.
        data = dict(data)

        # Status is inferred from the raw values, before any clean-up below.
        status_val = canonical_task_status_value(data.get("status"))
        data["status"] = status_val if status_val is not None else _infer_status(data)

        _backfill(data, "created", "created_at")
        _backfill(data, "started", "started_at")
        _backfill(data, "finished", "finished_at")
        _backfill(data, "retry_count", "retries")

        err = data.get("error")
        if isinstance(err, str):
            # A blank string means "no error"; left as-is it would fail the
            # ``Optional[Dict]`` field validation.
            data["error"] = {"message": err} if err.strip() else None

        if not isinstance(data.get("payload"), dict):
            data["payload"] = {}

        if not data.get("target"):
            data["target"] = _infer_target(data.get("task_type"), data["payload"])

        return data


def now_ts() -> float:
    """Return the current Unix timestamp in seconds."""
    return time.time()