""" Async SQLAlchemy Database Engine & Session Management. Uses async SQLAlchemy with support for both SQLite (dev) and PostgreSQL (prod). Provides a session dependency for FastAPI route injection. """ from __future__ import annotations from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase from app.config import get_settings settings = get_settings() # ── Normalize DATABASE_URL for asyncpg ──────────────────────────────────── _db_url = settings.database_url # Ensure asyncpg driver is specified if _db_url.startswith("postgresql://"): _db_url = _db_url.replace("postgresql://", "postgresql+asyncpg://", 1) # Remove Neon-specific params that asyncpg doesn't support _needs_ssl = False if "sslmode=require" in _db_url: _needs_ssl = True _db_url = _db_url.replace("?sslmode=require", "").replace("&sslmode=require", "").replace("sslmode=require&", "") if "channel_binding=require" in _db_url: _db_url = _db_url.replace("?channel_binding=require", "").replace("&channel_binding=require", "").replace("channel_binding=require&", "") # Clean up trailing ? or & _db_url = _db_url.rstrip("?&") if "?ssl=require" in _db_url: _needs_ssl = True _db_url = _db_url.replace("?ssl=require", "").replace("&ssl=require", "") _db_url = _db_url.rstrip("?&") # ── Engine ─────────────────────────────────────────────────────────────── _connect_args: dict = {} _engine_kwargs: dict = { "echo": False, "pool_pre_ping": True, } if _db_url.startswith("sqlite"): _connect_args = {"check_same_thread": False} else: # PostgreSQL pool settings _engine_kwargs.update({ "pool_size": 5, "max_overflow": 5, "pool_recycle": 300, }) # SSL for cloud providers (Neon, Supabase, etc.) 
import ssl as _ssl
from collections.abc import AsyncIterator

# SSL for cloud providers (Neon, Supabase, etc.).
# This context encrypts the connection but deliberately disables hostname and
# certificate verification (the libpq "sslmode=require" semantics).
# NOTE(review): any server certificate is accepted here.
if _needs_ssl:
    _ssl_ctx = _ssl.create_default_context()
    _ssl_ctx.check_hostname = False
    _ssl_ctx.verify_mode = _ssl.CERT_NONE
    _connect_args["ssl"] = _ssl_ctx

engine = create_async_engine(
    _db_url,
    connect_args=_connect_args,
    **_engine_kwargs,
)


# ── Session Factory ──────────────────────────────────────────────────────

# expire_on_commit=False keeps ORM objects readable after commit without
# triggering a lazy refresh (which would need I/O in async code).
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


# ── Declarative Base ─────────────────────────────────────────────────────

class Base(DeclarativeBase):
    """Base class for all ORM models."""


# ── Dependency ───────────────────────────────────────────────────────────

async def get_db() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields an async database session.

    After the consumer resumes the generator normally, the session is
    committed. If an exception is thrown into the generator at the ``yield``
    (e.g. the request handler raised), the session is rolled back and the
    exception is re-raised. The ``async with`` block closes the session in
    every case.
    """
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise