Spaces:
Running
Running
from __future__ import annotations

import re
import ssl as ssl_module
from typing import Any

import asyncpg

from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult
from app.core.logger import get_logger
# Module-level logger named after this module; the executor uses it to warn
# when a failed statement causes a transaction rollback.
_logger = get_logger(__name__)
class PostgreSQLExecutor(BaseExecutor):
    """Execute SQL statements against PostgreSQL through an asyncpg connection pool."""

    # Pool sizing / lifetime settings. Class attributes so a subclass can tune
    # them without re-implementing ``_create_pool``.
    POOL_MIN_SIZE: int = 1
    POOL_MAX_SIZE: int = 10
    POOL_MAX_INACTIVE_CONNECTION_LIFETIME: float = 3600.0

    # Leading keywords of statements whose result rows are fetched instead of
    # discarded. "(" covers parenthesised queries such as "(SELECT 1) UNION ...".
    _ROW_RETURNING_PREFIXES: tuple[str, ...] = (
        "SELECT",
        "WITH",
        "VALUES",
        "TABLE",
        "SHOW",
        "EXPLAIN",
        "(",
    )
    # Whitespace, "--" line comments and "/* */" block comments that may precede
    # the first keyword. Nested block comments are not handled.
    _LEADING_TRIVIA_RE = re.compile(r"(?:\s|--[^\n]*|/\*.*?\*/)*", re.DOTALL)
    # Whole-word RETURNING clause, so identifiers like "returning_customer" do not
    # match. Heuristic: the word inside a string literal or comment still matches.
    _RETURNING_RE = re.compile(r"\bRETURNING\b", re.IGNORECASE)

    async def _create_pool(self) -> asyncpg.Pool:
        """Create the asyncpg connection pool described by ``self._config``.

        An SSL context is passed only when ``ssl_enabled`` is set; otherwise
        ``ssl=None`` leaves the SSL decision to asyncpg's defaults.
        """
        ssl_ctx: ssl_module.SSLContext | None = (
            self._build_ssl_context() if self._config.ssl_enabled else None
        )
        return await asyncpg.create_pool(
            host=self._config.host,
            port=self._config.port,
            user=self._config.username,
            password=self._config.password,
            database=self._config.database,
            min_size=self.POOL_MIN_SIZE,
            max_size=self.POOL_MAX_SIZE,
            timeout=self._config.connection_timeout_seconds,
            ssl=ssl_ctx,
            max_inactive_connection_lifetime=self.POOL_MAX_INACTIVE_CONNECTION_LIFETIME,
        )

    async def _execute_queries(
        self, pool: asyncpg.Pool, queries: list[str], use_transaction: bool
    ) -> list[StatementResult]:
        """Execute *queries* in order on a single pooled connection.

        With ``use_transaction`` the whole batch runs in one transaction. It is
        committed only if every statement succeeds. The first failed statement
        rolls it back and stops the batch; results up to and including that
        failure are returned. Without a transaction every statement is
        attempted regardless of earlier failures.

        Exceptions escaping statement execution (including task cancellation)
        trigger a best-effort rollback and are re-raised.
        """
        results: list[StatementResult] = []
        async with pool.acquire() as conn:
            tr = conn.transaction() if use_transaction else None
            if tr is not None:
                await tr.start()
            try:
                for query in queries:
                    result = await self._execute_one(conn, query)
                    results.append(result)
                    if not result.success and tr is not None:
                        await tr.rollback()
                        _logger.warning(
                            "Query failed, transaction rolled back: %s",
                            result.error,
                        )
                        return results
                if tr is not None:
                    await tr.commit()
                return results
            except BaseException:
                # BaseException so asyncio.CancelledError also rolls back rather
                # than handing a connection with an open transaction back to the pool.
                if tr is not None:
                    try:
                        await tr.rollback()
                    except Exception as rollback_exc:
                        # Best effort: keep the original error, but don't hide this one.
                        _logger.warning(
                            "Transaction rollback after error failed: %s",
                            rollback_exc,
                        )
                raise

    @classmethod
    def _returns_rows(cls, query: str) -> bool:
        """Return True when *query* should be run with ``fetch`` to keep its rows."""
        # The pattern can match the empty string, so a match always exists.
        trivia = cls._LEADING_TRIVIA_RE.match(query)
        head = query[trivia.end() if trivia else 0 :].upper()
        if head.startswith(cls._ROW_RETURNING_PREFIXES):
            return True
        # DML with a RETURNING clause also produces rows.
        return cls._RETURNING_RE.search(query) is not None

    @staticmethod
    def _parse_rowcount(status: str | None) -> int:
        """Extract the affected-row count from a command status tag.

        Tags such as ``"INSERT 0 5"`` or ``"UPDATE 3"`` end in the count; tags
        without one (e.g. ``"CREATE TABLE"``) or an empty/missing status give 0.
        """
        parts = (status or "").split()
        if parts and parts[-1].isdecimal():
            return int(parts[-1])
        return 0

    async def _execute_one(self, conn: asyncpg.Connection, query: str) -> StatementResult:
        """Run one statement on *conn* and wrap the outcome in a StatementResult.

        Row-producing statements are fetched and truncated to
        ``self._config.max_rows``; ``rows`` is the number of rows returned. Other
        statements report the affected-row count parsed from the command status.
        Any ``Exception`` raised while running the statement is returned as a
        failed result (message and exception class name) instead of propagating.
        """
        try:
            if self._returns_rows(query):
                records = await conn.fetch(query)
                # Truncate before converting so discarded rows are never turned
                # into dicts.
                data = [dict(record) for record in records[: self._config.max_rows]]
                return StatementResult(success=True, rows=len(data), data=data)
            status = await conn.execute(query)
        except Exception as exc:
            return StatementResult(
                success=False,
                error=str(exc),
                error_code=type(exc).__name__,
            )
        # Parsed outside the try: the statement has already succeeded, so a status
        # parsing problem must not be reported as a failed statement (which would
        # also roll back an enclosing transaction).
        return StatementResult(success=True, rows=self._parse_rowcount(status), data=[])

    async def _close_pool(self, pool: asyncpg.Pool) -> None:
        """Close *pool* and its connections."""
        await pool.close()

    @staticmethod
    def _build_ssl_context() -> ssl_module.SSLContext:
        """Return a default client SSL context (system CA store, hostname checking)."""
        # Static: it is invoked as ``self._build_ssl_context()`` yet takes no
        # instance state; without the decorator that call raised TypeError.
        return ssl_module.create_default_context()