# XHS/service/tasks.py
# Author: Trae Bot
# Commit: "Upload Spider_XHS project" (c481f8a)
from __future__ import annotations
import time
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
class TaskStatus(str, Enum):
    """Every status a task record can carry.

    Members mix in ``str``, so each one compares equal to its plain string
    value (e.g. ``TaskStatus.failed == "failed"``).
    """

    # Not yet started / primary execution.
    queued = "queued"
    running = "running"
    retrying = "retrying"
    fallback_running = "fallback_running"
    # RPA-related states.
    waiting_rpa = "waiting_rpa"
    rpa_running = "rpa_running"
    rpa_imported = "rpa_imported"
    rpa_failed = "rpa_failed"
    # Paused and final outcomes.
    risk_paused = "risk_paused"
    succeeded = "succeeded"
    failed = "failed"
# Statuses of tasks that have started but not yet reached a final state.
# Note: ``TaskStatus.queued`` is in neither this set nor TERMINAL_STATUSES.
IN_PROGRESS_STATUSES = frozenset({
    TaskStatus.running,
    TaskStatus.retrying,
    TaskStatus.fallback_running,
    TaskStatus.waiting_rpa,
    TaskStatus.rpa_running,
    TaskStatus.risk_paused,
})

# Statuses after which a task is considered finished.
TERMINAL_STATUSES = frozenset({
    TaskStatus.succeeded,
    TaskStatus.failed,
    TaskStatus.rpa_imported,
    TaskStatus.rpa_failed,
})
# Legacy / external spellings mapped onto a canonical TaskStatus value.
# Canonical values need no entry: they pass the validity check directly.
_TASK_STATUS_ALIASES: Dict[str, str] = {
    "pending": TaskStatus.queued.value,
    "waiting": TaskStatus.queued.value,
    "in_progress": TaskStatus.running.value,
    "processing": TaskStatus.running.value,
    "done": TaskStatus.succeeded.value,
    "success": TaskStatus.succeeded.value,
    "error": TaskStatus.failed.value,
}

# Every recognised canonical status string; built once at import time.
_VALID_TASK_STATUS_VALUES = frozenset(s.value for s in TaskStatus)


def canonical_task_status_value(value: Any) -> str | None:
    """Normalise a raw status value to a canonical ``TaskStatus`` value string.

    Accepts enum members (anything exposing ``.value``), strings in any case
    with surrounding whitespace, and the legacy aliases in
    ``_TASK_STATUS_ALIASES``.

    Args:
        value: The raw status, e.g. from a stored record or an API request.

    Returns:
        The canonical status string, or ``None`` when *value* is ``None``,
        blank, or not recognised.
    """
    if value is None:
        return None
    # Unwrap enum members (or any object with a ``value`` attribute).
    if hasattr(value, "value"):
        value = value.value
    raw = str(value).strip().lower()
    if not raw:
        return None
    raw = _TASK_STATUS_ALIASES.get(raw, raw)
    return raw if raw in _VALID_TASK_STATUS_VALUES else None
class CallbackStatus(str, Enum):
    """Outcome of a task's callback; members compare equal to their string value."""

    # Initial state (the CallbackState default).
    pending = "pending"
    succeeded = "succeeded"
    failed = "failed"
class CallbackState(BaseModel):
    """Bookkeeping for a task's outbound callback.

    Every field has a default, so an empty ``CallbackState()`` is valid and
    starts in ``CallbackStatus.pending``.
    """

    status: CallbackStatus = CallbackStatus.pending
    # Destination and idempotency key; both optional.
    callback_url: str | None = None
    idempotency_key: str | None = None
    # Attempt counter and details of the most recent attempt.
    attempts: int = 0
    last_attempt_at: float | None = None
    last_http_status: int | None = None
    last_error: str | None = None
    # Scheduled time for the next attempt, if any.
    next_retry_at: float | None = None
# Legacy ``*_at`` timestamp keys and the canonical field each one feeds.
_LEGACY_TIMESTAMP_FIELDS = (
    ("created", "created_at"),
    ("started", "started_at"),
    ("finished", "finished_at"),
)

# Payload keys consulted, in priority order, to infer ``target`` per task type.
_TARGET_PAYLOAD_KEYS: Dict[str, tuple] = {
    "note_url": ("note_url", "url"),
    "search": ("query", "keyword"),
    "user_profile": ("user_id", "uid", "user_url", "url"),
}


def _has_legacy_error(err: Any) -> bool:
    """Return True when a raw ``error`` value actually records an error.

    ``None``, ``{}`` and empty/blank strings all mean "no error".
    """
    if isinstance(err, str):
        return err.strip() != ""
    return err not in (None, {})


def _infer_legacy_status(data: Dict[str, Any]) -> str:
    """Derive a status value for a raw record whose status is missing/unknown."""
    finished_any = data.get("finished") or data.get("finished_at")
    started_any = data.get("started") or data.get("started_at")
    # NOTE(review): a finished record without a recognised status is
    # classified as failed (not succeeded) -- confirm this is intended.
    if finished_any or _has_legacy_error(data.get("error")):
        return TaskStatus.failed.value
    if started_any:
        return TaskStatus.running.value
    return TaskStatus.queued.value


def _infer_target(task_type: str, payload: Dict[str, Any]) -> str:
    """Return the first truthy payload value for *task_type*'s keys, as a str.

    Returns ``""`` for unknown task types or when no key has a truthy value.
    """
    for key in _TARGET_PAYLOAD_KEYS.get(task_type, ()):
        candidate = payload.get(key)
        if candidate:
            return str(candidate)
    return ""


class TaskRecord(BaseModel):
    """A single task and its execution state.

    Raw dict input is normalised by ``_coerce_legacy_fields`` before field
    validation, so older record layouts are still accepted.
    """

    id: str
    status: TaskStatus
    task_type: str
    target: str = ""
    payload: Dict[str, Any] = Field(default_factory=dict)
    engine: Optional[str] = None
    callback: Optional[CallbackState] = None
    # Lambda (not ``time.time`` directly) so patched clocks are picked up.
    created: float = Field(default_factory=lambda: time.time())
    started: float | None = None
    finished: float | None = None
    retry_count: int = 0
    error: Dict[str, Any] | None = None

    @model_validator(mode="before")
    @classmethod
    def _coerce_legacy_fields(cls, data: Any) -> Any:
        """Normalise legacy/raw dict input before field validation.

        - ``status``: canonicalised via ``canonical_task_status_value``; if
          missing or unrecognised it is inferred from timestamps and error.
        - ``created``/``started``/``finished``: filled from ``*_at`` keys when
          absent or ``None``. A missing/``None`` ``created`` falls back to the
          field default instead of failing validation.
        - ``retry_count``: filled from ``retries`` when absent or ``None``;
          otherwise the field default applies.
        - ``error``: a non-blank string becomes ``{"message": ...}``; a blank
          string becomes ``None``.
        - ``payload``: anything that is not a dict becomes ``{}``.
        - ``target``: inferred from the payload when empty.

        Non-dict input is returned unchanged. The caller's dict is never
        mutated; a shallow copy is normalised and returned.
        """
        if not isinstance(data, dict):
            return data
        # Shallow copy so validating a stored dict does not rewrite it.
        data = dict(data)

        status_val = canonical_task_status_value(data.get("status"))
        if status_val is None:
            status_val = _infer_legacy_status(data)
        data["status"] = status_val

        for field_name, legacy_name in _LEGACY_TIMESTAMP_FIELDS:
            if data.get(field_name) is None and data.get(legacy_name) is not None:
                data[field_name] = data[legacy_name]
        # ``created`` is a required float: let the default factory supply it
        # rather than rejecting an explicit None.
        if data.get("created") is None:
            data.pop("created", None)

        if data.get("retry_count") is None:
            retries = data.get("retries")
            if retries is None:
                # Fall back to the field default (0) instead of failing on None.
                data.pop("retry_count", None)
            else:
                data["retry_count"] = retries

        err = data.get("error")
        if isinstance(err, str):
            # Legacy records stored plain-text errors; blank text means none.
            data["error"] = {"message": err} if err.strip() else None

        if not isinstance(data.get("payload"), dict):
            data["payload"] = {}

        if not data.get("target"):
            task_type = str(data.get("task_type") or "").strip().lower()
            data["target"] = _infer_target(task_type, data["payload"])
        return data
def now_ts() -> float:
    """Return the current wall-clock time from ``time.time()`` (epoch seconds)."""
    current = time.time()
    return current