File size: 14,511 Bytes
d7b2379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
"""
Persistent task memory (Phase 2).

A lightweight SQLite store that keeps the full lifecycle, plan and event
history of every agent run. The DB lives in $TASK_DB_PATH (default
``/home/user/app/data/tasks.db``); the directory is created on demand so the
container starts cleanly even on a fresh volume.

This module is INTENTIONALLY additive — none of the Phase-1 endpoints
import it, so existing behaviour cannot regress if SQLite is unavailable.

States (must match TaskState below):

    queued → planning → thinking → executing → retrying → completed | failed
                                          ↘────────────↗

``cancelled`` is a third terminal state alongside ``completed`` and ``failed``.

Concurrency model: a fresh ``sqlite3.connect`` per call, ``WAL`` journal mode,
all writes guarded by a process-wide ``asyncio.Lock``.  This is plenty for the
single-worker uvicorn deployment we run on HF Spaces.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import sqlite3
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Lifecycle states
# ---------------------------------------------------------------------------

class TaskState:
    """String constants for the task (and step) lifecycle states.

    ``ALL`` holds every valid state; ``TERMINAL`` holds the end states
    (``update_step`` stamps ``finished_at`` when a step enters one).
    """

    # In-flight states, listed in their usual order of progression.
    QUEUED = "queued"
    PLANNING = "planning"
    THINKING = "thinking"
    EXECUTING = "executing"
    RETRYING = "retrying"

    # End states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

    TERMINAL = {COMPLETED, FAILED, CANCELLED}
    ALL = {QUEUED, PLANNING, THINKING, EXECUTING, RETRYING} | TERMINAL


# ---------------------------------------------------------------------------
# DB location & schema
# ---------------------------------------------------------------------------

def _db_path() -> str:
    """Return the SQLite file path, creating its parent directory if needed.

    The path comes from ``TASK_DB_PATH`` and defaults to
    ``/home/user/app/data/tasks.db``. A bare filename (no directory part)
    uses the current working directory as its parent.
    """
    path = os.environ.get("TASK_DB_PATH", "/home/user/app/data/tasks.db")
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so a bare filename falls back to ".".
    os.makedirs(parent if parent else ".", exist_ok=True)
    return path


# DDL executed by init() through executescript(). Every statement uses
# IF NOT EXISTS, so running it against an existing database is a no-op.
# Timestamp columns (REAL) hold time.time() epoch seconds as written by the
# CRUD helpers in this module.
# steps/events reference tasks(id) with ON DELETE CASCADE; SQLite only
# enforces that on connections with PRAGMA foreign_keys=ON, which _conn() sets.
# NOTE(review): steps has no UNIQUE(task_id, idx); one row per idx relies on
# set_steps() deleting the old plan first — confirm no other writer inserts.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS tasks (
    id           TEXT PRIMARY KEY,
    created_at   REAL NOT NULL,
    updated_at   REAL NOT NULL,
    state        TEXT NOT NULL,
    user_message TEXT NOT NULL,
    metadata     TEXT NOT NULL DEFAULT '{}',
    sandbox_id   TEXT,
    final_reply  TEXT,
    error        TEXT
);
CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_tasks_state      ON tasks(state);

CREATE TABLE IF NOT EXISTS steps (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    task_id     TEXT NOT NULL REFERENCES tasks(id) ON DELETE CASCADE,
    idx         INTEGER NOT NULL,
    title       TEXT NOT NULL,
    description TEXT NOT NULL DEFAULT '',
    state       TEXT NOT NULL DEFAULT 'queued',
    attempts    INTEGER NOT NULL DEFAULT 0,
    started_at  REAL,
    finished_at REAL,
    result      TEXT,
    error       TEXT
);
CREATE INDEX IF NOT EXISTS idx_steps_task ON steps(task_id, idx);

CREATE TABLE IF NOT EXISTS events (
    id        INTEGER PRIMARY KEY AUTOINCREMENT,
    task_id   TEXT NOT NULL REFERENCES tasks(id) ON DELETE CASCADE,
    step_idx  INTEGER,
    ts        REAL NOT NULL,
    kind      TEXT NOT NULL,
    payload   TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX IF NOT EXISTS idx_events_task ON events(task_id, id);
"""


# One-shot guard for init(): the flag short-circuits every call after the
# first successful schema creation, and the lock stops concurrent first
# callers from running the DDL twice.
# NOTE(review): on Python < 3.10 asyncio.Lock() binds to the event loop that
# is current at construction (here: import time); confirm this module is
# imported on the loop that serves requests, or create the locks lazily.
_init_lock = asyncio.Lock()
_initialised = False
# Serialises the writes issued by this module (create/update/append/delete);
# the read helpers do not take it. It only coordinates coroutines within one
# process, not multiple workers.
_write_lock = asyncio.Lock()


@contextmanager
def _conn():
    """Yield a short-lived autocommit connection to the task database.

    Rows come back as ``sqlite3.Row``. Each connection switches on WAL
    journaling, ``synchronous=NORMAL`` and foreign-key enforcement (required
    for the schema's ``ON DELETE CASCADE``). The connection is closed on exit
    even if the caller's ``with`` body raises.
    """
    connection = sqlite3.connect(_db_path(), timeout=10, isolation_level=None)
    connection.row_factory = sqlite3.Row
    try:
        for pragma in (
            "PRAGMA journal_mode=WAL;",
            "PRAGMA synchronous=NORMAL;",
            "PRAGMA foreign_keys=ON;",
        ):
            connection.execute(pragma)
        yield connection
    finally:
        connection.close()


async def init() -> None:
    """Create the schema on first use; later calls return immediately.

    The flag is checked once without the lock (fast path) and again under
    ``_init_lock`` so two concurrent first callers cannot both run the DDL.
    The blocking SQLite work runs in a worker thread. The flag is only set
    after the DDL succeeds, so a failed attempt is retried on the next call.
    """
    global _initialised
    if _initialised:
        return

    def _create_schema() -> None:
        with _conn() as db:
            db.executescript(_SCHEMA)

    async with _init_lock:
        if not _initialised:
            await asyncio.to_thread(_create_schema)
            _initialised = True
            logger.info("task DB ready at %s", _db_path())


# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------

@dataclass
class Step:
    """One entry in a task's plan; mirrors a row of the ``steps`` table."""

    idx: int  # 0-based position in the plan
    title: str
    description: str = ""
    state: str = TaskState.QUEUED
    attempts: int = 0
    started_at: Optional[float] = None  # epoch seconds (time.time())
    finished_at: Optional[float] = None  # epoch seconds (time.time())
    result: Optional[str] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the step as a plain dict, keys in field order."""
        keys = (
            "idx", "title", "description", "state", "attempts",
            "started_at", "finished_at", "result", "error",
        )
        return {key: getattr(self, key) for key in keys}


@dataclass
class Task:
    """A persisted agent run; mirrors a row of ``tasks`` plus its plan."""

    id: str
    created_at: float  # epoch seconds (time.time())
    updated_at: float  # epoch seconds (time.time())
    state: str
    user_message: str
    metadata: Dict[str, Any] = field(default_factory=dict)
    sandbox_id: Optional[str] = None
    final_reply: Optional[str] = None
    error: Optional[str] = None
    steps: List[Step] = field(default_factory=list)

    def to_dict(self, include_steps: bool = True) -> Dict[str, Any]:
        """Return the task as a plain dict.

        ``metadata`` is returned by reference, not copied. When
        ``include_steps`` is true a ``"steps"`` key holds each step's
        ``to_dict()`` output, in plan order.
        """
        scalar_keys = (
            "id", "created_at", "updated_at", "state", "user_message",
            "metadata", "sandbox_id", "final_reply", "error",
        )
        data: Dict[str, Any] = {key: getattr(self, key) for key in scalar_keys}
        if include_steps:
            data["steps"] = [step.to_dict() for step in self.steps]
        return data


# ---------------------------------------------------------------------------
# CRUD helpers
# ---------------------------------------------------------------------------

async def create_task(user_message: str, metadata: Optional[Dict[str, Any]] = None) -> Task:
    """Insert a new task in the ``queued`` state and return it.

    The id is a random UUID4 hex string; ``created_at`` and ``updated_at``
    share one ``time.time()`` stamp. ``metadata`` is stored as JSON. When it
    is omitted or empty the task gets a fresh ``{}``; a non-empty dict is kept
    by reference on the returned ``Task``.
    """
    await init()
    stamp = time.time()
    task = Task(
        id=uuid.uuid4().hex,
        created_at=stamp,
        updated_at=stamp,
        state=TaskState.QUEUED,
        user_message=user_message,
        metadata=metadata or {},
    )

    def _insert() -> None:
        with _conn() as db:
            values = (
                task.id,
                task.created_at,
                task.updated_at,
                task.state,
                task.user_message,
                json.dumps(task.metadata),
            )
            db.execute(
                "INSERT INTO tasks (id, created_at, updated_at, state, user_message, metadata) "
                "VALUES (?,?,?,?,?,?)",
                values,
            )

    async with _write_lock:
        await asyncio.to_thread(_insert)
    return task


async def update_state(task_id: str, state: str, *, error: Optional[str] = None,
                       sandbox_id: Optional[str] = None,
                       final_reply: Optional[str] = None) -> None:
    """Move a task to ``state`` and optionally record extra columns.

    ``updated_at`` is always refreshed. ``error``, ``sandbox_id`` and
    ``final_reply`` are written only when not ``None``, so passing ``None``
    leaves any earlier value in place (it cannot clear a column).

    Raises:
        ValueError: if ``state`` is not one of ``TaskState.ALL``.
    """
    await init()
    if state not in TaskState.ALL:
        raise ValueError(f"invalid state: {state}")
    now = time.time()

    optional_columns = (
        ("error", error),
        ("sandbox_id", sandbox_id),
        ("final_reply", final_reply),
    )
    assignments: List[Any] = [("state", state), ("updated_at", now)]
    assignments.extend((col, val) for col, val in optional_columns if val is not None)

    set_clause = ", ".join(f"{col} = ?" for col, _ in assignments)
    sql = f"UPDATE tasks SET {set_clause} WHERE id = ?"
    bound = [val for _, val in assignments] + [task_id]

    def _apply() -> None:
        with _conn() as db:
            db.execute(sql, bound)

    async with _write_lock:
        await asyncio.to_thread(_apply)


async def set_steps(task_id: str, steps: Iterable[Dict[str, str]]) -> List[Step]:
    """Replace the plan for a task.

    Each input dict needs ``title`` and optional ``description``. A missing,
    ``None`` or blank title becomes ``"Step N"`` (1-based). A missing or
    ``None`` description becomes ``""``. Neither is ever stored as the literal
    text ``"None"``.

    Deleting the old plan, inserting the new one and bumping the task's
    ``updated_at`` all happen inside one ``BEGIN IMMEDIATE`` transaction.
    ``_conn`` connections are in autocommit mode, so without it a failure
    between the statements would leave the task with no plan at all. Readers
    (which do not take ``_write_lock``) could also observe the half-replaced
    state.

    Returns:
        The new steps, numbered from 0 in input order, all in the default
        (``queued``) state.
    """
    await init()
    rows: List[Step] = []
    for i, raw in enumerate(steps):
        title = raw.get("title")
        if title is None or not str(title).strip():
            title = f"Step {i+1}"
        description = raw.get("description")
        rows.append(Step(
            idx=i,
            title=str(title),
            description="" if description is None else str(description),
        ))

    def _do():
        with _conn() as c:
            # Explicit transaction: the connection autocommits otherwise.
            c.execute("BEGIN IMMEDIATE")
            try:
                c.execute("DELETE FROM steps WHERE task_id = ?", (task_id,))
                c.executemany(
                    "INSERT INTO steps (task_id, idx, title, description, state) VALUES (?,?,?,?,?)",
                    [(task_id, s.idx, s.title, s.description, s.state) for s in rows],
                )
                c.execute("UPDATE tasks SET updated_at = ? WHERE id = ?", (time.time(), task_id))
            except BaseException:
                # SQLite may already have rolled back on some errors; only
                # issue ROLLBACK if a transaction is still open so the
                # original exception is not masked.
                if c.in_transaction:
                    c.execute("ROLLBACK")
                raise
            c.execute("COMMIT")

    async with _write_lock:
        await asyncio.to_thread(_do)
    return rows


async def update_step(task_id: str, idx: int, *, state: Optional[str] = None,
                      attempts_delta: int = 0, result: Optional[str] = None,
                      error: Optional[str] = None) -> None:
    """Patch one step of a task's plan and bump the task's ``updated_at``.

    Only the arguments that are supplied are written. If none are, the call
    does nothing, not even the ``updated_at`` bump.

    Args:
        task_id: Owning task.
        idx: 0-based step position, as assigned by ``set_steps``.
        state: New step state. Entering ``executing`` has two effects:
            ``started_at`` is stamped only the first time (``COALESCE``
            keeps the original start across retries), and any
            ``finished_at`` left over from an earlier attempt is cleared.
            Entering a state in ``TaskState.TERMINAL`` stamps
            ``finished_at``.
        attempts_delta: Added to ``attempts`` inside the SQL statement.
        result: Stored, truncated to 4000 characters.
        error: Stored, truncated to 4000 characters.

    The step update and the task ``updated_at`` bump are committed together
    in one transaction.
    """
    await init()
    now = time.time()

    def _do():
        fields: List[str] = []
        params: List[Any] = []
        if state is not None:
            fields.append("state = ?")
            params.append(state)
            if state == TaskState.EXECUTING:
                fields.append("started_at = COALESCE(started_at, ?)")
                params.append(now)
                # A retried step is running again: drop the stale end time
                # recorded by the previous attempt.
                fields.append("finished_at = NULL")
            elif state in TaskState.TERMINAL:
                fields.append("finished_at = ?")
                params.append(now)
        if attempts_delta:
            fields.append("attempts = attempts + ?")
            params.append(attempts_delta)
        if result is not None:
            fields.append("result = ?")
            params.append(result[:4000])
        if error is not None:
            fields.append("error = ?")
            params.append(error[:4000])
        if not fields:
            return
        params.extend([task_id, idx])
        with _conn() as c:
            # Explicit transaction: the connection autocommits otherwise.
            c.execute("BEGIN IMMEDIATE")
            try:
                c.execute(
                    f"UPDATE steps SET {', '.join(fields)} WHERE task_id = ? AND idx = ?",
                    params,
                )
                c.execute("UPDATE tasks SET updated_at = ? WHERE id = ?", (now, task_id))
            except BaseException:
                # Avoid masking the real error if SQLite already rolled back.
                if c.in_transaction:
                    c.execute("ROLLBACK")
                raise
            c.execute("COMMIT")

    async with _write_lock:
        await asyncio.to_thread(_do)


async def append_event(task_id: str, kind: str, payload: Any, step_idx: Optional[int] = None) -> None:
    """Record a diagnostic event for a task on a best-effort basis.

    ``payload`` is JSON-encoded with ``default=str``, so non-JSON objects
    degrade to their string form. The encoded text is cut to 8000
    characters; a cut payload is no longer valid JSON, and ``get_events``
    returns it as ``{"raw": ...}``.

    Any ``Exception`` is logged and swallowed, including a failure of the
    lazy schema ``init()``. An event write therefore never breaks the
    caller's stream.
    """
    def _do():
        with _conn() as c:
            c.execute(
                "INSERT INTO events (task_id, step_idx, ts, kind, payload) VALUES (?,?,?,?,?)",
                (task_id, step_idx, time.time(), kind,
                 json.dumps(payload, ensure_ascii=False, default=str)[:8000]),
            )

    try:
        # init() is inside the try: an unavailable DB must not break callers.
        await init()
        async with _write_lock:
            await asyncio.to_thread(_do)
    except Exception as e:
        # Logging only — events are diagnostics; never break the stream.
        logger.warning("append_event failed: %s", e)


# ---------------------------------------------------------------------------
# Read helpers
# ---------------------------------------------------------------------------

async def get_task(task_id: str) -> Optional[Task]:
    """Load a task together with its plan (ordered by ``idx``).

    Returns ``None`` when no task has ``task_id``. ``metadata`` is decoded
    from its JSON column, and a NULL step description becomes ``""``.
    """
    await init()

    def _fetch():
        with _conn() as db:
            task_row = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone()
            if task_row is None:
                return None
            step_rows = db.execute(
                "SELECT * FROM steps WHERE task_id = ? ORDER BY idx", (task_id,)
            ).fetchall()
            return task_row, step_rows

    fetched = await asyncio.to_thread(_fetch)
    if fetched is None:
        return None
    task_row, step_rows = fetched

    def _as_step(s):
        return Step(
            idx=s["idx"],
            title=s["title"],
            description=s["description"] or "",
            state=s["state"],
            attempts=s["attempts"],
            started_at=s["started_at"],
            finished_at=s["finished_at"],
            result=s["result"],
            error=s["error"],
        )

    return Task(
        id=task_row["id"],
        created_at=task_row["created_at"],
        updated_at=task_row["updated_at"],
        state=task_row["state"],
        user_message=task_row["user_message"],
        metadata=json.loads(task_row["metadata"] or "{}"),
        sandbox_id=task_row["sandbox_id"],
        final_reply=task_row["final_reply"],
        error=task_row["error"],
        steps=[_as_step(s) for s in step_rows],
    )


async def list_tasks(limit: int = 50, state: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return up to ``limit`` task summaries, newest first.

    Each summary has ``id``, ``created_at``, ``updated_at``, ``state``,
    ``user_message`` and ``sandbox_id``; metadata, replies, errors and steps
    are not included. A truthy ``state`` restricts results to that state.
    """
    await init()
    select = "SELECT id, created_at, updated_at, state, user_message, sandbox_id "
    if state:
        sql = select + "FROM tasks WHERE state = ? ORDER BY created_at DESC LIMIT ?"
        params: tuple = (state, limit)
    else:
        sql = select + "FROM tasks ORDER BY created_at DESC LIMIT ?"
        params = (limit,)

    def _fetch():
        with _conn() as db:
            return [dict(row) for row in db.execute(sql, params).fetchall()]

    return await asyncio.to_thread(_fetch)


async def get_events(task_id: str, after_id: int = 0, limit: int = 500) -> List[Dict[str, Any]]:
    """Return up to ``limit`` events of a task with ``id > after_id``, by id.

    Passing the last id already seen as ``after_id`` returns only newer
    events. A payload that does not parse as JSON (for example one truncated
    on write) is returned as ``{"raw": <stored text>}``.
    """
    await init()

    def _decode(text):
        try:
            return json.loads(text)
        except Exception:
            return {"raw": text}

    def _fetch():
        with _conn() as db:
            rows = db.execute(
                "SELECT id, step_idx, ts, kind, payload FROM events "
                "WHERE task_id = ? AND id > ? ORDER BY id LIMIT ?",
                (task_id, after_id, limit),
            ).fetchall()
            return [
                {
                    "id": row["id"],
                    "step_idx": row["step_idx"],
                    "ts": row["ts"],
                    "kind": row["kind"],
                    "payload": _decode(row["payload"]),
                }
                for row in rows
            ]

    return await asyncio.to_thread(_fetch)


async def delete_task(task_id: str) -> bool:
    """Delete a task; its steps and events follow via ``ON DELETE CASCADE``.

    Returns ``True`` if a task row was removed, ``False`` if the id was
    unknown.
    """
    await init()

    def _delete() -> bool:
        with _conn() as db:
            cursor = db.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
            return cursor.rowcount > 0

    async with _write_lock:
        removed = await asyncio.to_thread(_delete)
    return removed