Spaces:
Running
Running
File size: 2,733 Bytes
4e88df3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | """SQLite connection management and schema initialisation."""
import sqlite3
import threading
from pathlib import Path
from db.models import ALL_TABLES
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "vital.db"
_connection: sqlite3.Connection | None = None
_write_lock = threading.Lock()
_db_path: Path = DEFAULT_DB_PATH
def get_db_path() -> Path:
    """Return the database file path currently configured for this module.

    This is the module-level ``_db_path`` value, i.e. the default path
    unless ``set_db_path`` has replaced it.
    """
    return _db_path
def set_db_path(path: Path) -> None:
    """Point the module at a different database file (used by tests).

    Any open shared connection is closed first, so the next
    ``get_connection`` call opens a connection against *path*.

    Args:
        path: Location of the SQLite database file to use from now on.
    """
    # Only _db_path is rebound here; close_connection() resets the
    # connection global on its own.
    global _db_path
    close_connection()
    _db_path = path
def get_connection() -> sqlite3.Connection:
    """Return the module's shared SQLite connection, opening it on first use.

    On first use the parent directory of the configured database path is
    created if missing, the connection is opened with
    ``check_same_thread=False``, rows are returned as ``sqlite3.Row``, and
    foreign-key enforcement is switched on.

    Returns:
        The shared ``sqlite3.Connection``.
    """
    global _connection
    if _connection is not None:
        return _connection

    # SQLite creates the file itself but not the directory that holds it.
    _db_path.parent.mkdir(parents=True, exist_ok=True)
    _connection = sqlite3.connect(str(_db_path), check_same_thread=False)
    _connection.row_factory = sqlite3.Row
    # Foreign-key enforcement is a per-connection setting in SQLite.
    _connection.execute("PRAGMA foreign_keys = ON;")
    return _connection
def close_connection() -> None:
    """Close the shared connection, if one is open, and forget it.

    After this call the next ``get_connection`` opens a fresh connection,
    which lets a new database path take effect.
    """
    global _connection
    if _connection is None:
        return
    _connection.close()
    _connection = None
def _run_migrations(connection: sqlite3.Connection) -> None:
    """Apply lightweight schema migrations for existing databases.

    Currently a single migration: add the ``profession`` text column
    (default ``''``) to the ``profile`` table when it is not present.
    Nothing is committed here; the caller commits.

    Args:
        connection: Open SQLite connection. Any ``row_factory`` is
            accepted because result rows are read by position.
    """
    # PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk).
    # Reading the name by index works for plain tuples and sqlite3.Row
    # alike, so this helper does not depend on the caller's row_factory.
    name_index = 1
    profile_columns = connection.execute("PRAGMA table_info(profile);").fetchall()
    column_names = {row[name_index] for row in profile_columns}
    if "profession" not in column_names:
        connection.execute("ALTER TABLE profile ADD COLUMN profession TEXT DEFAULT '';")
def initialize_database() -> None:
    """Ensure the schema exists on the shared connection.

    Executes every statement in ``ALL_TABLES``, runs the migrations, and
    commits, all while holding the module write lock.
    """
    db = get_connection()
    with _write_lock:
        for create_statement in ALL_TABLES:
            db.execute(create_statement)
        # Migrations run after the table statements so they can alter
        # tables that were just created or already existed.
        _run_migrations(db)
        db.commit()
def reset_database() -> None:
    """Drop and recreate all tables (test helper only).

    Every table listed in ``sqlite_master`` whose name does not match
    ``sqlite_%`` is dropped, then the statements in ``ALL_TABLES`` are
    executed, migrations are applied, and the result is committed. The
    whole sequence runs while holding the module write lock.

    Foreign-key enforcement is switched off only for the duration of the
    drops and is switched back on even if a drop fails, so the shared
    connection is never left with enforcement disabled.
    """
    connection = get_connection()
    with _write_lock:
        # Enforcement is disabled so tables can be dropped in any order.
        connection.execute("PRAGMA foreign_keys = OFF;")
        try:
            cursor = connection.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%';"
            )
            table_names = [row["name"] for row in cursor.fetchall()]
            for table_name in table_names:
                # Quote the identifier (doubling embedded quotes) so names
                # with spaces, punctuation or reserved words still drop.
                quoted_name = table_name.replace('"', '""')
                connection.execute(f'DROP TABLE IF EXISTS "{quoted_name}";')
        finally:
            connection.execute("PRAGMA foreign_keys = ON;")
        for table_sql in ALL_TABLES:
            connection.execute(table_sql)
        _run_migrations(connection)
        connection.commit()
|