Spaces:
Sleeping
Sleeping
File size: 3,502 Bytes
9d29748 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | """
Async SQLAlchemy Database Engine & Session Management.
Uses async SQLAlchemy with support for both SQLite (dev) and PostgreSQL (prod).
Provides a session dependency for FastAPI route injection.
"""
from __future__ import annotations
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config import get_settings
settings = get_settings()
# ββ Normalize DATABASE_URL for asyncpg ββββββββββββββββββββββββββββββββββββ
_db_url = settings.database_url
# Ensure asyncpg driver is specified
if _db_url.startswith("postgresql://"):
_db_url = _db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
# Remove Neon-specific params that asyncpg doesn't support
_needs_ssl = False
if "sslmode=require" in _db_url:
_needs_ssl = True
_db_url = _db_url.replace("?sslmode=require", "").replace("&sslmode=require", "").replace("sslmode=require&", "")
if "channel_binding=require" in _db_url:
_db_url = _db_url.replace("?channel_binding=require", "").replace("&channel_binding=require", "").replace("channel_binding=require&", "")
# Clean up trailing ? or &
_db_url = _db_url.rstrip("?&")
if "?ssl=require" in _db_url:
_needs_ssl = True
_db_url = _db_url.replace("?ssl=require", "").replace("&ssl=require", "")
_db_url = _db_url.rstrip("?&")
# ββ Engine βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
_connect_args: dict = {}
_engine_kwargs: dict = {
"echo": False,
"pool_pre_ping": True,
}
if _db_url.startswith("sqlite"):
_connect_args = {"check_same_thread": False}
else:
# PostgreSQL pool settings
_engine_kwargs.update({
"pool_size": 5,
"max_overflow": 5,
"pool_recycle": 300,
})
# SSL for cloud providers (Neon, Supabase, etc.)
if _needs_ssl:
import ssl as _ssl
_ssl_ctx = _ssl.create_default_context()
_ssl_ctx.check_hostname = False
_ssl_ctx.verify_mode = _ssl.CERT_NONE
_connect_args["ssl"] = _ssl_ctx
engine = create_async_engine(
_db_url,
connect_args=_connect_args,
**_engine_kwargs,
)
# ββ Session Factory ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
async_session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
# ββ Declarative Base βββββββββββββββββββββββββββββββββββββββββββββββββββββ
class Base(DeclarativeBase):
"""Base class for all ORM models."""
pass
# ββ Dependency βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
async def get_db() -> AsyncSession: # type: ignore[misc]
"""FastAPI dependency that yields an async database session."""
async with async_session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
|