# prediqai / tests/test_database.py
# Uploaded by ganesh-vilje — commit "Deploy to Hugging Face Main" (f8f02c0)
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from app.database.mongodb import AsyncMongoDB, async_mongodb, get_async_database, get_async_collection
@pytest.fixture
def clean_db():
    """Yield the shared ``async_mongodb`` instance with no client or database attached.

    The singleton is reset both before the test body runs and after it
    finishes (pytest runs the post-``yield`` code as teardown), so state a
    test assigns never leaks into the next one.
    """
    def _reset():
        async_mongodb.client = None
        async_mongodb.database = None

    _reset()
    yield async_mongodb
    _reset()
def test_init_missing_uri():
    """Constructing ``AsyncMongoDB`` with an empty ``MONGODB_URI`` must raise ``ValueError``."""
    missing_uri_msg = "MONGODB_URI environment variable is required"
    with patch("app.core.config.Config.MONGODB_URI", ""), pytest.raises(
        ValueError, match=missing_uri_msg
    ):
        AsyncMongoDB()
@pytest.mark.asyncio
async def test_connect_success(clean_db):
    """A successful ``ping`` leaves the wrapper connected with client and database set."""
    fake_db = MagicMock()
    fake_client = MagicMock()
    fake_client.__getitem__.return_value = fake_db
    # Methods that ``connect()`` awaits are stubbed with AsyncMock so they return awaitables.
    fake_client.admin.command = AsyncMock(return_value={"ok": 1})
    fake_client.list_database_names = AsyncMock(return_value=["test1", "test2"])

    with patch("app.database.mongodb.AsyncIOMotorClient", return_value=fake_client):
        connected = await clean_db.connect()

    assert connected is True
    assert clean_db.client is not None
    assert clean_db.database is not None
    fake_client.admin.command.assert_called_once_with("ping")
@pytest.mark.asyncio
async def test_connect_failure(clean_db):
    """``connect()`` reports failure by returning ``False`` when the ping raises."""
    refusing_client = MagicMock()
    refusing_client.admin.command = AsyncMock(
        side_effect=Exception("Connection Refused")
    )
    motor_patch = patch(
        "app.database.mongodb.AsyncIOMotorClient", return_value=refusing_client
    )

    with motor_patch:
        outcome = await clean_db.connect()

    assert outcome is False
@pytest.mark.asyncio
async def test_disconnect(clean_db):
    """``disconnect()`` closes the attached client exactly once.

    The mock is held in a local variable rather than re-read from
    ``clean_db.client``. The assertion therefore still works if
    ``disconnect()`` also clears the attribute (e.g. resets it to ``None``).
    """
    fake_client = MagicMock()
    clean_db.client = fake_client

    await clean_db.disconnect()

    fake_client.close.assert_called_once()
def test_get_collection_not_connected(clean_db):
    """``get_collection`` must raise ``ValueError`` before ``connect()`` has been awaited."""
    # ``match`` is a regular expression searched against the message. Escape
    # ``.`` and ``()`` so they match literally instead of "any character".
    with pytest.raises(
        ValueError,
        match=r"Database not connected\. Call await connect\(\) first\.",
    ):
        clean_db.get_collection("users")
def test_get_collection_connected(clean_db):
    """With a database attached, ``get_collection(name)`` returns ``database[name]``."""
    users_collection = MagicMock()
    fake_db = MagicMock()
    fake_db.__getitem__.return_value = users_collection
    clean_db.database = fake_db

    result = clean_db.get_collection("users")

    assert result is users_collection
    fake_db.__getitem__.assert_called_once_with("users")
@pytest.mark.asyncio
async def test_global_get_async_database_not_connected(clean_db):
    """The module-level accessor raises while the shared instance has no database."""
    expected_error = "Async database not connected"
    with pytest.raises(ValueError, match=expected_error):
        await get_async_database()
@pytest.mark.asyncio
async def test_global_get_async_database_connected(clean_db):
    """The module-level accessor returns the shared instance's database object."""
    attached_db = MagicMock()
    clean_db.database = attached_db
    assert await get_async_database() is clean_db.database
def test_global_get_async_collection(clean_db):
    """``get_async_collection(name)`` returns ``database[name]`` from the shared instance."""
    users_collection = MagicMock()
    mock_db = MagicMock()
    mock_db.__getitem__.return_value = users_collection
    clean_db.database = mock_db

    col = get_async_collection("users")

    # Check the object handed back, not only that indexing happened.
    assert col is users_collection
    mock_db.__getitem__.assert_called_once_with("users")