Spaces:
Sleeping
Sleeping
import os
from types import TracebackType
from typing import Any, Dict, List, Optional, Type

from dotenv import load_dotenv
from pymongo import MongoClient
from pymongo.errors import (
    ConnectionFailure,
    InvalidName,
    OperationFailure,
    ServerSelectionTimeoutError,
)
class DatabaseError(Exception):
    """Base class for database operation errors.

    Every custom exception defined in this module derives from this class,
    so callers can catch ``DatabaseError`` to handle all of them at once.
    """
class ConnectionError(DatabaseError):
    """Error when connecting to MongoDB Atlas.

    Note:
        Inside this module the name shadows Python's builtin
        ``ConnectionError``. This class derives from ``DatabaseError`` and
        is not an ``OSError``, so ``except OSError`` will not catch it.
    """
class OperationError(DatabaseError):
    """Error during database operations.

    Raised by the ``DatabaseUtils`` query helpers when a driver-level failure
    is wrapped with a message naming the database or collection involved.
    """
class DatabaseUtils:
    """Utility class for MongoDB Atlas database operations.

    This class provides methods to interact with MongoDB Atlas databases and
    collections, including listing databases, collections, and retrieving
    collection information.

    Instances can be used as context managers; the connection is closed when
    the ``with`` block exits.

    Attributes:
        atlas_uri (str): MongoDB Atlas connection string
        client (MongoClient): MongoDB client instance
    """

    def __init__(self) -> None:
        """Initialize DatabaseUtils with MongoDB Atlas connection.

        Loads environment variables via ``load_dotenv()``, reads ``ATLAS_URI``
        and verifies the connection with a ``ping`` command.

        Raises:
            ConnectionError: If unable to connect to MongoDB Atlas
            ValueError: If ATLAS_URI environment variable is not set
        """
        # Load environment variables
        load_dotenv()
        self.atlas_uri = os.getenv("ATLAS_URI")
        if not self.atlas_uri:
            raise ValueError("ATLAS_URI environment variable is not set")

        try:
            self.client = MongoClient(self.atlas_uri)
            # Test connection
            self.client.admin.command('ping')
        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            # The constructor is about to fail, so nobody else will ever get
            # a chance to close the client; release it here. close() copes
            # with the client attribute not having been assigned yet.
            self.close()
            raise ConnectionError(f"Failed to connect to MongoDB Atlas: {str(e)}") from e

    def __enter__(self) -> "DatabaseUtils":
        """Return this instance so it can be bound by ``with ... as``."""
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Close the connection on leaving the ``with`` block.

        Returns ``None``, so an exception raised inside the block propagates.
        """
        self.close()

    @staticmethod
    def _require_name(value: Any, label: str) -> None:
        """Raise ValueError unless ``value`` is a non-empty string.

        Args:
            value: Candidate database or collection name
            label (str): Leading word of the error message
                (e.g. ``"Database"``)
        """
        if not value or not isinstance(value, str):
            raise ValueError(f"{label} name must be a non-empty string")

    def get_databases(self) -> List[str]:
        """Get list of all databases in Atlas cluster.

        Returns:
            List[str]: List of database names

        Raises:
            OperationError: If unable to list databases
        """
        try:
            return self.client.list_database_names()
        except OperationFailure as e:
            raise OperationError(f"Failed to list databases: {str(e)}") from e

    def get_collections(self, db_name: str) -> List[str]:
        """Get list of collections in a database.

        Args:
            db_name (str): Name of the database

        Returns:
            List[str]: List of collection names

        Raises:
            OperationError: If unable to list collections
            ValueError: If db_name is empty or invalid
        """
        self._require_name(db_name, "Database")

        try:
            db = self.client[db_name]
            return db.list_collection_names()
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to list collections for database '{db_name}': {str(e)}"
            ) from e

    def get_collection_info(self, db_name: str, collection_name: str) -> Dict[str, Any]:
        """Get information about a collection including document count and sample document.

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            Dict[str, Any]: Dictionary containing collection information:
                - count: Number of documents in collection
                - sample: Sample document from collection (if exists)

        Raises:
            OperationError: If unable to get collection information
            ValueError: If db_name or collection_name is empty or invalid
        """
        self._require_name(db_name, "Database")
        self._require_name(collection_name, "Collection")

        try:
            db = self.client[db_name]
            collection = db[collection_name]
            return {
                # Empty filter: every document is counted.
                'count': collection.count_documents({}),
                # No filter: whichever single document the query yields.
                'sample': collection.find_one(),
            }
        except (OperationFailure, InvalidName) as e:
            raise OperationError(
                f"Failed to get info for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

    def get_field_names(self, db_name: str, collection_name: str) -> List[str]:
        """Get list of fields in a collection based on sample document.

        Only the top-level keys of the single sample document are inspected,
        so fields absent from that document are not reported.

        Args:
            db_name (str): Name of the database
            collection_name (str): Name of the collection

        Returns:
            List[str]: List of field names (excluding _id and embedding
                fields); empty when there is no sample document

        Raises:
            OperationError: If unable to get field names
            ValueError: If db_name or collection_name is empty or invalid
        """
        try:
            info = self.get_collection_info(db_name, collection_name)
        except DatabaseError as e:
            raise OperationError(
                f"Failed to get field names for collection '{collection_name}' "
                f"in database '{db_name}': {str(e)}"
            ) from e

        sample = info.get('sample')
        if not sample:
            return []

        # Get all field names except _id and any existing embedding fields
        return [
            field for field in sample
            if field != '_id' and not field.endswith('_embedding')
        ]

    def close(self) -> None:
        """Close MongoDB connection safely.

        Does nothing when no client was ever attached to the instance, which
        is the case when ``__init__`` failed before creating one.
        """
        if hasattr(self, 'client'):
            self.client.close()