from __future__ import annotations

import asyncio
import sys
from typing import Any

import aiomysql

from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult
from app.core.logger import get_logger

_logger = get_logger(__name__)


class MySQLExecutor(BaseExecutor):
    """Run SQL statements against MySQL through an ``aiomysql`` connection pool.

    The methods read ``self._config``, ``self._lock``, ``self._pool`` and
    ``self._max_connection_retries``; none of them is assigned in this class.
    NOTE(review): presumably they are initialised by ``BaseExecutor`` --
    confirm against the base class.
    """

    async def _create_pool(self) -> aiomysql.Pool:
        """Create a new ``aiomysql`` pool from the executor's configuration.

        Autocommit is disabled on pooled connections, so writes only become
        durable once something explicitly commits them.
        """
        return await aiomysql.create_pool(
            host=self._config.host,
            port=self._config.port,
            user=self._config.username,
            password=self._config.password,
            db=self._config.database,
            minsize=1,
            maxsize=10,
            autocommit=False,
            connect_timeout=self._config.connection_timeout_seconds,
            pool_recycle=3600,
        )

    async def _get_or_create_pool(self) -> aiomysql.Pool:
        """Return the cached pool, creating it (with retries) on first use.

        Creation is attempted up to ``self._max_connection_retries`` times,
        each attempt bounded by ``connection_timeout_seconds``, sleeping
        ``0.1 * 2**attempt`` seconds between failed attempts.

        Raises:
            RuntimeError: if every attempt fails; the last underlying error
                is chained as ``__cause__``.
        """
        # The lock is held for the whole create-and-retry sequence so that
        # concurrent callers wait for one pool instead of each creating one.
        async with self._lock:
            if self._pool is not None:
                return self._pool

            last_error: Exception | None = None
            for attempt in range(self._max_connection_retries):
                try:
                    self._pool = await asyncio.wait_for(
                        self._create_pool(),
                        timeout=self._config.connection_timeout_seconds,
                    )
                    return self._pool
                except Exception as exc:
                    last_error = exc
                    _logger.warning(
                        "MySQL connection failed (attempt %d): %s",
                        attempt + 1,
                        exc,
                    )
                    # No point sleeping after the final attempt.
                    if attempt < self._max_connection_retries - 1:
                        await asyncio.sleep(0.1 * (2**attempt))

            raise RuntimeError(
                f"Failed to connect to MySQL after {self._max_connection_retries} attempts"
            ) from last_error

    async def _execute_queries(
        self, pool: aiomysql.Pool, queries: list[str], use_transaction: bool
    ) -> list[StatementResult]:
        """Execute ``queries`` in order on a single pooled connection.

        Args:
            pool: Pool to borrow the connection from.
            queries: SQL statements, executed sequentially.
            use_transaction: When true, all statements run inside one explicit
                transaction: the first failing statement triggers a rollback
                and stops execution, and the commit happens only after every
                statement has succeeded. When false, each statement is
                committed on its own and a failure does not stop the batch.

        Returns:
            One ``StatementResult`` per executed statement; in transactional
            mode the list ends with the failing statement, if there was one.

        Raises:
            Exception: whatever ``commit``/``rollback`` raise; statement-level
                errors are reported through ``StatementResult`` instead.
        """
        async with pool.acquire() as conn:
            if use_transaction:
                await conn.begin()
            try:
                results: list[StatementResult] = []
                for query in queries:
                    # Inside an explicit transaction the per-statement commit
                    # must be suppressed, otherwise earlier statements would
                    # already be durable when a later one fails.
                    result = await self._execute_one(
                        conn, query, commit=not use_transaction
                    )
                    results.append(result)
                    if not result.success and use_transaction:
                        await conn.rollback()
                        return results
                if use_transaction:
                    await conn.commit()
                return results
            except Exception:
                if use_transaction:
                    # Best-effort cleanup: a failed rollback is logged but
                    # must not mask the original error re-raised below.
                    try:
                        await conn.rollback()
                    except Exception as rollback_exc:
                        _logger.warning("MySQL rollback failed: %s", rollback_exc)
                raise

    async def _execute_one(
        self, conn: Any, query: str, *, commit: bool = True
    ) -> StatementResult:
        """Execute a single statement and wrap the outcome in a result object.

        Args:
            conn: Open connection to execute on.
            query: SQL text to run.
            commit: Whether to commit after a statement that returns no result
                set. Defaults to ``True``; transactional callers pass
                ``False`` and commit once at the end.

        Returns:
            For statements with a result set: the rows as dicts, truncated to
            ``self._config.max_rows``, with ``rows`` set to the truncated
            count. For other statements: ``rows`` is the cursor's
            ``rowcount`` and ``data`` is empty. On any error: a failed result
            carrying the message and an error code (the first exception
            argument when it is an int, otherwise the exception class name).
        """
        try:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query)
                if cursor.description:
                    fetched = await cursor.fetchall()
                    # Truncate first so rows past the limit are never copied.
                    data = [dict(row) for row in fetched[: self._config.max_rows]]
                    return StatementResult(success=True, rows=len(data), data=data)
                if commit:
                    await conn.commit()
                return StatementResult(success=True, rows=cursor.rowcount, data=[])
        except Exception as exc:
            # Exceptions raised without arguments have an empty ``args``
            # tuple; indexing it unguarded would raise from this handler.
            first_arg = exc.args[0] if exc.args else None
            if isinstance(first_arg, int):
                error_code = str(first_arg)
            else:
                error_code = type(exc).__name__
            return StatementResult(success=False, error=str(exc), error_code=error_code)

    async def _close_pool(self, pool: aiomysql.Pool) -> None:
        """Close ``pool`` and wait until the close has completed."""
        pool.close()
        await pool.wait_closed()