# support-system/tests/test_retrieval.py
# Uploaded by ayush2917; last commit "Update tests/test_retrieval.py" (cb557ca, verified).
import json
import os
import unittest

from src.retrieval import DocumentRetriever
class TestRetrieval(unittest.TestCase):
    """Tests for ``DocumentRetriever.retrieve`` using a one-document fixture."""

    # Fixture file written by setUpClass and removed again by class cleanup.
    _FIXTURE_DIR = 'data'
    _FIXTURE_PATH = 'data/test_document.json'

    @staticmethod
    def _remove_fixture():
        """Delete the fixture file; do nothing if it is already gone."""
        try:
            os.remove(TestRetrieval._FIXTURE_PATH)
        except FileNotFoundError:
            pass

    @classmethod
    def setUpClass(cls):
        """Write the fixture file and build a retriever over the test document."""
        test_doc = [{
            "content": "Rupeia offers three investment plans: Basic, Plus, and Premium.",
            "category": "investments",
        }]

        # Register cleanup before touching the disk: class cleanups run even
        # when setUpClass itself raises, whereas tearDownClass would not.
        cls.addClassCleanup(cls._remove_fixture)

        os.makedirs(cls._FIXTURE_DIR, exist_ok=True)
        with open(cls._FIXTURE_PATH, 'w', encoding='utf-8') as f:
            json.dump(test_doc, f)

        # Override whatever the constructor loaded so the tests only ever see
        # the fixture document, then recompute the embeddings to match.
        cls.retriever = DocumentRetriever()
        cls.retriever.documents = test_doc
        cls.retriever.doc_embeddings = cls.retriever._embed_documents()

    def test_retrieve_relevant_documents(self):
        """A topical query returns the fixture document as the first hit."""
        query = "What investment options are available?"
        results = self.retriever.retrieve(query)
        self.assertGreater(len(results), 0)
        self.assertIn("investment plans", results[0]['content'])

    def test_empty_query(self):
        """An empty query still returns the default number of results."""
        results = self.retriever.retrieve("")
        # NOTE(review): the fixture holds a single document, yet this expects
        # 3 results (the default top_k) -- confirm that retrieve() can return
        # more entries than there are documents, or adjust the fixture.
        self.assertEqual(len(results), 3)  # default top_k
# Run the suite with unittest's command-line runner when executed as a script.
if __name__ == "__main__":
    unittest.main()