Vital / db /database.py
eddyejembi's picture
Upload folder using huggingface_hub
4e88df3 verified
Raw
History Blame Contribute Delete
2.73 kB
"""SQLite connection management and schema initialisation."""
import sqlite3
import threading
from pathlib import Path
from db.models import ALL_TABLES
# Parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Database file used unless set_db_path() installs a different one.
DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "vital.db"
# Shared connection; None until get_connection() opens one and again after
# close_connection() discards it.
_connection: sqlite3.Connection | None = None
# Held by initialize_database() and reset_database() while they change the schema.
_write_lock = threading.Lock()
# Path that get_connection() opens; replaced by set_db_path().
_db_path: Path = DEFAULT_DB_PATH
def get_db_path() -> Path:
    """Report which database file the module is currently pointed at.

    Returns:
        The module-level path: the default location unless ``set_db_path``
        has replaced it.
    """
    return _db_path
def set_db_path(path: Path) -> None:
    """Point the module at a different database file (used by tests).

    Any open shared connection is closed first, so the next
    ``get_connection`` call opens ``path`` rather than reusing the old handle.

    Args:
        path: Location of the SQLite file to use from now on.
    """
    # Only the path is rebound here; the connection global is reset by
    # close_connection() itself.
    global _db_path
    close_connection()
    _db_path = path
def get_connection() -> sqlite3.Connection:
    """Return the shared SQLite connection, opening it on first use.

    The connection is opened against the module's current database path
    (missing parent directories are created), yields ``sqlite3.Row`` rows,
    has foreign-key enforcement switched on, and is created with
    ``check_same_thread=False`` so it is not bound to the creating thread.

    Returns:
        The module-wide connection; the same object on every call until
        ``close_connection`` discards it.

    Raises:
        OSError: If the parent directory cannot be created.
        sqlite3.Error: If the database cannot be opened or configured.
    """
    global _connection
    if _connection is not None:
        return _connection

    _db_path.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(
        str(_db_path),
        check_same_thread=False,
    )
    try:
        connection.row_factory = sqlite3.Row
        connection.execute("PRAGMA foreign_keys = ON;")
    except Exception:
        # Do not leak the handle or cache a connection that was never
        # fully configured.
        connection.close()
        raise

    # Publish only once configuration succeeded, so no caller can ever be
    # handed a connection without the row factory or foreign keys.
    _connection = connection
    return connection
def close_connection() -> None:
    """Dispose of the shared connection, if one is open.

    Afterwards the module holds no connection, so the next
    ``get_connection`` call opens a fresh one.
    """
    global _connection
    if _connection is None:
        # Nothing was opened; nothing to release.
        return
    _connection.close()
    _connection = None
def _run_migrations(connection: sqlite3.Connection) -> None:
    """Apply lightweight schema migrations for existing databases.

    Currently one migration: add the ``profession`` text column (default
    ``''``) to the ``profile`` table when it is missing. Nothing is committed
    here; the caller decides when to commit.

    Args:
        connection: Open SQLite connection. Its ``row_factory`` setting does
            not matter; column names are read positionally.
    """
    # Read through a cursor with the row factory cleared so rows are plain
    # tuples regardless of how the caller configured the connection.
    cursor = connection.cursor()
    cursor.row_factory = None
    profile_columns = cursor.execute("PRAGMA table_info(profile);").fetchall()
    if not profile_columns:
        # table_info yields no rows when the table does not exist; there is
        # nothing to migrate, and ALTER TABLE would fail with "no such table".
        return

    # Index 1 of each table_info row is the column name.
    column_names = {row[1] for row in profile_columns}
    if "profession" not in column_names:
        connection.execute("ALTER TABLE profile ADD COLUMN profession TEXT DEFAULT '';")
def initialize_database() -> None:
    """Bring the schema on the shared connection up to date.

    While holding the write lock, executes every statement in ``ALL_TABLES``
    (presumably ``CREATE TABLE IF NOT EXISTS`` statements; they are defined
    in ``db.models``), applies the migrations, and commits.
    """
    db = get_connection()
    with _write_lock:
        for create_statement in ALL_TABLES:
            db.execute(create_statement)
        # Migrations run after table creation so the tables they alter exist.
        _run_migrations(db)
        db.commit()
def reset_database() -> None:
    """Drop and recreate all tables (test helper only).

    While holding the write lock: drops every table listed in
    ``sqlite_master`` whose name does not match ``sqlite_%``, with
    foreign-key enforcement switched off for the duration of the drops, then
    re-runs the ``ALL_TABLES`` statements and the migrations, and commits.

    Foreign-key enforcement is switched back on even if listing or dropping
    the tables fails, so the shared connection is never left with it
    disabled.
    """
    connection = get_connection()
    with _write_lock:
        # Enforcement is suspended for the drops, presumably so tables can be
        # removed regardless of the references between them.
        connection.execute("PRAGMA foreign_keys = OFF;")
        try:
            cursor = connection.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%';"
            )
            # Materialise the names before dropping anything.
            table_names = [row["name"] for row in cursor.fetchall()]
            for table_name in table_names:
                # Quote as an SQL identifier (doubling embedded quotes) so
                # names that are keywords or contain special characters
                # still drop cleanly.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                connection.execute(f"DROP TABLE IF EXISTS {quoted_name};")
        finally:
            connection.execute("PRAGMA foreign_keys = ON;")

        for table_sql in ALL_TABLES:
            connection.execute(table_sql)
        _run_migrations(connection)
        connection.commit()