secret-student / app /database.py
asanwari's picture
feat: polish Secret Student game experience
7aeedfb
Raw
History Blame Contribute Delete
2.28 kB
from __future__ import annotations
from collections.abc import Generator
from pathlib import Path
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from app.config import BASE_DIR, get_settings
class Base(DeclarativeBase):
pass
def _sqlite_path_from_url(database_url: str) -> Path | None:
if not database_url.startswith("sqlite:///"):
return None
raw_path = database_url.removeprefix("sqlite:///")
path = Path(raw_path)
if not path.is_absolute():
path = BASE_DIR / path
return path
settings = get_settings()
sqlite_path = _sqlite_path_from_url(settings.database_url)
if sqlite_path is not None:
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
engine = create_engine(
settings.database_url,
connect_args={"check_same_thread": False}
if settings.database_url.startswith("sqlite")
else {},
)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def init_db() -> None:
# Importing models registers their tables with SQLAlchemy metadata.
from app import models # noqa: F401
Base.metadata.create_all(bind=engine)
_upgrade_sqlite_schema()
def _upgrade_sqlite_schema() -> None:
if not settings.database_url.startswith("sqlite"):
return
upgrades = {
"users": {
"shirt_color": "VARCHAR(20) NOT NULL DEFAULT 'red'",
"pants_color": "VARCHAR(20) NOT NULL DEFAULT 'navy'",
"hair_color": "VARCHAR(20) NOT NULL DEFAULT 'dark_brown'",
},
"lessons": {
"villain_image_url": "VARCHAR(500)",
},
}
with engine.begin() as connection:
inspector = inspect(connection)
for table, columns in upgrades.items():
if table not in inspector.get_table_names():
continue
existing = {column["name"] for column in inspector.get_columns(table)}
for name, definition in columns.items():
if name not in existing:
connection.execute(text(f"ALTER TABLE {table} ADD COLUMN {name} {definition}"))
def get_db() -> Generator[Session, None, None]:
db = SessionLocal()
try:
yield db
finally:
db.close()