File size: 2,859 Bytes
f974658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import unittest
from unittest.mock import patch, MagicMock
from agents.rag_agent import RAGAgent
from langchain.schema import Document

class TestRAGAgent(unittest.TestCase):
    """Unit tests for ``RAGAgent`` with its vector store and LLM patched out.

    ``agents.rag_agent.VectorStore`` and ``agents.rag_agent.ChatGoogleGenerativeAI``
    are replaced by ``MagicMock`` classes for the duration of each test, so
    retrieval results and chain output are fully controlled by the test.
    """

    # Query shared by every test; kept in one place so assertions that check
    # it was forwarded stay in sync with the call sites.
    _QUERY = "What is LangChain?"

    def setUp(self):
        """Patch the agent's collaborators and build a fresh ``RAGAgent``.

        Each patcher's ``stop`` is registered with ``addCleanup`` right after
        ``start``. Cleanup functions run even when ``setUp`` itself raises
        (e.g. if ``RAGAgent(...)`` fails), whereas ``tearDown`` does not, so
        this guarantees the patches never leak into subsequent tests.
        """
        # Replace the VectorStore class; whatever instance the agent creates
        # from it is the class mock's ``return_value``.
        self.vector_store_patch = patch('agents.rag_agent.VectorStore')
        self.mock_vector_store_class = self.vector_store_patch.start()
        self.addCleanup(self.vector_store_patch.stop)
        self.mock_vector_store = self.mock_vector_store_class.return_value

        # Replace the Gemini chat model class in the same way.
        self.llm_patch = patch('agents.rag_agent.ChatGoogleGenerativeAI')
        self.mock_llm_class = self.llm_patch.start()
        self.addCleanup(self.llm_patch.stop)
        self.mock_llm = self.mock_llm_class.return_value

        # Two small documents used as canned similarity-search hits.
        self.sample_docs = [
            Document(page_content="This is a test document about AI.", metadata={"source": "test1.pdf"}),
            Document(page_content="LangChain is a framework for LLM applications.", metadata={"source": "test2.pdf"}),
        ]

        # Construct the agent only after both patches are active, so any
        # VectorStore / ChatGoogleGenerativeAI it instantiates are the mocks.
        self.agent = RAGAgent(api_key="test_api_key")

    def test_retrieve_context(self):
        """``retrieve_context`` returns the vector store's hits for the query."""
        search = self.mock_vector_store.similarity_search
        search.return_value = self.sample_docs

        result = self.agent.retrieve_context(self._QUERY)

        self.assertEqual(result, self.sample_docs)
        search.assert_called_once()
        # Accept the query positionally or by keyword, but make sure it is
        # actually forwarded to the search rather than dropped or replaced.
        args, kwargs = search.call_args
        self.assertIn(self._QUERY, (*args, *kwargs.values()))

    def test_get_rag_response_with_context(self):
        """With retrieved documents, the response comes from the RAG chain.

        The returned context mirrors the retrieved documents in order.
        """
        # Similarity search yields both sample documents.
        self.mock_vector_store.similarity_search.return_value = self.sample_docs

        # Swap in a chain whose ``invoke(...)`` result exposes ``.content``.
        mock_chain = MagicMock()
        mock_chain.invoke.return_value.content = "LangChain is a framework for building LLM applications."
        self.agent.rag_chain = mock_chain

        result = self.agent.get_rag_response(self._QUERY)

        self.assertEqual(result["response"], "LangChain is a framework for building LLM applications.")
        mock_chain.invoke.assert_called_once()
        self.assertEqual(len(result["context"]), 2)
        self.assertEqual(result["context"][0]["page_content"], "This is a test document about AI.")

    def test_get_rag_response_no_context(self):
        """With no retrieved documents, the agent returns a canned fallback reply."""
        # Similarity search finds nothing.
        self.mock_vector_store.similarity_search.return_value = []

        result = self.agent.get_rag_response(self._QUERY)

        self.assertEqual(len(result["context"]), 0)
        self.assertIn("couldn't find any relevant information", result["response"])