import re
import ssl
from typing import AsyncGenerator, Union

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker  # noqa: F401  (sessionmaker kept in case other modules import it from here)

from .config import get_settings

# libpq-style query parameters that asyncpg does NOT accept; they must be
# removed from the URL before it is handed to create_async_engine().
_UNSUPPORTED_QUERY_PARAMS = frozenset({"sslmode", "channel_binding"})


def _ssl_for_sslmode(sslmode: str) -> Union[ssl.SSLContext, bool]:
    """Translate a libpq ``sslmode`` value into asyncpg's ``ssl`` connect argument.

    - ``disable``     -> ``False`` (no TLS).
    - ``verify-full`` -> default context: certificate chain AND hostname verified.
    - ``verify-ca``   -> certificate chain verified, hostname not checked.
    - anything else (``require``, ``prefer``, ``allow``, empty) -> TLS without
      certificate verification, which is what this module always did before and
      matches libpq's ``require`` when no root certificate is configured.
    """
    mode = sslmode.strip().lower()
    if mode == "disable":
        return False

    ctx = ssl.create_default_context()
    if mode == "verify-full":
        return ctx

    # check_hostname must be switched off before verify_mode may be relaxed.
    ctx.check_hostname = False
    if mode != "verify-ca":
        ctx.verify_mode = ssl.CERT_NONE
    return ctx


def _make_async_url(url: str) -> tuple[str, dict]:
    """Convert a standard ``postgres://`` / ``postgresql://`` URL to asyncpg form.

    asyncpg does NOT accept ``sslmode`` or ``channel_binding`` as URL query
    params. They are stripped from the query string; every other parameter is
    kept verbatim and in its original order. When ``sslmode`` was present, it is
    translated into an ``ssl`` entry in the returned connect_args.

    Args:
        url: Database URL, e.g. ``postgres://user:pw@host/db?sslmode=require``.

    Returns:
        ``(async_url, connect_args)``. ``async_url`` uses the
        ``postgresql+asyncpg`` scheme. ``connect_args`` is meant for
        ``create_async_engine(connect_args=...)``.
    """
    # Accept both "postgres" and "postgresql" schemes and replace any existing
    # "+driver" suffix with asyncpg.
    url = re.sub(r"^postgres(?:ql)?(\+[^:]+)?:", "postgresql+asyncpg:", url)

    # Filter the query string pair by pair (rather than with a regex) so that
    # dropping the FIRST parameter cannot leave "&other=..." glued onto the
    # database name. Kept pairs are not re-encoded.
    base, _, query = url.partition("?")
    sslmode = None
    kept = []
    for pair in query.split("&"):
        if not pair:
            continue
        key, _, value = pair.partition("=")
        if key == "sslmode":
            sslmode = value
        if key not in _UNSUPPORTED_QUERY_PARAMS:
            kept.append(pair)
    url = f"{base}?{'&'.join(kept)}" if kept else base

    connect_args: dict = {
        # Crucial for Neon Serverless PgBouncer functionality:
        # - statement_cache_size disables asyncpg's own statement cache;
        # - prepared_statement_cache_size disables SQLAlchemy's asyncpg-adapter
        #   prepared statement cache. It is a connect argument, not an engine
        #   execution option, so it must be passed here.
        # NOTE(review): if "prepared statement ... already exists" errors still
        # occur behind PgBouncer, SQLAlchemy documents a
        # `prepared_statement_name_func` connect arg for unique statement names.
        "statement_cache_size": 0,
        "prepared_statement_cache_size": 0,
    }
    if sslmode is not None:
        connect_args["ssl"] = _ssl_for_sslmode(sslmode)
    return url, connect_args


settings = get_settings()
_db_url, _connect_args = _make_async_url(settings.database_url)

engine = create_async_engine(
    _db_url,
    echo=False,
    pool_pre_ping=True,
    pool_recycle=300,  # Refreshes old connections every 5 min to prevent stale connection errors
    pool_size=10,
    max_overflow=5,
    connect_args={
        **_connect_args,
        "command_timeout": 60,
    },
)

# Factory producing AsyncSession objects bound to `engine`.
# expire_on_commit=False keeps loaded attributes readable after commit instead
# of expiring them (a later access would otherwise need a refresh round-trip).
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


class Base(DeclarativeBase):
    """Declarative base class for this application's ORM models."""


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession; the ``async with`` closes it once the consumer is done."""
    async with AsyncSessionLocal() as session:
        yield session