# Last commit: 89157f5 "feat: add db apis" (Soumik-404)
from __future__ import annotations
import asyncio
import sys
from typing import Any
import aiomysql
from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult
from app.core.logger import get_logger
# Module-level logger, named after this module via the project's get_logger helper.
_logger = get_logger(__name__)
class MySQLExecutor(BaseExecutor):
    """MySQL implementation of ``BaseExecutor`` backed by an ``aiomysql`` pool.

    Connection settings are read from ``self._config`` (a ``ConnectionConfig``).
    The lazily created pool is stored in ``self._pool`` and guarded by
    ``self._lock``; ``self._config``, ``self._pool``, ``self._lock`` and
    ``self._max_connection_retries`` are presumably initialised by
    ``BaseExecutor`` — verify against the base class.
    """

    async def _create_pool(self) -> aiomysql.Pool:
        """Open a new ``aiomysql`` pool from the configured connection settings.

        Autocommit is disabled on the pooled connections, so writes only
        persist when ``commit()`` is called explicitly (see ``_execute_one``
        and ``_execute_queries``).
        """
        return await aiomysql.create_pool(
            host=self._config.host,
            port=self._config.port,
            user=self._config.username,
            password=self._config.password,
            db=self._config.database,
            minsize=1,
            maxsize=10,
            autocommit=False,
            connect_timeout=self._config.connection_timeout_seconds,
            # aiomysql recycles connections older than this many seconds on acquire.
            pool_recycle=3600,
        )

    async def _get_or_create_pool(self) -> aiomysql.Pool:
        """Return the shared pool, creating it on first use.

        Runs under ``self._lock`` (presumably an ``asyncio.Lock``) so concurrent
        callers do not create duplicate pools. Creation is attempted up to
        ``self._max_connection_retries`` times, each attempt bounded by
        ``connection_timeout_seconds``, sleeping 0.1 s, 0.2 s, 0.4 s, ...
        between attempts.

        Raises:
            RuntimeError: if every attempt fails; the last underlying error is
                chained as ``__cause__``.
        """
        async with self._lock:
            if self._pool is not None:
                return self._pool

            last_error: Exception | None = None
            for attempt in range(self._max_connection_retries):
                try:
                    self._pool = await asyncio.wait_for(
                        self._create_pool(),
                        timeout=self._config.connection_timeout_seconds,
                    )
                    return self._pool
                except Exception as exc:
                    last_error = exc
                    _logger.warning(
                        "MySQL connection failed (attempt %d): %s",
                        attempt + 1, exc,
                    )
                    # No point sleeping after the final attempt.
                    if attempt < self._max_connection_retries - 1:
                        await asyncio.sleep(0.1 * (2**attempt))
            raise RuntimeError(
                f"Failed to connect to MySQL after {self._max_connection_retries} attempts"
            ) from last_error

    async def _execute_queries(
        self, pool: aiomysql.Pool, queries: list[str], use_transaction: bool
    ) -> list[StatementResult]:
        """Run ``queries`` in order on a single pooled connection.

        With ``use_transaction`` the batch is one transaction. Nothing is
        committed until every statement has succeeded. The first failing
        statement triggers a rollback and ends the batch, and its failed result
        is the last entry returned. Without it, each write statement is
        committed on its own and execution continues past failed statements.

        Returns:
            One ``StatementResult`` per statement that was executed.

        Raises:
            Any exception that escapes (e.g. from ``commit``/``rollback``) is
            re-raised, after a best-effort rollback in transaction mode.
        """
        async with pool.acquire() as conn:
            if use_transaction:
                await conn.begin()
            try:
                results: list[StatementResult] = []
                for query in queries:
                    # Inside a transaction the per-statement commit must be
                    # suppressed; otherwise earlier writes would already be
                    # durable when a later statement fails, and the rollback
                    # below would undo nothing.
                    result = await self._execute_one(
                        conn, query, commit=not use_transaction
                    )
                    results.append(result)
                    if not result.success and use_transaction:
                        await conn.rollback()
                        return results
                if use_transaction:
                    await conn.commit()
                return results
            except Exception:
                if use_transaction:
                    try:
                        await conn.rollback()
                    except Exception:
                        # Keep the original error propagating; only record
                        # that the cleanup rollback also failed.
                        _logger.warning("MySQL rollback failed", exc_info=True)
                raise

    async def _execute_one(
        self, conn: Any, query: str, commit: bool = True
    ) -> StatementResult:
        """Execute one statement and wrap the outcome in a ``StatementResult``.

        Args:
            conn: An acquired ``aiomysql`` connection.
            query: SQL text executed as-is; no parameters are bound.
            commit: Commit after a statement that produces no result set.
                Callers running inside an explicit transaction pass ``False``
                so the enclosing transaction decides when to commit.

        Returns:
            On success with a result set: ``data`` holds at most
            ``self._config.max_rows`` row dicts and ``rows`` is how many were
            returned. On success without a result set: empty ``data`` and
            ``rows`` set to the cursor's ``rowcount``. On any exception:
            ``success=False`` with ``str(exc)`` and an ``error_code``. The code
            is the exception's first argument when that is an int (typically
            the MySQL error number for driver errors); otherwise it is the
            exception class name.
        """
        try:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query)
                if cursor.description:
                    rows = await cursor.fetchall()
                    # Truncate before copying so rows beyond max_rows are
                    # never converted only to be discarded.
                    data = [dict(row) for row in rows[: self._config.max_rows]]
                    return StatementResult(success=True, rows=len(data), data=data)
                if commit:
                    await conn.commit()
                return StatementResult(success=True, rows=cursor.rowcount, data=[])
        except Exception as exc:
            # exc.args can be empty (e.g. a bare TimeoutError()); indexing it
            # blindly would raise IndexError out of this handler.
            first_arg = exc.args[0] if exc.args else None
            if isinstance(first_arg, int):
                error_code = str(first_arg)
            else:
                error_code = type(exc).__name__
            return StatementResult(success=False, error=str(exc), error_code=error_code)

    async def _close_pool(self, pool: aiomysql.Pool) -> None:
        """Close ``pool`` and wait until all of its connections are closed."""
        pool.close()
        await pool.wait_closed()