File size: 2,031 Bytes
c1dbdc6
 
 
 
 
 
 
12fa3c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0f1a88
 
 
12fa3c2
c1dbdc6
 
 
12fa3c2
 
 
 
 
48a85ea
10090b1
3ff680a
10090b1
 
 
 
a0f1a88
12fa3c2
022fb5a
 
 
 
 
 
c1dbdc6
 
 
 
 
 
 
 
 
12fa3c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import re
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from .config import get_settings


def _make_async_url(url: str) -> tuple[str, dict]:
    """Convert a standard postgres:// URL to asyncpg-compatible form.

    asyncpg does NOT accept sslmode or channel_binding as URL query params.
    Strip them and return connect_args with ssl=True when sslmode was present.
    """
    needs_ssl = bool(re.search(r"[?&]sslmode=", url))
    # Switch scheme
    url = re.sub(r"^postgresql(\+[^:]+)?:", "postgresql+asyncpg:", url)
    # Remove unsupported query params
    for param in ("sslmode", "channel_binding"):
        url = re.sub(rf"[?&]{param}=[^&]*", "", url)
    # Clean up trailing ? or & left behind
    url = re.sub(r"\?$", "", url)
    url = re.sub(r"&$", "", url)
    connect_args: dict = {}
    if needs_ssl:
        import ssl as _ssl
        ctx = _ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = _ssl.CERT_NONE
        connect_args["ssl"] = ctx
    
    # Crucial for Neon Serverless PgBouncer functionality
    connect_args["statement_cache_size"] = 0
    return url, connect_args


settings = get_settings()
_db_url, _connect_args = _make_async_url(settings.database_url)

# Application-wide async engine / connection pool.
engine = create_async_engine(
    _db_url,
    echo=False,
    pool_pre_ping=True,      # Test each pooled connection on checkout and replace dead ones
    pool_recycle=300,        # Refreshes old connections every 5 min to prevent stale connection errors
    pool_size=10,
    max_overflow=5,
    connect_args={
        **_connect_args,
        # asyncpg default per-statement timeout, in seconds.
        "command_timeout": 60,
        # SQLAlchemy's asyncpg adapter keeps its own prepared-statement cache,
        # separate from asyncpg's ``statement_cache_size``. It is read from the
        # connect arguments / URL query, not from ``execution_options``, so it
        # must live here to actually be disabled for PgBouncer.
        "prepared_statement_cache_size": 0,
    },
)
from sqlalchemy.orm import sessionmaker
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    # Keep attribute values loaded after commit so accessing them does not
    # trigger an implicit refresh query.
    expire_on_commit=False,
)


class Base(DeclarativeBase):
    """Common declarative base for ORM models; subclasses map onto ``Base.metadata``."""


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh ``AsyncSession`` and close it once the consumer is done.

    The session comes from ``AsyncSessionLocal`` and is managed by
    ``async with``, so it is closed when the generator finishes, whether it
    finishes normally, is closed early, or has an exception thrown into it.
    NOTE(review): presumably consumed as a framework dependency — confirm
    with callers.
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session