Spaces:
Running
Running
| """Database connection and session management.""" | |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker | |
| from sqlalchemy.pool import NullPool | |
| from src.config import get_settings | |
| from loguru import logger | |
settings = get_settings()

# Hostname fragments of managed Postgres providers that need TLS even when the
# URL carries no ``sslmode`` parameter. "supabase.co" matches both direct hosts
# (db.<ref>.supabase.co) and pooler hosts (*.pooler.supabase.com).
_SSL_REQUIRED_HOSTS = ("neon.tech", "supabase.co")


def _prepare_asyncpg_connection(url):
    """Adapt a Postgres URL for the asyncpg driver and build its connect args.

    Args:
        url: Database URL from settings. ``postgres://`` and ``postgresql://``
            schemes are rewritten to ``postgresql+asyncpg://``; any other
            scheme (including an existing ``postgresql+asyncpg://``) is kept.

    Returns:
        A ``(url, connect_args)`` tuple for ``create_async_engine``.
        ``connect_args`` always sets a 60 s ``statement_timeout``. When the URL
        contains an ``sslmode`` query parameter or names a host listed in
        ``_SSL_REQUIRED_HOSTS``, ``connect_args["ssl"]`` is set (to the given
        ``sslmode``, defaulting to ``"require"``) and the query string is
        removed from the returned URL.
    """
    from urllib.parse import parse_qs

    # ``postgres://`` is not accepted by SQLAlchemy 1.4+, and ``postgresql://``
    # would select the synchronous default driver instead of asyncpg.
    for scheme in ("postgres://", "postgresql://"):
        if url.startswith(scheme):
            url = "postgresql+asyncpg://" + url[len(scheme):]
            break

    # statement_timeout is applied per connection (milliseconds). Some providers
    # default to a much shorter limit (Supabase free tier: ~8s), which makes
    # ALTER TABLE / CREATE TABLE randomly fail during startup migrations.
    connect_args = {"server_settings": {"statement_timeout": "60000"}}  # 60s

    base, _, query = url.partition("?")
    params = parse_qs(query, keep_blank_values=True)
    if "sslmode" in params or any(host in base for host in _SSL_REQUIRED_HOSTS):
        # asyncpg does not accept ?sslmode= in the URL, so pass the mode via
        # ``ssl=`` instead and drop the whole query string (which also removes
        # other libpq-only options that may accompany it).
        connect_args["ssl"] = params.get("sslmode", [""])[-1] or "require"
        url = base
    return url, connect_args


_db_url, _connect_args = _prepare_asyncpg_connection(settings.database_url)
# Shared async engine. NullPool does no pooling: every checkout opens a new
# DB connection and every release closes it.
engine = create_async_engine(
    _db_url,
    connect_args=_connect_args,
    poolclass=NullPool,
    echo=settings.database_echo,
)

# Session factory bound to the engine above. expire_on_commit=False keeps
# already-loaded attributes usable after commit instead of expiring them.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_db():
    """FastAPI dependency that yields one ``AsyncSession`` from ``AsyncSessionLocal``.

    Nothing is committed here; callers are responsible for committing.
    If an exception is thrown into the generator at the ``yield``, it is
    logged, the session is rolled back, and the exception is re-raised.
    The session is always closed by the ``async with`` block on exit, so no
    explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            logger.error(f"Database session error: {e}")
            await session.rollback()
            raise
async def init_db():
    """Create all tables registered on ``Base.metadata`` using the shared engine.

    The DDL runs inside a single ``engine.begin()`` block, and a log line is
    emitted once it has finished.
    """
    # Imported lazily so the models module is only loaded when this runs
    # (presumably to avoid an import cycle — confirm against src.models).
    from src.models.database import Base

    async with engine.begin() as connection:
        # metadata.create_all is a synchronous API; run_sync adapts it to the
        # async connection.
        await connection.run_sync(Base.metadata.create_all)
    logger.info("Database tables initialized")
async def close_db():
    """Dispose of the shared engine, then log that connections are closed."""
    await engine.dispose()  # release any connections the engine still holds
    logger.info("Database connections closed")