File size: 7,137 Bytes
aceb1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Universal SQLite Helper

Provides a process-wide cache of WAL-mode SQLite connections keyed by task_dir,
plus a migration registry so each feature module declares the tables it owns.

Design choices:
- One database file per project: `<task_dir>/project.sqlite`.
- WAL journal mode for concurrent reads alongside writes.
- `foreign_keys = ON` enforced on every connection.
- Migrations are idempotent and tracked in a `schema_migrations` table.
- Modules register migrations at import time via `register_migration()`;
  pending migrations run on first `get_db(task_dir)` call.
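- Thread-safe bookkeeping: `_DB_CACHE_LOCK` guards the connection cache (and
  the migration runner, which runs while holding it); `_REGISTRY_LOCK` guards
  the migration registry. The cached connection itself is shared by all
  threads, so callers must serialize their own transactions on it.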

Usage:
    from potato.persistence import register_migration, Migration, get_db

    register_migration(Migration(
        name="0001_memos",
        sql=\"""CREATE TABLE IF NOT EXISTS memos (
            id TEXT PRIMARY KEY,
            ...
        );\""",
    ))

    conn = get_db(task_dir)
    conn.execute("INSERT INTO memos ...")
    conn.commit()
"""

from __future__ import annotations

import logging
import os
import sqlite3
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class Migration:
    """Immutable description of one named schema change.

    Attributes:
        name: Stable, unique key such as ``"0001_memos"``. Once this key is
            recorded in the ``schema_migrations`` table, the migration is
            skipped for that database on every later run.
        sql: One or more SQL statements. The script is executed with
            ``sqlite3.Connection.executescript``, so multi-statement
            scripts are fine; by convention it should be idempotent
            (``CREATE TABLE IF NOT EXISTS`` and friends).
    """

    name: str  # registry key and schema_migrations primary key
    sql: str  # script body handed to executescript()


# --- Migration registry (every access holds _REGISTRY_LOCK) ---------------
# _MIGRATIONS keeps registration order; _MIGRATION_NAMES is the companion
# set used for the O(1) duplicate-name check in register_migration().
_REGISTRY_LOCK = threading.Lock()
_MIGRATIONS: List[Migration] = []
_MIGRATION_NAMES: set = set()

# --- Connection cache (every access holds _DB_CACHE_LOCK) -----------------
# Maps the absolute task_dir path to the one shared connection for it.
_DB_CACHE_LOCK = threading.Lock()
_DB_CACHE: Dict[str, sqlite3.Connection] = {}

# Bookkeeping table: one row per migration name that has been applied.
_SCHEMA_MIGRATIONS_DDL = """
CREATE TABLE IF NOT EXISTS schema_migrations (
    name TEXT PRIMARY KEY,
    applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
);
"""


def register_migration(migration: Migration) -> None:
    """Register a migration so it runs on the next `get_db()` call.

    Re-registering a migration with the same name is a no-op; this lets
    modules call `register_migration` unconditionally at import time without
    worrying about double-imports.

    If a *different* migration (same name, different SQL) is registered,
    it is still ignored (the first registration wins), but a warning is
    logged: that usually means two modules picked the same migration name,
    and the second one's schema would otherwise silently never be applied.

    Args:
        migration: The migration to append to the process-wide registry.
    """
    with _REGISTRY_LOCK:
        if migration.name in _MIGRATION_NAMES:
            existing = next(
                (m for m in _MIGRATIONS if m.name == migration.name), None
            )
            if existing is not None and existing.sql != migration.sql:
                logger.warning(
                    "Migration %r is already registered with different SQL; "
                    "ignoring the later registration",
                    migration.name,
                )
            return
        _MIGRATIONS.append(migration)
        _MIGRATION_NAMES.add(migration.name)
    logger.debug("Registered migration: %s", migration.name)


def registered_migrations() -> List[Migration]:
    """Snapshot the migration registry.

    Returns:
        A new list of the registered migrations, in registration order.
        Mutating the returned list does not affect the registry.
    """
    with _REGISTRY_LOCK:
        snapshot = _MIGRATIONS.copy()
    return snapshot


def clear_migrations() -> None:
    """Empty the process-global migration registry (test helper).

    Migrations are normally registered once, at import time, and live for
    the rest of the process. An ad-hoc migration registered by one test
    would therefore be applied by every later `get_db()` in the same
    pytest process. Call this together with `clear_db_cache()` to isolate
    tests from each other.
    """
    with _REGISTRY_LOCK:
        # Both structures are emptied under one lock hold so they never
        # disagree about which names are registered.
        _MIGRATION_NAMES.clear()
        _MIGRATIONS.clear()


def get_db(task_dir: str) -> sqlite3.Connection:
    """Return the cached WAL-mode SQLite connection for this project.

    On first call for a given task_dir, creates the directory if needed,
    opens `<task_dir>/project.sqlite`, sets WAL + foreign_keys +
    synchronous=NORMAL, and runs any pending migrations. Later calls return
    the same cached connection after applying any migrations registered
    since it was opened.

    Connections are cached per absolute task_dir and reused across requests.
    They are NOT thread-local: every caller receives the same connection,
    opened with `check_same_thread=False`. The connection is in autocommit
    mode (`isolation_level=None`) and callers manage their own
    transactions, so callers on different threads share one transaction
    state and must serialize their transactions themselves.

    Args:
        task_dir: Project directory. Relative paths are resolved against the
            current working directory.

    Returns:
        The shared `sqlite3.Connection`; rows come back as `sqlite3.Row`.

    Raises:
        OSError: If the directory cannot be created.
        sqlite3.Error: If the database cannot be opened or configured, or a
            migration fails. A connection that fails during setup is closed
            and not cached, so the next call retries from scratch.
    """
    abs_dir = os.path.abspath(task_dir)
    with _DB_CACHE_LOCK:
        existing = _DB_CACHE.get(abs_dir)
        if existing is not None:
            # Apply any migrations registered AFTER this connection was
            # first opened. Cheap when nothing is pending: a no-op
            # CREATE TABLE IF NOT EXISTS plus a SELECT over the small
            # schema_migrations table. Without this, a feature whose module
            # is imported/registered later than another feature's first
            # get_db() call would never get its tables created.
            _run_pending_migrations(existing)
            return existing
        os.makedirs(abs_dir, exist_ok=True)
        db_path = os.path.join(abs_dir, "project.sqlite")
        conn = sqlite3.connect(
            db_path,
            check_same_thread=False,
            detect_types=sqlite3.PARSE_DECLTYPES,
            isolation_level=None,  # autocommit; callers manage transactions
        )
        try:
            conn.row_factory = sqlite3.Row
            # PRAGMA journal_mode returns the mode actually in effect; SQLite
            # keeps the old mode if WAL is unavailable for this file, so
            # surface that instead of silently running without WAL.
            row = conn.execute("PRAGMA journal_mode = WAL").fetchone()
            mode = row[0] if row is not None else None
            if str(mode).lower() != "wal":
                logger.warning(
                    "SQLite project DB %s is using journal_mode=%s, not WAL",
                    db_path,
                    mode,
                )
            conn.execute("PRAGMA foreign_keys = ON")
            conn.execute("PRAGMA synchronous = NORMAL")
            _run_pending_migrations(conn)
        except BaseException:
            # Never leak a half-initialized connection: it is not cached, so
            # nothing else would ever close it.
            try:
                conn.close()
            except sqlite3.Error as close_err:
                logger.warning(
                    "Error closing DB for %s after failed setup: %s",
                    abs_dir,
                    close_err,
                )
            raise
        _DB_CACHE[abs_dir] = conn
    logger.info("Opened SQLite project DB: %s", db_path)
    return conn


def close_db(task_dir: str) -> None:
    """Evict the cached connection for *task_dir* and close it.

    Does nothing when no connection is cached for that directory. A
    `sqlite3.Error` raised while closing is logged as a warning rather than
    propagated.
    """
    key = os.path.abspath(task_dir)
    with _DB_CACHE_LOCK:
        evicted = _DB_CACHE.pop(key, None)
    if evicted is None:
        return
    try:
        evicted.close()
    except sqlite3.Error as err:
        logger.warning(f"Error closing DB for {key}: {err}")


def clear_db_cache() -> None:
    """Close and evict every cached connection. Primarily for tests.

    The cache is emptied while holding the lock, and connections are closed
    afterwards. A concurrent `get_db()` therefore opens a fresh connection
    instead of receiving one that is being closed. Closing is best-effort:
    a connection that fails to close is logged (matching `close_db`) and the
    remaining connections are still closed.
    """
    with _DB_CACHE_LOCK:
        entries = list(_DB_CACHE.items())
        _DB_CACHE.clear()
    for abs_dir, conn in entries:
        try:
            conn.close()
        except sqlite3.Error as e:
            logger.warning(f"Error closing DB for {abs_dir}: {e}")


def _run_pending_migrations(conn: sqlite3.Connection) -> None:
    """Apply registered migrations not yet recorded in `schema_migrations`.

    This runs on every `get_db()` call, including against a cached
    connection that some caller may have an open transaction on. The
    nothing-pending path must therefore stay cheap and must not end that
    transaction.

    Args:
        conn: Connection to migrate (opened in autocommit mode by `get_db()`).

    Raises:
        sqlite3.Error: Re-raised from the first migration that fails. That
            migration is not recorded and is retried on the next call.
    """
    # Plain execute(), NOT executescript(): executescript() COMMITs any
    # pending transaction before running. Because this line runs on every
    # get_db() call, that would silently commit a caller's open BEGIN on the
    # shared connection. The DDL is a single statement, so execute() is
    # enough.
    conn.execute(_SCHEMA_MIGRATIONS_DDL)
    # Positional indexing works whether or not row_factory is sqlite3.Row.
    applied = {
        row[0] for row in conn.execute("SELECT name FROM schema_migrations")
    }
    with _REGISTRY_LOCK:
        pending = [m for m in _MIGRATIONS if m.name not in applied]

    if not pending:
        return

    if conn.in_transaction:
        # Applying now would go through executescript(), whose implicit
        # COMMIT would end the caller's transaction early and make a later
        # ROLLBACK/COMMIT misbehave. Defer to the next get_db() call made
        # outside a transaction.
        logger.warning(
            "Deferring %d pending migration(s) (%s): connection is inside "
            "a transaction",
            len(pending),
            ", ".join(m.name for m in pending),
        )
        return

    # Note on atomicity: Python's `executescript()` issues an implicit COMMIT
    # at the start, which dissolves any transaction (BEGIN/COMMIT *or*
    # SAVEPOINT/RELEASE) we wrap around it. So we don't wrap. The convention
    # is that every migration uses idempotent DDL (`CREATE TABLE IF NOT
    # EXISTS`, `CREATE INDEX IF NOT EXISTS`, etc.); if `executescript()`
    # raises, the migration record is *not* inserted, and the migration will
    # retry cleanly on the next `get_db()` call. `INSERT OR IGNORE` makes the
    # success path safe against double-application from a benign race.
    for migration in pending:
        try:
            conn.executescript(migration.sql)
            conn.execute(
                "INSERT OR IGNORE INTO schema_migrations (name) VALUES (?)",
                (migration.name,),
            )
        except sqlite3.Error as e:
            logger.error("Migration %s failed: %s", migration.name, e)
            raise
        logger.info("Applied migration: %s", migration.name)