# llm-ready-data / app/core/database/mongodb.py
# Last commit by Soumik-404: 89157f5 "feat: add db apis"
from __future__ import annotations
from typing import Any
from urllib.parse import quote_plus
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from app.core.database.base import BaseExecutor, ConnectionConfig, StatementResult
from app.core.logger import get_logger
_logger = get_logger(__name__)

# Domains identifying a MongoDB Atlas deployment. Hosts in these domains are
# reached through a ``mongodb+srv://`` URI instead of a direct host/port pair.
_ATLAS_SRV_DOMAINS = frozenset(("mongodb.net", "mongodbatlas.com"))

# Sentinel ``collection`` value (also the default when a query item has no
# ``collection`` key) meaning: run the pipeline against the database itself
# rather than against a named collection.
_DATABASE_LEVEL_MARKER = "$database"
class MongoDBExecutor(BaseExecutor):
    """Run MongoDB aggregation pipelines through a Motor client.

    Each query item handed to ``_execute_queries`` is a mapping with:

    * ``"collection"`` - target collection name. It defaults to the
      ``_DATABASE_LEVEL_MARKER`` sentinel, which runs the pipeline against
      the database itself (``db.aggregate``).
    * ``"pipeline"`` - list of aggregation stages; defaults to ``[]``.
    """

    # URI scheme for DNS seed-list (SRV) connection strings.
    _SRV_SCHEME = "mongodb+srv://"

    async def _create_pool(self) -> AsyncIOMotorClient:
        """Create the Motor client for the configured deployment.

        Hosts recognised by ``_is_atlas_srv`` are connected through a
        ``mongodb+srv://`` URI (see ``_connect_via_uri``). Any other host is
        connected directly using the configured host, port and credentials.
        Both paths share the timeout and pool-size settings built here.
        """
        # The same millisecond budget is used for server selection and for
        # establishing each socket connection.
        timeout_ms = int(self._config.connection_timeout_seconds * 1000)
        client_kwargs: dict[str, Any] = {
            "serverSelectionTimeoutMS": timeout_ms,
            "connectTimeoutMS": timeout_ms,
            "maxPoolSize": 10,
            "minPoolSize": 1,
        }
        if self._config.ssl_enabled:
            client_kwargs["tls"] = True

        host = self._config.host
        if self._is_atlas_srv(host):
            return self._connect_via_uri(client_kwargs)

        client_kwargs["host"] = host
        client_kwargs["port"] = self._config.port
        client_kwargs["username"] = self._config.username
        client_kwargs["password"] = self._config.password
        return AsyncIOMotorClient(**client_kwargs)

    def _connect_via_uri(self, client_kwargs: dict[str, Any]) -> AsyncIOMotorClient:
        """Create a client from a ``mongodb+srv://`` URI built from the config.

        The URI embeds the percent-escaped credentials (only when a username
        is configured), the host and the database, and requests
        ``retryWrites=true&w=majority``.

        Args:
            client_kwargs: Shared client options from ``_create_pool``. The
                dict is not modified.

        Returns:
            The configured ``AsyncIOMotorClient``.
        """
        # The host may already carry the scheme (``_is_atlas_srv`` accepts
        # that form), so strip it to avoid "mongodb+srv://...mongodb+srv://".
        host = self._strip_srv_scheme(self._config.host)

        username = self._config.username
        if username:
            # quote_plus percent-escapes reserved characters (":", "@", "/")
            # that would otherwise corrupt the URI.
            password = self._config.password or ""
            credentials = f"{quote_plus(username)}:{quote_plus(password)}@"
        else:
            # quote_plus(None) would raise TypeError; connect without
            # credentials instead.
            credentials = ""

        database = self._config.database or ""
        uri = (
            f"{self._SRV_SCHEME}{credentials}{host}/"
            f"{database}?retryWrites=true&w=majority"
        )

        # The +srv scheme turns TLS on by default, so the explicit flag is
        # dropped. Work on a copy so the caller's dict is left untouched.
        uri_kwargs = dict(client_kwargs)
        uri_kwargs.pop("tls", None)
        return AsyncIOMotorClient(uri, **uri_kwargs)

    @classmethod
    def _strip_srv_scheme(cls, host: str) -> str:
        """Return *host* without a leading ``mongodb+srv://`` or trailing ``/``."""
        if host.lower().startswith(cls._SRV_SCHEME):
            host = host[len(cls._SRV_SCHEME):]
        return host.rstrip("/")

    @classmethod
    def _is_atlas_srv(cls, host: str) -> bool:
        """Return True when *host* should be reached through an SRV URI.

        That is the case when *host* already carries the ``mongodb+srv://``
        scheme, or when its hostname matches a domain in
        ``_ATLAS_SRV_DOMAINS`` exactly or is a subdomain of one. Any
        ``:port`` or ``/path`` suffix is ignored for the comparison.
        """
        host_lower = host.lower()
        if host_lower.startswith(cls._SRV_SCHEME):
            return True
        # Compare whole DNS labels, so that e.g. "mongodb.network.corp" is
        # not mistaken for an Atlas "mongodb.net" host.
        hostname = host_lower.split("/", 1)[0].split(":", 1)[0].rstrip(".")
        return any(
            hostname == domain or hostname.endswith("." + domain)
            for domain in _ATLAS_SRV_DOMAINS
        )

    @classmethod
    def _normalize_id(cls, value: Any) -> Any:
        """Make an ``_id`` value safe to hand back as plain data.

        JSON-native scalars (``None``, ``bool``, ``int``, ``float``, ``str``)
        are returned unchanged. Dicts and lists/tuples are normalised
        recursively, which keeps compound ``$group`` keys as structures.
        Anything else (e.g. an ObjectId) is converted with ``str``.
        """
        if value is None or isinstance(value, (bool, int, float, str)):
            return value
        if isinstance(value, dict):
            return {key: cls._normalize_id(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._normalize_id(item) for item in value]
        return str(value)

    async def _execute_queries(
        self,
        pool: AsyncIOMotorClient,
        queries: list[dict[str, Any]],
        use_transaction: bool,
    ) -> list[StatementResult]:
        """Run each query item's aggregation pipeline in order.

        Args:
            pool: Client returned by ``_create_pool``.
            queries: Query items with optional ``"collection"`` and
                ``"pipeline"`` keys (see the class docstring).
            use_transaction: Accepted for interface compatibility; this
                implementation does not use it.

        Returns:
            One ``StatementResult`` per executed query. Execution stops at the
            first query that raises, whose failure is the last entry.
            Successful results hold at most ``max_rows`` documents. Because
            the limit is checked after appending, a non-empty result still
            contains one document when ``max_rows`` is 0 or negative.
        """
        db: AsyncIOMotorDatabase = pool[self._config.database]
        max_rows = self._config.max_rows
        results: list[StatementResult] = []
        for query_item in queries:
            collection_name = query_item.get("collection", _DATABASE_LEVEL_MARKER)
            pipeline = query_item.get("pipeline", [])
            try:
                target = (
                    db
                    if collection_name == _DATABASE_LEVEL_MARKER
                    else db[collection_name]
                )
                cursor = target.aggregate(pipeline)
                data: list[Any] = []
                async for doc in cursor:
                    if "_id" in doc:
                        doc["_id"] = self._normalize_id(doc["_id"])
                    data.append(doc)
                    if len(data) >= max_rows:
                        break
                results.append(
                    StatementResult(success=True, rows=len(data), data=data)
                )
            except Exception as exc:
                # Per-query boundary: record the failure and stop processing
                # the remaining queries.
                results.append(
                    StatementResult(
                        success=False,
                        error=str(exc),
                        error_code=type(exc).__name__,
                    )
                )
                break
        return results

    async def _close_pool(self, pool: AsyncIOMotorClient) -> None:
        """Close the Motor client (``close()`` is called without ``await``)."""
        pool.close()

    async def _get_or_create_pool(self) -> AsyncIOMotorClient:
        """Return the cached client if set, else defer to the base class."""
        if self._pool is not None:
            return self._pool
        return await super()._get_or_create_pool()