from __future__ import annotations

import ssl as ssl_module
from typing import Any

import asyncpg

from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult
from app.core.logger import get_logger

_logger = get_logger(__name__)


class PostgreSQLExecutor(BaseExecutor):
    """Run SQL statements against PostgreSQL through an asyncpg pool.

    Connection settings are read from ``self._config``, which is not set in
    this class (presumably provided by ``BaseExecutor`` -- confirm there).
    """

    async def _create_pool(self) -> asyncpg.Pool:
        """Create the asyncpg connection pool described by ``self._config``.

        TLS is used only when ``self._config.ssl_enabled`` is truthy;
        otherwise ``ssl=None`` is passed to asyncpg.

        Returns:
            A pool holding between 1 and 10 connections.
        """
        ssl_ctx: ssl_module.SSLContext | None = (
            self._build_ssl_context() if self._config.ssl_enabled else None
        )
        return await asyncpg.create_pool(
            host=self._config.host,
            port=self._config.port,
            user=self._config.username,
            password=self._config.password,
            database=self._config.database,
            min_size=1,
            max_size=10,
            timeout=self._config.connection_timeout_seconds,
            ssl=ssl_ctx,
            # Seconds an idle pooled connection may live before being closed.
            max_inactive_connection_lifetime=3600.0,
        )

    async def _execute_queries(
        self, pool: asyncpg.Pool, queries: list[str], use_transaction: bool
    ) -> list[StatementResult]:
        """Execute ``queries`` in order on a single pooled connection.

        Args:
            pool: Pool to borrow the connection from.
            queries: SQL statements, executed sequentially.
            use_transaction: When true, all statements run inside one
                transaction. The first failing statement rolls it back and
                stops execution. When false, every statement is attempted
                regardless of earlier failures.

        Returns:
            One ``StatementResult`` per statement that was attempted. In
            transaction mode the list ends with the failing statement.

        Raises:
            Exception: Anything raised outside ``_execute_one`` (for example
                by commit) is re-raised after a best-effort rollback.
        """
        async with pool.acquire() as conn:
            tr = conn.transaction() if use_transaction else None
            if tr is not None:
                await tr.start()

            results: list[StatementResult] = []
            try:
                for query in queries:
                    result = await self._execute_one(conn, query)
                    results.append(result)
                    if tr is not None and not result.success:
                        await tr.rollback()
                        _logger.warning(
                            "Query failed, transaction rolled back: %s",
                            result.error,
                        )
                        return results

                if tr is not None:
                    await tr.commit()
                return results
            except Exception:
                if tr is not None:
                    # Best-effort cleanup: the original error is what the
                    # caller needs to see, so a failing rollback is only
                    # logged, never raised in its place.
                    try:
                        await tr.rollback()
                    except Exception as rollback_exc:
                        _logger.warning(
                            "Rollback after error failed: %s", rollback_exc
                        )
                raise

    async def _execute_one(self, conn: asyncpg.Connection, query: str) -> StatementResult:
        """Execute one statement and wrap its outcome in a ``StatementResult``.

        Statements that start with ``SELECT``/``WITH`` or contain the word
        ``RETURNING`` are fetched, and at most ``self._config.max_rows`` rows
        are returned as dicts. Everything else is executed and the affected
        row count is parsed from the trailing number of the status tag
        (0 when the tag has none).

        NOTE(review): the row-returning check is a plain text heuristic on
        the upper-cased statement. A leading SQL comment hides a SELECT, and
        ``RETURNING`` inside a literal or identifier also matches.

        Args:
            conn: Connection to run the statement on.
            query: SQL text of a single statement.

        Returns:
            A successful result, or ``success=False`` carrying the exception
            message and class name. Exceptions derived from ``Exception`` are
            never propagated.
        """
        try:
            normalized = query.strip().upper()
            returns_rows = (
                normalized.startswith(("SELECT", "WITH")) or "RETURNING" in normalized
            )
            if returns_rows:
                rows = await conn.fetch(query)
                # Truncate first so only the rows that are kept get converted.
                data = [dict(row) for row in rows[: self._config.max_rows]]
                return StatementResult(success=True, rows=len(data), data=data)

            status = await conn.execute(query)
            parts = status.split()
            # A blank status tag must not be reported as a failed statement,
            # so guard the index instead of letting IndexError reach the
            # handler below.
            rowcount = int(parts[-1]) if parts and parts[-1].isdigit() else 0
            return StatementResult(success=True, rows=rowcount, data=[])
        except Exception as exc:
            return StatementResult(
                success=False,
                error=str(exc),
                error_code=type(exc).__name__,
            )

    async def _close_pool(self, pool: asyncpg.Pool) -> None:
        """Close ``pool`` and wait for the close to complete."""
        await pool.close()

    @staticmethod
    def _build_ssl_context() -> ssl_module.SSLContext:
        """Return an SSL context built with ``ssl.create_default_context()``."""
        return ssl_module.create_default_context()