| """ |
| Database connection and session management. |
| Uses async SQLAlchemy with asyncpg driver. |
| """ |
|
|
import logging
import uuid
from typing import AsyncGenerator

from sqlalchemy import CHAR, text
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.types import TypeDecorator

from app.config import get_settings
|
|
|
|
class GUID(TypeDecorator):
    """
    Platform-independent GUID column type.

    On PostgreSQL the column is a native UUID (``as_uuid=True``); on every
    other dialect (e.g. SQLite) it is a CHAR(32) holding the UUID's hex digits.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect-specific column type: native UUID or CHAR(32)."""
        if dialect.name == 'postgresql':
            column_type = PG_UUID(as_uuid=True)
        else:
            column_type = CHAR(32)
        return dialect.type_descriptor(column_type)

    def process_bind_param(self, value, dialect):
        """
        Convert a Python-side value into what is sent to the database.

        ``None`` and any value bound for PostgreSQL are passed through
        untouched. For other dialects the value is rendered as a hex string;
        anything that is not already a ``uuid.UUID`` is first parsed with
        ``uuid.UUID(value)``.
        """
        if value is None or dialect.name == 'postgresql':
            return value
        parsed = value if isinstance(value, uuid.UUID) else uuid.UUID(value)
        return parsed.hex

    def process_result_value(self, value, dialect):
        """
        Convert a database-side value into what the application sees.

        ``None``, any PostgreSQL result and values that are already
        ``uuid.UUID`` instances are returned as-is; everything else is parsed
        with ``uuid.UUID(value)``.
        """
        if (
            value is None
            or dialect.name == 'postgresql'
            or isinstance(value, uuid.UUID)
        ):
            return value
        return uuid.UUID(value)
|
|
# Application settings, resolved once when this module is imported.
settings = get_settings()

# Process-wide async engine; every session created below shares it.
engine = create_async_engine(
    settings.database_url,
    echo=settings.debug,  # SQL echo follows the application's debug flag
    future=True,
    # Connection-pool configuration.
    pool_size=20,
    max_overflow=10,
    pool_recycle=3600,
    pool_pre_ping=True,
)

# Factory producing AsyncSession objects bound to the shared engine.
# expire_on_commit=False is passed so the factory does not expire loaded
# objects when a session commits.
async_session_maker = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base that all SQLAlchemy models inherit from."""
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that provides a database session.

    Opens a session from ``async_session_maker`` and yields it to the
    consumer. If the consumer finishes without error the transaction is
    committed; if an exception is raised (by the consumer or by the commit
    itself) the transaction is rolled back and the exception is re-raised.

    Yields:
        The ``AsyncSession`` to use for the duration of the request/unit of work.

    Raises:
        Exception: Any exception raised while the session was in use or
            during commit, re-raised after the rollback.
    """
    # The ``async with`` block closes the session on every exit path, so no
    # explicit ``session.close()`` is needed here.
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
|
|
|
async def check_db_health() -> bool:
    """
    Check database connectivity by running a simple query.

    Opens a connection from the module-level engine and executes
    ``SELECT 1``.

    Returns:
        ``True`` if the query ran and the connection was released without
        error, ``False`` otherwise. This function never raises for ordinary
        errors; the failure is logged instead.
    """
    try:
        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception:
        # Still a best-effort probe, but record why it failed so an unhealthy
        # status can be diagnosed from the logs.
        logging.getLogger(__name__).warning(
            "Database health check failed", exc_info=True
        )
        return False
    return True
|
|
|
|
async def init_db() -> None:
    """
    Create the tables registered on ``Base.metadata``.

    Intended for development use only. Runs inside a single
    ``engine.begin()`` block.
    """
    async with engine.begin() as connection:
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)
|
|