Rhythm@28
deploy: final verified championship submission
ef737d3
"""
database.py — SQLite persistence layer for the Autonomy Calibration Environment.
Uses stdlib sqlite3 only — no external dependencies.
Tables:
    episodes — one row per episode (id, task, seed, started_at, ended_at, total_reward, done)
    steps — one row per environment step (id, episode_id, step_index, decision, reward, done, timestamp)
Public API:
    init_db() — create tables (idempotent)
    create_episode(task, seed) — insert an episode row and return its episode_id
    log_step(...) — insert a step row
    close_episode(id, total_reward) — mark the episode done and record the final score + end time
    get_episode(id) — fetch an episode together with all of its steps
    list_episodes(limit) — list the most recent episodes
    replay_episode(id) — return the ordered step list for replay
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Generator
# Module-level logger; the functions in this module log at INFO and DEBUG.
logger = logging.getLogger(__name__)
# Default database file. Overridable via the AUTONOMY_ENV_DB environment
# variable, which is read once at import time (it is also bound as the default
# ``path`` argument of every function). The fallback "autonomy_env.db" is a
# relative path, i.e. resolved against the current working directory.
DB_PATH = os.getenv("AUTONOMY_ENV_DB", "autonomy_env.db")
# DDL executed by init_db() through executescript(). Every statement uses
# IF NOT EXISTS, so re-running it against an existing database is a no-op.
#   episodes: one row per episode; ended_at, total_reward and done are filled
#             in later by close_episode().
#   steps:    one row per environment step, linked to episodes(id). The link
#             is only enforced because _conn() enables PRAGMA foreign_keys.
#             idx_steps_episode speeds up the per-episode step lookups.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS episodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task TEXT NOT NULL,
seed INTEGER,
started_at TEXT NOT NULL,
ended_at TEXT,
total_reward REAL DEFAULT 0.0,
done INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS steps (
id INTEGER PRIMARY KEY AUTOINCREMENT,
episode_id INTEGER NOT NULL REFERENCES episodes(id),
step_index INTEGER NOT NULL,
decision TEXT NOT NULL,
reward REAL NOT NULL,
done INTEGER NOT NULL DEFAULT 0,
timestamp TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_steps_episode ON steps(episode_id);
"""
# ─── Connection ───────────────────────────────────────────────────────────────
@contextmanager
def _conn(path: str = DB_PATH) -> Generator[sqlite3.Connection, None, None]:
    """Open a SQLite connection scoped to a ``with`` block.

    The connection uses WAL journaling and has foreign-key enforcement turned
    on. It is committed when the block exits normally, rolled back when the
    block raises, and always closed.

    Args:
        path: Filesystem path of the SQLite database (defaults to DB_PATH).

    Yields:
        A connection whose rows are returned as ``sqlite3.Row`` objects.
    """
    conn = sqlite3.connect(path, check_same_thread=False)
    try:
        # Configure inside the outer try so that a failing PRAGMA (for example
        # "database is locked") still reaches the finally and closes the
        # connection instead of leaking it.
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            # Undo any partial writes made inside the with-block, then propagate.
            conn.rollback()
            raise
    finally:
        conn.close()
# ─── Init ─────────────────────────────────────────────────────────────────────
def init_db(path: str = DB_PATH) -> None:
    """Create the ``episodes`` and ``steps`` tables (plus the step index) at *path*.

    Each schema statement is ``IF NOT EXISTS``, so calling this repeatedly
    is harmless.
    """
    with _conn(path) as connection:
        connection.executescript(_SCHEMA)
    logger.info("DB: Initialised SQLite at %s", path)
# ─── Write ────────────────────────────────────────────────────────────────────
def create_episode(task: str, seed: int | None, path: str = DB_PATH) -> int:
    """Insert a new, still-open episode row and return its id.

    Args:
        task: Task name stored in ``episodes.task``.
        seed: Optional seed stored in ``episodes.seed`` (may be None).
        path: Database file path (defaults to DB_PATH).

    Returns:
        The AUTOINCREMENT id SQLite assigned to the new row.
    """
    _ensure(path)
    started = _now()
    with _conn(path) as connection:
        cursor = connection.execute(
            "INSERT INTO episodes (task, seed, started_at) VALUES (?, ?, ?)",
            (task, seed, started),
        )
        new_id = cursor.lastrowid
    logger.debug("DB: Episode created id=%d task=%s seed=%s", new_id, task, seed)
    return new_id
def log_step(
    episode_id: int,
    step_index: int,
    decision: str,
    reward: float,
    done: bool,
    path: str = DB_PATH,
) -> None:
    """Record a single environment step.

    Args:
        episode_id: Id of the parent row in ``episodes``.
        step_index: Position of the step in its episode; replay orders by this.
        decision: Decision text stored in the ``decision`` column.
        reward: Step reward, stored rounded to 4 decimal places.
        done: Whether this step ended the episode; stored as 0/1.
        path: Database file path (defaults to DB_PATH).

    Raises:
        sqlite3.IntegrityError: If ``episode_id`` does not reference an
            existing episode (foreign keys are enabled by ``_conn``).
    """
    # Ensure the schema exists, as the other entry points do. Without this,
    # logging to a fresh ``path`` fails with "no such table: steps".
    _ensure(path)
    with _conn(path) as c:
        c.execute(
            "INSERT INTO steps (episode_id, step_index, decision, reward, done, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (episode_id, step_index, decision, round(reward, 4), int(done), _now()),
        )
def close_episode(episode_id: int, total_reward: float, path: str = DB_PATH) -> None:
    """Mark an episode as finished and record its final score.

    Sets ``ended_at`` to the current UTC time, ``total_reward`` (rounded to
    4 decimal places) and ``done = 1`` on the matching ``episodes`` row.

    Args:
        episode_id: Id of the episode to close.
        total_reward: Final score for the episode.
        path: Database file path (defaults to DB_PATH).

    If no row has ``episode_id``, nothing is updated and a warning is logged.
    The function still does not raise, which preserves the original contract.
    """
    # Keep schema handling consistent with the other entry points.
    _ensure(path)
    with _conn(path) as c:
        cur = c.execute(
            "UPDATE episodes SET ended_at=?, total_reward=?, done=1 WHERE id=?",
            (_now(), round(total_reward, 4), episode_id),
        )
        updated = cur.rowcount
    if updated == 0:
        # Previously this case logged "closed" even though no row changed.
        logger.warning("DB: close_episode found no episode id=%d", episode_id)
        return
    logger.debug("DB: Episode closed id=%d score=%.4f", episode_id, total_reward)
# ─── Read ─────────────────────────────────────────────────────────────────────
def list_episodes(limit: int = 20, path: str = DB_PATH) -> list[dict[str, Any]]:
    """Return up to *limit* episode rows as plain dicts, highest id (newest) first.

    NOTE(review): *limit* is passed straight to SQL ``LIMIT``; SQLite treats a
    negative value as "no limit" — confirm callers rely on that.
    """
    _ensure(path)
    query = "SELECT * FROM episodes ORDER BY id DESC LIMIT ?"
    with _conn(path) as connection:
        records = connection.execute(query, (limit,)).fetchall()
    return list(map(dict, records))
def get_episode(episode_id: int, path: str = DB_PATH) -> dict[str, Any]:
    """Fetch one episode together with all of its steps.

    Returns:
        ``{"episode": <episode row dict>, "steps": [<step row dict>, ...]}``
        with the steps ordered by ascending ``step_index``.

    Raises:
        ValueError: If no episode has id *episode_id*.
    """
    _ensure(path)
    with _conn(path) as connection:
        header = connection.execute(
            "SELECT * FROM episodes WHERE id=?", (episode_id,)
        ).fetchone()
        if header is None:
            raise ValueError(f"Episode {episode_id} not found.")
        step_rows = connection.execute(
            "SELECT * FROM steps WHERE episode_id=? ORDER BY step_index ASC",
            (episode_id,),
        ).fetchall()
    episode_dict = dict(header)
    step_dicts = [dict(row) for row in step_rows]
    return {"episode": episode_dict, "steps": step_dicts}
def replay_episode(episode_id: int, path: str = DB_PATH) -> list[dict[str, Any]]:
    """Return just the step dicts of an episode, ordered by ``step_index``, for replay.

    Thin wrapper over ``get_episode``; raises ValueError if the episode is missing.
    """
    full_record = get_episode(episode_id, path)
    return full_record["steps"]
# ─── Helpers ──────────────────────────────────────────────────────────────────
# Database paths whose schema has already been created in this process.
_initialised: set[str] = set()


def _ensure(path: str = DB_PATH) -> None:
    """Run ``init_db`` for *path* the first time it is seen in this process.

    NOTE(review): the cache is per-process and never invalidated, so a
    database file deleted after first use will not be re-initialised.
    """
    if path in _initialised:
        return
    init_db(path)
    _initialised.add(path)
def _now() -> str:
return datetime.now(timezone.utc).isoformat()
# Create the schema for the default database as soon as the module is
# imported, so callers that rely on DB_PATH never see a missing table.
# NOTE(review): this opens (and creates, if absent) the DB file at import
# time; with the default DB_PATH that is relative to the working directory.
_ensure(path=DB_PATH)