# coderound/backend/src/database.py
# Commit 022fb5a (author: ketannnn): "feat: implement admin endpoint and UI for database and vector store reset"
import re
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from .config import get_settings
def _make_async_url(url: str) -> tuple[str, dict]:
"""Convert a standard postgres:// URL to asyncpg-compatible form.
asyncpg does NOT accept sslmode or channel_binding as URL query params.
Strip them and return connect_args with ssl=True when sslmode was present.
"""
needs_ssl = bool(re.search(r"[?&]sslmode=", url))
# Switch scheme
url = re.sub(r"^postgresql(\+[^:]+)?:", "postgresql+asyncpg:", url)
# Remove unsupported query params
for param in ("sslmode", "channel_binding"):
url = re.sub(rf"[?&]{param}=[^&]*", "", url)
# Clean up trailing ? or & left behind
url = re.sub(r"\?$", "", url)
url = re.sub(r"&$", "", url)
connect_args: dict = {}
if needs_ssl:
import ssl as _ssl
ctx = _ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = _ssl.CERT_NONE
connect_args["ssl"] = ctx
# Crucial for Neon Serverless PgBouncer functionality
connect_args["statement_cache_size"] = 0
return url, connect_args
settings = get_settings()
_db_url, _connect_args = _make_async_url(settings.database_url)

# Shared async engine for the application.
engine = create_async_engine(
    _db_url,
    echo=False,
    pool_pre_ping=True,  # Test each pooled connection before handing it out
    pool_recycle=300,  # Refreshes old connections every 5 min to prevent stale connection errors
    pool_size=10,
    max_overflow=5,
    connect_args={
        **_connect_args,
        "command_timeout": 60,
        # Disable SQLAlchemy's own prepared-statement cache. The asyncpg dialect
        # reads this key from connect_args (or the URL query) at connect time.
        # As an execution option it had no effect. Together with asyncpg's
        # statement_cache_size=0 (set in _make_async_url), this keeps statements
        # from being cached across PgBouncer-pooled server connections.
        "prepared_statement_cache_size": 0,
    },
)
from sqlalchemy.orm import sessionmaker
# Factory for AsyncSession objects bound to the shared engine. async_sessionmaker
# is SQLAlchemy's purpose-built async factory; it is correctly typed for
# `async with AsyncSessionLocal() as session:`. expire_on_commit=False keeps
# loaded attributes readable after commit without another round-trip.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
class Base(DeclarativeBase):
    """Shared SQLAlchemy declarative base.

    ORM model classes are presumably declared as subclasses of this class so
    they share one metadata registry (confirm against the models module).
    """
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a single AsyncSession from AsyncSessionLocal.

    The session is opened as an async context manager. It is therefore closed
    when the generator is resumed or finalized after the yield. This is
    presumably used as a generator-style request dependency (e.g. FastAPI
    ``Depends``); verify against callers.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session