"""Phase 4 — Task Queue Engine. Priority-based persistent task queue backed by Redis (Upstash). Falls back to in-memory queue when Redis is not available. States: QUEUED → RUNNING → COMPLETE | FAILED | CANCELLED """ from __future__ import annotations import asyncio import json import logging import os import time import threading from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) REDIS_URL = os.getenv("REDIS_URL", "") class JobPriority(int, Enum): CRITICAL = 0 HIGH = 1 NORMAL = 2 LOW = 3 @dataclass class QueueJob: job_id: str task_id: int prompt: str priority: JobPriority = JobPriority.NORMAL created_at: float = field(default_factory=time.time) started_at: Optional[float] = None completed_at: Optional[float] = None status: str = "QUEUED" # QUEUED | RUNNING | COMPLETE | FAILED | CANCELLED error: Optional[str] = None retry_count: int = 0 max_retries: int = 2 def to_dict(self) -> Dict[str, Any]: return { "job_id": self.job_id, "task_id": self.task_id, "prompt": self.prompt, "priority": int(self.priority), "created_at": self.created_at, "started_at": self.started_at, "completed_at": self.completed_at, "status": self.status, "error": self.error, "retry_count": self.retry_count, "max_retries": self.max_retries, } @classmethod def from_dict(cls, d: Dict[str, Any]) -> "QueueJob": return cls( job_id=d["job_id"], task_id=d["task_id"], prompt=d["prompt"], priority=JobPriority(d.get("priority", 2)), created_at=d.get("created_at", time.time()), started_at=d.get("started_at"), completed_at=d.get("completed_at"), status=d.get("status", "QUEUED"), error=d.get("error"), retry_count=d.get("retry_count", 0), max_retries=d.get("max_retries", 2), ) class RedisQueue: """Redis-backed priority queue using sorted sets.""" QUEUE_KEY = "onehands:job_queue" JOB_PREFIX = "onehands:job:" RUNNING_KEY = "onehands:running" def __init__(self, redis_client: Any) -> None: self._r = redis_client def enqueue(self, job: QueueJob) -> None: # score = priority * 1e12 + timestamp (lower = higher priority) score = int(job.priority) * 1_000_000_000_000 + int(job.created_at * 1000) pipe = self._r.pipeline() pipe.set(f"{self.JOB_PREFIX}{job.job_id}", json.dumps(job.to_dict())) pipe.zadd(self.QUEUE_KEY, {job.job_id: score}) pipe.execute() def dequeue(self) -> Optional[QueueJob]: """Pop lowest-score (highest priority + oldest) job.""" result = self._r.zpopmin(self.QUEUE_KEY, 1) if not result: return None job_id = result[0][0] if isinstance(job_id, bytes): job_id = job_id.decode() raw = self._r.get(f"{self.JOB_PREFIX}{job_id}") if not raw: return None job = QueueJob.from_dict(json.loads(raw)) job.status = "RUNNING" job.started_at = time.time() self._r.set(f"{self.JOB_PREFIX}{job_id}", json.dumps(job.to_dict())) self._r.sadd(self.RUNNING_KEY, job_id) return job def update(self, job: QueueJob) -> None: self._r.set(f"{self.JOB_PREFIX}{job.job_id}", json.dumps(job.to_dict())) if job.status in ("COMPLETE", "FAILED", "CANCELLED"): self._r.srem(self.RUNNING_KEY, job.job_id) def get_job(self, job_id: str) -> Optional[QueueJob]: raw = self._r.get(f"{self.JOB_PREFIX}{job_id}") if not raw: return None return QueueJob.from_dict(json.loads(raw)) def queue_depth(self) -> int: return self._r.zcard(self.QUEUE_KEY) def running_count(self) -> int: return self._r.scard(self.RUNNING_KEY) def list_queued(self) -> List[QueueJob]: ids = self._r.zrange(self.QUEUE_KEY, 0, -1) jobs = [] for jid in ids: if isinstance(jid, bytes): jid = jid.decode() raw = self._r.get(f"{self.JOB_PREFIX}{jid}") if raw: jobs.append(QueueJob.from_dict(json.loads(raw))) return jobs class MemoryQueue: """In-memory fallback queue (no Redis).""" def __init__(self) -> None: self._lock = threading.Lock() self._queue: List[QueueJob] = [] self._all: Dict[str, QueueJob] = {} self._running: set = set() def enqueue(self, job: QueueJob) -> None: with self._lock: self._all[job.job_id] = job self._queue.append(job) self._queue.sort(key=lambda j: (int(j.priority), j.created_at)) def dequeue(self) -> Optional[QueueJob]: with self._lock: for job in self._queue: if job.status == "QUEUED": job.status = "RUNNING" job.started_at = time.time() self._queue.remove(job) self._running.add(job.job_id) return job return None def update(self, job: QueueJob) -> None: with self._lock: self._all[job.job_id] = job if job.status in ("COMPLETE", "FAILED", "CANCELLED"): self._running.discard(job.job_id) def get_job(self, job_id: str) -> Optional[QueueJob]: with self._lock: return self._all.get(job_id) def queue_depth(self) -> int: with self._lock: return len([j for j in self._queue if j.status == "QUEUED"]) def running_count(self) -> int: with self._lock: return len(self._running) def list_queued(self) -> List[QueueJob]: with self._lock: return [j for j in self._queue if j.status == "QUEUED"] # ---------- Singleton ---------- # _queue_instance: Optional[Any] = None _queue_lock = threading.Lock() def get_queue() -> Any: global _queue_instance with _queue_lock: if _queue_instance is None: _queue_instance = _build_queue() return _queue_instance def _build_queue() -> Any: if not REDIS_URL: logger.info("Phase 4 Queue: Redis not configured — using in-memory queue") return MemoryQueue() try: import redis as redis_lib # Support TLS (rediss://) and plain (redis://) URLs client = redis_lib.from_url(REDIS_URL, decode_responses=False, socket_timeout=5) client.ping() logger.info("Phase 4 Queue: Connected to Redis") return RedisQueue(client) except Exception as exc: logger.warning("Phase 4 Queue: Redis connection failed (%s) — using in-memory queue", exc) return MemoryQueue() # ---------- Background Worker ---------- # class QueueWorker: """Background worker that drains the job queue.""" def __init__(self, max_concurrent: int = 3) -> None: self._max_concurrent = max_concurrent self._semaphore: Optional[asyncio.Semaphore] = None self._running = False self._loop: Optional[asyncio.AbstractEventLoop] = None self._thread: Optional[threading.Thread] = None def start(self) -> None: if self._running: return self._running = True self._thread = threading.Thread(target=self._run_loop, daemon=True) self._thread.start() logger.info("Phase 4 QueueWorker started (max_concurrent=%d)", self._max_concurrent) def stop(self) -> None: self._running = False def _run_loop(self) -> None: self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) self._semaphore = asyncio.Semaphore(self._max_concurrent) self._loop.run_until_complete(self._worker_loop()) async def _worker_loop(self) -> None: from .runner_v2 import AgentRunnerV2 from ..db.schema import SessionLocal, Task q = get_queue() while self._running: try: job = q.dequeue() if job is None: await asyncio.sleep(1.0) continue logger.info("Phase 4: Dequeued job %s (task %d)", job.job_id, job.task_id) async def run_job(j: QueueJob) -> None: async with self._semaphore: try: runner = AgentRunnerV2(j.task_id) result = await runner.run(j.prompt) j.status = "COMPLETE" j.completed_at = time.time() except Exception as exc: logger.exception("Job %s failed: %s", j.job_id, exc) j.status = "FAILED" j.error = str(exc) j.completed_at = time.time() finally: q.update(j) asyncio.create_task(run_job(job)) except Exception as exc: logger.exception("QueueWorker loop error: %s", exc) await asyncio.sleep(2.0) _worker_instance: Optional[QueueWorker] = None def get_worker() -> QueueWorker: global _worker_instance if _worker_instance is None: _worker_instance = QueueWorker(max_concurrent=3) return _worker_instance