File size: 3,602 Bytes
1fed057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import unittest
from unittest.mock import patch, MagicMock
from src.database.connection import DatabaseConnection
from src.utils.constants import Constants

class TestDatabaseConnection(unittest.TestCase):
    """
    Test cases for the DatabaseConnection class.

    Every test patches ``load_dotenv``, ``os.getenv`` and ``MongoClient`` inside
    ``src.database.connection`` so that no ``.env`` file is read and no real
    MongoDB connection is attempted.

    Note: stacked ``@patch`` decorators are applied bottom-up, so the mock for
    the bottom-most decorator is the first positional argument after ``self``.
    """

    TEST_URI = "mongodb://test_uri"

    @staticmethod
    def _wire_client(mock_mongo_client):
        """
        Configure the patched ``MongoClient`` so that ``client[db][collection]``
        resolves to a dedicated mock collection.

        Args:
            mock_mongo_client: The patched ``MongoClient`` class mock.

        Returns:
            A ``(mock_client, mock_db, mock_collection)`` tuple of the mocks wired up.
        """
        mock_client = MagicMock()
        mock_db = MagicMock()
        mock_collection = MagicMock()
        mock_client.__getitem__.return_value = mock_db
        mock_db.__getitem__.return_value = mock_collection
        mock_mongo_client.return_value = mock_client
        return mock_client, mock_db, mock_collection

    @patch('src.database.connection.load_dotenv')
    @patch('src.database.connection.os.getenv')
    @patch('src.database.connection.MongoClient')
    def test_init_success(self, mock_mongo_client, mock_getenv, mock_load_dotenv):
        """
        Test successful initialization of DatabaseConnection.

        Verifies that the environment is loaded, the URI is read from
        ``MONGODB_URI``, the client is built from it, and the configured
        database/collection are selected.
        """
        mock_getenv.return_value = self.TEST_URI
        mock_client, mock_db, mock_collection = self._wire_client(mock_mongo_client)

        db_connection = DatabaseConnection()

        mock_load_dotenv.assert_called_once()
        mock_getenv.assert_called_once_with("MONGODB_URI")
        mock_mongo_client.assert_called_once_with(self.TEST_URI)
        mock_client.__getitem__.assert_called_once_with(Constants.DB_NAME)
        mock_db.__getitem__.assert_called_once_with(Constants.COLLECTION_NAME)
        # Identity check: the instance must hold exactly the collection object
        # returned by the client lookup, not merely an "equal" one.
        self.assertIs(db_connection.collection, mock_collection)

    @patch('src.database.connection.load_dotenv')
    @patch('src.database.connection.os.getenv')
    @patch('src.database.connection.MongoClient')
    def test_init_missing_uri(self, mock_mongo_client, mock_getenv, mock_load_dotenv):
        """
        Test initialization with missing MongoDB URI.

        ``MongoClient`` is patched as well so that a regression in the
        validation order can never reach out to a real database.
        """
        mock_getenv.return_value = None

        with self.assertRaises(ValueError) as context:
            DatabaseConnection()

        self.assertEqual(str(context.exception), "MONGODB_URI environment variable is not set")
        mock_getenv.assert_called_once_with("MONGODB_URI")
        # A missing URI should be rejected before any client is constructed.
        mock_mongo_client.assert_not_called()

    @patch('src.database.connection.load_dotenv')
    @patch('src.database.connection.os.getenv')
    @patch('src.database.connection.MongoClient')
    def test_get_collection(self, mock_mongo_client, mock_getenv, mock_load_dotenv):
        """
        Test that get_collection returns the collection selected at init time.
        """
        mock_getenv.return_value = self.TEST_URI
        _, _, mock_collection = self._wire_client(mock_mongo_client)

        db_connection = DatabaseConnection()
        result = db_connection.get_collection()

        self.assertIs(result, mock_collection)

    @patch('src.database.connection.load_dotenv')
    @patch('src.database.connection.os.getenv')
    @patch('src.database.connection.MongoClient')
    def test_close_connection(self, mock_mongo_client, mock_getenv, mock_load_dotenv):
        """
        Test that close_connection closes the underlying MongoClient exactly once.
        """
        mock_getenv.return_value = self.TEST_URI
        mock_client, _, _ = self._wire_client(mock_mongo_client)

        db_connection = DatabaseConnection()
        db_connection.close_connection()

        mock_client.close.assert_called_once()

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()