from __future__ import annotations from typing import Any from urllib.parse import quote_plus from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult from app.core.logger import get_logger _logger = get_logger(__name__) _ATLAS_SRV_DOMAINS = frozenset({"mongodb.net", "mongodbatlas.com"}) _DATABASE_LEVEL_MARKER = "$database" class MongoDBExecutor(BaseExecutor): async def _create_pool(self) -> AsyncIOMotorClient: client_kwargs: dict[str, Any] = { "serverSelectionTimeoutMS": int( self._config.connection_timeout_seconds * 1000 ), "connectTimeoutMS": int(self._config.connection_timeout_seconds * 1000), "maxPoolSize": 10, "minPoolSize": 1, } if self._config.ssl_enabled: client_kwargs["tls"] = True host = self._config.host if self._is_atlas_srv(host): return self._connect_via_uri(client_kwargs) client_kwargs["host"] = host client_kwargs["port"] = self._config.port client_kwargs["username"] = self._config.username client_kwargs["password"] = self._config.password return AsyncIOMotorClient(**client_kwargs) def _connect_via_uri(self, client_kwargs: dict[str, Any]) -> AsyncIOMotorClient: escaped_user = quote_plus(self._config.username) escaped_pass = quote_plus(self._config.password) uri = ( f"mongodb+srv://{escaped_user}:{escaped_pass}@{self._config.host}/" f"{self._config.database}?retryWrites=true&w=majority" ) client_kwargs.pop("tls", None) return AsyncIOMotorClient(uri, **client_kwargs) @staticmethod def _is_atlas_srv(host: str) -> bool: host_lower = host.lower() for domain in _ATLAS_SRV_DOMAINS: if domain in host_lower: return True return host_lower.startswith("mongodb+srv://") async def _execute_queries( self, pool: AsyncIOMotorClient, queries: list[dict[str, Any]], use_transaction: bool, ) -> list[StatementResult]: db: AsyncIOMotorDatabase = pool[self._config.database] results: list[StatementResult] = [] for query_item in queries: collection_name = query_item.get("collection", _DATABASE_LEVEL_MARKER) pipeline = query_item.get("pipeline", []) try: if collection_name == _DATABASE_LEVEL_MARKER: cursor = db.aggregate(pipeline) else: cursor = db[collection_name].aggregate(pipeline) data = [] async for doc in cursor: if "_id" in doc: doc["_id"] = str(doc["_id"]) data.append(doc) if len(data) >= self._config.max_rows: break results.append( StatementResult(success=True, rows=len(data), data=data) ) except Exception as exc: results.append( StatementResult( success=False, error=str(exc), error_code=type(exc).__name__, ) ) break return results async def _close_pool(self, pool: AsyncIOMotorClient) -> None: pool.close() async def _get_or_create_pool(self) -> AsyncIOMotorClient: if self._pool is not None: return self._pool return await super()._get_or_create_pool()