File size: 9,615 Bytes
2f3fd39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""Phase 4 — Task Queue Engine.

Priority-based persistent task queue backed by Redis (Upstash).
Falls back to in-memory queue when Redis is not available.

States: QUEUED → RUNNING → COMPLETE | FAILED | CANCELLED
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import time
import threading
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# Connection URL for the Redis (Upstash) backend. An empty value makes
# _build_queue() fall back to the in-memory queue.
REDIS_URL = os.environ.get("REDIS_URL", "")


class JobPriority(int, Enum):
    """Scheduling priority of a queued job.

    Lower numeric values are served first by both queue backends in this
    module. The ``int`` mixin makes every member a real ``int``, so members
    compare and serialise as plain integers.
    """

    CRITICAL = 0  # most urgent
    HIGH = 1
    NORMAL = 2  # default priority of a QueueJob
    LOW = 3  # least urgent


@dataclass
class QueueJob:
    """One unit of work in the queue, plus its lifecycle bookkeeping.

    Instances round-trip through JSON via ``to_dict`` / ``from_dict``.
    """

    job_id: str
    task_id: int
    prompt: str
    priority: JobPriority = JobPriority.NORMAL
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    status: str = "QUEUED"   # QUEUED | RUNNING | COMPLETE | FAILED | CANCELLED
    error: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 2

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict of every field, in declaration order.

        ``priority`` is stored as a plain ``int`` rather than the enum member.
        """
        names = (
            "job_id", "task_id", "prompt", "priority", "created_at",
            "started_at", "completed_at", "status", "error",
            "retry_count", "max_retries",
        )
        payload: Dict[str, Any] = {name: getattr(self, name) for name in names}
        # Overwriting an existing key keeps its position in the dict.
        payload["priority"] = int(self.priority)
        return payload

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "QueueJob":
        """Rebuild a job from a ``to_dict()`` payload.

        ``job_id``, ``task_id`` and ``prompt`` are required (KeyError otherwise).
        Every other key falls back to the field's default when absent.
        """
        required = {key: d[key] for key in ("job_id", "task_id", "prompt")}
        return cls(
            **required,
            priority=JobPriority(d.get("priority", 2)),
            created_at=d.get("created_at", time.time()),
            started_at=d.get("started_at"),
            completed_at=d.get("completed_at"),
            status=d.get("status", "QUEUED"),
            error=d.get("error"),
            retry_count=d.get("retry_count", 0),
            max_retries=d.get("max_retries", 2),
        )


class RedisQueue:
    """Redis-backed priority queue using sorted sets.

    Key layout:

    * ``QUEUE_KEY`` - sorted set of pending job ids; the lowest score is served first.
    * ``JOB_PREFIX + job_id`` - JSON snapshot of each job (``QueueJob.to_dict``).
    * ``RUNNING_KEY`` - set of ids handed out by ``dequeue()`` that have not yet
      been reported terminal through ``update()``.
    """

    QUEUE_KEY = "onehands:job_queue"
    JOB_PREFIX = "onehands:job:"
    RUNNING_KEY = "onehands:running"
    # Statuses after which a job must never be (re)started.
    _TERMINAL = ("COMPLETE", "FAILED", "CANCELLED")

    def __init__(self, redis_client: Any) -> None:
        # The client is built with decode_responses=False by _build_queue, so
        # members and values may come back as bytes.
        self._r = redis_client

    def _job_key(self, job_id: str) -> str:
        """Return the Redis key holding the JSON snapshot of ``job_id``."""
        return f"{self.JOB_PREFIX}{job_id}"

    @staticmethod
    def _decode(value: Any) -> str:
        """Return ``value`` as ``str``, decoding it if the client returned bytes."""
        return value.decode() if isinstance(value, bytes) else value

    def enqueue(self, job: QueueJob) -> None:
        """Store the job payload and add its id to the pending sorted set."""
        # score = priority * 1e12 + created_at in milliseconds (lower = served first).
        # One priority band spans 1e12 ms (~31.7 years), so priority dominates
        # for any jobs whose creation times are closer together than that.
        score = int(job.priority) * 1_000_000_000_000 + int(job.created_at * 1000)
        pipe = self._r.pipeline()
        pipe.set(self._job_key(job.job_id), json.dumps(job.to_dict()))
        pipe.zadd(self.QUEUE_KEY, {job.job_id: score})
        pipe.execute()

    def dequeue(self) -> Optional[QueueJob]:
        """Pop the lowest-score (highest priority, then oldest) runnable job.

        The job is marked RUNNING, persisted, and added to ``RUNNING_KEY``.
        Ids whose payload is missing, or whose stored status is no longer
        QUEUED (e.g. cancelled while waiting), are discarded and the next id
        is tried. Returns None only when the pending set is empty.
        """
        while True:
            popped = self._r.zpopmin(self.QUEUE_KEY, 1)
            if not popped:
                return None
            job_id = self._decode(popped[0][0])
            raw = self._r.get(self._job_key(job_id))
            if not raw:
                logger.warning("RedisQueue: dropping job id %s with no stored payload", job_id)
                continue
            job = QueueJob.from_dict(json.loads(raw))
            if job.status != "QUEUED":
                # Never resurrect a cancelled/finished job as RUNNING.
                logger.debug("RedisQueue: skipping job %s in status %s", job_id, job.status)
                continue
            job.status = "RUNNING"
            job.started_at = time.time()
            pipe = self._r.pipeline()
            pipe.set(self._job_key(job_id), json.dumps(job.to_dict()))
            pipe.sadd(self.RUNNING_KEY, job_id)
            pipe.execute()
            return job

    def update(self, job: QueueJob) -> None:
        """Persist ``job``'s current snapshot.

        A terminal status also removes the id from the running set and from
        the pending set, so a job cancelled before it was dequeued is never run.
        """
        pipe = self._r.pipeline()
        pipe.set(self._job_key(job.job_id), json.dumps(job.to_dict()))
        if job.status in self._TERMINAL:
            pipe.srem(self.RUNNING_KEY, job.job_id)
            pipe.zrem(self.QUEUE_KEY, job.job_id)
        pipe.execute()

    def get_job(self, job_id: str) -> Optional[QueueJob]:
        """Return the stored snapshot of ``job_id``, or None if unknown."""
        raw = self._r.get(self._job_key(job_id))
        if not raw:
            return None
        return QueueJob.from_dict(json.loads(raw))

    def queue_depth(self) -> int:
        """Return the number of ids in the pending sorted set."""
        return self._r.zcard(self.QUEUE_KEY)

    def running_count(self) -> int:
        """Return the number of ids currently in the running set."""
        return self._r.scard(self.RUNNING_KEY)

    def list_queued(self) -> List[QueueJob]:
        """Return pending jobs in dequeue order, fetched with a single MGET.

        Ids without a payload, and jobs whose status is no longer QUEUED, are
        omitted. This matches MemoryQueue.list_queued.
        """
        ids = [self._decode(jid) for jid in self._r.zrange(self.QUEUE_KEY, 0, -1)]
        if not ids:
            # MGET with zero keys is a Redis argument error.
            return []
        raws = self._r.mget([self._job_key(jid) for jid in ids])
        jobs = [QueueJob.from_dict(json.loads(raw)) for raw in raws if raw]
        return [job for job in jobs if job.status == "QUEUED"]


class MemoryQueue:
    """In-memory fallback queue (no Redis).

    Jobs live only in this process. Every public method holds ``self._lock``
    while it touches the internal containers.
    """

    # Statuses after which a job must never be (re)started.
    _TERMINAL = ("COMPLETE", "FAILED", "CANCELLED")

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Pending jobs sorted by (priority, created_at); index 0 is served next.
        self._queue: List[QueueJob] = []
        # Latest snapshot of every job ever enqueued or updated, keyed by job_id.
        self._all: Dict[str, QueueJob] = {}
        # job_ids handed out by dequeue() that have not been reported terminal.
        self._running: set = set()

    def enqueue(self, job: QueueJob) -> None:
        """Record ``job`` and insert it into the pending list in priority order."""
        with self._lock:
            self._all[job.job_id] = job
            self._queue.append(job)
            # list.sort is stable: equal (priority, created_at) keys stay FIFO.
            self._queue.sort(key=lambda j: (int(j.priority), j.created_at))

    def dequeue(self) -> Optional[QueueJob]:
        """Pop the next QUEUED job, mark it RUNNING, and return it (None if empty).

        Entries whose status is no longer QUEUED (e.g. cancelled in place) are
        dropped from the pending list as they are encountered.
        """
        with self._lock:
            while self._queue:
                job = self._queue.pop(0)
                if job.status != "QUEUED":
                    continue
                job.status = "RUNNING"
                job.started_at = time.time()
                self._running.add(job.job_id)
                return job
        return None

    def update(self, job: QueueJob) -> None:
        """Store ``job`` as the latest snapshot for its id.

        A terminal status also clears the id from the running set and removes
        any still-pending entry for it. The removal matches by job_id, not by
        object identity, so a job cancelled before dequeue never runs.
        """
        with self._lock:
            self._all[job.job_id] = job
            if job.status in self._TERMINAL:
                self._running.discard(job.job_id)
                self._queue = [j for j in self._queue if j.job_id != job.job_id]

    def get_job(self, job_id: str) -> Optional[QueueJob]:
        """Return the latest snapshot of ``job_id``, or None if unknown."""
        with self._lock:
            return self._all.get(job_id)

    def queue_depth(self) -> int:
        """Return how many pending jobs are still in QUEUED status."""
        with self._lock:
            return sum(1 for j in self._queue if j.status == "QUEUED")

    def running_count(self) -> int:
        """Return how many dequeued jobs have not yet been reported terminal."""
        with self._lock:
            return len(self._running)

    def list_queued(self) -> List[QueueJob]:
        """Return the pending QUEUED jobs in dequeue order."""
        with self._lock:
            return [j for j in self._queue if j.status == "QUEUED"]


# ---------- Singleton ---------- #

_queue_instance: Optional[Any] = None
_queue_lock = threading.Lock()


def get_queue() -> Any:
    """Return the process-wide queue backend, building it on the first call.

    Construction is delegated to _build_queue() and done while holding
    _queue_lock, so concurrent first calls create only one backend.
    """
    global _queue_instance
    with _queue_lock:
        existing = _queue_instance
        if existing is not None:
            return existing
        _queue_instance = _build_queue()
        return _queue_instance


def _build_queue() -> Any:
    """Choose the queue backend.

    Returns a RedisQueue when REDIS_URL is set and the server answers PING.
    Any failure while importing redis, building the client, or pinging is
    logged as a warning and yields a MemoryQueue instead of raising.
    """
    if REDIS_URL:
        try:
            import redis as redis_lib
            # Support TLS (rediss://) and plain (redis://) URLs
            client = redis_lib.from_url(REDIS_URL, decode_responses=False, socket_timeout=5)
            client.ping()
            logger.info("Phase 4 Queue: Connected to Redis")
            return RedisQueue(client)
        except Exception as exc:
            logger.warning("Phase 4 Queue: Redis connection failed (%s) — using in-memory queue", exc)
            return MemoryQueue()
    logger.info("Phase 4 Queue: Redis not configured — using in-memory queue")
    return MemoryQueue()


# ---------- Background Worker ---------- #

class QueueWorker:
    """Background worker that drains the job queue.

    ``start()`` spawns a daemon thread that runs a private asyncio event loop.
    That loop pulls jobs from ``get_queue()`` and runs each with
    ``AgentRunnerV2``, keeping at most ``max_concurrent`` jobs in flight.
    """

    def __init__(self, max_concurrent: int = 3) -> None:
        self._max_concurrent = max_concurrent
        # Created in _run_loop, after the worker thread's event loop is installed.
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._running = False
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        # Strong references to in-flight job tasks; the event loop itself only
        # keeps weak references, so unreferenced tasks could be collected.
        self._tasks: set = set()

    def start(self) -> None:
        """Start the worker thread; a no-op if the worker is already running."""
        if self._running:
            return
        self._running = True
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        logger.info("Phase 4 QueueWorker started (max_concurrent=%d)", self._max_concurrent)

    def stop(self) -> None:
        """Ask the worker loop to exit (non-blocking).

        No new jobs are dequeued afterwards. Jobs already running are awaited
        and their final status recorded before the worker thread ends.
        """
        self._running = False

    def _run_loop(self) -> None:
        """Thread entry point: run _worker_loop on a fresh event loop, then close it."""
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)
        self._semaphore = asyncio.Semaphore(self._max_concurrent)
        try:
            loop.run_until_complete(self._worker_loop())
        finally:
            loop.close()

    async def _worker_loop(self) -> None:
        """Dequeue and dispatch jobs until stop() is called, then drain in-flight jobs."""
        from .runner_v2 import AgentRunnerV2
        # NOTE(review): SessionLocal/Task are unused here; the import is kept in
        # case loading the schema module has side effects the runner relies on.
        from ..db.schema import SessionLocal, Task  # noqa: F401

        sem = self._semaphore
        q = get_queue()
        while self._running:
            # Reserve a slot *before* dequeuing. Jobs therefore stay QUEUED
            # (and priority-ordered) until there is capacity to run them,
            # instead of being drained into RUNNING all at once.
            await sem.acquire()
            handed_off = False
            try:
                if not self._running:
                    break
                job = q.dequeue()
                if job is None:
                    await asyncio.sleep(1.0)
                    continue

                logger.info("Phase 4: Dequeued job %s (task %d)", job.job_id, job.task_id)
                task = asyncio.create_task(self._run_job(q, AgentRunnerV2, job))
                handed_off = True  # _run_job now owns the slot and releases it
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
            except Exception as exc:
                logger.exception("QueueWorker loop error: %s", exc)
                await asyncio.sleep(2.0)
            finally:
                if not handed_off:
                    sem.release()

        # Let in-flight jobs finish and persist their status rather than
        # abandoning them when run_until_complete returns.
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)

    async def _run_job(self, q: Any, runner_cls: Any, job: QueueJob) -> None:
        """Run one dequeued job, persist its final status, and free its slot.

        The caller must already hold one permit of ``self._semaphore``. The
        permit is released here after ``q.update(job)``, even if that raises.
        """
        try:
            runner = runner_cls(job.task_id)
            await runner.run(job.prompt)
            job.status = "COMPLETE"
            job.completed_at = time.time()
        except Exception as exc:
            logger.exception("Job %s failed: %s", job.job_id, exc)
            job.status = "FAILED"
            job.error = str(exc)
            job.completed_at = time.time()
        finally:
            try:
                q.update(job)
            finally:
                self._semaphore.release()


_worker_instance: Optional[QueueWorker] = None
# Guards lazy creation in get_worker(), mirroring how get_queue() uses _queue_lock.
_worker_lock = threading.Lock()


def get_worker() -> QueueWorker:
    """Return the process-wide QueueWorker, creating it on first use.

    The worker is constructed with max_concurrent=3 but not started; call
    ``start()`` on the result. Creation happens under a lock so that
    concurrent first calls cannot create (and later start) two workers.
    """
    global _worker_instance
    with _worker_lock:
        if _worker_instance is None:
            _worker_instance = QueueWorker(max_concurrent=3)
    return _worker_instance