llm-ready-data / app /core /database /postgresql.py
Soumik-404's picture
feat: add db apis
89157f5
Raw
History Blame Contribute Delete
3.42 kB
from __future__ import annotations

import re
import ssl as ssl_module
from typing import Any

import asyncpg

from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult
from app.core.logger import get_logger
# Logger for this module, named after ``__name__``.
_logger = get_logger(__name__)
class PostgreSQLExecutor(BaseExecutor):
    """Execute SQL statements against PostgreSQL through an asyncpg pool."""

    # Leading keywords of statements that yield a result set, so they are run
    # with ``fetch`` instead of ``execute`` (which only returns a status tag).
    _ROW_RETURNING_PREFIXES = (
        "SELECT",
        "WITH",
        "VALUES",
        "TABLE",
        "SHOW",
        "EXPLAIN",
    )
    # Whitespace, SQL comments and opening parentheses that may precede the
    # first keyword, e.g. "-- note\nSELECT ..." or "(SELECT ...) UNION ...".
    _LEADING_NOISE_RE = re.compile(r"(?:\s|--[^\n]*|/\*.*?\*/|\()*", re.DOTALL)
    # RETURNING as a whole word, so identifiers such as "returning_date" do
    # not turn a plain UPDATE into a fetch and lose its row count.
    _RETURNING_RE = re.compile(r"\bRETURNING\b")

    async def _create_pool(self) -> asyncpg.Pool:
        """Create the asyncpg connection pool described by ``self._config``.

        The pool keeps between 1 and 10 connections. Idle connections are
        closed after 3600 seconds. When ``ssl_enabled`` is false, ``ssl=None``
        is passed, which leaves the SSL mode to asyncpg's defaults.
        """
        ssl_ctx: ssl_module.SSLContext | None = (
            self._build_ssl_context() if self._config.ssl_enabled else None
        )
        return await asyncpg.create_pool(
            host=self._config.host,
            port=self._config.port,
            user=self._config.username,
            password=self._config.password,
            database=self._config.database,
            min_size=1,
            max_size=10,
            # Forwarded to asyncpg.connect: per-connection establishment timeout.
            timeout=self._config.connection_timeout_seconds,
            ssl=ssl_ctx,
            max_inactive_connection_lifetime=3600.0,
        )

    async def _execute_queries(
        self, pool: asyncpg.Pool, queries: list[str], use_transaction: bool
    ) -> list[StatementResult]:
        """Run *queries* in order on a single pooled connection.

        With ``use_transaction``, all statements share one transaction. The
        first failed statement rolls it back, and the results collected so far
        (including that failure) are returned. If no statement fails, the
        transaction is committed after the last one.

        Without a transaction, every statement is attempted and each result is
        reported independently.

        Raises:
            Exception: errors from acquiring a connection or from starting,
                committing or rolling back the transaction propagate. If one
                occurs inside the statement loop, a best-effort rollback is
                attempted first.
        """
        async with pool.acquire() as conn:
            tr = conn.transaction() if use_transaction else None
            if tr is not None:
                await tr.start()
            try:
                results: list[StatementResult] = []
                for query in queries:
                    result = await self._execute_one(conn, query)
                    results.append(result)
                    if not result.success and tr is not None:
                        await tr.rollback()
                        _logger.warning(
                            "Query failed, transaction rolled back: %s",
                            result.error,
                        )
                        return results
                if tr is not None:
                    await tr.commit()
                return results
            except Exception:
                if tr is not None:
                    try:
                        await tr.rollback()
                    except Exception as rollback_exc:
                        # Keep the original error as the one re-raised, but do
                        # not hide the rollback failure entirely.
                        _logger.warning(
                            "Rollback after failed execution also failed: %s",
                            rollback_exc,
                        )
                raise

    @classmethod
    def _returns_rows(cls, query: str) -> bool:
        """Return True if *query* yields rows that should be fetched."""
        upper = query.upper()
        # The noise pattern can match the empty string, so it always matches;
        # the guard only keeps the Optional type honest.
        noise = cls._LEADING_NOISE_RE.match(upper)
        head = upper[noise.end():] if noise else upper
        if head.startswith(cls._ROW_RETURNING_PREFIXES):
            return True
        return cls._RETURNING_RE.search(upper) is not None

    async def _execute_one(self, conn: asyncpg.Connection, query: str) -> StatementResult:
        """Run one statement and capture its outcome as a ``StatementResult``.

        Row-producing statements are run with ``fetch``. At most
        ``self._config.max_rows`` records are returned as dicts, and ``rows``
        is the number returned. Other statements are run with ``execute``, and
        ``rows`` is the trailing integer of the status tag (for example
        "INSERT 0 5" gives 5), or 0 when the tag has none.

        Exceptions are not raised. They are reported as a failed result whose
        ``error_code`` is the exception class name.
        """
        try:
            if self._returns_rows(query):
                records = await conn.fetch(query)
                # Convert only the records that will actually be returned.
                data = [dict(record) for record in records[: self._config.max_rows]]
                return StatementResult(success=True, rows=len(data), data=data)
            status = await conn.execute(query)
            # An empty tag must not raise IndexError and mask a successful run.
            parts = status.split()
            rowcount = int(parts[-1]) if parts and parts[-1].isdigit() else 0
            return StatementResult(success=True, rows=rowcount, data=[])
        except Exception as exc:
            return StatementResult(
                success=False,
                error=str(exc),
                error_code=type(exc).__name__,
            )

    async def _close_pool(self, pool: asyncpg.Pool) -> None:
        """Close *pool* via ``asyncpg.Pool.close``."""
        await pool.close()

    @staticmethod
    def _build_ssl_context() -> ssl_module.SSLContext:
        """Build a client TLS context with the stdlib default settings.

        ``ssl.create_default_context()`` enables certificate and hostname
        verification against the system trust store.
        """
        return ssl_module.create_default_context()