Spaces:
Sleeping
Sleeping
| """ | |
| Persistent task memory (Phase 2). | |
| A lightweight SQLite store that keeps the full lifecycle, plan and event | |
| history of every agent run. The DB lives in $TASK_DB_PATH (default | |
| ``/home/user/app/data/tasks.db``); the directory is created on demand so the | |
| container starts cleanly even on a fresh volume. | |
| This module is INTENTIONALLY additive — none of the Phase-1 endpoints | |
| import it, so existing behaviour cannot regress if SQLite is unavailable. | |
| States (must match TaskState below): | |
| queued → planning → thinking → executing → retrying → completed | failed | cancelled | |
| ↘────────────↗ | |
| Concurrency model: a single ``sqlite3.connect`` per call, ``WAL`` journal mode, | |
| all writes guarded by a process-wide ``asyncio.Lock``. This is plenty for the | |
| single-worker uvicorn deployment we run on HF Spaces. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sqlite3 | |
| import time | |
| import uuid | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Iterable, List, Optional | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Lifecycle states | |
| # --------------------------------------------------------------------------- | |
class TaskState:
    """String constants naming every lifecycle state a task or step can be in.

    ``ALL`` is the full set of valid state names; ``TERMINAL`` is the subset
    after which a run is considered finished.
    """

    # States a run passes through while it is still in progress.
    QUEUED = "queued"
    PLANNING = "planning"
    THINKING = "thinking"
    EXECUTING = "executing"
    RETRYING = "retrying"

    # States that end a run.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

    TERMINAL = {COMPLETED, FAILED, CANCELLED}
    # Every valid state: the in-progress ones plus the terminal ones.
    ALL = {QUEUED, PLANNING, THINKING, EXECUTING, RETRYING} | TERMINAL
| # --------------------------------------------------------------------------- | |
| # DB location & schema | |
| # --------------------------------------------------------------------------- | |
| def _db_path() -> str: | |
| raw = os.environ.get("TASK_DB_PATH", "/home/user/app/data/tasks.db") | |
| os.makedirs(os.path.dirname(raw) or ".", exist_ok=True) | |
| return raw | |
# DDL script executed by init() via executescript(). Every statement uses
# IF NOT EXISTS, so applying it to an already-initialised database is a no-op.
#   tasks  - one row per agent run (JSON metadata, final reply, error text).
#   steps  - the plan of a task, addressed by (task_id, idx).
#   events - append-only log rows of a task, ordered by their autoincrement id.
# steps and events reference tasks(id) with ON DELETE CASCADE; _conn() turns
# on PRAGMA foreign_keys so the cascade is actually enforced.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
created_at REAL NOT NULL,
updated_at REAL NOT NULL,
state TEXT NOT NULL,
user_message TEXT NOT NULL,
metadata TEXT NOT NULL DEFAULT '{}',
sandbox_id TEXT,
final_reply TEXT,
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_tasks_created_at ON tasks(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);
CREATE TABLE IF NOT EXISTS steps (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL REFERENCES tasks(id) ON DELETE CASCADE,
idx INTEGER NOT NULL,
title TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
state TEXT NOT NULL DEFAULT 'queued',
attempts INTEGER NOT NULL DEFAULT 0,
started_at REAL,
finished_at REAL,
result TEXT,
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_steps_task ON steps(task_id, idx);
CREATE TABLE IF NOT EXISTS events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL REFERENCES tasks(id) ON DELETE CASCADE,
step_idx INTEGER,
ts REAL NOT NULL,
kind TEXT NOT NULL,
payload TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX IF NOT EXISTS idx_events_task ON events(task_id, id);
"""
# Held by init() while the schema is being applied for the first time.
# NOTE(review): both locks are constructed at import time; on Python < 3.10
# an asyncio.Lock binds to the event loop current at construction - confirm
# the runtime version or that the server reuses that loop.
_init_lock = asyncio.Lock()
# Flipped to True by init() once the schema script has run successfully.
_initialised = False
# Held by every write helper around its DB call so writes issued from this
# process never overlap.
_write_lock = asyncio.Lock()
@contextmanager
def _conn():
    """Yield a configured SQLite connection and close it on exit.

    A fresh connection is opened for every call at the path returned by
    ``_db_path()``. It is opened with ``isolation_level=None`` (the sqlite3
    module's autocommit mode), a 10 second busy timeout, and rows returned as
    ``sqlite3.Row`` so columns can be read by name.

    The ``@contextmanager`` decorator is required: every caller uses
    ``with _conn() as c:``, which a bare generator does not support.

    Yields:
        sqlite3.Connection: connection with WAL journalling,
        ``synchronous=NORMAL`` and foreign-key enforcement enabled.
    """
    c = sqlite3.connect(_db_path(), timeout=10, isolation_level=None)
    c.row_factory = sqlite3.Row
    try:
        c.execute("PRAGMA journal_mode=WAL;")
        c.execute("PRAGMA synchronous=NORMAL;")
        # Foreign keys are off by default in SQLite; the schema's
        # ON DELETE CASCADE clauses depend on this being enabled.
        c.execute("PRAGMA foreign_keys=ON;")
        yield c
    finally:
        # Runs whether a PRAGMA failed or the caller's block raised.
        c.close()
async def init() -> None:
    """Apply the schema once per process; later calls return immediately.

    The ``_initialised`` flag is checked before and after acquiring
    ``_init_lock`` so that concurrent first callers run the schema script only
    once. The script itself runs in a worker thread.
    """
    global _initialised
    if _initialised:
        return

    def _apply_schema() -> None:
        with _conn() as db:
            db.executescript(_SCHEMA)

    async with _init_lock:
        # Another coroutine may have finished the work while we waited.
        if not _initialised:
            await asyncio.to_thread(_apply_schema)
            _initialised = True
            logger.info("task DB ready at %s", _db_path())
| # --------------------------------------------------------------------------- | |
| # Dataclasses | |
| # --------------------------------------------------------------------------- | |
@dataclass
class Step:
    """One entry of a task's plan, mirroring a row of the ``steps`` table.

    The ``@dataclass`` decorator is required: the rest of the module builds
    instances with keyword arguments (``Step(idx=..., title=...)``) and relies
    on the field defaults declared here.
    """

    idx: int  # position of the step within its task's plan
    title: str
    description: str = ""
    state: str = TaskState.QUEUED
    attempts: int = 0
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    result: Optional[str] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the step as a plain dict keyed by field name."""
        return {
            "idx": self.idx,
            "title": self.title,
            "description": self.description,
            "state": self.state,
            "attempts": self.attempts,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "result": self.result,
            "error": self.error,
        }
@dataclass
class Task:
    """One agent run, mirroring a row of the ``tasks`` table plus its steps.

    The ``@dataclass`` decorator is required: without it the
    ``field(default_factory=...)`` declarations would be left as raw ``Field``
    objects on the class and keyword construction (``Task(id=..., ...)``)
    would raise ``TypeError``.
    """

    id: str
    created_at: float
    updated_at: float
    state: str
    user_message: str
    # default_factory gives every instance its own dict / list.
    metadata: Dict[str, Any] = field(default_factory=dict)
    sandbox_id: Optional[str] = None
    final_reply: Optional[str] = None
    error: Optional[str] = None
    steps: List["Step"] = field(default_factory=list)

    def to_dict(self, include_steps: bool = True) -> Dict[str, Any]:
        """Return the task as a plain dict.

        Args:
            include_steps: when true (the default) a ``"steps"`` key holding
                each step's ``to_dict()`` output is added.

        Returns:
            Dict with the task's own fields, plus ``"steps"`` if requested.
        """
        out: Dict[str, Any] = {
            "id": self.id,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "state": self.state,
            "user_message": self.user_message,
            "metadata": self.metadata,
            "sandbox_id": self.sandbox_id,
            "final_reply": self.final_reply,
            "error": self.error,
        }
        if include_steps:
            out["steps"] = [s.to_dict() for s in self.steps]
        return out
| # --------------------------------------------------------------------------- | |
| # CRUD helpers | |
| # --------------------------------------------------------------------------- | |
async def create_task(user_message: str, metadata: Optional[Dict[str, Any]] = None) -> Task:
    """Insert a new task in the ``queued`` state and return it.

    Args:
        user_message: text stored in the ``user_message`` column.
        metadata: optional dict stored as JSON; a falsy value becomes ``{}``.

    Returns:
        The in-memory ``Task`` that was persisted, with a fresh hex UUID as id
        and identical ``created_at`` / ``updated_at`` timestamps.
    """
    await init()
    stamp = time.time()
    new_task = Task(
        id=uuid.uuid4().hex,
        created_at=stamp,
        updated_at=stamp,
        state=TaskState.QUEUED,
        user_message=user_message,
        metadata=metadata or {},
    )
    insert_sql = (
        "INSERT INTO tasks (id, created_at, updated_at, state, user_message, metadata) "
        "VALUES (?,?,?,?,?,?)"
    )

    def _insert() -> None:
        with _conn() as db:
            values = (
                new_task.id,
                new_task.created_at,
                new_task.updated_at,
                new_task.state,
                new_task.user_message,
                json.dumps(new_task.metadata),
            )
            db.execute(insert_sql, values)

    async with _write_lock:
        await asyncio.to_thread(_insert)
    return new_task
async def update_state(task_id: str, state: str, *, error: Optional[str] = None,
                       sandbox_id: Optional[str] = None,
                       final_reply: Optional[str] = None) -> None:
    """Set a task's state and refresh its ``updated_at`` timestamp.

    Args:
        task_id: id of the task row to update.
        state: new state; must be a member of ``TaskState.ALL``.
        error, sandbox_id, final_reply: written to the column of the same
            name only when not ``None``; otherwise the column is left as is.

    Raises:
        ValueError: if ``state`` is not a known state name.
    """
    await init()
    if state not in TaskState.ALL:
        raise ValueError(f"invalid state: {state}")
    stamp = time.time()
    optional_columns = (
        ("error", error),
        ("sandbox_id", sandbox_id),
        ("final_reply", final_reply),
    )

    def _write() -> None:
        assignments = ["state = ?", "updated_at = ?"]
        values: List[Any] = [state, stamp]
        for column, value in optional_columns:
            if value is not None:
                assignments.append(f"{column} = ?")
                values.append(value)
        values.append(task_id)
        # Only fixed column names are interpolated; all data goes through
        # bound parameters.
        statement = f"UPDATE tasks SET {', '.join(assignments)} WHERE id = ?"
        with _conn() as db:
            db.execute(statement, values)

    async with _write_lock:
        await asyncio.to_thread(_write)
async def set_steps(task_id: str, steps: Iterable[Dict[str, str]]) -> List[Step]:
    """Replace the plan for a task.

    Existing step rows of the task are deleted, the new ones inserted and the
    task's ``updated_at`` refreshed. The connection runs in autocommit mode,
    so the three statements are wrapped in an explicit transaction: a failure
    part-way through rolls back instead of leaving the task without a plan.

    Args:
        task_id: id of the task whose plan is replaced.
        steps: dicts with a ``title`` (defaults to ``"Step <n>"``) and an
            optional ``description`` (defaults to ``""``).

    Returns:
        The new ``Step`` objects, indexed from 0 in input order.
    """
    await init()
    rows: List[Step] = [
        Step(idx=i, title=str(raw.get("title", f"Step {i+1}")),
             description=str(raw.get("description", "")))
        for i, raw in enumerate(steps)
    ]

    def _do():
        with _conn() as c:
            c.execute("BEGIN")
            try:
                c.execute("DELETE FROM steps WHERE task_id = ?", (task_id,))
                c.executemany(
                    "INSERT INTO steps (task_id, idx, title, description, state) VALUES (?,?,?,?,?)",
                    [(task_id, s.idx, s.title, s.description, s.state) for s in rows],
                )
                c.execute("UPDATE tasks SET updated_at = ? WHERE id = ?", (time.time(), task_id))
                c.execute("COMMIT")
            except Exception:
                # Undo the DELETE (and any partial inserts) before re-raising;
                # skip if SQLite already ended the transaction itself.
                if c.in_transaction:
                    c.execute("ROLLBACK")
                raise

    async with _write_lock:
        await asyncio.to_thread(_do)
    return rows
async def update_step(task_id: str, idx: int, *, state: Optional[str] = None,
                      attempts_delta: int = 0, result: Optional[str] = None,
                      error: Optional[str] = None) -> None:
    """Update selected columns of one step and touch the owning task.

    Args:
        task_id: id of the task that owns the step.
        idx: index of the step within the task's plan.
        state: new state, if any. ``executing`` also records ``started_at``
            (only if not already set); a terminal state records
            ``finished_at``.
        attempts_delta: amount added to ``attempts`` when non-zero.
        result: stored truncated to 4000 characters when not ``None``.
        error: stored truncated to 4000 characters when not ``None``.

    When nothing is requested the call writes nothing, not even the task's
    ``updated_at``.
    """
    await init()
    stamp = time.time()

    def _apply() -> None:
        assignments: List[str] = []
        values: List[Any] = []

        def _set(fragment: str, value: Any) -> None:
            assignments.append(fragment)
            values.append(value)

        if state is not None:
            _set("state = ?", state)
            if state == TaskState.EXECUTING:
                _set("started_at = COALESCE(started_at, ?)", stamp)
            elif state in TaskState.TERMINAL:
                # TERMINAL already contains COMPLETED and FAILED.
                _set("finished_at = ?", stamp)
        if attempts_delta:
            _set("attempts = attempts + ?", attempts_delta)
        if result is not None:
            _set("result = ?", result[:4000])
        if error is not None:
            _set("error = ?", error[:4000])

        if not assignments:
            return

        values.extend([task_id, idx])
        with _conn() as db:
            db.execute(
                f"UPDATE steps SET {', '.join(assignments)} WHERE task_id = ? AND idx = ?",
                values,
            )
            db.execute("UPDATE tasks SET updated_at = ? WHERE id = ?", (stamp, task_id))

    async with _write_lock:
        await asyncio.to_thread(_apply)
async def append_event(task_id: str, kind: str, payload: Any, step_idx: Optional[int] = None) -> None:
    """Append one diagnostic event row for a task, best-effort.

    Events are diagnostics only, so any failure is logged as a warning and
    swallowed rather than propagated to the caller. That includes failures of
    ``init()`` itself, which is why it is called inside the guarded region.

    Args:
        task_id: id of the task the event belongs to.
        kind: short event type label stored in the ``kind`` column.
        payload: any value; serialised with ``json.dumps`` (non-serialisable
            objects fall back to ``str``) and truncated to 8000 characters.
        step_idx: optional index of the step the event relates to.
    """
    def _do():
        with _conn() as c:
            c.execute(
                "INSERT INTO events (task_id, step_idx, ts, kind, payload) VALUES (?,?,?,?,?)",
                (task_id, step_idx, time.time(), kind,
                 json.dumps(payload, ensure_ascii=False, default=str)[:8000]),
            )

    try:
        await init()
        async with _write_lock:
            await asyncio.to_thread(_do)
    except Exception as e:
        # Logging only — events are diagnostics; never break the stream.
        logger.warning("append_event failed: %s", e)
| # --------------------------------------------------------------------------- | |
| # Read helpers | |
| # --------------------------------------------------------------------------- | |
async def get_task(task_id: str) -> Optional[Task]:
    """Load a task together with its steps, ordered by step index.

    Returns:
        The reconstructed ``Task``, or ``None`` when no row has that id.
    """
    await init()

    def _load():
        with _conn() as db:
            task_row = db.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone()
            if not task_row:
                return None
            step_rows = db.execute(
                "SELECT * FROM steps WHERE task_id = ? ORDER BY idx", (task_id,)
            ).fetchall()
            return task_row, step_rows

    found = await asyncio.to_thread(_load)
    if not found:
        return None
    task_row, step_rows = found

    plan: List[Step] = []
    for rec in step_rows:
        plan.append(
            Step(
                idx=rec["idx"],
                title=rec["title"],
                description=rec["description"] or "",
                state=rec["state"],
                attempts=rec["attempts"],
                started_at=rec["started_at"],
                finished_at=rec["finished_at"],
                result=rec["result"],
                error=rec["error"],
            )
        )

    return Task(
        id=task_row["id"],
        created_at=task_row["created_at"],
        updated_at=task_row["updated_at"],
        state=task_row["state"],
        user_message=task_row["user_message"],
        # The metadata column holds JSON text; an empty value means "{}".
        metadata=json.loads(task_row["metadata"] or "{}"),
        sandbox_id=task_row["sandbox_id"],
        final_reply=task_row["final_reply"],
        error=task_row["error"],
        steps=plan,
    )
async def list_tasks(limit: int = 50, state: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return summary dicts of the most recently created tasks.

    Args:
        limit: maximum number of rows returned.
        state: when truthy, only tasks in exactly this state are listed.

    Returns:
        Dicts with ``id``, ``created_at``, ``updated_at``, ``state``,
        ``user_message`` and ``sandbox_id``, newest first.
    """
    await init()

    def _query():
        # The filter clause and its bound value are added together so the
        # placeholders always line up with the parameters.
        statement = "SELECT id, created_at, updated_at, state, user_message, sandbox_id FROM tasks "
        bound: tuple = (limit,)
        if state:
            statement += "WHERE state = ? "
            bound = (state, limit)
        statement += "ORDER BY created_at DESC LIMIT ?"
        with _conn() as db:
            fetched = db.execute(statement, bound).fetchall()
            return [dict(record) for record in fetched]

    return await asyncio.to_thread(_query)
async def get_events(task_id: str, after_id: int = 0, limit: int = 500) -> List[Dict[str, Any]]:
    """Return a task's events with id greater than ``after_id``, oldest first.

    Args:
        task_id: id of the task whose events are read.
        after_id: only events with a larger id are returned.
        limit: maximum number of events returned.

    Returns:
        Dicts with ``id``, ``step_idx``, ``ts``, ``kind`` and ``payload``.
        A payload that is not valid JSON is returned as ``{"raw": <text>}``.
    """
    await init()

    def _decode(text: Any) -> Any:
        try:
            return json.loads(text)
        except Exception:
            # Keep the stored text visible instead of failing the whole read.
            return {"raw": text}

    def _query() -> List[Dict[str, Any]]:
        with _conn() as db:
            fetched = db.execute(
                "SELECT id, step_idx, ts, kind, payload FROM events "
                "WHERE task_id = ? AND id > ? ORDER BY id LIMIT ?",
                (task_id, after_id, limit),
            ).fetchall()
            return [
                {
                    "id": rec["id"],
                    "step_idx": rec["step_idx"],
                    "ts": rec["ts"],
                    "kind": rec["kind"],
                    "payload": _decode(rec["payload"]),
                }
                for rec in fetched
            ]

    return await asyncio.to_thread(_query)
async def delete_task(task_id: str) -> bool:
    """Delete a task row by id.

    Returns:
        ``True`` if a row was removed, ``False`` if no task had that id.
    """
    await init()

    def _remove() -> bool:
        with _conn() as db:
            cursor = db.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
            removed = cursor.rowcount
            return removed > 0

    async with _write_lock:
        deleted = await asyncio.to_thread(_remove)
    return deleted