File size: 3,502 Bytes
9d29748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Async SQLAlchemy Database Engine & Session Management.

Uses async SQLAlchemy with support for both SQLite (dev) and PostgreSQL (prod).
Provides a session dependency for FastAPI route injection.
"""

from __future__ import annotations

from collections.abc import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from app.config import get_settings

settings = get_settings()


# ── Normalize DATABASE_URL for asyncpg ────────────────────────────────────
def _normalize_database_url(url: str) -> tuple[str, bool]:
    """Rewrite a libpq-style database URL into one usable with asyncpg.

    Hosted PostgreSQL providers (Neon, Supabase, Heroku, ...) hand out URLs
    written for libpq. Two things about them break the asyncpg driver:

    * the scheme names no async driver (``postgresql://``) or uses the
      ``postgres://`` alias; both are rewritten to ``postgresql+asyncpg://``;
    * asyncpg does not understand libpq's ``sslmode`` / ``channel_binding``
      query parameters.

    Query parameters are filtered one ``key=value`` pair at a time. The
    remaining pairs are kept byte-for-byte, and the ``?``/``&`` separators
    stay well-formed no matter where a removed parameter appeared.

    Args:
        url: The raw ``DATABASE_URL`` value. URLs without the parameters
            above pass through unchanged, apart from dropping empty query
            segments (e.g. a trailing ``?`` or ``&``).

    Returns:
        ``(url, needs_ssl)``: the cleaned URL, and whether it requested
        ``sslmode=require`` or ``ssl=require``. Those pairs are removed, so
        the caller must supply an SSL context itself. Any other ``sslmode``
        value is renamed to ``ssl``, keeping its value, because asyncpg's
        ``ssl`` option takes the same mode names (``disable`` ...
        ``verify-full``). ``channel_binding`` is dropped whatever its value.
    """
    for legacy_prefix in ("postgresql://", "postgres://"):
        if url.startswith(legacy_prefix):
            url = "postgresql+asyncpg://" + url[len(legacy_prefix):]
            break

    base, _, query = url.partition("?")
    needs_ssl = False
    kept: list[str] = []
    for param in query.split("&"):
        if not param:
            continue  # empty segment: "?", "&&" or a trailing "&"
        key, _, value = param.partition("=")
        if key in ("sslmode", "ssl") and value == "require":
            needs_ssl = True
        elif key == "sslmode":
            # Preserve stricter/looser modes instead of silently dropping them.
            kept.append(f"ssl={value}")
        elif key != "channel_binding":
            kept.append(param)

    cleaned = f"{base}?{'&'.join(kept)}" if kept else base
    return cleaned, needs_ssl


_db_url, _needs_ssl = _normalize_database_url(settings.database_url)

# ── Engine ───────────────────────────────────────────────────────────────
# Options shared by every backend: no SQL echo, and ping pooled connections
# before reuse so stale ones are detected instead of failing a request.
_engine_kwargs: dict = {"echo": False, "pool_pre_ping": True}
_connect_args: dict = {}

if not _db_url.startswith("sqlite"):
    # Server database (PostgreSQL): a small pool whose connections are
    # recycled after 300 seconds.
    _engine_kwargs.update(pool_size=5, max_overflow=5, pool_recycle=300)

    if _needs_ssl:
        # TLS for cloud providers (Neon, Supabase, ...) that asked for it in the URL.
        import ssl as _ssl

        # NOTE(review): traffic is encrypted, but the server certificate and
        # hostname are NOT verified (CERT_NONE). This mirrors libpq's
        # ``sslmode=require`` and does not protect against a man-in-the-middle.
        # check_hostname has to be switched off before verify_mode may be CERT_NONE.
        _ssl_ctx = _ssl.create_default_context()
        _ssl_ctx.check_hostname = False
        _ssl_ctx.verify_mode = _ssl.CERT_NONE
        _connect_args["ssl"] = _ssl_ctx
else:
    # By default sqlite3 refuses to use a connection from a thread other
    # than the one that created it; lift that restriction.
    _connect_args = {"check_same_thread": False}

engine = create_async_engine(_db_url, connect_args=_connect_args, **_engine_kwargs)

# ── Session Factory ──────────────────────────────────────────────────────
# Produces AsyncSession objects bound to ``engine``. With
# ``expire_on_commit=False``, loaded attribute values survive a commit, so
# ORM objects stay readable afterwards without an implicit refresh query.
async_session_factory = async_sessionmaker(
    engine,
    expire_on_commit=False,
    class_=AsyncSession,
)


# ── Declarative Base ─────────────────────────────────────────────────────
class Base(DeclarativeBase):
    """Common declarative base for every ORM model in the application.

    Each mapped model subclasses ``Base``, so all tables are registered on the
    single shared ``Base.metadata``.
    """


# ── Dependency ───────────────────────────────────────────────────────────
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield a request-scoped ``AsyncSession`` (FastAPI dependency).

    If the consumer finishes normally, the session is committed. If an
    exception is thrown into the generator at the ``yield``, or the commit
    itself fails, the session is rolled back and the exception is re-raised.

    Leaving the ``async with`` block always closes the session, so no
    explicit ``close()`` is needed. A ``BaseException`` that is not an
    ``Exception`` (e.g. ``asyncio.CancelledError``) skips the explicit
    rollback, but closing the session still discards its open transaction.

    Yields:
        A fresh session created by ``async_session_factory``.
    """
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise