quanthedge / backend /app /database.py
jashdoshi77's picture
QuantHedge: Full deployment with Docker + nginx + uvicorn
9d29748
"""
Async SQLAlchemy Database Engine & Session Management.
Uses async SQLAlchemy with support for both SQLite (dev) and PostgreSQL (prod).
Provides a session dependency for FastAPI route injection.
"""
from __future__ import annotations

from collections.abc import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from app.config import get_settings
settings = get_settings()
# ── Normalize DATABASE_URL for asyncpg ────────────────────────────────────
def _normalize_db_url(raw_url: str) -> tuple[str, bool]:
    """Return an asyncpg-compatible URL and whether SSL was requested.

    A plain ``postgresql://`` scheme is rewritten to ``postgresql+asyncpg://``.
    The query parameters ``sslmode=require``, ``ssl=require`` and
    ``channel_binding=require`` are removed from the query string wherever
    they appear; every other parameter is kept verbatim and in order.

    Args:
        raw_url: The database URL as configured.

    Returns:
        A ``(url, needs_ssl)`` pair. ``needs_ssl`` is True when
        ``sslmode=require`` or ``ssl=require`` was present in the query.
    """
    url = raw_url
    # Ensure the asyncpg driver is specified.
    if url.startswith("postgresql://"):
        url = url.replace("postgresql://", "postgresql+asyncpg://", 1)

    # Filter whole query parameters instead of doing substring replacement,
    # so the "?" separator survives no matter which parameter comes first.
    base, _, query = url.partition("?")
    needs_ssl = False
    kept: list[str] = []
    for param in query.split("&"):
        if param in ("sslmode=require", "ssl=require"):
            needs_ssl = True
        elif param and param != "channel_binding=require":
            kept.append(param)

    if kept:
        return f"{base}?{'&'.join(kept)}", needs_ssl
    return base, needs_ssl


_db_url, _needs_ssl = _normalize_db_url(settings.database_url)
# ── Engine ───────────────────────────────────────────────────────────────
# Options applied to every backend.
_engine_kwargs: dict = dict(echo=False, pool_pre_ping=True)
_connect_args: dict = {}

if not _db_url.startswith("sqlite"):
    # PostgreSQL pool settings
    _engine_kwargs["pool_size"] = 5
    _engine_kwargs["max_overflow"] = 5
    _engine_kwargs["pool_recycle"] = 300

    # SSL for cloud providers (Neon, Supabase, etc.)
    if _needs_ssl:
        import ssl as _ssl

        # NOTE(review): hostname checking and certificate verification are
        # both switched off below, so the link is encrypted but the server
        # is not authenticated. check_hostname must be cleared before
        # verify_mode can be set to CERT_NONE.
        _ssl_ctx = _ssl.create_default_context()
        _ssl_ctx.check_hostname = False
        _ssl_ctx.verify_mode = _ssl.CERT_NONE
        _connect_args["ssl"] = _ssl_ctx
else:
    # SQLite: handed to the driver through connect_args.
    _connect_args["check_same_thread"] = False

engine = create_async_engine(_db_url, connect_args=_connect_args, **_engine_kwargs)
# ── Session Factory ──────────────────────────────────────────────────────
# Factory producing AsyncSession objects bound to the module-level engine.
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
# ── Declarative Base ─────────────────────────────────────────────────────
class Base(DeclarativeBase):
    """Declarative base from which every ORM model inherits."""
# ── Dependency ───────────────────────────────────────────────────────────
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session for FastAPI dependency injection.

    The session is committed once the consumer resumes the generator
    without raising. If an exception is raised while the session is in use
    (or by the commit itself), the session is rolled back and the exception
    is re-raised.

    Yields:
        A session created by ``async_session_factory``.

    Raises:
        Exception: Whatever the consumer or the commit raised, after the
            rollback has been issued.
    """
    # The ``async with`` block closes the session on exit, so no explicit
    # ``close()`` call is needed.
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise