Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import time | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| from app.core.logger import get_logger | |
# Module-level logger; BaseExecutor reports pool creation, retries and shutdown through it.
_logger = get_logger(__name__)
@dataclass
class ConnectionConfig:
    """Settings needed to open a connection pool to one database.

    ``password`` (and ``ssl_key``) are excluded from the generated ``repr`` so
    that printing or logging a config object cannot leak credentials; use
    :attr:`safe_repr` for an explicit log-friendly description.
    """

    db_type: str
    host: str
    port: int
    database: str
    username: str
    # Excluded from repr(): the dataclass-generated repr would otherwise print it verbatim.
    password: str = field(repr=False)
    selected_schema: str = "public"
    ssl_enabled: bool = False
    ssl_ca_cert: str | None = None
    ssl_cert: str | None = None
    # NOTE(review): may hold private-key material rather than a file path — kept out of repr() either way.
    ssl_key: str | None = field(default=None, repr=False)
    query_timeout_seconds: float = 30.0
    connection_timeout_seconds: float = 10.0
    max_rows: int = 10000

    @property
    def pool_key(self) -> str:
        """Return a string identifying the target: db type, user, host, port and database.

        Password, schema and TLS settings are not part of the key (presumably
        so equal targets can share a pool — confirm against callers).
        """
        return f"{self.db_type}:{self.username}@{self.host}:{self.port}/{self.database}"

    @property
    def safe_repr(self) -> str:
        """Return a description suitable for logs; omits the password and TLS material."""
        return (
            f"ConnectionConfig(db_type={self.db_type}, host={self.host}, "
            f"port={self.port}, database={self.database}, "
            f"selected_schema={self.selected_schema}, username={self.username}, ssl={self.ssl_enabled})"
        )
@dataclass
class StatementResult:
    """Outcome of executing one statement through an executor."""

    # Whether the statement completed without error.
    success: bool
    # Row count for the statement (TODO confirm: affected vs. returned rows).
    rows: int = 0
    # Result rows; presumably column name -> value mappings. Fresh list per instance.
    data: list[dict[str, Any]] = field(default_factory=list)
    # Error message when success is False.
    error: str | None = None
    # Driver/database-specific error code, if one was available.
    error_code: str | None = None
class BaseExecutor(ABC):
    """Abstract base for executors that run queries on a lazily created pool.

    Subclasses implement the driver-specific hooks ``_create_pool``,
    ``_execute_queries`` and ``_close_pool``. This class creates the pool on
    first use (retrying with exponential backoff), serialises creation and
    shutdown behind an ``asyncio.Lock``, and refuses to create a pool once
    :meth:`close` has run.
    """

    def __init__(self, config: ConnectionConfig) -> None:
        self._config = config
        # Created on first use by _get_or_create_pool; reset to None by close().
        self._pool: Any = None
        self._closed = False
        # Guards _pool and _closed so pool creation and close() never interleave.
        self._lock = asyncio.Lock()
        self._max_connection_retries = 3

    @abstractmethod
    async def _create_pool(self) -> Any:
        """Create and return a new driver-specific connection pool."""

    @abstractmethod
    async def _execute_queries(
        self, pool: Any, queries: list[Any], use_transaction: bool
    ) -> list[StatementResult]:
        """Run *queries* against *pool* and return their results."""

    async def execute(
        self, queries: list[Any], use_transaction: bool = True
    ) -> list[StatementResult]:
        """Execute *queries*, creating the pool on first use.

        Raises:
            RuntimeError: if the executor has been closed, or the pool could
                not be created after all retry attempts.
        """
        # Fast-path check; _get_or_create_pool re-checks under the lock.
        if self._closed:
            raise RuntimeError("Executor has been closed")
        pool = await self._get_or_create_pool()
        return await self._execute_queries(pool, queries, use_transaction)

    async def _get_or_create_pool(self) -> Any:
        """Return the existing pool, or create one with retries and backoff.

        Each attempt is bounded by ``connection_timeout_seconds``; between
        failed attempts the delay doubles starting at 0.1 s.

        Raises:
            RuntimeError: if the executor is closed, or every attempt failed
                (chained from the last underlying exception).
        """
        async with self._lock:
            # Re-check under the lock: close() may have run after execute()'s
            # unlocked check, and creating a pool now would leak it.
            if self._closed:
                raise RuntimeError("Executor has been closed")
            if self._pool is not None:
                return self._pool
            last_exc: Exception | None = None
            for attempt in range(self._max_connection_retries):
                try:
                    self._pool = await asyncio.wait_for(
                        self._create_pool(),
                        timeout=self._config.connection_timeout_seconds,
                    )
                    _logger.info(
                        "Created pool for %s (attempt %d)",
                        self._config.safe_repr, attempt + 1,
                    )
                    return self._pool
                except asyncio.TimeoutError:
                    last_exc = TimeoutError(
                        f"Connection timed out after {self._config.connection_timeout_seconds}s"
                    )
                    _logger.warning(
                        "Pool creation timeout for %s (attempt %d)",
                        self._config.safe_repr, attempt + 1,
                    )
                except Exception as exc:
                    last_exc = exc
                    _logger.warning(
                        "Pool creation failed for %s (attempt %d): %s",
                        self._config.safe_repr, attempt + 1, exc,
                    )
                # Exponential backoff, skipped after the final attempt.
                if attempt < self._max_connection_retries - 1:
                    wait = 0.1 * (2**attempt)
                    await asyncio.sleep(wait)
            msg = (
                f"Failed to create connection pool after {self._max_connection_retries} attempts"
            )
            if last_exc is not None:
                msg += f": {last_exc}"
            raise RuntimeError(msg) from last_exc

    async def close(self) -> None:
        """Close the pool (if any) and mark the executor closed; repeat calls are no-ops."""
        async with self._lock:
            if self._closed:
                return
            self._closed = True
            # Detach before awaiting so a failing _close_pool cannot leave a
            # stale pool reference on a closed executor.
            pool, self._pool = self._pool, None
            if pool is not None:
                await self._close_pool(pool)
                _logger.info("Closed pool for %s", self._config.safe_repr)

    @abstractmethod
    async def _close_pool(self, pool: Any) -> None:
        """Release all resources held by *pool*."""

    def is_closed(self) -> bool:
        """Return True once :meth:`close` has been called."""
        return self._closed