File size: 2,031 Bytes
c1dbdc6 12fa3c2 a0f1a88 12fa3c2 c1dbdc6 12fa3c2 48a85ea 10090b1 3ff680a 10090b1 a0f1a88 12fa3c2 022fb5a c1dbdc6 12fa3c2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | import re
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from .config import get_settings
def _make_async_url(url: str) -> tuple[str, dict]:
    """Convert a PostgreSQL URL into an asyncpg-compatible URL plus connect args.

    ``postgres://``, ``postgresql://`` and ``postgresql+<driver>://`` schemes
    are rewritten to ``postgresql+asyncpg://``. URLs with any other scheme
    keep their scheme unchanged.

    asyncpg does NOT accept ``sslmode`` or ``channel_binding`` as URL query
    params. Both are removed from the query string; every other query param
    is kept verbatim and in its original order. The ``sslmode`` value is
    translated into an ``ssl`` connect argument instead:

    * absent or ``disable`` -> no ``ssl`` argument (plain connection)
    * ``verify-full``       -> TLS, certificate and hostname verified
    * ``verify-ca``         -> TLS, certificate verified, hostname not checked
    * anything else (``require``, ``prefer``, ``allow``, ...) -> TLS without
      certificate verification (libpq ``require`` semantics)

    Args:
        url: Database URL, e.g. ``settings.database_url``.

    Returns:
        ``(url, connect_args)``. ``connect_args`` always contains
        ``statement_cache_size=0`` and contains ``ssl`` when TLS is requested.
    """
    # Normalise the scheme. "postgres" is a widespread alias that SQLAlchemy
    # rejects, and any explicitly named driver is replaced by asyncpg.
    url = re.sub(r"^postgres(?:ql)?(?:\+[^:]+)?:", "postgresql+asyncpg:", url)

    # Filter the query string param-by-param rather than with a regex, so that
    # dropping the *first* param cannot remove the "?" and glue the remaining
    # params onto the database name.
    base, _, query = url.partition("?")
    kept = []
    ssl_mode = None
    for pair in query.split("&"):
        if not pair:
            continue
        key, _, value = pair.partition("=")
        if key == "sslmode":
            ssl_mode = value
        elif key != "channel_binding":
            kept.append(pair)
    url = f"{base}?{'&'.join(kept)}" if kept else base

    connect_args: dict = {}
    if ssl_mode is not None and ssl_mode != "disable":
        import ssl as _ssl

        ctx = _ssl.create_default_context()
        if ssl_mode == "verify-ca":
            # Verify the certificate chain but not the server host name.
            ctx.check_hostname = False
        elif ssl_mode != "verify-full":
            # require/prefer/allow: encrypt, but do not verify the server.
            ctx.check_hostname = False
            ctx.verify_mode = _ssl.CERT_NONE
        connect_args["ssl"] = ctx

    # Disable asyncpg's prepared-statement cache; crucial for Neon Serverless
    # PgBouncer functionality.
    connect_args["statement_cache_size"] = 0
    return url, connect_args
# Module-level engine, built at import time from the configured database URL.
settings = get_settings()
_db_url, _connect_args = _make_async_url(settings.database_url)

engine = create_async_engine(
    _db_url,
    echo=False,
    pool_pre_ping=True,
    # Recycle pooled connections older than 5 min to prevent stale
    # connection errors.
    pool_recycle=300,
    pool_size=10,
    max_overflow=5,
    connect_args={
        **_connect_args,
        "command_timeout": 60,
        # SQLAlchemy's asyncpg adapter reads this from the connect arguments
        # (or the URL query). Passed as an execution option it was silently
        # ignored, leaving SQLAlchemy's own prepared-statement cache enabled.
        "prepared_statement_cache_size": 0,
    },
)
from sqlalchemy.orm import sessionmaker
# Factory for AsyncSession objects bound to the module-level engine.
# async_sessionmaker is SQLAlchemy's async-aware factory (e.g. its .begin()
# is an async context manager, unlike plain sessionmaker's).
# expire_on_commit=False keeps loaded attributes readable after commit
# instead of expiring them and forcing a refresh on next access.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
class Base(DeclarativeBase):
    """Declarative base class; ORM models subclass it to declare mapped tables."""
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a single AsyncSession, closing it when the consumer is done.

    Presumably used as a request-scoped dependency (e.g. FastAPI
    ``Depends``) — confirm against callers.
    """
    session_cm = AsyncSessionLocal()
    async with session_cm as db_session:
        yield db_session
|