"""
database.py — SQLite persistence layer for Autonomy Calibration Environment.

Uses stdlib sqlite3 only — no external dependencies.

Tables:
    episodes — one row per episode (id, task, seed, started_at, ended_at, total_reward, done)
    steps    — one row per environment step (id, episode_id, step_index, decision, reward, done, timestamp)

Public API:
    init_db()                  — create tables (idempotent)
    create_episode(task, seed) — insert episode row, return episode_id
    log_step(...)              — insert step row
    close_episode(id, score)   — update episode with final score + ended_at
    get_episode(id)            — fetch episode + all steps
    list_episodes(limit)       — list recent episodes
    replay_episode(id)         — return ordered step list for replay
"""
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
| import sqlite3 |
| from contextlib import contextmanager |
| from datetime import datetime, timezone |
| from typing import Any, Generator |
|
|
# Module-wide logger; handler and level configuration is left to the caller.
logger = logging.getLogger(__name__)

# Database file location: the AUTONOMY_ENV_DB environment variable wins,
# otherwise the file name "autonomy_env.db" is used.
DB_PATH = os.environ.get("AUTONOMY_ENV_DB", "autonomy_env.db")

# DDL applied by init_db(). Every statement carries IF NOT EXISTS, so running
# the script against an already-initialised database changes nothing.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS episodes (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    task TEXT NOT NULL,
    seed INTEGER,
    started_at TEXT NOT NULL,
    ended_at TEXT,
    total_reward REAL DEFAULT 0.0,
    done INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS steps (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    episode_id INTEGER NOT NULL REFERENCES episodes(id),
    step_index INTEGER NOT NULL,
    decision TEXT NOT NULL,
    reward REAL NOT NULL,
    done INTEGER NOT NULL DEFAULT 0,
    timestamp TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_steps_episode ON steps(episode_id);
"""
|
|
|
|
| |
|
|
@contextmanager
def _conn(path: str = DB_PATH) -> Generator[sqlite3.Connection, None, None]:
    """Yield a SQLite connection that commits on success and always closes.

    The connection is opened with ``check_same_thread=False``, switched to
    WAL journaling, has foreign-key enforcement turned on, and returns rows
    as ``sqlite3.Row`` objects.

    Args:
        path: Database file to open.

    Yields:
        The configured, open connection.

    Raises:
        Exception: Anything raised by the ``with`` body is re-raised after
            the pending transaction has been rolled back. Errors raised
            while configuring the connection propagate as well.
    """
    conn = sqlite3.connect(path, check_same_thread=False)
    try:
        # Configuration happens inside the outer try: a failing PRAGMA (for
        # example when the file is not a database, or is locked) must still
        # reach the finally clause, otherwise the handle would leak.
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            # Discard partial writes made by the caller before re-raising.
            conn.rollback()
            raise
    finally:
        conn.close()
|
|
|
|
| |
|
|
def init_db(path: str = DB_PATH) -> None:
    """Apply the schema script to the database at *path*.

    The schema only uses ``IF NOT EXISTS`` statements, so calling this again
    on an initialised database is harmless.

    Args:
        path: Database file to initialise.
    """
    with _conn(path) as db:
        db.executescript(_SCHEMA)
    logger.info("DB: Initialised SQLite at %s", path)
|
|
|
|
| |
|
|
def create_episode(task: str, seed: int | None, path: str = DB_PATH) -> int:
    """Start a new episode record and hand back its row id.

    Args:
        task: Task name stored with the episode.
        seed: Seed stored with the episode; may be ``None``.
        path: Database file to write to.

    Returns:
        The id of the freshly inserted ``episodes`` row.
    """
    _ensure(path)
    started_at = _now()
    insert_sql = "INSERT INTO episodes (task, seed, started_at) VALUES (?, ?, ?)"
    with _conn(path) as db:
        episode_id = db.execute(insert_sql, (task, seed, started_at)).lastrowid
    logger.debug(
        "DB: Episode created id=%d task=%s seed=%s", episode_id, task, seed
    )
    return episode_id
|
|
|
|
def log_step(
    episode_id: int,
    step_index: int,
    decision: str,
    reward: float,
    done: bool,
    path: str = DB_PATH,
) -> None:
    """Record a single environment step.

    Args:
        episode_id: Id of the episode row this step belongs to. Foreign keys
            are enforced on the connection, so an unknown id makes the
            insert fail.
        step_index: Position of the step within its episode.
        decision: Decision text stored for the step.
        reward: Step reward; stored rounded to 4 decimal places.
        done: Whether the step ended the episode; stored as 0 or 1.
        path: Database file to write to.

    Raises:
        sqlite3.IntegrityError: If ``episode_id`` does not reference an
            existing episode.
    """
    # Create the schema on first use, as the other public functions do, so a
    # path that was never initialised does not fail with "no such table".
    _ensure(path)
    with _conn(path) as c:
        c.execute(
            "INSERT INTO steps (episode_id, step_index, decision, reward, done, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (episode_id, step_index, decision, round(reward, 4), int(done), _now()),
        )
|
|
|
|
def close_episode(episode_id: int, total_reward: float, path: str = DB_PATH) -> None:
    """Mark an episode as done and record its final score.

    Sets ``ended_at`` to the current UTC time, stores ``total_reward``
    rounded to 4 decimal places and sets ``done`` to 1.

    Args:
        episode_id: Id of the episode row to close.
        total_reward: Final score for the episode.
        path: Database file to write to.

    If no episode with ``episode_id`` exists, nothing is updated and a
    warning is logged; no exception is raised.
    """
    # Create the schema on first use, as the other public functions do.
    _ensure(path)
    with _conn(path) as c:
        cur = c.execute(
            "UPDATE episodes SET ended_at=?, total_reward=?, done=1 WHERE id=?",
            (_now(), round(total_reward, 4), episode_id),
        )
        matched = cur.rowcount
    if matched == 0:
        # The UPDATE touched no row: do not claim the episode was closed.
        logger.warning("DB: close_episode found no episode with id=%d", episode_id)
        return
    logger.debug("DB: Episode closed id=%d score=%.4f", episode_id, total_reward)
|
|
|
|
| |
|
|
def list_episodes(limit: int = 20, path: str = DB_PATH) -> list[dict[str, Any]]:
    """Fetch the newest episodes, highest id first.

    Args:
        limit: Maximum number of episode rows to return.
        path: Database file to read from.

    Returns:
        One plain dict per episode row, keyed by column name.
    """
    _ensure(path)
    query = "SELECT * FROM episodes ORDER BY id DESC LIMIT ?"
    with _conn(path) as db:
        cursor = db.execute(query, (limit,))
        recent = list(map(dict, cursor.fetchall()))
    return recent
|
|
|
|
def get_episode(episode_id: int, path: str = DB_PATH) -> dict[str, Any]:
    """Load one episode together with every step recorded for it.

    Args:
        episode_id: Id of the episode row to load.
        path: Database file to read from.

    Returns:
        A dict with two keys: ``"episode"`` (the episode row as a dict) and
        ``"steps"`` (the step rows as dicts, ordered by ``step_index``).

    Raises:
        ValueError: If no episode with ``episode_id`` exists.
    """
    _ensure(path)
    lookup = (episode_id,)
    with _conn(path) as db:
        episode_row = db.execute(
            "SELECT * FROM episodes WHERE id=?", lookup
        ).fetchone()
        if episode_row is None:
            raise ValueError(f"Episode {episode_id} not found.")
        step_cursor = db.execute(
            "SELECT * FROM steps WHERE episode_id=? ORDER BY step_index ASC",
            lookup,
        )
        step_rows = step_cursor.fetchall()
        payload = {
            "episode": dict(episode_row),
            "steps": list(map(dict, step_rows)),
        }
    return payload
|
|
|
|
def replay_episode(episode_id: int, path: str = DB_PATH) -> list[dict[str, Any]]:
    """Return only the ordered steps of an episode, for replay.

    Delegates to ``get_episode`` and keeps its ``"steps"`` entry, so a
    missing episode raises the same ``ValueError``.

    Args:
        episode_id: Id of the episode to replay.
        path: Database file to read from.
    """
    episode_data = get_episode(episode_id, path)
    return episode_data["steps"]
|
|
|
|
| |
|
|
# Database paths whose schema has already been applied in this process.
_initialised: set[str] = set()


def _ensure(path: str = DB_PATH) -> None:
    """Initialise the database at *path* the first time it is seen.

    Later calls with the same path string return immediately.

    Args:
        path: Database file whose schema should exist.
    """
    if path in _initialised:
        return
    init_db(path)
    _initialised.add(path)
|
|
|
|
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string with its offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
| |
# Import-time side effect: apply the schema for the default database path as
# soon as this module is loaded.
_ensure(DB_PATH)
|
|