Spaces:
Running
Running
| """ | |
| Async SQLAlchemy Database Engine & Session Management. | |
| Uses async SQLAlchemy with support for both SQLite (dev) and PostgreSQL (prod). | |
| Provides a session dependency for FastAPI route injection. | |
| """ | |
| from __future__ import annotations | |
| from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine | |
| from sqlalchemy.orm import DeclarativeBase | |
| from app.config import get_settings | |
| settings = get_settings() | |
| # ββ Normalize DATABASE_URL for asyncpg ββββββββββββββββββββββββββββββββββββ | |
| _db_url = settings.database_url | |
| # Ensure asyncpg driver is specified | |
| if _db_url.startswith("postgresql://"): | |
| _db_url = _db_url.replace("postgresql://", "postgresql+asyncpg://", 1) | |
| # Remove Neon-specific params that asyncpg doesn't support | |
| _needs_ssl = False | |
| if "sslmode=require" in _db_url: | |
| _needs_ssl = True | |
| _db_url = _db_url.replace("?sslmode=require", "").replace("&sslmode=require", "").replace("sslmode=require&", "") | |
| if "channel_binding=require" in _db_url: | |
| _db_url = _db_url.replace("?channel_binding=require", "").replace("&channel_binding=require", "").replace("channel_binding=require&", "") | |
| # Clean up trailing ? or & | |
| _db_url = _db_url.rstrip("?&") | |
| if "?ssl=require" in _db_url: | |
| _needs_ssl = True | |
| _db_url = _db_url.replace("?ssl=require", "").replace("&ssl=require", "") | |
| _db_url = _db_url.rstrip("?&") | |
| # ββ Engine βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _connect_args: dict = {} | |
| _engine_kwargs: dict = { | |
| "echo": False, | |
| "pool_pre_ping": True, | |
| } | |
| if _db_url.startswith("sqlite"): | |
| _connect_args = {"check_same_thread": False} | |
| else: | |
| # PostgreSQL pool settings | |
| _engine_kwargs.update({ | |
| "pool_size": 5, | |
| "max_overflow": 5, | |
| "pool_recycle": 300, | |
| }) | |
| # SSL for cloud providers (Neon, Supabase, etc.) | |
| if _needs_ssl: | |
| import ssl as _ssl | |
| _ssl_ctx = _ssl.create_default_context() | |
| _ssl_ctx.check_hostname = False | |
| _ssl_ctx.verify_mode = _ssl.CERT_NONE | |
| _connect_args["ssl"] = _ssl_ctx | |
| engine = create_async_engine( | |
| _db_url, | |
| connect_args=_connect_args, | |
| **_engine_kwargs, | |
| ) | |
| # ββ Session Factory ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async_session_factory = async_sessionmaker( | |
| bind=engine, | |
| class_=AsyncSession, | |
| expire_on_commit=False, | |
| ) | |
| # ββ Declarative Base βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class Base(DeclarativeBase): | |
| """Base class for all ORM models.""" | |
| pass | |
| # ββ Dependency βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def get_db() -> AsyncSession: # type: ignore[misc] | |
| """FastAPI dependency that yields an async database session.""" | |
| async with async_session_factory() as session: | |
| try: | |
| yield session | |
| await session.commit() | |
| except Exception: | |
| await session.rollback() | |
| raise | |
| finally: | |
| await session.close() | |