| import re |
| from typing import AsyncGenerator |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker |
| from sqlalchemy.orm import DeclarativeBase |
| from .config import get_settings |
|
|
|
|
def _make_async_url(url: str) -> tuple[str, dict]:
    """Convert a standard Postgres URL to the form SQLAlchemy's asyncpg dialect accepts.

    The scheme (``postgres://``, ``postgresql://`` or ``postgresql+<driver>://``)
    is rewritten to ``postgresql+asyncpg://``.

    asyncpg does NOT accept the libpq ``sslmode`` / ``channel_binding`` query
    parameters. Both are removed from the query string; every other query
    parameter is kept verbatim. ``sslmode`` is translated into an ``ssl``
    connect argument following libpq's meaning of each mode:

    * absent or ``disable`` -> no ``ssl`` argument (driver default);
    * ``verify-full``       -> SSLContext verifying certificate and hostname;
    * ``verify-ca``         -> SSLContext verifying the certificate chain only;
    * any other value (``require``, ``prefer``, ``allow``, ...) -> encrypted
      connection with certificate verification disabled.

    Args:
        url: Database URL, e.g. taken from application settings.

    Returns:
        ``(async_url, connect_args)``. ``connect_args`` always contains
        ``statement_cache_size=0``, which disables asyncpg's statement cache.
        NOTE(review): presumably this is for PgBouncer-style poolers; confirm.
    """
    import ssl as _ssl

    # Split off the query ourselves so removing a parameter can never drop
    # the "?" separator that introduces the remaining parameters.
    base, _, query = url.partition("?")
    base = re.sub(r"^postgres(?:ql)?(?:\+[^:]+)?:", "postgresql+asyncpg:", base)

    sslmode = None
    kept_params = []
    for part in query.split("&"):
        if not part:
            continue
        key, _, value = part.partition("=")
        if key == "sslmode":
            sslmode = value.lower()
        elif key != "channel_binding":
            kept_params.append(part)

    async_url = f"{base}?{'&'.join(kept_params)}" if kept_params else base

    connect_args: dict = {}
    if sslmode is not None and sslmode != "disable":
        ctx = _ssl.create_default_context()  # CERT_REQUIRED + hostname check
        if sslmode == "verify-ca":
            # Verify the certificate chain but not that it names this host.
            ctx.check_hostname = False
        elif sslmode != "verify-full":
            # libpq "require"-style: encrypt without authenticating the server.
            ctx.check_hostname = False
            ctx.verify_mode = _ssl.CERT_NONE
        connect_args["ssl"] = ctx

    connect_args["statement_cache_size"] = 0
    return async_url, connect_args
|
|
|
|
settings = get_settings()
_db_url, _connect_args = _make_async_url(settings.database_url)

# Application-wide async engine. pool_pre_ping tests each pooled connection
# before handing it out, and pool_recycle replaces connections older than
# 300 s, so stale connections closed on the server side are not reused.
engine = create_async_engine(
    _db_url,
    echo=False,
    pool_pre_ping=True,
    pool_recycle=300,
    pool_size=10,
    max_overflow=5,
    connect_args={
        **_connect_args,
        # asyncpg default timeout (seconds) for operations on a connection.
        "command_timeout": 60,
        # SQLAlchemy's asyncpg adapter keeps its own prepared-statement cache
        # (separate from asyncpg's statement_cache_size). That cache is sized
        # when the connection is created, so the setting must go here (or in
        # the URL query string). Passing it as an execution option has no effect.
        "prepared_statement_cache_size": 0,
    },
)

# Session factory. expire_on_commit=False keeps loaded attributes readable
# after commit instead of expiring them, which would require an implicit
# (async-incompatible) refresh on next access.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class for ORM models.

    Subclasses are mapped classes registered on the shared ``Base.metadata``.
    NOTE(review): presumably every model in the project derives from this;
    confirm.
    """
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a new ``AsyncSession`` and close it once the generator finishes.

    The session is created by ``AsyncSessionLocal``. Exiting the
    ``async with`` block closes it, which releases its connection and ends
    any uncommitted transaction.
    NOTE(review): presumably used as a framework dependency (e.g. FastAPI
    ``Depends``); confirm against callers.
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session
|
|
|
|