from __future__ import annotations from collections.abc import Generator from pathlib import Path from sqlalchemy import create_engine, inspect, text from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from app.config import BASE_DIR, get_settings class Base(DeclarativeBase): pass def _sqlite_path_from_url(database_url: str) -> Path | None: if not database_url.startswith("sqlite:///"): return None raw_path = database_url.removeprefix("sqlite:///") path = Path(raw_path) if not path.is_absolute(): path = BASE_DIR / path return path settings = get_settings() sqlite_path = _sqlite_path_from_url(settings.database_url) if sqlite_path is not None: sqlite_path.parent.mkdir(parents=True, exist_ok=True) engine = create_engine( settings.database_url, connect_args={"check_same_thread": False} if settings.database_url.startswith("sqlite") else {}, ) SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False) def init_db() -> None: # Importing models registers their tables with SQLAlchemy metadata. from app import models # noqa: F401 Base.metadata.create_all(bind=engine) _upgrade_sqlite_schema() def _upgrade_sqlite_schema() -> None: if not settings.database_url.startswith("sqlite"): return upgrades = { "users": { "shirt_color": "VARCHAR(20) NOT NULL DEFAULT 'red'", "pants_color": "VARCHAR(20) NOT NULL DEFAULT 'navy'", "hair_color": "VARCHAR(20) NOT NULL DEFAULT 'dark_brown'", }, "lessons": { "villain_image_url": "VARCHAR(500)", }, } with engine.begin() as connection: inspector = inspect(connection) for table, columns in upgrades.items(): if table not in inspector.get_table_names(): continue existing = {column["name"] for column in inspector.get_columns(table)} for name, definition in columns.items(): if name not in existing: connection.execute(text(f"ALTER TABLE {table} ADD COLUMN {name} {definition}")) def get_db() -> Generator[Session, None, None]: db = SessionLocal() try: yield db finally: db.close()