Spaces:
Build error
Build error
"""
Unit tests for the ``src.vector_store`` module.

Covers embedding generation, cosine-similarity scoring, context formatting
and ChromaDB vector store initialisation. Collaborators (the embedding
client, the ChromaDB client, timers, documents) are replaced with
``unittest.mock`` doubles.
"""
| import unittest | |
| from unittest.mock import MagicMock, Mock, patch | |
| from src.vector_store import (_calculate_similarity_impl, | |
| _generate_embeddings_impl, _process_context_impl, | |
| calculate_similarity, generate_embeddings, | |
| process_context) | |
class TestVectorStore(unittest.TestCase):
    """Test cases for the src.vector_store module.

    Exercises the ``_*_impl`` functions directly and checks that the public
    wrappers report to the optional ``timer`` under the expected step name.

    NOTE(review): the embedding tests and the ``initialize_vector_store``
    tests declare ``mock_*`` parameters, but no ``@patch`` decorators are
    visible in this copy of the file. unittest calls test methods with
    ``self`` only, so those tests raise ``TypeError`` until the decorators
    are restored. TODO: confirm the patch targets in ``src.vector_store``
    (remember that stacked ``@patch`` decorators inject arguments
    bottom-up).
    """

    def setUp(self):
        """Create a single mock document exposing page_content and metadata."""
        self.mock_doc = Mock()
        self.mock_doc.page_content = "Test document content"
        self.mock_doc.metadata = {
            "id": "KB001",
            "question": "Test question?",
            "content": "Test answer.",
            "section": "Test",
        }
        self.mock_documents = [self.mock_doc]

    # ------------------------------------------------------------------
    # Helpers (not collected by unittest: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    def _make_mock_timer():
        """Return a mock timer whose ``time_step(name)`` works in a ``with``.

        ``MagicMock`` implements the context-manager protocol out of the box,
        and its default ``__exit__`` returns ``False``, so exceptions raised
        inside the timed block propagate to the test. Do NOT replace
        ``__exit__`` with a plain ``Mock()``: its return value is truthy,
        which tells the ``with`` statement to suppress the exception and lets
        a crashing implementation pass the test.
        """
        mock_timer = Mock()
        mock_timer.time_step = MagicMock()
        return mock_timer

    @staticmethod
    def _embed_response(*vectors):
        """Build a fake ``embed_content`` response with one embedding per vector.

        Each embedding is a Mock whose ``values`` attribute is the given list.
        """
        response = Mock()
        response.embeddings = []
        for values in vectors:
            embedding = Mock()
            embedding.values = values
            response.embeddings.append(embedding)
        return response

    @staticmethod
    def _query_then_docs(query_response, doc_response):
        """Return an ``embed_content`` side effect.

        The first call yields *query_response* (the query embedding); every
        later call yields *doc_response* (the document embeddings).
        """
        calls = [0]

        def side_effect(*args, **kwargs):
            calls[0] += 1
            return query_response if calls[0] == 1 else doc_response

        return side_effect

    # ------------------------------------------------------------------
    # Embedding generation
    # ------------------------------------------------------------------

    def test_generate_embeddings_impl(self, mock_genai_client):
        """Test internal embedding generation implementation"""
        # NOTE(review): mock_genai_client must be injected by a @patch
        # decorator that is not visible here -- TODO restore it.
        mock_genai_client.models.embed_content.side_effect = self._query_then_docs(
            self._embed_response([0.1, 0.2, 0.3]),
            self._embed_response([0.2, 0.3, 0.4]),
        )

        query_emb, doc_embs = _generate_embeddings_impl(
            "Test query", self.mock_documents
        )

        # One call for the query, one for the documents.
        self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)
        self.assertEqual(query_emb, [0.1, 0.2, 0.3])
        self.assertEqual(len(doc_embs), 1)
        self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])

    def test_generate_embeddings_with_timer(self, mock_genai_client):
        """Test embedding generation with timer"""
        # NOTE(review): mock_genai_client must be injected by a @patch
        # decorator that is not visible here -- TODO restore it.
        mock_genai_client.models.embed_content.return_value = self._embed_response(
            [0.1, 0.2, 0.3]
        )
        mock_timer = self._make_mock_timer()

        generate_embeddings("Test", self.mock_documents, timer=mock_timer)

        mock_timer.time_step.assert_called_once_with("embedding_generation")

    def test_generate_embeddings_multiple_docs(self, mock_genai_client):
        """Test embedding generation with multiple documents"""
        # NOTE(review): mock_genai_client must be injected by a @patch
        # decorator that is not visible here -- TODO restore it.
        mock_doc2 = Mock()
        mock_doc2.page_content = "Second document"
        docs = [self.mock_doc, mock_doc2]

        # First call embeds the query; the second embeds both documents at once.
        mock_genai_client.models.embed_content.side_effect = self._query_then_docs(
            self._embed_response([0.1, 0.2, 0.3]),
            self._embed_response([0.2, 0.3, 0.4], [0.3, 0.4, 0.5]),
        )

        query_emb, doc_embs = _generate_embeddings_impl("Test", docs)

        self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)
        self.assertEqual(query_emb, [0.1, 0.2, 0.3])
        self.assertEqual(len(doc_embs), 2)
        self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])
        self.assertEqual(doc_embs[1], [0.3, 0.4, 0.5])

    # ------------------------------------------------------------------
    # Similarity calculation
    # ------------------------------------------------------------------

    def test_calculate_similarity_impl(self):
        """Test internal similarity calculation implementation"""
        query_embedding = [1.0, 0.0, 0.0]
        doc_embeddings = [
            [1.0, 0.0, 0.0],  # Same as query - score should be ~1.0
            [0.0, 1.0, 0.0],  # Orthogonal - score should be ~0.0
            [0.5, 0.5, 0.0],  # Partial similarity
        ]

        scores = _calculate_similarity_impl(query_embedding, doc_embeddings)

        self.assertEqual(len(scores), 3)
        self.assertAlmostEqual(scores[0], 1.0, places=5)
        self.assertAlmostEqual(scores[1], 0.0, places=5)
        self.assertGreater(scores[2], 0.0)
        self.assertLess(scores[2], 1.0)

    def test_calculate_similarity_with_timer(self):
        """Test similarity calculation with timer"""
        mock_timer = self._make_mock_timer()

        calculate_similarity([1.0, 0.0, 0.0], [[1.0, 0.0, 0.0]], timer=mock_timer)

        mock_timer.time_step.assert_called_once_with("similarity_calculation")

    # ------------------------------------------------------------------
    # Context processing
    # ------------------------------------------------------------------

    def test_process_context_impl(self):
        """Test internal context processing implementation"""
        results = []
        for i in range(3):
            mock_result = Mock()
            mock_result.metadata = {
                "id": f"KB00{i+1}",
                "question": f"Question {i+1}?",
                "content": f"Answer {i+1}.",
            }
            results.append(mock_result)
        # Cosine scores (sorted: 0.9, 0.7, 0.5)
        cosine_scores = [0.7, 0.5, 0.9]

        context, source_ids, knowledge_pairs = _process_context_impl(
            results, cosine_scores, max_results=2
        )

        # Should return top 2 results
        self.assertEqual(len(source_ids), 2)
        self.assertEqual(len(knowledge_pairs), 2)
        # Highest score (0.9, index 2) comes first
        self.assertEqual(source_ids[0], "KB003")
        self.assertEqual(knowledge_pairs[0][0], "Question 3?")
        # Formatted context
        self.assertIn("Knowledge Entry 1:", context)
        self.assertIn("Knowledge Entry 2:", context)
        self.assertIn("Q: Question 3?", context)
        self.assertIn("A: Answer 3.", context)

    def test_process_context_with_timer(self):
        """Test context processing with timer"""
        mock_result = Mock()
        mock_result.metadata = {"id": "KB001", "question": "Q?", "content": "A."}
        mock_timer = self._make_mock_timer()

        process_context([mock_result], [0.9], timer=mock_timer)

        mock_timer.time_step.assert_called_once_with("context_processing")

    def test_process_context_max_results(self):
        """Test that max_results keeps only the highest-scoring entries"""
        results = []
        for i in range(5):
            mock_result = Mock()
            mock_result.metadata = {
                "id": f"KB00{i}",
                "question": f"Q{i}?",
                "content": f"A{i}.",
            }
            results.append(mock_result)
        scores = [0.9, 0.8, 0.7, 0.6, 0.5]

        _context, source_ids, knowledge_pairs = _process_context_impl(
            results, scores, max_results=3
        )

        self.assertEqual(len(source_ids), 3)
        self.assertEqual(len(knowledge_pairs), 3)
        # Counting alone would accept the wrong 3 entries; pin which ones
        # survive and that they come out highest score first.
        self.assertEqual(list(source_ids), ["KB000", "KB001", "KB002"])

    def test_process_context_formatting(self):
        """Test context formatting details"""
        mock_result = Mock()
        mock_result.metadata = {
            "id": "KB001",
            "question": "Test question?",
            "content": "Test answer.",
        }

        context, _, _ = _process_context_impl([mock_result], [0.9], max_results=1)

        self.assertIn("Knowledge Entry 1:", context)
        self.assertIn("Q: Test question?", context)
        self.assertIn("A: Test answer.", context)
        self.assertIn("-" * 40, context)

    def test_process_context_missing_metadata(self):
        """Test context processing with missing metadata fields"""
        mock_result = Mock()
        mock_result.metadata = {}  # No metadata

        context, source_ids, _knowledge_pairs = _process_context_impl(
            [mock_result], [0.9], max_results=1
        )

        # Missing fields are rendered as N/A
        self.assertIn("N/A", context)
        self.assertEqual(source_ids[0], "N/A")

    # ------------------------------------------------------------------
    # Vector store initialisation
    # ------------------------------------------------------------------

    def test_initialize_vector_store_new_collection(
        self, mock_chroma_class, mock_client_class, mock_get_kb
    ):
        """Test initializing vector store with new collection"""
        # NOTE(review): the three mocks must be injected by @patch decorators
        # that are not visible here -- TODO restore them.
        mock_get_kb.return_value = (
            ["doc1", "doc2"],
            [{"id": "1"}, {"id": "2"}],
            ["id1", "id2"],
        )
        mock_client = Mock()
        mock_client_class.return_value = mock_client
        # Collection does not exist yet: get_collection raises.
        mock_client.get_collection.side_effect = Exception("Collection not found")
        mock_collection = Mock()
        mock_client.create_collection.return_value = mock_collection
        mock_vector_store = Mock()
        mock_retriever = Mock()
        mock_vector_store.as_retriever.return_value = mock_retriever
        mock_chroma_class.return_value = mock_vector_store

        from src.vector_store import initialize_vector_store

        _collection, vector_store, retriever = initialize_vector_store()

        # A new collection was created and populated.
        mock_client.create_collection.assert_called_once()
        mock_collection.add.assert_called_once()
        self.assertEqual(vector_store, mock_vector_store)
        self.assertEqual(retriever, mock_retriever)

    def test_initialize_vector_store_existing_collection(
        self, mock_chroma_class, mock_client_class, mock_get_kb
    ):
        """Test initializing vector store with existing collection"""
        # NOTE(review): the three mocks must be injected by @patch decorators
        # that are not visible here -- TODO restore them.
        mock_get_kb.return_value = (
            ["doc1", "doc2"],
            [{"id": "1"}, {"id": "2"}],
            ["id1", "id2"],
        )
        mock_client = Mock()
        mock_client_class.return_value = mock_client
        # Collection already exists.
        mock_collection = Mock()
        mock_client.get_collection.return_value = mock_collection
        mock_vector_store = Mock()
        mock_retriever = Mock()
        mock_vector_store.as_retriever.return_value = mock_retriever
        mock_chroma_class.return_value = mock_vector_store

        from src.vector_store import initialize_vector_store

        collection, vector_store, retriever = initialize_vector_store()

        # The existing collection was loaded, not re-created.
        mock_client.get_collection.assert_called_once()
        mock_client.create_collection.assert_not_called()
        self.assertEqual(collection, mock_collection)
        self.assertEqual(vector_store, mock_vector_store)
        self.assertEqual(retriever, mock_retriever)

    def test_initialize_vector_store_failure(self, mock_client_class, mock_get_kb):
        """Test initialize_vector_store handles errors properly"""
        # NOTE(review): the two mocks must be injected by @patch decorators
        # that are not visible here -- TODO restore them.
        mock_get_kb.return_value = (["doc1"], [{"id": "1"}], ["id1"])
        mock_client_class.side_effect = Exception("Database connection failed")

        from src.vector_store import initialize_vector_store

        with self.assertRaises(Exception) as context:
            initialize_vector_store()
        self.assertIn("Database connection failed", str(context.exception))
if __name__ == "__main__":
    # Allow running this test module directly as a script; "__main__" is
    # unittest.main's default module, spelled out here for clarity.
    unittest.main(module="__main__")