| """ |
| Database connection and session management. |
| Uses async SQLAlchemy with asyncpg driver for PostgreSQL, |
| or aiosqlite for local/free-tier deployment. |
| """ |
|
|
import os
import uuid
from typing import AsyncGenerator

from sqlalchemy import CHAR, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.types import TypeDecorator

from app.config import get_settings
|
|
|
|
class GUID(TypeDecorator):
    """
    Portable UUID column type.

    On the PostgreSQL dialect the column is backed by the native UUID type
    (returned as ``uuid.UUID`` objects); on every other dialect it is stored
    as a 32-character hex string in a CHAR(32) column.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect-specific type that backs this column."""
        if dialect.name != 'postgresql':
            return dialect.type_descriptor(CHAR(32))

        # Imported lazily so the PostgreSQL dialect is only loaded when used.
        from sqlalchemy.dialects.postgresql import UUID as PG_UUID
        return dialect.type_descriptor(PG_UUID(as_uuid=True))

    def process_bind_param(self, value, dialect):
        """Convert a Python-side value into the form sent to the database."""
        # None and PostgreSQL-bound values are passed through untouched.
        if value is None or dialect.name == 'postgresql':
            return value

        # Other dialects store the 32-character hex form; anything that is
        # not already a UUID is parsed by uuid.UUID first.
        parsed = value if isinstance(value, uuid.UUID) else uuid.UUID(value)
        return parsed.hex

    def process_result_value(self, value, dialect):
        """Convert a value read from the database into a ``uuid.UUID``."""
        passthrough = (
            value is None
            or dialect.name == 'postgresql'
            or isinstance(value, uuid.UUID)
        )
        if passthrough:
            return value
        return uuid.UUID(value)
|
|
settings = get_settings()

# Resolve the database URL. When none is configured, fall back to a SQLite
# file stored in a "data" directory one level above this module's directory.
_db_url = settings.database_url

if not _db_url:
    _db_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
    os.makedirs(_db_dir, exist_ok=True)
    _db_url = f"sqlite+aiosqlite:///{_db_dir}/fairrelay.db"

# Decide the backend from the resolved URL itself rather than from whether
# the fallback was taken: an explicitly configured SQLite URL must also be
# treated as SQLite so it does not receive PostgreSQL-only pool options.
_is_sqlite = str(_db_url).startswith("sqlite")
|
|
| |
# Build the async engine. SQLite gets only the driver-level
# ``check_same_thread`` flag; any other backend gets explicit pool settings.
engine = (
    create_async_engine(
        _db_url,
        echo=settings.debug,
        future=True,
        connect_args={"check_same_thread": False},
    )
    if _is_sqlite
    else create_async_engine(
        _db_url,
        echo=settings.debug,
        future=True,
        pool_size=20,
        max_overflow=10,
        pool_recycle=3600,
        pool_pre_ping=True,
    )
)
|
|
| |
# Session factory bound to the module-level engine. ``expire_on_commit`` is
# disabled so attributes on returned objects are not expired by a commit.
async_session_maker = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
class Base(DeclarativeBase):
    """Shared declarative base class for all SQLAlchemy models."""
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that provides a database session.

    Yields a session created by ``async_session_maker``. If the consumer
    finishes without raising, the session is committed. If an exception
    propagates, the session is rolled back and the exception is re-raised.

    Yields:
        The open ``AsyncSession`` for the duration of the consumer's work.

    Raises:
        Exception: Whatever the consumer or the commit raised, after rollback.
    """
    # The ``async with`` block closes the session on exit, so no explicit
    # ``close()`` call is needed here.
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
|
|
|
async def check_db_health() -> bool:
    """
    Report whether the database answers a trivial query.

    Opens a connection from the engine and runs ``SELECT 1``.

    Returns:
        True if the query completed, False if any exception was raised.
    """
    try:
        async with engine.connect() as connection:
            await connection.execute(text("SELECT 1"))
    except Exception:
        # Deliberately best-effort: any failure is reported as "unhealthy".
        return False
    else:
        return True
|
|
|
|
async def init_db() -> None:
    """Initialize database tables by running ``Base.metadata.create_all``."""
    create_tables = Base.metadata.create_all
    # ``create_all`` is a synchronous API, so it is dispatched through
    # ``run_sync`` on a connection inside an ``engine.begin()`` block.
    async with engine.begin() as connection:
        await connection.run_sync(create_tables)
|
|