Aditya Bharadwaj committed on
Commit
3f12090
·
1 Parent(s): e95a894

Adding tests for vector_store.py (#98)

Browse files
Files changed (1) hide show
  1. tests/test_vector_store.py +124 -0
tests/test_vector_store.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import Namespace
3
+
4
+ import pytest
5
+ from unittest.mock import patch, MagicMock
6
+
7
+ from sage.vector_store import PineconeVectorStore, MarqoVectorStore, build_vector_store_from_args
8
+
9
+ mock_vectors = [({"id": "1", "text": "example"}, [0.1, 0.2, 0.3])]
10
+ mock_namespace = "test_namespace"
11
+
12
+
13
class TestVectorStore:
    """Tests for ``PineconeVectorStore``, ``MarqoVectorStore`` and
    ``build_vector_store_from_args``.

    External clients are replaced with mocks via ``unittest.mock.patch``.
    Several tests request a fixture only to keep its patch active for the
    duration of the test, not because they use the yielded object.
    """

    # ------------------------------------------------------------------ #
    # Fixtures
    # ------------------------------------------------------------------ #

    @pytest.fixture
    def pinecone_store(self):
        """Yield a ``PineconeVectorStore`` while ``sage.vector_store.Pinecone`` is patched."""
        # The yield stays inside the ``with`` so the patch remains in effect
        # for the whole test that requested this fixture.
        with patch("sage.vector_store.Pinecone"):
            store = PineconeVectorStore(index_name="test_index", dimension=128, alpha=0.5)
            yield store

    @pytest.fixture
    def marqo_store(self):
        """Yield a ``MarqoVectorStore`` while ``marqo.Client`` is patched."""
        with patch("marqo.Client"):
            store = MarqoVectorStore(url="http://localhost:8882", index_name="test_index")
            yield store

    @pytest.fixture
    def mock_data_manager(self):
        """Return a ``MagicMock`` whose ``walk()`` returns one (content, metadata) pair."""
        data_manager = MagicMock()
        data_manager.walk.return_value = [("sample content", {})]
        return data_manager

    @pytest.fixture
    def mock_nltk(self):
        """Patch ``nltk.data.find`` so that every call raises ``LookupError``."""
        with patch("nltk.data.find") as mock_find:
            mock_find.side_effect = LookupError
            yield mock_find

    @pytest.fixture
    def mock_bm25_encoder(self):
        """Patch ``sage.vector_store.BM25Encoder`` and yield the mocked instance.

        The yielded object is ``MockBM25Encoder.return_value``, i.e. what any
        ``BM25Encoder()`` call inside ``sage.vector_store`` will produce.
        """
        with patch("sage.vector_store.BM25Encoder") as MockBM25Encoder:
            mock_instance = MockBM25Encoder.return_value
            mock_instance.encode_documents.return_value = [0.1, 0.2, 0.3]
            mock_instance.fit = MagicMock()
            mock_instance.dump = MagicMock()
            yield mock_instance

    # ------------------------------------------------------------------ #
    # PineconeVectorStore
    # ------------------------------------------------------------------ #

    def test_pinecone_vector_store_initialization(self, pinecone_store):
        """Constructor arguments are exposed as attributes."""
        assert pinecone_store.index_name == "test_index"
        assert pinecone_store.dimension == 128
        assert pinecone_store.alpha == 0.5

    def test_pinecone_vector_store_ensure_exists(self, pinecone_store):
        """``ensure_exists`` calls ``create_index`` on the (mocked) client once."""
        pinecone_store.ensure_exists()
        pinecone_store.client.create_index.assert_called_once()

    def test_pinecone_vector_store_upsert_batch(self, pinecone_store):
        """``upsert_batch`` calls ``upsert`` on the (mocked) index once."""
        pinecone_store.upsert_batch(mock_vectors, mock_namespace)
        # ``client.Index`` is a mock, so calling it again here returns the
        # same ``return_value`` object that the store used.
        pinecone_store.client.Index().upsert.assert_called_once()

    # ------------------------------------------------------------------ #
    # MarqoVectorStore
    # ------------------------------------------------------------------ #

    def test_marqo_vector_store_initialization(self, marqo_store):
        """The index name passed to the constructor is exposed as an attribute."""
        assert marqo_store.index_name == "test_index"

    def test_marqo_vector_store_ensure_exists(self, marqo_store):
        """``ensure_exists`` can be called without raising."""
        # No specific assertion as ensure_exists is a no-op
        marqo_store.ensure_exists()

    def test_marqo_vector_store_upsert_batch(self, marqo_store):
        """``upsert_batch`` can be called without raising."""
        # No specific assertion as upsert_batch is a no-op
        marqo_store.upsert_batch(mock_vectors, mock_namespace)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def build_args(self, provider, alpha=1.0):
        """Build the ``Namespace`` passed to ``build_vector_store_from_args``.

        Args:
            provider: Either ``"pinecone"`` or ``"marqo"``.
            alpha: Value stored as ``retrieval_alpha`` (Pinecone only).

        Returns:
            A ``Namespace`` populated with the fields for that provider.

        Raises:
            ValueError: If ``provider`` is not one of the two supported names,
                instead of silently returning ``None``.
        """
        if provider == "pinecone":
            return Namespace(
                vector_store_provider="pinecone",
                pinecone_index_name="test_index",
                embedding_size=128,
                retrieval_alpha=alpha,
                index_namespace="test_namespace",
            )
        if provider == "marqo":
            return Namespace(
                vector_store_provider="marqo",
                marqo_url="http://localhost:8882",
                index_namespace="test_index",
            )
        raise ValueError(f"build_args does not support provider {provider!r}")

    def build_bm25_cache_path(self):
        """Return the BM25 cache path the tests expect for ``test_namespace``."""
        return os.path.join(".bm25_cache", "test_namespace", "bm25_encoder.json")

    # ------------------------------------------------------------------ #
    # build_vector_store_from_args
    # ------------------------------------------------------------------ #

    def test_builds_pinecone_vector_store_with_default_bm25_encoder(
        self, pinecone_store, mock_bm25_encoder, mock_data_manager, mock_nltk
    ):
        """With ``alpha < 1`` a BM25 encoder is fitted and dumped to the cache path."""
        # NOTE(review): this test does not patch ``os.path.exists``; it presumably
        # relies on the cache file being absent from the working directory - confirm.
        args = self.build_args("pinecone", alpha=0.5)
        store = build_vector_store_from_args(args, data_manager=mock_data_manager)
        assert isinstance(store, PineconeVectorStore)
        assert store.bm25_encoder is not None
        mock_bm25_encoder.fit.assert_called_once()
        mock_bm25_encoder.dump.assert_called_once_with(self.build_bm25_cache_path())

    def test_builds_pinecone_vector_store_with_cached_bm25_encoder(self, pinecone_store, mock_bm25_encoder):
        """When ``os.path.exists`` reports True, the encoder is loaded from the cache path."""
        with patch("os.path.exists", return_value=True):
            args = self.build_args("pinecone", alpha=0.5)
            store = build_vector_store_from_args(args)
            assert isinstance(store, PineconeVectorStore)
            assert store.bm25_encoder is not None
            mock_bm25_encoder.load.assert_called_once_with(path=self.build_bm25_cache_path())

    def test_builds_pinecone_vector_store_without_bm25_encoder(self, pinecone_store):
        """With ``alpha == 1.0`` the store is built with no BM25 encoder."""
        args = self.build_args("pinecone", alpha=1.0)
        store = build_vector_store_from_args(args)
        assert isinstance(store, PineconeVectorStore)
        assert store.bm25_encoder is None

    def test_builds_marqo_vector_store(self):
        """The factory returns a ``MarqoVectorStore`` for the marqo provider."""
        args = self.build_args("marqo")
        # Patch the client exactly as the ``marqo_store`` fixture does, so the
        # factory never constructs a real ``marqo.Client`` during the test.
        with patch("marqo.Client"):
            store = build_vector_store_from_args(args)
        assert isinstance(store, MarqoVectorStore)

    def test_raises_value_error_for_unrecognized_provider(self):
        """An unknown provider name makes the factory raise ``ValueError``."""
        args = Namespace(vector_store_provider="unknown")
        with pytest.raises(ValueError, match="Unrecognized vector store type unknown"):
            build_vector_store_from_args(args)
122
+
123
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; propagate it so that running
    # this file as a script exits non-zero when tests fail.
    raise SystemExit(pytest.main())