# Doc-Weather-Bot / tests / test_rag_agent.py
# AmritSbisht's picture
# Upload 25 files
# f974658 verified
import unittest
from unittest.mock import patch, MagicMock
from agents.rag_agent import RAGAgent
from langchain.schema import Document
class TestRAGAgent(unittest.TestCase):
    """Unit tests for ``RAGAgent`` with its vector store and LLM patched out.

    For every test, ``agents.rag_agent.VectorStore`` and
    ``agents.rag_agent.ChatGoogleGenerativeAI`` are replaced by mocks.
    ``self.mock_vector_store`` and ``self.mock_llm`` are the ``return_value``
    of those mocks, i.e. the objects obtained when the patched classes are
    instantiated.
    """

    def setUp(self):
        """Start the patches, prepare sample documents and build the agent.

        Each patcher's ``stop`` is registered with ``addCleanup`` right after
        ``start()``. ``unittest`` does not call ``tearDown`` when ``setUp``
        raises, but it always runs registered cleanups, so the patches cannot
        leak into later tests even if constructing the agent fails.
        """
        # Create a mock for vector store
        self.vector_store_patch = patch('agents.rag_agent.VectorStore')
        self.mock_vector_store_class = self.vector_store_patch.start()
        self.addCleanup(self.vector_store_patch.stop)
        self.mock_vector_store = self.mock_vector_store_class.return_value

        # Create a mock for LLM
        self.llm_patch = patch('agents.rag_agent.ChatGoogleGenerativeAI')
        self.mock_llm_class = self.llm_patch.start()
        self.addCleanup(self.llm_patch.stop)
        self.mock_llm = self.mock_llm_class.return_value

        # Sample documents returned by the mocked similarity search
        self.sample_docs = [
            Document(
                page_content="This is a test document about AI.",
                metadata={"source": "test1.pdf"},
            ),
            Document(
                page_content="LangChain is a framework for LLM applications.",
                metadata={"source": "test2.pdf"},
            ),
        ]

        # Initialize agent (the patches above must already be active here)
        self.agent = RAGAgent(api_key="test_api_key")

    def test_retrieve_context(self):
        """retrieve_context returns the documents from similarity_search."""
        # Configure mock
        self.mock_vector_store.similarity_search.return_value = self.sample_docs

        # Call the method
        result = self.agent.retrieve_context("What is LangChain?")

        # Assertions: same documents back, and exactly one search performed
        self.assertEqual(result, self.sample_docs)
        self.mock_vector_store.similarity_search.assert_called_once()

    def test_get_rag_response_with_context(self):
        """get_rag_response reports the chain output and the retrieved context."""
        # Mock similarity_search to return 2 documents
        self.mock_vector_store.similarity_search.return_value = self.sample_docs

        # Mock rag_chain: invoke() yields an object whose .content is the answer
        mock_chain = MagicMock()
        mock_chain.invoke.return_value.content = "LangChain is a framework for building LLM applications."
        self.agent.rag_chain = mock_chain

        # Call the method
        result = self.agent.get_rag_response("What is LangChain?")

        # Assertions
        self.assertEqual(result["response"], "LangChain is a framework for building LLM applications.")
        self.assertEqual(len(result["context"]), 2)
        self.assertEqual(result["context"][0]["page_content"], "This is a test document about AI.")

    def test_get_rag_response_no_context(self):
        """get_rag_response returns a fallback message when nothing is retrieved."""
        # Configure mock to return empty list
        self.mock_vector_store.similarity_search.return_value = []

        # Call the method
        result = self.agent.get_rag_response("What is LangChain?")

        # Assertions
        self.assertEqual(len(result["context"]), 0)
        self.assertIn("couldn't find any relevant information", result["response"])