|
|
from abc import ABC, abstractmethod |
|
|
from typing import Optional, Any |
|
|
from motor.motor_asyncio import AsyncIOMotorClient |
|
|
from unittest.mock import AsyncMock |
|
|
|
|
|
|
|
|
class DatabaseClient(ABC):
    """Interface that every database client backend must implement.

    Concrete subclasses supply the async lifecycle hooks ``connect`` and
    ``disconnect`` plus ``get_database``, which hands back whatever handle
    the backend uses to talk to its database.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Open the connection to the backing database."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Close the connection to the backing database."""

    @abstractmethod
    async def get_database(self) -> Any:
        """Return the backend-specific database handle."""
|
|
|
|
|
|
|
|
class MongoClient(DatabaseClient):
    """Real MongoDB client implementation backed by Motor's ``AsyncIOMotorClient``.

    The Motor client is created on :meth:`connect`, or lazily on the first
    :meth:`get_database` call, and is released by :meth:`disconnect`.
    """

    def __init__(self, connection_string: str):
        """Store the connection string; no client is created yet.

        Args:
            connection_string: MongoDB URI passed to ``AsyncIOMotorClient``.
        """
        self.connection_string = connection_string
        self._client: Optional[AsyncIOMotorClient] = None

    async def connect(self) -> None:
        """Create the Motor client if one is not already held.

        Repeated calls are no-ops while a client exists, so an existing
        client is never overwritten (and leaked) without being closed.
        """
        if self._client is None:
            self._client = AsyncIOMotorClient(self.connection_string)

    async def disconnect(self) -> None:
        """Close the held Motor client, if any, and drop the reference.

        Clearing ``_client`` lets a later ``connect``/``get_database`` build
        a fresh client instead of handing back one that was already closed.
        """
        # Detach first so the reference is cleared even if close() raises.
        client, self._client = self._client, None
        if client is not None:
            client.close()

    async def get_database(self) -> AsyncIOMotorClient:
        """Return the Motor client, connecting first if none is held.

        NOTE(review): despite the method name this returns the client object
        itself, not a specific database; callers presumably select the
        database from it — confirm against call sites.

        Returns:
            AsyncIOMotorClient: the (possibly newly created) client.
        """
        if self._client is None:
            await self.connect()
        return self._client
|
|
|
|
|
|
|
|
class MockMongoClient(DatabaseClient):
    """Stand-in for :class:`MongoClient` intended for tests.

    Holds one ``AsyncMock`` that ``get_database`` returns on every call;
    the connect/disconnect hooks intentionally do nothing.
    """

    def __init__(self):
        self._client = AsyncMock()

    async def connect(self) -> None:
        """No-op: there is no real connection to open."""
        return None

    async def disconnect(self) -> None:
        """No-op: there is no real connection to close."""
        return None

    async def get_database(self) -> AsyncMock:
        """Return the same ``AsyncMock`` instance created in ``__init__``."""
        return self._client
|
|
|
|
|
|
|
|
class DatabaseClientFactory:
    """Factory that creates and holds a single shared database client.

    The client is stored on the class attribute ``_client``, so every caller
    shares the same instance until :meth:`reset_client` clears it.
    """

    _instance: Optional["DatabaseClientFactory"] = None
    _client: Optional[DatabaseClient] = None

    def __new__(cls):
        # Singleton: every construction returns the same factory object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def create_client(cls, db_type: str, connection_string: Optional[str] = None) -> DatabaseClient:
        """
        Create a database client based on the database type

        If a client already exists it is returned unchanged; in that case
        ``db_type`` and ``connection_string`` are ignored and not validated.

        Args:
            db_type: Type of database ('mongodb' or 'mock'), case-insensitive
            connection_string: Connection string for the database (required
                for 'mongodb')

        Returns:
            DatabaseClient: Instance of the appropriate database client

        Raises:
            ValueError: If no client exists yet and ``db_type`` is unsupported,
                or 'mongodb' is requested without a connection string.
        """
        if cls._client is not None:
            return cls._client

        kind = db_type.lower()
        if kind == "mongodb":
            if not connection_string:
                raise ValueError("Connection string is required for MongoDB")
            cls._client = MongoClient(connection_string)
        elif kind == "mock":
            cls._client = MockMongoClient()
        else:
            raise ValueError(f"Unsupported database type: {db_type}")

        return cls._client

    @classmethod
    async def get_client(cls) -> DatabaseClient:
        """
        Get the current database client instance

        Returns:
            DatabaseClient: Current database client instance

        Raises:
            RuntimeError: If no client has been created yet.
        """
        if cls._client is None:
            raise RuntimeError("Database client not initialized")
        return cls._client

    @classmethod
    async def reset_client(cls) -> None:
        """Disconnect and forget the current database client instance.

        The stored client is detached before ``disconnect`` is awaited, so the
        factory is cleared even if ``disconnect`` raises (the exception still
        propagates), and ``get_client`` never returns a client that is being
        torn down.
        """
        client, cls._client = cls._client, None
        if client is not None:
            await client.disconnect()
|
|
|