"""Unit tests for src.retrieval.DocumentRetriever."""

import json
import os
import unittest

from src.retrieval import DocumentRetriever

# Fixture file written in setUpClass and removed again via a class cleanup.
TEST_DOC_PATH = os.path.join('data', 'test_document.json')

# The top_k value the empty-query test treats as retrieve()'s default.
# NOTE(review): mirrors DocumentRetriever's default - confirm against src/retrieval.py.
DEFAULT_TOP_K = 3


class TestRetrieval(unittest.TestCase):
    """Exercise DocumentRetriever.retrieve against a single-document corpus."""

    @classmethod
    def setUpClass(cls):
        """Write a one-document fixture and build a retriever over it."""
        test_doc = [{
            "content": "Rupeia offers three investment plans: Basic, Plus, and Premium.",
            "category": "investments",
        }]

        # Write the fixture to disk; register removal immediately so the file
        # is cleaned up even if the retriever construction below fails.
        os.makedirs('data', exist_ok=True)
        with open(TEST_DOC_PATH, 'w', encoding='utf-8') as f:
            json.dump(test_doc, f)
        cls.addClassCleanup(os.remove, TEST_DOC_PATH)

        # Replace the retriever's documents with the fixture and recompute the
        # embeddings so they correspond to exactly these documents.
        cls.retriever = DocumentRetriever()
        cls.retriever.documents = test_doc
        cls.retriever.doc_embeddings = cls.retriever._embed_documents()

    def test_retrieve_relevant_documents(self):
        """A query about investments should rank the investment-plans doc first."""
        query = "What investment options are available?"
        results = self.retriever.retrieve(query)
        self.assertGreater(len(results), 0)
        self.assertIn("investment plans", results[0]['content'])

    def test_empty_query(self):
        """An empty query returns up to the default top_k results."""
        results = self.retriever.retrieve("")
        # Only one document is indexed, so asking for 3 cannot yield 3 hits.
        # NOTE(review): assumes retrieve() returns each document at most once
        # and caps results at the corpus size - confirm against retrieve().
        expected = min(DEFAULT_TOP_K, len(self.retriever.documents))
        self.assertEqual(len(results), expected)


if __name__ == '__main__':
    unittest.main()