File size: 6,480 Bytes
ef737d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
database.py — SQLite persistence layer for Autonomy Calibration Environment.

Uses stdlib sqlite3 only — no external dependencies.

Tables:
  episodes  — one row per episode (id, task, seed, started_at, ended_at, total_reward, done)
  steps     — one row per environment step (id, episode_id, step_index, decision, reward, done, timestamp)

Public API:
  init_db()                  — create tables (idempotent)
  create_episode(task, seed) — insert episode row, return episode_id
  log_step(...)              — insert step row
  close_episode(id, score)   — update episode with final score + ended_at
  get_episode(id)            — fetch episode + all steps
  list_episodes(limit)       — list recent episodes
  replay_episode(id)         — return ordered step list for replay
"""
from __future__ import annotations

import json
import logging
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Generator

logger = logging.getLogger(__name__)

# Default database file. Read once at import time from the AUTONOMY_ENV_DB
# environment variable; falls back to "autonomy_env.db" when it is unset.
DB_PATH = os.getenv("AUTONOMY_ENV_DB", "autonomy_env.db")

# DDL script run by init_db() through executescript(). Every statement uses
# IF NOT EXISTS, so running the script again against an existing database
# leaves it unchanged.
#   episodes — one row per episode; `done` defaults to 0 and close_episode()
#              sets it to 1 together with ended_at and total_reward.
#   steps    — one row per logged step; episode_id references episodes(id).
#   idx_steps_episode — index on steps(episode_id), the column get_episode()
#              filters on.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS episodes (
    id            INTEGER PRIMARY KEY AUTOINCREMENT,
    task          TEXT    NOT NULL,
    seed          INTEGER,
    started_at    TEXT    NOT NULL,
    ended_at      TEXT,
    total_reward  REAL    DEFAULT 0.0,
    done          INTEGER DEFAULT 0
);

CREATE TABLE IF NOT EXISTS steps (
    id            INTEGER PRIMARY KEY AUTOINCREMENT,
    episode_id    INTEGER NOT NULL REFERENCES episodes(id),
    step_index    INTEGER NOT NULL,
    decision      TEXT    NOT NULL,
    reward        REAL    NOT NULL,
    done          INTEGER NOT NULL DEFAULT 0,
    timestamp     TEXT    NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_steps_episode ON steps(episode_id);
"""


# ─── Connection ───────────────────────────────────────────────────────────────

@contextmanager
def _conn(path: str = DB_PATH) -> Generator[sqlite3.Connection, None, None]:
    """Yield a configured SQLite connection that is always closed on exit.

    The connection requests WAL journaling, enables foreign-key enforcement
    and returns rows as ``sqlite3.Row`` objects. On a clean exit from the
    ``with`` body the transaction is committed; if the body (or the commit)
    raises an ``Exception`` the transaction is rolled back and the exception
    is re-raised.

    Args:
        path: Database file to open. Defaults to the module-level DB_PATH.

    Yields:
        The open ``sqlite3.Connection``.
    """
    conn = sqlite3.connect(path, check_same_thread=False)
    try:
        # Configure inside the try block so that a failing PRAGMA still
        # reaches the `finally` clause and the connection is not leaked.
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
        conn.row_factory = sqlite3.Row
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()


# ─── Init ─────────────────────────────────────────────────────────────────────

def init_db(path: str = DB_PATH) -> None:
    """Run the schema script against the database at *path*.

    The script only contains ``CREATE ... IF NOT EXISTS`` statements, so
    calling this repeatedly on the same database is harmless.

    Args:
        path: Database file to initialise. Defaults to DB_PATH.
    """
    with _conn(path) as db:
        db.executescript(_SCHEMA)

    logger.info("DB: Initialised SQLite at %s", path)


# ─── Write ────────────────────────────────────────────────────────────────────

def create_episode(task: str, seed: int | None, path: str = DB_PATH) -> int:
    """Start a new episode record and hand back its primary key.

    Args:
        task: Task name stored in the ``task`` column.
        seed: Seed stored with the episode; may be ``None``.
        path: Database file to write to. Defaults to DB_PATH.

    Returns:
        The ``id`` of the freshly inserted ``episodes`` row.
    """
    _ensure(path)
    started_at = _now()
    insert_sql = "INSERT INTO episodes (task, seed, started_at) VALUES (?, ?, ?)"

    with _conn(path) as db:
        episode_id = db.execute(insert_sql, (task, seed, started_at)).lastrowid

    logger.debug("DB: Episode created id=%d task=%s seed=%s", episode_id, task, seed)
    return episode_id


def log_step(
    episode_id: int,
    step_index: int,
    decision: str,
    reward: float,
    done: bool,
    path: str = DB_PATH,
) -> None:
    """Append one row to the ``steps`` table.

    The reward is rounded to four decimal places, ``done`` is stored as an
    integer and the row is stamped with the current UTC time.

    Args:
        episode_id: Id of the episode this step belongs to.
        step_index: Position of the step within its episode.
        decision: Decision text recorded for the step.
        reward: Reward obtained at this step.
        done: Whether this step ended the episode.
        path: Database file to write to. Defaults to DB_PATH.
    """
    insert_sql = (
        "INSERT INTO steps (episode_id, step_index, decision, reward, done, timestamp) "
        "VALUES (?, ?, ?, ?, ?, ?)"
    )
    with _conn(path) as db:
        row = (
            episode_id,
            step_index,
            decision,
            round(reward, 4),
            int(done),
            _now(),
        )
        db.execute(insert_sql, row)


def close_episode(episode_id: int, total_reward: float, path: str = DB_PATH) -> None:
    """Mark an episode as done and record its final score and end time.

    Sets ``ended_at`` to the current UTC time, ``total_reward`` to the score
    rounded to four decimal places, and ``done`` to 1.

    If no episode with ``episode_id`` exists the UPDATE touches no rows; this
    is reported with a warning instead of a misleading "closed" debug line.
    No exception is raised in that case, so existing callers are unaffected.

    Args:
        episode_id: Id of the episode to close.
        total_reward: Final score for the episode.
        path: Database file to write to. Defaults to DB_PATH.
    """
    with _conn(path) as c:
        cur = c.execute(
            "UPDATE episodes SET ended_at=?, total_reward=?, done=1 WHERE id=?",
            (_now(), round(total_reward, 4), episode_id),
        )
        # rowcount is the number of rows the UPDATE modified; 0 => unknown id.
        updated = cur.rowcount

    if updated == 0:
        logger.warning("DB: Episode close skipped, id=%s not found", episode_id)
        return

    logger.debug("DB: Episode closed id=%d score=%.4f", episode_id, total_reward)


# ─── Read ─────────────────────────────────────────────────────────────────────

def list_episodes(limit: int = 20, path: str = DB_PATH) -> list[dict[str, Any]]:
    """Fetch the newest episodes, highest id first.

    Args:
        limit: Value bound to the query's ``LIMIT`` clause.
        path: Database file to read from. Defaults to DB_PATH.

    Returns:
        One plain dict per ``episodes`` row, keyed by column name.
    """
    _ensure(path)
    query = "SELECT * FROM episodes ORDER BY id DESC LIMIT ?"

    with _conn(path) as db:
        episode_rows = db.execute(query, (limit,)).fetchall()

    return list(map(dict, episode_rows))


def get_episode(episode_id: int, path: str = DB_PATH) -> dict[str, Any]:
    """Load one episode together with its steps.

    Args:
        episode_id: Id of the episode to load.
        path: Database file to read from. Defaults to DB_PATH.

    Returns:
        A dict with two keys: ``"episode"`` (the episode row as a dict) and
        ``"steps"`` (the episode's step rows as dicts, ordered by
        ``step_index`` ascending).

    Raises:
        ValueError: If no episode with ``episode_id`` exists.
    """
    _ensure(path)
    episode_sql = "SELECT * FROM episodes WHERE id=?"
    steps_sql = "SELECT * FROM steps WHERE episode_id=? ORDER BY step_index ASC"
    key = (episode_id,)

    with _conn(path) as db:
        episode_row = db.execute(episode_sql, key).fetchone()
        if episode_row is None:
            raise ValueError(f"Episode {episode_id} not found.")
        step_rows = db.execute(steps_sql, key).fetchall()

    result: dict[str, Any] = {}
    result["episode"] = dict(episode_row)
    result["steps"] = list(map(dict, step_rows))
    return result


def replay_episode(episode_id: int, path: str = DB_PATH) -> list[dict[str, Any]]:
    """Return only the ordered steps of an episode, for replay.

    Delegates to get_episode() and discards the episode row, so it raises the
    same ValueError when the episode does not exist.

    Args:
        episode_id: Id of the episode to replay.
        path: Database file to read from. Defaults to DB_PATH.

    Returns:
        The episode's step dicts in ``step_index`` order.
    """
    episode = get_episode(episode_id, path)
    return episode["steps"]


# ─── Helpers ──────────────────────────────────────────────────────────────────

# Database paths whose schema has already been created by this process.
_initialised: set[str] = set()


def _ensure(path: str = DB_PATH) -> None:
    """Create the schema for *path* the first time it is seen.

    Later calls with the same path return immediately. The path is only
    remembered after init_db() has completed without raising.

    Args:
        path: Database file to initialise lazily. Defaults to DB_PATH.
    """
    if path in _initialised:
        return

    init_db(path)
    _initialised.add(path)


def _now() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()


# Auto-init on import: merely importing this module runs the lazy schema
# initialisation for the default DB_PATH, so the import has a database
# side effect.
_ensure(DB_PATH)