|
|
import time |
|
|
import json |
|
|
from typing import Dict, Any, List, Union, Optional |
|
|
from pathlib import Path |
|
|
from bson import json_util |
|
|
from pymongo import MongoClient |
|
|
|
|
|
from .database_base import DatabaseBase, DatabaseType, QueryType, DatabaseConnection |
|
|
from .tool import Tool, Toolkit |
|
|
from ..core.logging import logger |
|
|
|
|
|
|
|
|
class MongoDBConnection(DatabaseConnection):
    """MongoDB-specific connection management.

    Wraps a single ``MongoClient`` built from ``connection_string`` and the
    keyword arguments the base class keeps in ``connection_params``.
    """

    def __init__(self, connection_string: str, **kwargs):
        super().__init__(connection_string, **kwargs)
        # Set by connect(), cleared by disconnect().
        self.client = None
        # Not assigned by this class itself; reset to None by disconnect().
        self.database = None

    def connect(self) -> bool:
        """Establish a connection to MongoDB and verify it with a ``ping``.

        Returns:
            True when the server answered the ping; False otherwise. Failures
            are logged, not raised.
        """
        client = None
        try:
            # The previous implementation branched on the URI scheme but built
            # the client identically in both branches; one call is equivalent.
            client = MongoClient(self.connection_string, **self.connection_params)
            # Force a round trip so an unreachable server is reported here.
            client.admin.command('ping')
        except Exception as e:
            logger.error(f"Failed to connect to MongoDB: {str(e)}")
            # Release resources held by a client that never became usable.
            if client is not None:
                try:
                    client.close()
                except Exception as close_error:  # best-effort cleanup
                    logger.debug(f"Ignoring error while closing failed client: {close_error}")
            self.client = None
            self._is_connected = False
            return False

        self.client = client
        self._is_connected = True
        logger.info("Successfully connected to MongoDB")
        return True

    def disconnect(self) -> bool:
        """Close the MongoDB client (if any) and reset connection state.

        Returns:
            True on success; False if closing raised (the error is logged).
        """
        try:
            if self.client is not None:
                self.client.close()
                self.client = None
                self.database = None
            self._is_connected = False
            logger.info("Disconnected from MongoDB")
            return True
        except Exception as e:
            logger.error(f"Error disconnecting from MongoDB: {str(e)}")
            return False

    def test_connection(self) -> bool:
        """Return True if a client exists and the server answers ``ping``."""
        try:
            if self.client is not None:
                self.client.admin.command('ping')
                return True
            return False
        except Exception:
            return False

    def get_database(self, database_name: str):
        """Return the database handle for ``database_name``.

        Returns None when there is no client or the name is empty.
        """
        if self.client is not None and database_name:
            return self.client[database_name]
        return None
|
|
|
|
|
|
|
|
class MongoDBDatabase(DatabaseBase):
    """
    MongoDB database implementation with automatic initialization.

    Handles remote connections, existing local databases, and new local database creation.

    The mode is chosen in ``__init__``:

    * remote: ``connection_string`` starts with ``mongodb://`` or
      ``mongodb+srv://``, or mentions ``localhost`` or ``127.0.0.1``;
    * existing local: ``local_path`` exists and contains ``*.json`` files.
      Each file (except ``db_info.json``) is loaded into a same-named
      collection on ``mongodb://localhost:27017``;
    * new local: otherwise. ``local_path`` (default ``./mongodb_local``) is
      created together with a ``db_info.json`` metadata file.

    In the local modes, collections touched by insert/update/delete are
    written back to ``<local_path>/<collection>.json`` when ``auto_save`` is on.
    """

    def __init__(self,
                 connection_string: str = None,
                 database_name: str = None,
                 local_path: str = None,
                 auto_save: bool = True,
                 read_only: bool = False,
                 **kwargs):
        """
        Initialize MongoDB database with automatic detection and setup.

        Args:
            connection_string: MongoDB connection string (for remote)
            database_name: Name of the database
            local_path: Path for local file-based database
            auto_save: Automatically save changes to local files
            read_only: If True, only read operations are allowed (no insert, update, delete)
            **kwargs: Additional connection parameters; forwarded to the base
                class and to every ``MongoClient`` this instance creates

        Raises:
            Exception: anything raised by the selected ``_init_*`` method,
                which logs the failure and re-raises.
        """
        super().__init__(connection_string=connection_string,
                         database_name=database_name,
                         **kwargs)

        self.local_path = Path(local_path) if local_path else None
        self.auto_save = auto_save
        self.read_only = read_only
        self.connection_params = kwargs

        self.is_local_database = False
        self.client = None
        self.database = None

        if self._is_remote_connection():
            self._init_remote_database()
        elif self._is_existing_local_database():
            self._init_existing_local_database()
        else:
            self._init_new_local_database()

    # ------------------------------------------------------------------ #
    # Initialization
    # ------------------------------------------------------------------ #

    def _is_remote_connection(self) -> bool:
        """Return True when ``connection_string`` should be used as given.

        Matches ``mongodb://`` / ``mongodb+srv://`` URIs and any string that
        mentions ``localhost`` or ``127.0.0.1``.
        """
        conn = self.connection_string
        # bool() so callers get a real boolean, not the str/None that a bare
        # ``and`` expression would return.
        return bool(conn) and (
            conn.startswith(('mongodb://', 'mongodb+srv://'))
            or 'localhost' in conn
            or '127.0.0.1' in conn
        )

    def _is_existing_local_database(self) -> bool:
        """Return True if ``local_path`` exists and holds at least one ``*.json`` file."""
        if not self.local_path or not self.local_path.exists():
            return False
        # db_info.json itself matches *.json, so one glob covers both checks.
        return any(self.local_path.glob("*.json"))

    def _init_remote_database(self):
        """Connect to ``connection_string`` and verify the server with ``ping``.

        ``self.database`` is selected only when ``database_name`` is set.

        Raises:
            Exception: connection/ping errors are logged and re-raised.
        """
        try:
            self.client = MongoClient(self.connection_string, **self.connection_params)
            self.client.admin.command('ping')

            if self.database_name:
                self.database = self.client[self.database_name]

            self._is_initialized = True
            self.is_local_database = False
            logger.info(f"Connected to remote MongoDB: {self.database_name}")

        except Exception as e:
            logger.error(f"Failed to connect to remote MongoDB: {str(e)}")
            self._is_initialized = False
            raise

    def _connect_local_server(self):
        """Create the local-mode client and select the database.

        Local mode always uses ``mongodb://localhost:27017``. The database
        name defaults to the last component of ``local_path``. No ``ping`` is
        issued here.
        """
        self.connection_string = "mongodb://localhost:27017"
        self.client = MongoClient(self.connection_string, **self.connection_params)
        if not self.database_name:
            self.database_name = self.local_path.name
        self.database = self.client[self.database_name]

    def _init_existing_local_database(self):
        """Connect to the local server and load the JSON collection files.

        Raises:
            Exception: failures are logged and re-raised.
        """
        try:
            self._connect_local_server()
            self._load_local_collections()

            self._is_initialized = True
            self.is_local_database = True
            logger.info(f"Loaded existing local database from: {self.local_path}")

        except Exception as e:
            logger.error(f"Failed to load existing local database: {str(e)}")
            self._is_initialized = False
            raise

    def _init_new_local_database(self):
        """Create the local directory, connect, and write ``db_info.json``.

        Raises:
            Exception: failures are logged and re-raised.
        """
        try:
            if not self.local_path:
                self.local_path = Path("./mongodb_local")
            self.local_path.mkdir(parents=True, exist_ok=True)

            self._connect_local_server()
            self._create_db_info_file()

            self._is_initialized = True
            self.is_local_database = True
            logger.info(f"Created new local database at: {self.local_path}")

        except Exception as e:
            logger.error(f"Failed to create new local database: {str(e)}")
            self._is_initialized = False
            raise

    # ------------------------------------------------------------------ #
    # Local file persistence
    # ------------------------------------------------------------------ #

    def _load_local_collections(self):
        """Load every ``<name>.json`` in ``local_path`` into collection ``<name>``.

        A file holding a single object is treated as one document; a list as
        many. Other JSON values, and empty lists, are skipped. A non-empty
        file replaces the collection (``drop`` then ``insert_many``).
        Per-file failures are logged as warnings and do not stop the loop.
        """
        if not self.local_path or not self.local_path.exists():
            return

        json_files = [f for f in self.local_path.glob("*.json") if f.name != "db_info.json"]

        for json_file in json_files:
            collection_name = json_file.stem
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                if isinstance(data, dict):
                    documents = [data]
                elif isinstance(data, list):
                    documents = data
                else:
                    continue

                if documents:
                    cleaned_documents = [self._clean_document_for_insert(doc) for doc in documents]

                    collection = self.database[collection_name]
                    collection.drop()
                    if cleaned_documents:
                        collection.insert_many(cleaned_documents)
                    logger.info(f"Loaded {len(cleaned_documents)} documents into '{collection_name}'")

            except Exception as e:
                logger.warning(f"Failed to load collection from {json_file}: {str(e)}")

    def _clean_document_for_insert(self, doc: Dict) -> Dict:
        """Recursively drop ``_id`` fields written as extended JSON ``{"$oid": ...}``.

        Nested dicts, and dicts inside lists, are cleaned too. Non-dict input
        is returned unchanged.
        """
        if not isinstance(doc, dict):
            return doc

        cleaned = {}
        for key, value in doc.items():
            if key == '_id' and isinstance(value, dict) and '$oid' in value:
                # Let the server assign a fresh ObjectId instead.
                continue
            if isinstance(value, dict):
                cleaned[key] = self._clean_document_for_insert(value)
            elif isinstance(value, list):
                cleaned[key] = [self._clean_document_for_insert(item) if isinstance(item, dict) else item
                                for item in value]
            else:
                cleaned[key] = value
        return cleaned

    def _create_db_info_file(self):
        """Write ``db_info.json`` (name, creation time, path, auto_save, version).

        Failures are logged as warnings, not raised.
        """
        try:
            db_info = {
                "database_name": self.database_name,
                "created_at": time.time(),
                "local_path": str(self.local_path.absolute()),
                "auto_save": self.auto_save,
                "version": "1.0"
            }

            info_file = self.local_path / "db_info.json"
            with open(info_file, 'w', encoding='utf-8') as f:
                json.dump(db_info, f, indent=2, ensure_ascii=False)

        except Exception as e:
            logger.warning(f"Failed to create db info file: {str(e)}")

    def _save_collection_to_file(self, collection_name: str):
        """Dump a collection to ``<local_path>/<collection_name>.json``.

        Does nothing outside local mode. ``_id`` values are written as strings,
        and other non-JSON values go through ``str`` (``default=str``).
        Failures are logged as warnings.
        """
        if not self.is_local_database or not self.local_path:
            return

        try:
            collection = self.database[collection_name]
            documents = list(collection.find())

            for doc in documents:
                if '_id' in doc:
                    doc['_id'] = str(doc['_id'])

            file_path = self.local_path / f"{collection_name}.json"
            # Write to a sibling temp file, then swap it in with os.replace
            # semantics, so an interrupted write cannot truncate the file.
            tmp_path = file_path.with_name(file_path.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(documents, f, indent=2, ensure_ascii=False, default=str)
            tmp_path.replace(file_path)

            logger.debug(f"Saved collection '{collection_name}' to {file_path}")

        except Exception as e:
            logger.warning(f"Failed to save collection '{collection_name}': {str(e)}")

    def _auto_save_if_needed(self, collection_name: str):
        """Persist ``collection_name`` to disk in local mode when ``auto_save`` is on."""
        if self.is_local_database and self.auto_save:
            self._save_collection_to_file(collection_name)

    # ------------------------------------------------------------------ #
    # Connection management
    # ------------------------------------------------------------------ #

    def _get_database_type(self) -> DatabaseType:
        return DatabaseType.MONGODB

    def connect(self) -> bool:
        """Connection is already established in __init__; report whether it succeeded."""
        return self._is_initialized

    def disconnect(self) -> bool:
        """Close the client (if any) and mark the instance uninitialized.

        Returns:
            True on success; False if closing raised (the error is logged).
        """
        try:
            if self.client is not None:
                self.client.close()
                self.client = None
                self.database = None
            self._is_initialized = False
            logger.info("Disconnected from MongoDB")
            return True
        except Exception as e:
            logger.error(f"Error disconnecting: {str(e)}")
            return False

    def test_connection(self) -> bool:
        """Return True if a client exists and the server answers ``ping``."""
        try:
            if self.client is not None:
                self.client.admin.command('ping')
                return True
            return False
        except Exception:
            return False

    # ------------------------------------------------------------------ #
    # Query execution
    # ------------------------------------------------------------------ #

    def execute_query(self,
                      query: Union[str, Dict, List],
                      query_type: QueryType = None,
                      collection_name: str = None,
                      **kwargs) -> Dict[str, Any]:
        """Execute a query on MongoDB with automatic result handling.

        Args:
            query: Filter/option dict, ``"field=value"`` string, insert
                document(s), update/delete spec, or aggregation pipeline (list).
            query_type: Operation to run; inferred by ``_infer_query_type``
                when omitted.
            collection_name: Target collection (required).
            **kwargs: Forwarded to the ``_execute_*`` helper.

        Returns:
            The dict produced by ``format_query_result`` /
            ``format_error_result``. Dict results from the helpers gain an
            ``execution_time`` key. In read-only mode, write types are
            rejected before touching the server.
        """
        if not self._is_initialized or self.database is None:
            return self.format_error_result("Database not connected")

        if not collection_name:
            return self.format_error_result("Collection name is required")

        start_time = time.time()

        try:
            collection = self.database[collection_name]

            if not query_type:
                query_type = self._infer_query_type(query)

            if self.read_only and query_type in [QueryType.INSERT, QueryType.UPDATE, QueryType.DELETE,
                                                 QueryType.CREATE, QueryType.DROP]:
                return self.format_error_result(
                    f"Write operation '{query_type.value}' is not allowed in read-only mode. "
                    "Only SELECT and AGGREGATE operations are permitted.",
                    query_type,
                    execution_time=time.time() - start_time
                )

            if query_type == QueryType.SELECT:
                result = self._execute_find(collection, query, **kwargs)
            elif query_type == QueryType.INSERT:
                result = self._execute_insert(collection, query, **kwargs)
                self._auto_save_if_needed(collection_name)
            elif query_type == QueryType.UPDATE:
                result = self._execute_update(collection, query, **kwargs)
                self._auto_save_if_needed(collection_name)
            elif query_type == QueryType.DELETE:
                result = self._execute_delete(collection, query, **kwargs)
                self._auto_save_if_needed(collection_name)
            elif query_type == QueryType.AGGREGATE:
                result = self._execute_aggregate(collection, query, **kwargs)
            else:
                return self.format_error_result(f"Unsupported query type: {query_type}")

            execution_time = time.time() - start_time
            if isinstance(result, dict):
                result["execution_time"] = execution_time

            return result

        except Exception as e:
            execution_time = time.time() - start_time
            logger.error(f"Error executing MongoDB query: {str(e)}")
            return self.format_error_result(str(e), query_type, execution_time=execution_time)

    def _infer_query_type(self, query: Union[str, Dict, List]) -> QueryType:
        """Infer the query type from the query's structure.

        * lists are aggregation pipelines;
        * in read-only mode everything else is treated as SELECT;
        * dicts are checked for operation keys (insert*, update*, delete*,
          create*, drop*) in that order;
        * strings are matched by prefix ("insert"/"create" -> INSERT,
          "update", "delete", "drop");
        * anything else falls back to SELECT.
        """
        if isinstance(query, list):
            return QueryType.AGGREGATE

        if self.read_only:
            return QueryType.SELECT

        if isinstance(query, dict):
            markers = (
                (("insert", "insertOne", "insertMany"), QueryType.INSERT),
                (("update", "updateOne", "updateMany"), QueryType.UPDATE),
                (("delete", "deleteOne", "deleteMany"), QueryType.DELETE),
                (("create", "createCollection"), QueryType.CREATE),
                (("drop", "dropCollection"), QueryType.DROP),
            )
            for keys, inferred in markers:
                if any(key in query for key in keys):
                    return inferred
            return QueryType.SELECT

        if isinstance(query, str):
            query_lower = query.lower().strip()
            if query_lower.startswith(("insert", "create")):
                return QueryType.INSERT
            if query_lower.startswith("update"):
                return QueryType.UPDATE
            if query_lower.startswith("delete"):
                return QueryType.DELETE
            if query_lower.startswith("drop"):
                return QueryType.DROP
            return QueryType.SELECT

        return QueryType.SELECT

    @staticmethod
    def _normalize_sort(sort):
        """Convert a JSON-style sort spec into the form ``Cursor.sort`` accepts.

        ``{"age": -1}`` becomes ``[("age", -1)]`` and two-element JSON lists
        become tuples. Anything else (e.g. a bare field name) is returned as is.
        """
        if isinstance(sort, dict):
            return list(sort.items())
        if isinstance(sort, list):
            return [tuple(item) if isinstance(item, list) else item for item in sort]
        return sort

    def _execute_find(self, collection, query: Dict, **kwargs) -> Dict[str, Any]:
        """Execute a find query.

        ``query`` may be a ``"field=value"`` string, which becomes
        ``{field: value}``; any other string matches everything. It may also
        be a dict with optional ``filter``, ``projection``, ``sort``, ``limit``
        and ``skip`` keys. Without ``filter``, the dict itself is the filter,
        minus those option keys.

        ``limit``/``skip`` fall back to the same-named ``kwargs``; 0 means
        unlimited / no skip. Documents are passed through ``bson.json_util``
        so ObjectIds and dates come back JSON-compatible.
        """
        try:
            if isinstance(query, str):
                if "=" in query:
                    field, value = query.split("=", 1)
                    query = {field.strip(): value.strip()}
                else:
                    query = {}

            if "filter" in query:
                filter_query = query["filter"]
            else:
                # The dict doubles as the filter. Strip cursor options so they
                # are not matched as document fields. (To filter on a field
                # named like an option, pass an explicit "filter".)
                filter_query = {key: value for key, value in query.items()
                                if key not in ("projection", "sort", "limit", "skip")}
            projection = query.get("projection", {})
            sort = self._normalize_sort(query.get("sort", None))
            limit = query.get("limit", kwargs.get("limit", 0))
            skip = query.get("skip", kwargs.get("skip", 0))

            # An empty projection means "all fields". Older PyMongo versions
            # turned {} into an _id-only projection, so pass None instead.
            cursor = collection.find(filter_query, projection or None)

            if sort:
                cursor = cursor.sort(sort)
            if skip:
                cursor = cursor.skip(skip)
            if limit:
                cursor = cursor.limit(limit)

            results = [json.loads(json_util.dumps(doc)) for doc in cursor]

            return self.format_query_result(
                results,
                QueryType.SELECT,
                collection_name=collection.name,
                filter_applied=filter_query
            )

        except Exception as e:
            return self.format_error_result(str(e), QueryType.SELECT)

    def _execute_insert(self, collection, query: Union[Dict, List], **kwargs) -> Dict[str, Any]:
        """Insert one document (dict) or many (list).

        A dict with a ``document`` key inserts that value. A list of dicts is
        inserted as-is; any other list is wrapped as ``{"documents": [...]}``.
        """
        try:
            if isinstance(query, dict):
                document = query["document"] if "document" in query else query
                result = collection.insert_one(document)
                return self.format_query_result(
                    {"inserted_id": str(result.inserted_id)},
                    QueryType.INSERT,
                    collection_name=collection.name
                )

            if isinstance(query, list):
                if all(isinstance(item, dict) for item in query):
                    documents = query
                else:
                    documents = [{"documents": query}]

                result = collection.insert_many(documents)
                return self.format_query_result(
                    {"inserted_ids": [str(inserted_id) for inserted_id in result.inserted_ids]},
                    QueryType.INSERT,
                    collection_name=collection.name
                )

            return self.format_error_result("Invalid insert query format", QueryType.INSERT)

        except Exception as e:
            return self.format_error_result(str(e), QueryType.INSERT)

    def _execute_update(self, collection, query: Dict, **kwargs) -> Dict[str, Any]:
        """Run ``update_one`` or, when ``multi`` is true, ``update_many``.

        Reads ``filter``, ``update``, ``upsert`` and ``multi`` from ``query``.
        """
        try:
            filter_query = query.get("filter", {})
            update_query = query.get("update", {})
            upsert = query.get("upsert", False)
            multi = query.get("multi", False)

            if multi:
                result = collection.update_many(filter_query, update_query, upsert=upsert)
            else:
                result = collection.update_one(filter_query, update_query, upsert=upsert)

            return self.format_query_result(
                {
                    "matched_count": result.matched_count,
                    "modified_count": result.modified_count,
                    # Compare with None: a falsy upserted _id (e.g. 0) is still an upsert.
                    "upserted_id": str(result.upserted_id) if result.upserted_id is not None else None
                },
                QueryType.UPDATE,
                collection_name=collection.name
            )

        except Exception as e:
            return self.format_error_result(str(e), QueryType.UPDATE)

    def _execute_delete(self, collection, query: Dict, **kwargs) -> Dict[str, Any]:
        """Run ``delete_one`` or, when ``multi`` is true, ``delete_many``.

        The filter is ``query["filter"]`` if present, otherwise ``query`` itself.
        """
        try:
            filter_query = query.get("filter", query)
            multi = query.get("multi", False)

            if multi:
                result = collection.delete_many(filter_query)
            else:
                result = collection.delete_one(filter_query)

            return self.format_query_result(
                {"deleted_count": result.deleted_count},
                QueryType.DELETE,
                collection_name=collection.name
            )

        except Exception as e:
            return self.format_error_result(str(e), QueryType.DELETE)

    def _execute_aggregate(self, collection, pipeline: List, **kwargs) -> Dict[str, Any]:
        """Run an aggregation pipeline; results go through ``bson.json_util``."""
        try:
            results = [json.loads(json_util.dumps(doc)) for doc in collection.aggregate(pipeline)]

            return self.format_query_result(
                results,
                QueryType.AGGREGATE,
                collection_name=collection.name,
                pipeline_stages=len(pipeline)
            )

        except Exception as e:
            return self.format_error_result(str(e), QueryType.AGGREGATE)

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #

    @staticmethod
    def _redact_connection_string(connection_string):
        """Mask any ``user:password@`` credentials in a MongoDB URI.

        ``mongodb://user:pw@host/db`` becomes ``mongodb://***@host/db``.
        Values without ``://`` or without credentials are returned unchanged.
        """
        if not connection_string:
            return connection_string
        scheme, sep, remainder = connection_string.partition("://")
        if not sep:
            return connection_string
        authority, slash, tail = remainder.partition("/")
        if "@" not in authority:
            return connection_string
        hosts = authority.rpartition("@")[2]
        return f"{scheme}://***@{hosts}{slash}{tail}"

    def get_database_info(self) -> Dict[str, Any]:
        """Return ``dbStats`` figures plus server version/type information.

        Credentials in the connection string are masked, because this result
        may be shown to end users or LLM agents.
        """
        try:
            if not self._is_initialized or self.database is None:
                return self.format_error_result("Database not connected")

            stats = self.database.command("dbStats")
            server_info = self.client.server_info()

            info = {
                "database_name": self.database_name,
                "collections": stats.get("collections", 0),
                "data_size": stats.get("dataSize", 0),
                "storage_size": stats.get("storageSize", 0),
                "indexes": stats.get("indexes", 0),
                "index_size": stats.get("indexSize", 0),
                "server_version": server_info.get("version", "Unknown"),
                "server_type": server_info.get("type", "Unknown"),
                "connection_string": self._redact_connection_string(self.connection_string),
                "is_connected": self._is_initialized
            }

            return self.format_query_result(info, QueryType.SELECT)

        except Exception as e:
            return self.format_error_result(str(e))

    def list_collections(self) -> List[str]:
        """List collection names; returns [] when not connected or on error."""
        try:
            if not self._is_initialized or self.database is None:
                return []
            return self.database.list_collection_names()
        except Exception as e:
            logger.error(f"Error listing collections: {str(e)}")
            return []

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """Return ``collStats`` figures, index keys and two sample documents."""
        try:
            # PyMongo Database objects reject truth-value testing; compare with None.
            if not self._is_initialized or self.database is None:
                return self.format_error_result("Database not connected")

            collection = self.database[collection_name]
            stats = self.database.command("collStats", collection_name)
            indexes = list(collection.list_indexes())
            # Only two samples are reported. Convert them like _execute_find so
            # ObjectIds/dates are JSON-compatible.
            sample_docs = [json.loads(json_util.dumps(doc)) for doc in collection.find().limit(2)]

            info = {
                "collection_name": collection_name,
                "document_count": stats.get("count", 0),
                "data_size": stats.get("size", 0),
                "storage_size": stats.get("storageSize", 0),
                "index_count": stats.get("nindexes", 0),
                "indexes": [{"name": idx["name"], "keys": idx["key"]} for idx in indexes],
                "sample_documents": sample_docs
            }

            return self.format_query_result(info, QueryType.SELECT)

        except Exception as e:
            return self.format_error_result(str(e))

    def get_schema(self, collection_name: str = None) -> Dict[str, Any]:
        """Infer a schema from up to 100 documents of one collection.

        Without ``collection_name``, schemas are returned for the first 10
        collections reported by ``list_collections``.
        """
        try:
            # PyMongo Database objects reject truth-value testing; compare with None.
            if not self._is_initialized or self.database is None:
                return self.format_error_result("Database not connected")

            if collection_name:
                collection = self.database[collection_name]
                sample_docs = list(collection.find().limit(100))

                if not sample_docs:
                    return self.format_query_result(
                        {"collection_name": collection_name, "schema": {}, "message": "No documents found"},
                        QueryType.SELECT
                    )

                schema = self._infer_schema_from_documents(sample_docs)
                return self.format_query_result(
                    {
                        "collection_name": collection_name,
                        "schema": schema,
                        "sample_count": len(sample_docs)
                    },
                    QueryType.SELECT
                )

            schemas = {}
            for coll_name in self.list_collections()[:10]:
                coll_schema = self.get_schema(coll_name)
                if coll_schema.get("success"):
                    schemas[coll_name] = coll_schema.get("data", {}).get("schema", {})

            return self.format_query_result(
                {"database_name": self.database_name, "schemas": schemas},
                QueryType.SELECT
            )

        except Exception as e:
            return self.format_error_result(str(e))

    def _infer_schema_from_documents(self, documents: List[Dict]) -> Dict[str, Any]:
        """Merge the per-document schemas of ``documents`` into one mapping."""
        schema = {}
        for doc in documents:
            self._update_schema_from_document(schema, doc)
        return schema

    def _update_schema_from_document(self, schema: Dict, document: Dict, path: str = ""):
        """Recursively merge one document's field types into ``schema``.

        Entries are keyed by dotted path. Dicts become
        ``{"type": "object", "fields": {...}}``. Lists become
        ``{"type": "array", "element_types": [...]}``, built from at most the
        first three items. Scalars become ``{"type": <python type name>}``.
        A field whose kind differs between documents is marked ``"mixed"``,
        and any nested/element info seen so far is kept.
        """
        for key, value in document.items():
            current_path = f"{path}.{key}" if path else key

            if isinstance(value, dict):
                kind = "object"
            elif isinstance(value, list):
                kind = "array"
            else:
                kind = type(value).__name__

            entry = schema.get(current_path)
            if entry is None:
                entry = {"type": kind}
                schema[current_path] = entry
            elif entry["type"] != kind:
                entry["type"] = "mixed"

            if kind == "object":
                self._update_schema_from_document(entry.setdefault("fields", {}), value, current_path)
            elif kind == "array":
                # Keep element_types a JSON-friendly list across documents
                # (the old set/list swap broke on the second document).
                element_types = entry.setdefault("element_types", [])
                for item in value[:3]:
                    item_type = "object" if isinstance(item, dict) else type(item).__name__
                    if item_type not in element_types:
                        element_types.append(item_type)

    def get_supported_query_types(self) -> List[QueryType]:
        """Return the query types ``execute_query`` can actually dispatch.

        CREATE, DROP and INDEX are not listed: ``execute_query`` answers them
        with "Unsupported query type".
        """
        if self.read_only:
            return [
                QueryType.SELECT,
                QueryType.AGGREGATE
            ]
        return [
            QueryType.SELECT,
            QueryType.INSERT,
            QueryType.UPDATE,
            QueryType.DELETE,
            QueryType.AGGREGATE
        ]

    def get_capabilities(self) -> Dict[str, Any]:
        """Extend the base capabilities with MongoDB-specific flags."""
        base_capabilities = super().get_capabilities()
        base_capabilities.update({
            "supports_aggregation": True,
            "supports_full_text_search": True,
            "supports_geospatial_queries": True,
            "supports_change_streams": True,
            "supports_transactions": True,
            "supports_indexing": True,
            "document_oriented": True,
            "schema_flexible": True,
            "read_only": self.read_only,
            "write_operations_allowed": not self.read_only
        })
        return base_capabilities
|
|
|
|
|
|
|
|
class MongoDBExecuteQueryTool(Tool):
    name: str = "mongodb_execute_query"
    description: str = "Execute MongoDB queries including find and aggregation pipelines (read-only operations)"
    inputs: Dict[str, Dict[str, str]] = {
        "query": {
            "type": "string",
            "description": "MongoDB query (JSON string for find, array for aggregation pipeline)"
        },
        "query_type": {
            "type": "string",
            "description": "Type of query (select, aggregate) - auto-detected if not provided"
        },
        "collection_name": {
            "type": "string",
            "description": "Collection name (required for all operations)"
        }
    }
    required: Optional[List[str]] = ["query", "collection_name"]

    def __init__(self, database: MongoDBDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str, query_type: str = None, collection_name: str = None) -> Dict[str, Any]:
        """Execute a MongoDB query.

        Args:
            query: JSON text, decoded when valid and otherwise passed through
                as a raw string. Already-decoded dicts/lists are accepted too.
            query_type: Optional ``QueryType`` value (case-insensitive).
            collection_name: Target collection.

        Returns:
            The database result dict, or ``{"success": False, "error": ...,
            "data": None}`` on setup/argument errors.
        """
        try:
            if self.database is None:
                return {"success": False, "error": "MongoDB database not initialized", "data": None}

            parsed_query = self._parse_query(query)

            query_type_enum = None
            if query_type:
                try:
                    query_type_enum = QueryType(query_type.lower())
                except ValueError:
                    return {"success": False, "error": f"Invalid query type: {query_type}", "data": None}

            result = self.database.execute_query(
                query=parsed_query,
                query_type=query_type_enum,
                collection_name=collection_name
            )

            if result.get("success"):
                logger.info(f"Successfully executed MongoDB query on collection {collection_name}")
            else:
                logger.error(f"Failed to execute MongoDB query: {result.get('error', 'Unknown error')}")

            return result

        except Exception as e:
            logger.error(f"Error in mongodb_execute_query tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}

    def _parse_query(self, query: str) -> Union[str, Dict, List]:
        """Decode JSON query text; return the input unchanged if it is not JSON.

        Non-text input (an already-parsed dict or list) is returned as is,
        instead of making ``json.loads`` raise ``TypeError``.
        """
        if not isinstance(query, (str, bytes, bytearray)):
            return query
        try:
            return json.loads(query)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return query
|
|
|
|
|
|
|
|
class MongoDBFindTool(Tool):
    name: str = "mongodb_find"
    description: str = "Find documents in a MongoDB collection with filtering, projection, sorting, and pagination"
    inputs: Dict[str, Dict[str, str]] = {
        "collection_name": {
            "type": "string",
            "description": "Collection name to query"
        },
        "filter": {
            "type": "string",
            "description": "MongoDB filter query (JSON string, e.g., '{\"age\": {\"$gt\": 18}}')"
        },
        "projection": {
            "type": "string",
            "description": "Fields to include/exclude (JSON string, e.g., '{\"name\": 1, \"_id\": 0}')"
        },
        "sort": {
            "type": "string",
            "description": "Sort criteria (JSON string, e.g., '{\"age\": -1}')"
        },
        "limit": {
            "type": "integer",
            "description": "Maximum number of documents to return"
        },
        "skip": {
            "type": "integer",
            "description": "Number of documents to skip"
        }
    }
    required: Optional[List[str]] = ["collection_name"]

    def __init__(self, database: MongoDBDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, collection_name: str, filter: str = "{}", projection: str = "{}",
                 sort: str = None, limit: int = 0, skip: int = 0) -> Dict[str, Any]:
        """Find documents in a MongoDB collection.

        Args:
            collection_name: Collection to query.
            filter: JSON filter document; empty means "match all".
            projection: JSON projection document; empty means "all fields".
            sort: Optional JSON sort spec, e.g. ``{"age": -1}``.
            limit: Maximum number of documents (0 = no limit).
            skip: Number of documents to skip.

        Returns:
            The database result dict, or ``{"success": False, "error": ...,
            "data": None}`` when the database is missing or an argument
            cannot be decoded.
        """
        try:
            if self.database is None:
                return {"success": False, "error": "MongoDB database not initialized", "data": None}

            filter_dict = json.loads(filter) if filter else {}
            projection_dict = json.loads(projection) if projection else {}
            sort_spec = json.loads(sort) if sort else None
            # A JSON object decodes to a dict. Pass Cursor.sort the list of
            # (field, direction) pairs that every PyMongo version accepts.
            if isinstance(sort_spec, dict):
                sort_spec = list(sort_spec.items())

            query = {
                "filter": filter_dict,
                "projection": projection_dict,
                # Tool callers sometimes send numbers as strings; the cursor
                # requires ints.
                "limit": int(limit) if limit else 0,
                "skip": int(skip) if skip else 0
            }
            if sort_spec:
                query["sort"] = sort_spec

            result = self.database.execute_query(
                query=query,
                query_type=QueryType.SELECT,
                collection_name=collection_name
            )

            if result.get("success"):
                logger.info(f"Successfully found documents in collection {collection_name}")
            else:
                logger.error(f"Failed to find documents: {result.get('error', 'Unknown error')}")

            return result

        except Exception as e:
            logger.error(f"Error in mongodb_find tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
|
|
|
|
|
|
class MongoDBUpdateTool(Tool):
    name: str = "mongodb_update"
    description: str = "Update documents in a MongoDB collection"
    inputs: Dict[str, Dict[str, str]] = {
        "collection_name": {
            "type": "string",
            "description": "Collection name to update"
        },
        "filter": {
            "type": "string",
            "description": "Filter to match documents to update (JSON string)"
        },
        "update": {
            "type": "string",
            "description": "Update operations (JSON string, e.g., '{\"$set\": {\"status\": \"active\"}}')"
        },
        "upsert": {
            "type": "boolean",
            "description": "Create document if it doesn't exist"
        },
        "multi": {
            "type": "boolean",
            "description": "Update multiple documents (default: false)"
        }
    }
    required: Optional[List[str]] = ["collection_name", "filter", "update"]

    def __init__(self, database: MongoDBDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, collection_name: str, filter: str, update: str,
                 upsert: bool = False, multi: bool = False) -> Dict[str, Any]:
        """Update documents in a MongoDB collection.

        Args:
            collection_name: Collection to update.
            filter: JSON filter selecting the documents.
            update: JSON update document (e.g. ``{"$set": {...}}``).
            upsert: Insert a document when nothing matches.
            multi: Update every match instead of only the first.

        Returns:
            The database result dict, or ``{"success": False, "error": ...,
            "data": None}`` when the database is missing or the JSON is invalid.
        """
        try:
            if self.database is None:
                return {"success": False, "error": "MongoDB database not initialized", "data": None}

            query = {
                "filter": json.loads(filter),
                "update": json.loads(update),
                "upsert": upsert,
                "multi": multi
            }

            result = self.database.execute_query(
                query=query,
                query_type=QueryType.UPDATE,
                collection_name=collection_name
            )

            if result.get("success"):
                logger.info(f"Successfully updated documents in collection {collection_name}")
            else:
                logger.error(f"Failed to update documents: {result.get('error', 'Unknown error')}")

            return result

        except Exception as e:
            logger.error(f"Error in mongodb_update tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
|
|
|
|
|
|
class MongoDBDeleteTool(Tool):
    """Tool that deletes documents from a MongoDB collection via ``execute_query``."""

    name: str = "mongodb_delete"
    description: str = "Delete documents from a MongoDB collection"
    inputs: Dict[str, Dict[str, str]] = {
        "collection_name": {
            "type": "string",
            "description": "Collection name to delete from"
        },
        "filter": {
            "type": "string",
            "description": "Filter to match documents to delete (JSON string)"
        },
        "multi": {
            "type": "boolean",
            "description": "Delete multiple documents (default: false)"
        }
    }
    required: Optional[List[str]] = ["collection_name", "filter"]

    def __init__(self, database: MongoDBDatabase = None):
        """Bind the tool to a MongoDB database wrapper (may be None; checked per call)."""
        super().__init__()
        self.database = database

    def __call__(self, collection_name: str, filter: str, multi: bool = False) -> Dict[str, Any]:
        """Delete documents from MongoDB collection.

        Args:
            collection_name: Name of the collection to delete from.
            filter: JSON string that must decode to an object; selects the
                documents to delete. ``"{}"`` matches every document.
            multi: Forwarded as the query's ``multi`` flag (per the input
                description, True deletes multiple documents).

        Returns:
            The dict returned by ``self.database.execute_query``, or
            ``{"success": False, "error": <message>, "data": None}`` when the
            database is missing, the filter is invalid, or an exception occurs.
        """
        try:
            if not self.database:
                return {"success": False, "error": "MongoDB database not initialized", "data": None}

            # ``json`` is imported at module level; malformed JSON raises and is
            # reported by the generic handler below.
            filter_dict = json.loads(filter)

            # A destructive operation must never receive a non-object filter
            # (e.g. "null"); fail clearly instead of forwarding it.
            if not isinstance(filter_dict, dict):
                error = "filter must be a JSON object"
                logger.error(f"Failed to delete documents: {error}")
                return {"success": False, "error": error, "data": None}

            if not filter_dict and multi:
                # Still allowed (callers may intend it), but make it visible.
                logger.warning(f"Empty filter with multi=True will delete all documents in collection {collection_name}")

            query = {
                "filter": filter_dict,
                "multi": multi
            }

            result = self.database.execute_query(
                query=query,
                query_type=QueryType.DELETE,
                collection_name=collection_name
            )

            # .get: a result without "success" is treated as a failure instead of
            # raising KeyError and discarding the actual result.
            if result.get("success"):
                logger.info(f"Successfully deleted documents from collection {collection_name}")
            else:
                logger.error(f"Failed to delete documents: {result.get('error', 'Unknown error')}")

            return result

        except Exception as e:
            logger.error(f"Error in mongodb_delete tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
|
|
|
|
|
|
class MongoDBInfoTool(Tool):
    """Tool that reports database, collection, schema, or capability information."""

    name: str = "mongodb_info"
    description: str = "Get MongoDB database and collection information"
    inputs: Dict[str, Dict[str, str]] = {
        "info_type": {
            "type": "string",
            "description": "Type of information (database, collections, collection, schema, capabilities)"
        },
        "collection_name": {
            "type": "string",
            "description": "Collection name for collection-specific info (optional)"
        }
    }
    required: Optional[List[str]] = []

    def __init__(self, database: MongoDBDatabase = None):
        """Bind the tool to a MongoDB database wrapper (may be None; checked per call)."""
        super().__init__()
        self.database = database

    def __call__(self, info_type: str = "database", collection_name: Optional[str] = None) -> Dict[str, Any]:
        """Get MongoDB information.

        Args:
            info_type: One of "database", "collections", "collection",
                "schema", "capabilities" (case-insensitive, surrounding
                whitespace ignored). None or empty falls back to "database".
            collection_name: Required for "collection"; passed through to
                ``get_schema`` for "schema".

        Returns:
            A dict with at least a "success" key; on failure it also carries
            "error" and "data": None.
        """
        try:
            if not self.database:
                return {"success": False, "error": "MongoDB database not initialized", "data": None}

            # info_type is optional, so a null value means "use the default".
            info_type = (info_type or "database").strip().lower()

            if info_type == "database":
                result = self.database.get_database_info()
            elif info_type == "collections":
                collections = self.database.list_collections()
                result = {"success": True, "data": collections, "collection_count": len(collections)}
            elif info_type == "collection":
                # Previously reported as "Invalid info type: collection", which
                # hid the real problem (missing collection name).
                if not collection_name:
                    return {"success": False,
                            "error": "collection_name is required when info_type is 'collection'",
                            "data": None}
                result = self.database.get_collection_info(collection_name)
            elif info_type == "schema":
                result = self.database.get_schema(collection_name)
            elif info_type == "capabilities":
                result = {"success": True, "data": self.database.get_capabilities()}
            else:
                return {"success": False, "error": f"Invalid info type: {info_type}", "data": None}

            if result.get("success"):
                logger.info(f"Successfully retrieved {info_type} information")
            else:
                logger.error(f"Failed to retrieve {info_type} information: {result.get('error', 'Unknown error')}")

            return result

        except Exception as e:
            logger.error(f"Error in mongodb_info tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
|
|
|
|
|
|
class MongoDBToolkit(Toolkit):
    """
    MongoDB-specific toolkit with simplified design.
    Automatically handles remote, local file-based, or new database creation.

    All tools share one ``MongoDBDatabase`` instance. In read-only mode the
    update and delete tools are not registered.
    """

    def __init__(self,
                 name: str = "MongoDBToolkit",
                 connection_string: Optional[str] = None,
                 database_name: Optional[str] = None,
                 local_path: Optional[str] = None,
                 auto_save: bool = True,
                 read_only: bool = False,
                 **kwargs):
        """
        Initialize the MongoDB toolkit.

        Args:
            name: Name of the toolkit
            connection_string: MongoDB connection string (for remote/existing)
            database_name: Name of the database to use
            local_path: Path for local file-based database
            auto_save: Automatically save changes to local files
            read_only: If True, only read operations are allowed (no insert, update, delete)
            **kwargs: Additional connection parameters
        """
        database = MongoDBDatabase(
            connection_string=connection_string,
            database_name=database_name,
            local_path=local_path,
            auto_save=auto_save,
            read_only=read_only,
            **kwargs
        )

        # Read tools are always exposed; write tools only when not read-only.
        # Resulting order: execute_query, find, [update, delete], info.
        tools = [
            MongoDBExecuteQueryTool(database=database),
            MongoDBFindTool(database=database),
        ]
        if not read_only:
            tools.extend([
                MongoDBUpdateTool(database=database),
                MongoDBDeleteTool(database=database),
            ])
        tools.append(MongoDBInfoTool(database=database))

        super().__init__(name=name, tools=tools)

        self.database = database
        self.connection_string = connection_string
        self.database_name = database_name
        self.local_path = local_path
        self.auto_save = auto_save

        # NOTE: atexit keeps a strong reference to this toolkit until the
        # interpreter exits, so the cleanup always runs at shutdown.
        import atexit
        atexit.register(self._cleanup)

    def _cleanup(self):
        """Cleanup function called when program exits.

        Saves every collection to file when the database is local with
        auto-save enabled, then disconnects. Errors are logged as warnings,
        never raised, because this runs during interpreter shutdown.
        """
        database = getattr(self, "database", None)
        # Guard before any attribute access (the original dereferenced
        # self.database before checking it).
        if not database:
            return

        try:
            if database.is_local_database and database.auto_save:
                logger.info("Auto-saving local database before exit...")
                for collection_name in database.list_collections():
                    database._save_collection_to_file(collection_name)
        except Exception as e:
            logger.warning(f"Error during cleanup: {str(e)}")

        # Disconnect in its own try so a failed save does not leak the client.
        try:
            if database.disconnect():
                logger.info("Disconnected from MongoDB database")
            else:
                logger.warning("Failed to disconnect from MongoDB database during cleanup")
        except Exception as e:
            logger.warning(f"Error during cleanup: {str(e)}")

    def get_capabilities(self) -> Dict[str, Any]:
        """Get MongoDB-specific capabilities.

        Returns the database's capabilities merged with the local-setup
        flags, or ``{"error": ...}`` when no database is attached.
        """
        if not self.database:
            return {"error": "MongoDB database not initialized"}

        # Copy so the dict returned by the database is never mutated in place.
        capabilities = dict(self.database.get_capabilities())
        capabilities.update({
            "is_local_database": self.database.is_local_database,
            "local_path": str(self.database.local_path) if self.database.local_path else None,
            "auto_save": self.database.auto_save,
            "read_only": self.database.read_only
        })
        return capabilities

    def connect(self) -> bool:
        """Connect to MongoDB; False when no database is attached."""
        return self.database.connect() if self.database else False

    def disconnect(self) -> bool:
        """Disconnect from MongoDB; False when no database is attached."""
        return self.database.disconnect() if self.database else False

    def test_connection(self) -> bool:
        """Test MongoDB connection; False when no database is attached."""
        return self.database.test_connection() if self.database else False

    def get_database(self) -> MongoDBDatabase:
        """Get the underlying MongoDB database instance"""
        return self.database

    def get_local_info(self) -> Dict[str, Any]:
        """Get information about local database setup.

        Returns ``{"error": "Database not initialized"}`` when no database
        is attached.
        """
        if not self.database:
            return {"error": "Database not initialized"}

        return {
            "is_local_database": self.database.is_local_database,
            "local_path": str(self.database.local_path) if self.database.local_path else None,
            "auto_save": self.database.auto_save,
            "read_only": self.database.read_only,
            "database_name": self.database_name,
            "connection_string": self.connection_string
        }