Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import Any | |
| from urllib.parse import quote_plus | |
| from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase | |
| from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult | |
| from app.core.logger import get_logger | |
# Module-scoped logger obtained from the project's logging helper.
_logger = get_logger(__name__)

# Domain names used by MongoDBExecutor._is_atlas_srv to decide whether a host
# should be reached through a mongodb+srv:// connection string.
_ATLAS_SRV_DOMAINS = frozenset(("mongodb.net", "mongodbatlas.com"))

# Default "collection" value in MongoDBExecutor._execute_queries: when a query
# item names no collection, its pipeline is run against the database itself.
_DATABASE_LEVEL_MARKER = "$database"
class MongoDBExecutor(BaseExecutor):
    """Executor that runs aggregation pipelines against MongoDB via Motor.

    The Motor client doubles as the "pool": it is created by ``_create_pool``,
    used by ``_execute_queries`` and released by ``_close_pool``.
    NOTE(review): these methods look like hooks invoked by ``BaseExecutor``;
    confirm against the base class, which is not visible here.
    """

    # Scheme prefix of a DNS seed-list (SRV) connection string.
    _SRV_SCHEME = "mongodb+srv://"

    async def _create_pool(self) -> AsyncIOMotorClient:
        """Build the Motor client for the configured server.

        Hosts recognised by ``_is_atlas_srv`` are connected through a
        ``mongodb+srv://`` URI; every other host is connected with explicit
        host/port/username/password keyword arguments.

        Returns:
            A new ``AsyncIOMotorClient``.
        """
        # Both timeouts derive from the same configured value (seconds -> ms).
        timeout_ms = int(self._config.connection_timeout_seconds * 1000)
        client_kwargs: dict[str, Any] = {
            "serverSelectionTimeoutMS": timeout_ms,
            "connectTimeoutMS": timeout_ms,
            "maxPoolSize": 10,
            "minPoolSize": 1,
        }
        if self._config.ssl_enabled:
            client_kwargs["tls"] = True

        host = self._config.host
        if self._is_atlas_srv(host):
            return self._connect_via_uri(client_kwargs)

        client_kwargs["host"] = host
        client_kwargs["port"] = self._config.port
        client_kwargs["username"] = self._config.username
        client_kwargs["password"] = self._config.password
        return AsyncIOMotorClient(**client_kwargs)

    def _connect_via_uri(self, client_kwargs: dict[str, Any]) -> AsyncIOMotorClient:
        """Create a client from a ``mongodb+srv://`` connection string.

        Args:
            client_kwargs: Client options prepared by ``_create_pool``. The
                mapping is not modified.

        Returns:
            A new ``AsyncIOMotorClient`` built from the SRV URI.
        """
        host = self._config.host
        # The configured host may already carry the scheme (``_is_atlas_srv``
        # accepts that form); strip it so the URI does not end up with two.
        if host.lower().startswith(self._SRV_SCHEME):
            host = host[len(self._SRV_SCHEME):]
        host = host.rstrip("/")

        # Only emit a userinfo section when a username is configured: an empty
        # or missing username would otherwise produce an invalid ":@" prefix
        # (or make quote_plus fail on None).
        credentials = ""
        username = self._config.username
        if username:
            escaped_user = quote_plus(username)
            escaped_pass = quote_plus(self._config.password or "")
            credentials = f"{escaped_user}:{escaped_pass}@"

        uri = (
            f"{self._SRV_SCHEME}{credentials}{host}/"
            f"{self._config.database}?retryWrites=true&w=majority"
        )
        # The explicit "tls" option is not forwarded on this path (SRV
        # connection strings enable TLS by default). Build a filtered copy
        # instead of popping from the caller's dict.
        uri_kwargs = {
            key: value for key, value in client_kwargs.items() if key != "tls"
        }
        return AsyncIOMotorClient(uri, **uri_kwargs)

    @staticmethod
    def _is_atlas_srv(host: str) -> bool:
        """Return True when ``host`` should be reached through an SRV URI.

        That is the case when the host already starts with ``mongodb+srv://``
        or when it is one of ``_ATLAS_SRV_DOMAINS`` or a sub-domain of one.
        Matching is by exact name or dot-separated suffix, so a host that
        merely contains the text (``notmongodb.net``,
        ``mongodb.net.corp.local``) is not treated as an SRV host.
        """
        host_lower = host.lower()
        if host_lower.startswith(MongoDBExecutor._SRV_SCHEME):
            return True
        return any(
            host_lower == domain or host_lower.endswith("." + domain)
            for domain in _ATLAS_SRV_DOMAINS
        )

    async def _execute_queries(
        self,
        pool: AsyncIOMotorClient,
        queries: list[dict[str, Any]],
        use_transaction: bool,
    ) -> list[StatementResult]:
        """Run each query item as an aggregation and collect the results.

        Args:
            pool: Motor client to run the aggregations on.
            queries: Query items. Each may carry a ``"collection"`` name
                (defaults to the database-level marker) and a ``"pipeline"``
                list (defaults to an empty pipeline).
            use_transaction: Accepted for interface compatibility; this
                implementation does not read it.

        Returns:
            One ``StatementResult`` per executed query, in order. Processing
            stops after the first failing query, whose failure result is the
            last element of the list.
        """
        db: AsyncIOMotorDatabase = pool[self._config.database]
        results: list[StatementResult] = []
        for query_item in queries:
            collection_name = query_item.get("collection", _DATABASE_LEVEL_MARKER)
            pipeline = query_item.get("pipeline", [])
            try:
                if collection_name == _DATABASE_LEVEL_MARKER:
                    cursor = db.aggregate(pipeline)
                else:
                    cursor = db[collection_name].aggregate(pipeline)
                data = []
                async for doc in cursor:
                    # Stringify the document id before storing the row.
                    if "_id" in doc:
                        doc["_id"] = str(doc["_id"])
                    data.append(doc)
                    # Stop reading once the configured row cap is reached.
                    if len(data) >= self._config.max_rows:
                        break
                results.append(
                    StatementResult(success=True, rows=len(data), data=data)
                )
            except Exception as exc:
                # Report the failure as a result instead of raising, then
                # abandon the remaining queries.
                results.append(
                    StatementResult(
                        success=False,
                        error=str(exc),
                        error_code=type(exc).__name__,
                    )
                )
                break
        return results

    async def _close_pool(self, pool: AsyncIOMotorClient) -> None:
        """Close the given Motor client."""
        pool.close()

    async def _get_or_create_pool(self) -> AsyncIOMotorClient:
        """Return the existing client if one is set, else defer to the base class."""
        if self._pool is not None:
            return self._pool
        return await super()._get_or_create_pool()