Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import sys | |
| from typing import Any | |
| import aiomysql | |
| from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult | |
| from app.core.logger import get_logger | |
| _logger = get_logger(__name__) | |
class MySQLExecutor(BaseExecutor):
    """Executor that runs SQL statements against MySQL via an aiomysql pool.

    The methods read ``self._config``, ``self._lock``, ``self._pool`` and
    ``self._max_connection_retries``; none of them is defined in this class
    (presumably they are provided by ``BaseExecutor`` -- confirm there).
    """

    async def _create_pool(self) -> aiomysql.Pool:
        """Create a new aiomysql connection pool from ``self._config``.

        The pool is created with ``autocommit=False``, so writes only become
        durable once ``commit()`` is called on the connection.

        Returns:
            The newly created pool.
        """
        cfg = self._config
        return await aiomysql.create_pool(
            host=cfg.host,
            port=cfg.port,
            user=cfg.username,
            password=cfg.password,
            db=cfg.database,
            minsize=1,
            maxsize=10,
            autocommit=False,
            connect_timeout=cfg.connection_timeout_seconds,
            pool_recycle=3600,
        )

    async def _get_or_create_pool(self) -> aiomysql.Pool:
        """Return the cached pool, creating it with retries when missing.

        Creation runs under ``self._lock``. Each attempt is bounded by
        ``connection_timeout_seconds``; between failed attempts the coroutine
        sleeps ``0.1 * 2**attempt`` seconds (no sleep after the final attempt).

        Returns:
            The existing or newly created pool.

        Raises:
            RuntimeError: If every attempt failed; chained from the last error.
        """
        async with self._lock:
            if self._pool is not None:
                return self._pool

            last_error: Exception | None = None
            for attempt in range(self._max_connection_retries):
                try:
                    self._pool = await asyncio.wait_for(
                        self._create_pool(),
                        timeout=self._config.connection_timeout_seconds,
                    )
                    return self._pool
                except Exception as exc:
                    last_error = exc
                    _logger.warning(
                        "MySQL connection failed (attempt %d): %s",
                        attempt + 1, exc,
                    )
                    if attempt < self._max_connection_retries - 1:
                        # Exponential backoff before the next attempt.
                        await asyncio.sleep(0.1 * (2**attempt))

            raise RuntimeError(
                f"Failed to connect to MySQL after {self._max_connection_retries} attempts"
            ) from last_error

    async def _execute_queries(
        self, pool: aiomysql.Pool, queries: list[str], use_transaction: bool
    ) -> list[StatementResult]:
        """Run ``queries`` in order on a single pooled connection.

        With ``use_transaction`` the statements run inside one explicit
        transaction: the first failed statement triggers a rollback and stops
        execution, otherwise everything is committed once at the end.
        Without it, every statement is executed regardless of earlier
        failures and each one commits on its own.

        Args:
            pool: Pool to acquire the connection from.
            queries: SQL statements to execute, in order.
            use_transaction: Whether to wrap all statements in one transaction.

        Returns:
            One ``StatementResult`` per executed statement. In transaction
            mode the list ends with the failing statement, if any.

        Raises:
            Exception: Whatever ``begin``/``commit``/``rollback`` raise is
                propagated; in transaction mode a rollback is attempted first.
        """
        async with pool.acquire() as conn:
            if use_transaction:
                await conn.begin()
            try:
                results: list[StatementResult] = []
                for query in queries:
                    # In transaction mode the per-statement commit must be
                    # suppressed; otherwise earlier statements would already
                    # be committed and a later rollback could not undo them.
                    result = await self._execute_one(
                        conn, query, commit=not use_transaction
                    )
                    results.append(result)
                    if use_transaction and not result.success:
                        await conn.rollback()
                        return results
                if use_transaction:
                    await conn.commit()
                return results
            except Exception:
                if use_transaction:
                    try:
                        await conn.rollback()
                    except Exception as rollback_exc:
                        # Best effort: report the rollback failure but let the
                        # original error propagate.
                        _logger.warning(
                            "MySQL rollback failed: %s", rollback_exc
                        )
                raise

    async def _execute_one(
        self, conn: Any, query: str, *, commit: bool = True
    ) -> StatementResult:
        """Execute a single statement and wrap the outcome in a result object.

        Statements that produce a result set return at most
        ``self._config.max_rows`` rows as dictionaries. Other statements
        report the cursor's ``rowcount`` and, when ``commit`` is true, are
        committed immediately.

        Args:
            conn: Connection to execute on.
            query: SQL text to execute.
            commit: Commit after a statement without a result set. Pass
                ``False`` when the caller manages the transaction itself.

        Returns:
            A successful ``StatementResult``, or a failed one carrying the
            error text and an ``error_code``. Exceptions of type ``Exception``
            are never propagated.
        """
        try:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query)
                if cursor.description:
                    rows = await cursor.fetchall()
                    # Truncate before converting so rows beyond the limit are
                    # not copied needlessly.
                    data = [dict(row) for row in rows[: self._config.max_rows]]
                    return StatementResult(success=True, rows=len(data), data=data)
                if commit:
                    await conn.commit()
                return StatementResult(success=True, rows=cursor.rowcount, data=[])
        except Exception as exc:
            # An integer first argument is used as the error code (presumably
            # the MySQL errno for driver errors -- confirm); otherwise fall
            # back to the exception class name. Exceptions raised without
            # arguments have empty ``args``, so guard the index access.
            first_arg = exc.args[0] if exc.args else None
            if isinstance(first_arg, int):
                error_code = str(first_arg)
            else:
                error_code = type(exc).__name__
            return StatementResult(success=False, error=str(exc), error_code=error_code)

    async def _close_pool(self, pool: aiomysql.Pool) -> None:
        """Close ``pool`` and wait until it has finished closing."""
        pool.close()
        await pool.wait_closed()