| from __future__ import annotations |
|
|
| from collections.abc import Generator |
| from pathlib import Path |
|
|
| from sqlalchemy import create_engine, inspect, text |
| from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker |
|
|
| from app.config import BASE_DIR, get_settings |
|
|
|
|
| class Base(DeclarativeBase): |
| pass |
|
|
|
|
| def _sqlite_path_from_url(database_url: str) -> Path | None: |
| if not database_url.startswith("sqlite:///"): |
| return None |
| raw_path = database_url.removeprefix("sqlite:///") |
| path = Path(raw_path) |
| if not path.is_absolute(): |
| path = BASE_DIR / path |
| return path |
|
|
|
|
| settings = get_settings() |
| sqlite_path = _sqlite_path_from_url(settings.database_url) |
| if sqlite_path is not None: |
| sqlite_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
| engine = create_engine( |
| settings.database_url, |
| connect_args={"check_same_thread": False} |
| if settings.database_url.startswith("sqlite") |
| else {}, |
| ) |
| SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False) |
|
|
|
|
| def init_db() -> None: |
| |
| from app import models |
|
|
| Base.metadata.create_all(bind=engine) |
| _upgrade_sqlite_schema() |
|
|
|
|
| def _upgrade_sqlite_schema() -> None: |
| if not settings.database_url.startswith("sqlite"): |
| return |
| upgrades = { |
| "users": { |
| "shirt_color": "VARCHAR(20) NOT NULL DEFAULT 'red'", |
| "pants_color": "VARCHAR(20) NOT NULL DEFAULT 'navy'", |
| "hair_color": "VARCHAR(20) NOT NULL DEFAULT 'dark_brown'", |
| }, |
| "lessons": { |
| "villain_image_url": "VARCHAR(500)", |
| }, |
| } |
| with engine.begin() as connection: |
| inspector = inspect(connection) |
| for table, columns in upgrades.items(): |
| if table not in inspector.get_table_names(): |
| continue |
| existing = {column["name"] for column in inspector.get_columns(table)} |
| for name, definition in columns.items(): |
| if name not in existing: |
| connection.execute(text(f"ALTER TABLE {table} ADD COLUMN {name} {definition}")) |
|
|
|
|
| def get_db() -> Generator[Session, None, None]: |
| db = SessionLocal() |
| try: |
| yield db |
| finally: |
| db.close() |
|
|