light-infer-chat's picture
ok
55ae875
Raw
History Blame Contribute Delete
4.41 kB
from __future__ import annotations
import asyncio
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from app.core.logger import get_logger
_logger = get_logger(__name__)
@dataclass(frozen=True)
class ConnectionConfig:
    """Immutable settings describing how to reach one database.

    Instances are frozen, so they are hashable and cannot be modified after
    construction. The password is deliberately kept out of every string
    representation of the object (``repr()``, ``pool_key`` and ``safe_repr``).
    """

    db_type: str
    host: str
    port: int
    database: str
    username: str
    # repr=False keeps the credential out of the auto-generated __repr__, so an
    # accidental ``repr(config)`` / ``%r`` in a log line or traceback cannot
    # leak it. The field still takes part in __init__, __eq__ and __hash__.
    password: str = field(repr=False)
    selected_schema: str = "public"
    ssl_enabled: bool = False
    ssl_ca_cert: str | None = None
    ssl_cert: str | None = None
    ssl_key: str | None = None
    # NOTE(review): presumably a per-statement limit enforced by concrete
    # executors; nothing in this module reads it -- confirm.
    query_timeout_seconds: float = 30.0
    # Used by BaseExecutor as the timeout for a single pool-creation attempt.
    connection_timeout_seconds: float = 10.0
    # NOTE(review): presumably a cap on returned rows; nothing in this module
    # reads it -- confirm.
    max_rows: int = 10000

    @property
    def pool_key(self) -> str:
        """Return ``db_type:username@host:port/database`` for this config.

        The password, schema and SSL settings are not part of the key.
        """
        return f"{self.db_type}:{self.username}@{self.host}:{self.port}/{self.database}"

    @property
    def safe_repr(self) -> str:
        """Return a log-friendly description that omits the password."""
        return (
            f"ConnectionConfig(db_type={self.db_type}, host={self.host}, "
            f"port={self.port}, database={self.database}, "
            f"selected_schema={self.selected_schema}, username={self.username}, ssl={self.ssl_enabled})"
        )
@dataclass
class StatementResult:
    """Outcome of one executed statement.

    Only ``success`` is required. Every other field has a default, and each
    instance gets its own fresh ``data`` list.
    """

    # Whether the statement completed without error.
    success: bool
    # Row count reported for the statement (0 when not supplied).
    rows: int = 0
    # Result rows as dictionaries; a new empty list per instance.
    data: list[dict[str, Any]] = field(default_factory=list)
    # Error message and error code, both None when not supplied.
    error: str | None = None
    error_code: str | None = None
class BaseExecutor(ABC):
    """Abstract executor that runs queries through a lazily created pool.

    Subclasses provide the driver-specific parts: ``_create_pool``,
    ``_execute_queries`` and ``_close_pool``. This base class handles lazy
    pool creation with retries and makes ``close()`` safe to call repeatedly.
    Pool creation and closing are serialized by an ``asyncio.Lock``.
    """

    def __init__(self, config: ConnectionConfig) -> None:
        """Store *config*; the pool itself is created on first ``execute``."""
        self._config = config
        self._pool: Any = None
        self._closed = False
        self._lock = asyncio.Lock()
        self._max_connection_retries = 3

    @abstractmethod
    async def _create_pool(self) -> Any:
        """Create and return a new driver-specific connection pool."""
        ...

    @abstractmethod
    async def _execute_queries(
        self, pool: Any, queries: list[Any], use_transaction: bool
    ) -> list[StatementResult]:
        """Run *queries* on *pool* and return the statement results."""
        ...

    async def execute(
        self, queries: list[Any], use_transaction: bool = True
    ) -> list[StatementResult]:
        """Execute *queries*, creating the pool first if necessary.

        Args:
            queries: Statements handed unchanged to ``_execute_queries``.
            use_transaction: Forwarded unchanged to ``_execute_queries``.

        Returns:
            Whatever ``_execute_queries`` returns.

        Raises:
            RuntimeError: If the executor has been closed, or if the pool
                could not be created after all retry attempts.
        """
        if self._closed:
            raise RuntimeError("Executor has been closed")
        pool = await self._get_or_create_pool()
        return await self._execute_queries(pool, queries, use_transaction)

    async def _get_or_create_pool(self) -> Any:
        """Return the existing pool, or create it with retries and backoff.

        Each attempt is bounded by ``connection_timeout_seconds`` from the
        config. Between failed attempts the coroutine sleeps
        ``0.1 * 2**attempt`` seconds. The lock is held for the whole
        procedure, so concurrent callers wait rather than creating a second
        pool.

        Raises:
            RuntimeError: If the executor was closed, or if every attempt
                failed (chained from the last underlying error).
        """
        async with self._lock:
            # Re-check under the lock: close() may have completed while this
            # caller was waiting to acquire it. Creating a pool now would leak
            # it, because close() is a no-op once the executor is closed.
            if self._closed:
                raise RuntimeError("Executor has been closed")
            if self._pool is not None:
                return self._pool

            last_exc: Exception | None = None
            for attempt in range(self._max_connection_retries):
                try:
                    self._pool = await asyncio.wait_for(
                        self._create_pool(),
                        timeout=self._config.connection_timeout_seconds,
                    )
                    _logger.info(
                        "Created pool for %s (attempt %d)",
                        self._config.safe_repr, attempt + 1,
                    )
                    return self._pool
                except asyncio.TimeoutError:
                    # Replace the bare timeout with one that states the limit.
                    last_exc = TimeoutError(
                        f"Connection timed out after {self._config.connection_timeout_seconds}s"
                    )
                    _logger.warning(
                        "Pool creation timeout for %s (attempt %d)",
                        self._config.safe_repr, attempt + 1,
                    )
                except Exception as exc:
                    last_exc = exc
                    _logger.warning(
                        "Pool creation failed for %s (attempt %d): %s",
                        self._config.safe_repr, attempt + 1, exc,
                    )
                # Back off before the next attempt, but not after the last one.
                if attempt < self._max_connection_retries - 1:
                    wait = 0.1 * (2**attempt)
                    await asyncio.sleep(wait)

            msg = (
                f"Failed to create connection pool after {self._max_connection_retries} attempts"
            )
            if last_exc is not None:
                msg += f": {last_exc}"
            raise RuntimeError(msg) from last_exc

    async def close(self) -> None:
        """Mark the executor closed and release its pool, if one exists.

        Calling ``close()`` again is a no-op. If ``_close_pool`` raises, the
        exception propagates, but the executor is already marked closed and
        no longer references the pool.
        """
        async with self._lock:
            if self._closed:
                return
            self._closed = True
            # Detach the pool before awaiting so a failing _close_pool cannot
            # leave a half-closed pool attached to this executor.
            pool, self._pool = self._pool, None
            if pool is None:
                return
            await self._close_pool(pool)
            _logger.info("Closed pool for %s", self._config.safe_repr)

    @abstractmethod
    async def _close_pool(self, pool: Any) -> None:
        """Release all resources held by *pool*."""
        ...

    @property
    def is_closed(self) -> bool:
        """Whether ``close()`` has been called on this executor."""
        return self._closed