# AskXeno / tests / test_vector_store.py
# github-actions
# Sync from GitHub
# 3cdce90
"""
Unit tests for vector_store module
Tests ChromaDB vector store operations
"""
import unittest
from unittest.mock import MagicMock, Mock, patch
from src.vector_store import (_calculate_similarity_impl,
_generate_embeddings_impl, _process_context_impl,
calculate_similarity, generate_embeddings,
process_context)
class TestVectorStore(unittest.TestCase):
    """Test cases for the vector_store module.

    The collaborators referenced by the tests (``genai_client``,
    ``chromadb.PersistentClient``, ``Chroma`` and ``get_knowledge_base_data``)
    are patched inside ``src.vector_store`` and replaced with mocks.
    """

    # ------------------------------------------------------------------
    # Shared fixtures and helpers
    # ------------------------------------------------------------------

    def setUp(self):
        """Create the single mock document shared by the embedding tests."""
        self.mock_doc = Mock()
        self.mock_doc.page_content = "Test document content"
        self.mock_doc.metadata = {
            "id": "KB001",
            "question": "Test question?",
            "content": "Test answer.",
            "section": "Test",
        }
        self.mock_documents = [self.mock_doc]

    @staticmethod
    def _make_mock_timer():
        """Return a mock timer whose ``time_step(...)`` is a context manager.

        ``__exit__`` explicitly returns ``False``. A bare ``Mock()`` used as
        ``__exit__`` returns another (truthy) ``Mock``, which tells the
        ``with`` statement to suppress any exception raised in its body and
        would let a failing implementation pass the timer tests unnoticed.
        """
        timer = Mock()
        timer.time_step = MagicMock()
        step_context = timer.time_step.return_value
        step_context.__enter__ = Mock()
        step_context.__exit__ = Mock(return_value=False)
        return timer

    @staticmethod
    def _make_embed_response(*vectors):
        """Return a mock ``embed_content`` response with one embedding per vector.

        Each embedding exposes its numbers through a ``values`` attribute and
        the response lists the embeddings under ``embeddings``.
        """
        return Mock(embeddings=[Mock(values=list(vector)) for vector in vectors])

    @staticmethod
    def _wire_store_mocks(mock_chroma_class, mock_client_class, mock_get_kb):
        """Configure the patched collaborators used by ``initialize_vector_store``.

        Sets the knowledge-base loader to return two documents, makes the
        patched client class return a fresh mock client, and makes the patched
        ``Chroma`` class return a mock store whose ``as_retriever()`` yields a
        mock retriever.

        Returns:
            Tuple ``(mock_client, mock_vector_store, mock_retriever)``.
        """
        mock_get_kb.return_value = (
            ["doc1", "doc2"],
            [{"id": "1"}, {"id": "2"}],
            ["id1", "id2"],
        )

        mock_client = Mock()
        mock_client_class.return_value = mock_client

        mock_vector_store = Mock()
        mock_retriever = Mock()
        mock_vector_store.as_retriever.return_value = mock_retriever
        mock_chroma_class.return_value = mock_vector_store

        return mock_client, mock_vector_store, mock_retriever

    # ------------------------------------------------------------------
    # Embedding generation
    # ------------------------------------------------------------------

    @patch("src.vector_store.genai_client")
    def test_generate_embeddings_impl(self, mock_genai_client):
        """Test internal embedding generation implementation."""
        # An iterable side_effect hands out one response per call: the first
        # call receives the query embedding, the second the document embedding.
        mock_genai_client.models.embed_content.side_effect = [
            self._make_embed_response([0.1, 0.2, 0.3]),
            self._make_embed_response([0.2, 0.3, 0.4]),
        ]

        query_emb, doc_embs = _generate_embeddings_impl(
            "Test query", self.mock_documents
        )

        # One call for the query and one for the documents.
        self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)

        self.assertEqual(query_emb, [0.1, 0.2, 0.3])
        self.assertEqual(len(doc_embs), 1)
        self.assertEqual(doc_embs[0], [0.2, 0.3, 0.4])

    @patch("src.vector_store.genai_client")
    def test_generate_embeddings_with_timer(self, mock_genai_client):
        """Test that embedding generation reports to the supplied timer."""
        mock_genai_client.models.embed_content.return_value = (
            self._make_embed_response([0.1, 0.2, 0.3])
        )
        mock_timer = self._make_mock_timer()

        generate_embeddings("Test", self.mock_documents, timer=mock_timer)

        mock_timer.time_step.assert_called_once_with("embedding_generation")

    @patch("src.vector_store.genai_client")
    def test_generate_embeddings_multiple_docs(self, mock_genai_client):
        """Test embedding generation with multiple documents."""
        docs = [self.mock_doc, Mock(page_content="Second document")]

        # First call for the query, second call for both documents at once.
        mock_genai_client.models.embed_content.side_effect = [
            self._make_embed_response([0.1, 0.2, 0.3]),
            self._make_embed_response([0.2, 0.3, 0.4], [0.3, 0.4, 0.5]),
        ]

        _, doc_embs = _generate_embeddings_impl("Test", docs)

        self.assertEqual(len(doc_embs), 2)
        self.assertEqual(mock_genai_client.models.embed_content.call_count, 2)

    # ------------------------------------------------------------------
    # Similarity calculation
    # ------------------------------------------------------------------

    def test_calculate_similarity_impl(self):
        """Test internal similarity calculation implementation."""
        query_embedding = [1.0, 0.0, 0.0]
        doc_embeddings = [
            [1.0, 0.0, 0.0],  # Same as query - score should be ~1.0
            [0.0, 1.0, 0.0],  # Orthogonal - score should be ~0.0
            [0.5, 0.5, 0.0],  # Partial similarity
        ]

        scores = _calculate_similarity_impl(query_embedding, doc_embeddings)

        self.assertEqual(len(scores), 3)
        self.assertAlmostEqual(scores[0], 1.0, places=5)
        self.assertAlmostEqual(scores[1], 0.0, places=5)
        # The partial match must fall strictly between the two extremes.
        self.assertGreater(scores[2], 0.0)
        self.assertLess(scores[2], 1.0)

    def test_calculate_similarity_with_timer(self):
        """Test that similarity calculation reports to the supplied timer."""
        mock_timer = self._make_mock_timer()

        calculate_similarity([1.0, 0.0, 0.0], [[1.0, 0.0, 0.0]], timer=mock_timer)

        mock_timer.time_step.assert_called_once_with("similarity_calculation")

    # ------------------------------------------------------------------
    # Context processing
    # ------------------------------------------------------------------

    def test_process_context_impl(self):
        """Test internal context processing implementation."""
        results = [
            Mock(
                metadata={
                    "id": f"KB00{number}",
                    "question": f"Question {number}?",
                    "content": f"Answer {number}.",
                }
            )
            for number in range(1, 4)
        ]

        # Cosine scores (sorted: 0.9, 0.7, 0.5)
        cosine_scores = [0.7, 0.5, 0.9]

        context, source_ids, knowledge_pairs = _process_context_impl(
            results, cosine_scores, max_results=2
        )

        # Only the top two results are kept.
        self.assertEqual(len(source_ids), 2)
        self.assertEqual(len(knowledge_pairs), 2)

        # The highest score (0.9, index 2) must come first.
        self.assertEqual(source_ids[0], "KB003")
        self.assertEqual(knowledge_pairs[0][0], "Question 3?")

        # Formatted context contains numbered entries and the Q/A text.
        self.assertIn("Knowledge Entry 1:", context)
        self.assertIn("Knowledge Entry 2:", context)
        self.assertIn("Q: Question 3?", context)
        self.assertIn("A: Answer 3.", context)

    def test_process_context_with_timer(self):
        """Test that context processing reports to the supplied timer."""
        mock_result = Mock(
            metadata={"id": "KB001", "question": "Q?", "content": "A."}
        )
        mock_timer = self._make_mock_timer()

        process_context([mock_result], [0.9], timer=mock_timer)

        mock_timer.time_step.assert_called_once_with("context_processing")

    def test_process_context_max_results(self):
        """Test that the max_results parameter limits the output."""
        results = [
            Mock(
                metadata={
                    "id": f"KB00{index}",
                    "question": f"Q{index}?",
                    "content": f"A{index}.",
                }
            )
            for index in range(5)
        ]
        scores = [0.9, 0.8, 0.7, 0.6, 0.5]

        # Five candidates are supplied but only three are requested.
        _, source_ids, knowledge_pairs = _process_context_impl(
            results, scores, max_results=3
        )

        self.assertEqual(len(source_ids), 3)
        self.assertEqual(len(knowledge_pairs), 3)

    def test_process_context_formatting(self):
        """Test context formatting details."""
        mock_result = Mock(
            metadata={
                "id": "KB001",
                "question": "Test question?",
                "content": "Test answer.",
            }
        )

        context, _, _ = _process_context_impl([mock_result], [0.9], max_results=1)

        self.assertIn("Knowledge Entry 1:", context)
        self.assertIn("Q: Test question?", context)
        self.assertIn("A: Test answer.", context)
        # A 40-dash rule is expected somewhere in the formatted context.
        self.assertIn("-" * 40, context)

    def test_process_context_missing_metadata(self):
        """Test context processing with missing metadata fields."""
        mock_result = Mock(metadata={})  # No metadata at all

        context, source_ids, _ = _process_context_impl(
            [mock_result], [0.9], max_results=1
        )

        # Missing fields are expected to be rendered as "N/A".
        self.assertIn("N/A", context)
        self.assertEqual(source_ids[0], "N/A")

    # ------------------------------------------------------------------
    # Vector store initialisation
    # ------------------------------------------------------------------

    @patch("src.vector_store.get_knowledge_base_data")
    @patch("src.vector_store.chromadb.PersistentClient")
    @patch("src.vector_store.Chroma")
    def test_initialize_vector_store_new_collection(
        self, mock_chroma_class, mock_client_class, mock_get_kb
    ):
        """Test initializing the vector store when the collection is missing."""
        mock_client, mock_vector_store, mock_retriever = self._wire_store_mocks(
            mock_chroma_class, mock_client_class, mock_get_kb
        )

        # Simulate a missing collection: the lookup raises.
        mock_client.get_collection.side_effect = Exception("Collection not found")
        mock_collection = Mock()
        mock_client.create_collection.return_value = mock_collection

        from src.vector_store import initialize_vector_store

        _, vector_store, retriever = initialize_vector_store()

        # The collection must be created and populated exactly once.
        mock_client.create_collection.assert_called_once()
        mock_collection.add.assert_called_once()

        self.assertEqual(vector_store, mock_vector_store)
        self.assertEqual(retriever, mock_retriever)

    @patch("src.vector_store.get_knowledge_base_data")
    @patch("src.vector_store.chromadb.PersistentClient")
    @patch("src.vector_store.Chroma")
    def test_initialize_vector_store_existing_collection(
        self, mock_chroma_class, mock_client_class, mock_get_kb
    ):
        """Test initializing the vector store when the collection exists."""
        mock_client, mock_vector_store, mock_retriever = self._wire_store_mocks(
            mock_chroma_class, mock_client_class, mock_get_kb
        )

        # Simulate an existing collection: the lookup succeeds.
        mock_collection = Mock()
        mock_client.get_collection.return_value = mock_collection

        from src.vector_store import initialize_vector_store

        collection, vector_store, retriever = initialize_vector_store()

        # The existing collection is loaded, never re-created.
        mock_client.get_collection.assert_called_once()
        mock_client.create_collection.assert_not_called()

        self.assertEqual(collection, mock_collection)
        self.assertEqual(vector_store, mock_vector_store)
        self.assertEqual(retriever, mock_retriever)

    @patch("src.vector_store.get_knowledge_base_data")
    @patch("src.vector_store.chromadb.PersistentClient")
    def test_initialize_vector_store_failure(self, mock_client_class, mock_get_kb):
        """Test that initialize_vector_store propagates client errors."""
        mock_get_kb.return_value = (["doc1"], [{"id": "1"}], ["id1"])

        # Constructing the client fails.
        mock_client_class.side_effect = Exception("Database connection failed")

        from src.vector_store import initialize_vector_store

        # assertRaisesRegex checks both that an exception escapes and that its
        # message mentions the original failure.
        with self.assertRaisesRegex(Exception, "Database connection failed"):
            initialize_vector_store()
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()