# Author: cevheri
# Commit 3ef184a: "feat: enhanced repository usages"
from abc import ABC, abstractmethod
from typing import Optional, Any
from motor.motor_asyncio import AsyncIOMotorClient
from unittest.mock import AsyncMock
class DatabaseClient(ABC):
    """Common async interface that every database client must implement.

    Concrete subclasses provide ``connect``, ``disconnect`` and
    ``get_database``; the class itself cannot be instantiated.
    """

    @abstractmethod
    async def connect(self) -> None:
        """Open the connection to the backing database."""
        ...

    @abstractmethod
    async def disconnect(self) -> None:
        """Close the connection to the backing database."""
        ...

    @abstractmethod
    async def get_database(self) -> Any:
        """Return the implementation-specific database handle."""
        ...
class MongoClient(DatabaseClient):
    """MongoDB client backed by Motor's ``AsyncIOMotorClient``.

    The Motor client is created lazily, either explicitly via :meth:`connect`
    or implicitly on the first :meth:`get_database` call.
    """

    def __init__(self, connection_string: str):
        """Store the connection string; no connection is opened here.

        Args:
            connection_string: URI passed verbatim to ``AsyncIOMotorClient``.
        """
        self.connection_string = connection_string
        self._client: Optional[AsyncIOMotorClient] = None

    async def connect(self) -> None:
        """Create the Motor client if one is not already held.

        Repeated calls are a no-op while a client exists, so an existing client
        is never replaced (and leaked) without being closed.
        """
        if self._client is not None:
            return
        self._client = AsyncIOMotorClient(self.connection_string)

    async def disconnect(self) -> None:
        """Close the held Motor client, if any, and drop the reference.

        Clearing the reference lets a later :meth:`get_database` open a fresh
        client instead of returning the one that was just closed.
        """
        client, self._client = self._client, None
        if client is not None:
            client.close()

    async def get_database(self) -> AsyncIOMotorClient:
        """Return the Motor client, connecting first if necessary.

        Note: despite the method name, this returns the ``AsyncIOMotorClient``
        itself rather than a specific database handle. Presumably callers pick
        the database from it; confirm against call sites.

        Returns:
            The connected ``AsyncIOMotorClient`` instance.
        """
        if self._client is None:
            await self.connect()
        return self._client
class MockMongoClient(DatabaseClient):
    """Test double for :class:`MongoClient` that never touches a real server.

    One ``AsyncMock`` is created per instance and is returned by every
    :meth:`get_database` call. Connecting and disconnecting do nothing.
    """

    def __init__(self):
        mock_handle = AsyncMock()
        self._client = mock_handle

    async def connect(self) -> None:
        """No-op: the mock has no real connection to open."""
        return None

    async def disconnect(self) -> None:
        """No-op: the mock has no real connection to close."""
        return None

    async def get_database(self) -> AsyncMock:
        """Return the ``AsyncMock`` created at construction time."""
        mock_handle = self._client
        return mock_handle
class DatabaseClientFactory:
    """Factory that creates and caches a single shared database client.

    The client is stored on the class attribute ``_client``, so every caller
    gets the same instance until :meth:`reset_client` is called. ``__new__``
    also makes the factory object itself a singleton, although all public
    methods are classmethods and need no instance.
    """

    _instance: Optional["DatabaseClientFactory"] = None
    # The cached client shared by all callers; None until create_client runs.
    _client: Optional[DatabaseClient] = None

    def __new__(cls):
        # Hand back the same factory object on every instantiation.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def create_client(cls, db_type: str, connection_string: Optional[str] = None) -> DatabaseClient:
        """
        Create a database client based on the database type

        If a client is already cached, it is returned unchanged. In that case
        ``db_type`` and ``connection_string`` are neither validated nor
        compared with the cached client.

        Args:
            db_type: Type of database ('mongodb' or 'mock'), case-insensitive
            connection_string: Connection string for the database (required
                for 'mongodb', ignored for 'mock')

        Returns:
            DatabaseClient: Instance of the appropriate database client

        Raises:
            ValueError: If ``db_type`` is unsupported, or if it is 'mongodb'
                and ``connection_string`` is missing or empty.
        """
        if cls._client is not None:
            return cls._client

        kind = db_type.lower()
        if kind == "mongodb":
            if not connection_string:
                raise ValueError("Connection string is required for MongoDB")
            cls._client = MongoClient(connection_string)
        elif kind == "mock":
            cls._client = MockMongoClient()
        else:
            raise ValueError(f"Unsupported database type: {db_type}")
        return cls._client

    @classmethod
    async def get_client(cls) -> DatabaseClient:
        """
        Get the current database client instance

        Returns:
            DatabaseClient: Current database client instance

        Raises:
            RuntimeError: If :meth:`create_client` has not been called yet
                (or the client was reset).
        """
        if cls._client is None:
            raise RuntimeError("Database client not initialized")
        return cls._client

    @classmethod
    async def reset_client(cls) -> None:
        """Disconnect and discard the cached client, if any.

        The cached reference is cleared even if ``disconnect()`` raises. This
        keeps a failed shutdown from leaving a broken client cached for every
        later ``create_client``/``get_client`` call. The exception still
        propagates to the caller.
        """
        client = cls._client
        if client is None:
            return
        try:
            await client.disconnect()
        finally:
            cls._client = None