File size: 7,596 Bytes
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d852d80
 
 
e391a84
 
 
d852d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e391a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
infrastructure/database/connection.py
───────────────────────────────────────
SQLAlchemy async engine + session factory — Supabase / PostgreSQL ready.

Supports:
  • PostgreSQL via asyncpg   (Supabase direct    port 5432)
  • PostgreSQL via asyncpg   (Supabase pooler    port 6543 — Transaction mode)
  • SQLite    via aiosqlite  (local dev / tests)

Supabase-specific tuning:
  • pool_size / max_overflow read from Settings (default: 5 / 10)
  • pool_recycle read from Settings (e.g. 1800 s = 30 min) — avoids idle
    connection drops
  • connect_args.server_settings identifies the app in pg_stat_activity
  • Pooler (port 6543) mode sets prepared_statement_cache_size to 0 because
    pgBouncer Transaction mode does not support server-side prepared stmts.
"""
from __future__ import annotations

from collections.abc import AsyncGenerator

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from src.shared.config import get_settings
from src.shared.logger import get_logger
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.sql.elements import TextClause

# Module-level logger, obtained from the project's shared logger helper using
# this module's ``__name__``.
logger = get_logger(__name__)

# ── SQLite Compatibility Compilers ───────────────────────────────────────────
# Registers global overrides to compile PostgreSQL-specific DDL for SQLite fallback.

@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw):
    """Render the PostgreSQL ``JSONB`` type as ``JSON`` when compiling for SQLite.

    Registered via ``@compiles`` for the ``sqlite`` dialect only.

    Args:
        type_: The ``JSONB`` type object being compiled (not inspected).
        compiler: The active type compiler (not used).
        **kw: Extra compilation keywords (ignored).

    Returns:
        The literal type name ``"JSON"``.
    """
    sqlite_type_name = "JSON"
    return sqlite_type_name

@compiles(TextClause, "sqlite")
def compile_text_sqlite(element, compiler, **kw):
    """Compile ``text()`` constructs for SQLite, emulating ``gen_random_uuid()``.

    A clause whose text is exactly ``gen_random_uuid()`` is replaced by a
    SQLite expression that builds a random, version-4 shaped UUID string out
    of ``randomblob()`` and ``random()``.

    Because this hook is registered for *every* ``TextClause`` compiled for
    SQLite, all other clauses are delegated to the compiler's stock
    ``visit_textclause``.  Returning ``element.text`` verbatim would skip the
    compiler's normal handling of bind parameters (``:name``) and escaped
    colons for any ``text()`` statement run against SQLite.

    Args:
        element: The ``TextClause`` being compiled.
        compiler: The SQL compiler performing the compilation.
        **kw: Compilation keywords, forwarded unchanged to the default visitor.

    Returns:
        The SQL string to emit for this clause.
    """
    if element.text != "gen_random_uuid()":
        # Not the function we emulate: fall back to the default compilation.
        return compiler.visit_textclause(element, **kw)

    # Layout 8-4-4-4-12: the third group is forced to start with '4' and the
    # fourth group starts with one of 8/9/a/b.
    return (
        "(lower(hex(randomblob(4))) || '-' || "
        "lower(hex(randomblob(2))) || '-4' || "
        "substr(lower(hex(randomblob(2))),2) || '-' || "
        "substr('89ab',abs(random()) % 4 + 1, 1) || "
        "substr(lower(hex(randomblob(2))),2) || '-' || "
        "lower(hex(randomblob(6))))"
    )

# ── Module-level singletons (created once per process) ───────────────────────
# Both start as None and are filled in lazily: _engine by get_engine() and
# _session_factory by get_session_factory(), each on its first call.
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None


def get_engine() -> AsyncEngine:
    """Return the process-wide async SQLAlchemy engine, creating it on first use.

    The engine is built from ``get_settings()`` and cached in the module-level
    ``_engine`` variable, so subsequent calls return the same object.

    * SQLite (``settings.is_sqlite``): created with ``echo`` and
      ``check_same_thread=False`` only; no pool or server settings are passed.
    * PostgreSQL: created with ``pool_pre_ping`` plus the pool size, overflow
      and recycle values from settings; ``application_name`` is sent through
      asyncpg ``server_settings``.
    * PostgreSQL through the transaction pooler (``settings.uses_pooler``):
      additionally sends ``statement_timeout=30000`` and switches off both
      prepared-statement caches: SQLAlchemy's ``prepared_statement_cache_size``
      and asyncpg's own ``statement_cache_size``.

    Returns:
        The shared ``AsyncEngine`` instance.
    """
    global _engine
    if _engine is not None:
        return _engine

    settings = get_settings()
    db_url = settings.database_url

    if settings.is_sqlite:
        # SQLite does not support connection pools or server_settings.
        _engine = create_async_engine(
            db_url,
            echo=settings.debug,
            connect_args={"check_same_thread": False},
        )
    else:
        # PostgreSQL / Supabase path ──────────────────────────────────────────
        server_settings: dict[str, str] = {
            "application_name": "bp_monitoring_pipeline",
        }
        # Direct connection (port 5432) needs nothing beyond server_settings.
        connect_args: dict = {"server_settings": server_settings}

        if settings.uses_pooler:
            # Supabase Transaction Pooler (pgBouncer, port 6543):
            # prepared statements are not supported in transaction mode, so
            # disable the SQLAlchemy-level cache *and* asyncpg's own statement
            # cache (the latter stays on by default otherwise).
            server_settings["options"] = "-c statement_timeout=30000"
            connect_args["prepared_statement_cache_size"] = 0
            connect_args["statement_cache_size"] = 0
            logger.info(
                "Supabase Connection Pooler (port 6543) detected — "
                "prepared statement cache disabled."
            )

        _engine = create_async_engine(
            db_url,
            echo=settings.debug,
            pool_pre_ping=True,  # Validate connections before use
            pool_size=settings.db_pool_size,
            max_overflow=settings.db_max_overflow,
            pool_recycle=settings.db_pool_recycle,
            connect_args=connect_args,
        )

    # Log only what follows the last "@" so credentials never reach the logs;
    # a URL without "@" is logged as-is.
    safe_url = db_url.rpartition("@")[2]
    logger.info("Database engine created → %s", safe_url)

    return _engine


def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared async session factory, building it on the first call.

    The factory is bound to the engine from ``get_engine()`` and cached in the
    module-level ``_session_factory`` variable.

    Returns:
        The ``async_sessionmaker`` that produces ``AsyncSession`` objects.
    """
    global _session_factory
    if _session_factory is not None:
        return _session_factory

    session_options = {
        "class_": AsyncSession,
        # Avoid lazy-load issues after commit.
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    _session_factory = async_sessionmaker(bind=get_engine(), **session_options)
    return _session_factory


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """
    FastAPI dependency that yields one ``AsyncSession`` per request.

    The session is committed once the consumer finishes without raising.  If
    an exception reaches this generator (including one raised by the commit
    itself), the session is rolled back and the exception is re-raised.

    Usage::

        @router.post("/...")
        async def endpoint(session: AsyncSession = Depends(get_async_session)):
            ...
    """
    async with get_session_factory()() as db_session:
        try:
            yield db_session
            # Reached only when the consumer completed without an error.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise


async def ping_database() -> bool:
    """
    Check database connectivity by running a lightweight ``SELECT 1``.

    Returns:
        True  — the query ran on a connection from the shared engine.
        False — any exception occurred; it is logged and not re-raised.

    Used by the health-check endpoint and startup validation.
    """
    from sqlalchemy import text

    try:
        async with get_engine().connect() as connection:
            await connection.execute(text("SELECT 1"))
    except Exception as error:
        logger.error("Database ping failed: %s", error)
        return False
    else:
        return True


async def create_all_tables() -> None:
    """
    Create every table registered on the ORM ``Base`` metadata.

    Intended for **development and test environments only**.  In production
    (Supabase), always use Alembic migrations::

        alembic upgrade head

    When settings report Supabase with debug off, this logs a warning and
    returns without touching the database.  Per the original note, app.py
    additionally guards the call with a debug/SQLite check (not visible here).
    """
    settings = get_settings()
    if settings.is_supabase and not settings.debug:
        logger.warning(
            "create_all_tables() skipped — Supabase production detected. "
            "Run 'alembic upgrade head' to apply migrations."
        )
        return

    # The model modules are imported only for their side effects (presumably
    # registering their tables on Base.metadata), hence the F401 suppressions.
    from src.infrastructure.database.models.base import Base
    import src.infrastructure.database.models.ppg_model  # noqa: F401
    import src.infrastructure.database.models.prediction_model  # noqa: F401

    async with get_engine().begin() as conn:
        # create_all is a synchronous API, so it runs through run_sync.
        await conn.run_sync(Base.metadata.create_all)
    logger.info("All database tables created.")


async def dispose_engine() -> None:
    """Dispose the engine and all pooled connections (call on application shutdown).

    Both module-level singletons are cleared afterwards, so the next
    ``get_engine()`` / ``get_session_factory()`` call builds a fresh engine and
    a factory bound to it.  Does nothing when no engine has been created.
    """
    global _engine, _session_factory
    if _engine is None:
        return

    await _engine.dispose()
    _engine = None
    # The cached factory is bound to the engine that was just disposed; drop
    # it so it cannot outlive that engine and diverge from get_engine().
    _session_factory = None
    logger.info("Database engine disposed.")