File size: 2,733 Bytes
4e88df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""SQLite connection management and schema initialisation."""

import sqlite3
import threading
from pathlib import Path

from db.models import ALL_TABLES

# Directory two levels above this file (presumably the project root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Database file used until set_db_path() points the module elsewhere.
DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "vital.db"

# Shared connection, opened lazily by get_connection() and reset by close_connection().
_connection: sqlite3.Connection | None = None
# Held by initialize_database() and reset_database() while they modify the schema.
# Note: threading.Lock is not re-entrant. Other modules may also use it (not visible here).
_write_lock = threading.Lock()
# Path the next connection will open; replaced via set_db_path().
_db_path: Path = DEFAULT_DB_PATH


def get_db_path() -> Path:
    """Report which database file the module currently points at.

    This is ``DEFAULT_DB_PATH`` unless ``set_db_path`` has replaced it.
    """
    current_path = _db_path
    return current_path


def set_db_path(path: Path | str) -> None:
    """Point the module at a different database file (used by tests).

    The current shared connection is closed first, so the next call to
    ``get_connection`` opens *path* instead.

    Args:
        path: New database location. A plain string is accepted and converted
            to ``Path``, because ``get_connection`` relies on ``_db_path.parent``.
    """
    global _db_path
    close_connection()
    _db_path = Path(path)


_connection_lock = threading.Lock()


def get_connection() -> sqlite3.Connection:
    """Return the shared SQLite connection, creating it on first use.

    On first use this creates the parent directory of ``_db_path`` if needed.
    It then opens the database with ``check_same_thread=False`` (so the
    connection may be shared across threads). The connection gets
    ``sqlite3.Row`` as its row factory and has foreign-key enforcement
    switched on.

    The new connection is fully configured before it is published to
    ``_connection``. Otherwise a concurrent caller could receive it before
    ``row_factory`` is set. Creation happens under ``_connection_lock``, with a
    re-check, so only one connection is opened. A dedicated lock is used
    instead of ``_write_lock``: that one is a non-re-entrant ``threading.Lock``,
    and a caller already holding it would deadlock here.

    Returns:
        The process-wide shared ``sqlite3.Connection``.
    """
    global _connection
    existing = _connection
    if existing is not None:
        return existing
    with _connection_lock:
        # Another thread may have created the connection while we waited.
        if _connection is None:
            _db_path.parent.mkdir(parents=True, exist_ok=True)
            connection = sqlite3.connect(
                str(_db_path),
                check_same_thread=False,
            )
            connection.row_factory = sqlite3.Row
            connection.execute("PRAGMA foreign_keys = ON;")
            _connection = connection
        return _connection


def close_connection() -> None:
    """Dispose of the shared connection, if one is open.

    Afterwards ``_connection`` is ``None``, so the next ``get_connection``
    call opens a fresh connection (at whatever ``_db_path`` is then).
    Calling this when no connection is open does nothing.
    """
    global _connection
    if _connection is None:
        return
    _connection.close()
    _connection = None


def _run_migrations(connection: sqlite3.Connection) -> None:
    """Apply lightweight, idempotent schema migrations to an existing database.

    Currently this adds ``profile.profession`` (``TEXT``, default ``''``) when
    that column is missing. It does nothing if the ``profile`` table does not
    exist. It does not commit; the callers in this module commit afterwards.

    Args:
        connection: Open SQLite connection. Any row factory works, because
            column names are read by position (index 1 of ``PRAGMA table_info``).
    """
    profile_columns = connection.execute("PRAGMA table_info(profile);").fetchall()
    if not profile_columns:
        # PRAGMA table_info yields no rows for a missing table; ALTER would fail.
        return
    column_names = {row[1] for row in profile_columns}
    if "profession" not in column_names:
        connection.execute("ALTER TABLE profile ADD COLUMN profession TEXT DEFAULT '';")


def initialize_database() -> None:
    """Bring the schema on the shared connection up to date.

    While holding ``_write_lock``, this executes every statement in
    ``ALL_TABLES`` (imported from ``db.models``; presumably
    ``CREATE TABLE IF NOT EXISTS`` statements, which are not visible here).
    It then applies ``_run_migrations`` and commits.
    """
    db = get_connection()
    with _write_lock:
        for ddl in ALL_TABLES:
            db.execute(ddl)
        _run_migrations(db)
        db.commit()


def reset_database() -> None:
    """Drop every user table and recreate the schema (test helper only).

    Any transaction still open on the shared connection is rolled back first.
    SQLite ignores ``PRAGMA foreign_keys`` inside a transaction. Without the
    rollback, enforcement might stay on and the drops could trip
    foreign-key checks. Uncommitted work would be lost anyway, because every
    table is about to be dropped.

    Foreign keys are switched back on in a ``finally`` block, so a failed drop
    does not leave the shared connection with enforcement disabled.
    """
    connection = get_connection()
    with _write_lock:
        connection.rollback()
        connection.execute("PRAGMA foreign_keys = OFF;")
        try:
            # GLOB treats '_' literally (LIKE would treat it as a wildcard).
            cursor = connection.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT GLOB 'sqlite_*';"
            )
            table_names = [row[0] for row in cursor.fetchall()]
            for table_name in table_names:
                connection.execute(f"DROP TABLE IF EXISTS {_quote_identifier(table_name)};")
        finally:
            connection.execute("PRAGMA foreign_keys = ON;")
        for table_sql in ALL_TABLES:
            connection.execute(table_sql)
        _run_migrations(connection)
        connection.commit()


def _quote_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQLite identifier, escaping embedded quotes."""
    return '"' + name.replace('"', '""') + '"'