Spaces:
Running
Running
File size: 5,025 Bytes
3f12090 28fa826 3f12090 28fa826 3f12090 28fa826 3f12090 28fa826 3f12090 28fa826 3f12090 28fa826 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | import os
from argparse import Namespace
from unittest.mock import MagicMock, patch
import pytest
from sage.vector_store import (
MarqoVectorStore,
PineconeVectorStore,
build_vector_store_from_args,
)
# Sample inputs handed to ``upsert_batch`` by the tests below: a single
# (dict, list-of-floats) tuple and the namespace string used for it.
mock_vectors = [
    (
        {"id": "1", "text": "example"},
        [0.1, 0.2, 0.3],
    ),
]
mock_namespace = "test_namespace"
class TestVectorStore:
    """Tests for the stores in ``sage.vector_store`` and ``build_vector_store_from_args``.

    External dependencies (the Pinecone client, ``marqo.Client``, ``BM25Encoder``
    and NLTK data lookups) are patched with mocks so no real service is needed.
    """

    @pytest.fixture
    def pinecone_store(self):
        """Yield a PineconeVectorStore built while ``sage.vector_store.Pinecone`` is mocked.

        The patch stays active until the requesting test finishes, so builder
        tests request this fixture just to keep the Pinecone client mocked.
        """
        with patch("sage.vector_store.Pinecone"):
            yield PineconeVectorStore(index_name="test_index", dimension=128, alpha=0.5)

    @pytest.fixture
    def marqo_store(self):
        """Yield a MarqoVectorStore built while ``marqo.Client`` is mocked.

        Like ``pinecone_store``, the patch remains active for the whole test.
        """
        with patch("marqo.Client"):
            yield MarqoVectorStore(url="http://localhost:8882", index_name="test_index")

    @pytest.fixture
    def mock_data_manager(self):
        """Return a data-manager mock whose ``walk()`` returns one ``("sample content", {})`` item."""
        data_manager = MagicMock()
        data_manager.walk.return_value = [("sample content", {})]
        return data_manager

    @pytest.fixture
    def mock_nltk(self):
        """Make ``nltk.data.find`` raise LookupError, simulating missing NLTK data.

        ``nltk.download`` is patched as well: if the code under test reacts to
        the LookupError by downloading data (presumably it does; verify in
        ``sage.vector_store``), the test must not touch the network.
        """
        with patch("nltk.data.find", side_effect=LookupError) as mock_find, patch("nltk.download"):
            yield mock_find

    @pytest.fixture
    def mock_bm25_encoder(self):
        """Patch ``sage.vector_store.BM25Encoder`` and yield the instance it constructs.

        ``fit``, ``dump`` and ``load`` are auto-created child mocks, so tests can
        assert on them without configuring them here.
        """
        with patch("sage.vector_store.BM25Encoder") as encoder_cls:
            encoder = encoder_cls.return_value
            encoder.encode_documents.return_value = [0.1, 0.2, 0.3]
            yield encoder

    def test_pinecone_vector_store_initialization(self, pinecone_store):
        """Constructor arguments are exposed as attributes."""
        assert pinecone_store.index_name == "test_index"
        assert pinecone_store.dimension == 128
        assert pinecone_store.alpha == 0.5

    def test_pinecone_vector_store_ensure_exists(self, pinecone_store):
        """ensure_exists asks the (mocked) client to create the index exactly once."""
        pinecone_store.ensure_exists()
        pinecone_store.client.create_index.assert_called_once()

    def test_pinecone_vector_store_upsert_batch(self, pinecone_store):
        """upsert_batch performs a single upsert on the index handle."""
        pinecone_store.upsert_batch(mock_vectors, mock_namespace)
        # Read the mocked Index handle through return_value; calling Index()
        # here would record an extra, test-induced call on the mock.
        pinecone_store.client.Index.return_value.upsert.assert_called_once()

    def test_marqo_vector_store_initialization(self, marqo_store):
        """The index name passed to the constructor is exposed as an attribute."""
        assert marqo_store.index_name == "test_index"

    def test_marqo_vector_store_ensure_exists(self, marqo_store):
        """ensure_exists is expected to be a no-op for Marqo; only check it does not raise."""
        marqo_store.ensure_exists()

    def test_marqo_vector_store_upsert_batch(self, marqo_store):
        """upsert_batch is expected to be a no-op for Marqo; only check it does not raise."""
        marqo_store.upsert_batch(mock_vectors, mock_namespace)

    def build_args(self, provider, alpha=1.0):
        """Return a Namespace mimicking CLI args for ``provider``.

        Args:
            provider: "pinecone" or "marqo".
            alpha: value for ``retrieval_alpha`` (Pinecone only).

        Raises:
            ValueError: for any other provider, instead of silently returning None.
        """
        if provider == "pinecone":
            return Namespace(
                vector_store_provider="pinecone",
                pinecone_index_name="test_index",
                embedding_size=128,
                retrieval_alpha=alpha,
                index_namespace=mock_namespace,
            )
        if provider == "marqo":
            return Namespace(
                vector_store_provider="marqo",
                marqo_url="http://localhost:8882",
                index_namespace="test_index",
            )
        raise ValueError(f"Unsupported provider for test args: {provider}")

    def build_bm25_cache_path(self):
        """Return the BM25 cache path the tests expect for ``mock_namespace``."""
        return os.path.join(".bm25_cache", mock_namespace, "bm25_encoder.json")

    def test_builds_pinecone_vector_store_with_default_bm25_encoder(
        self, pinecone_store, mock_bm25_encoder, mock_data_manager, mock_nltk
    ):
        """With alpha=0.5 and no cached encoder, a BM25 encoder is fit and dumped."""
        # Force the "no cache" branch regardless of what is on disk, and stub
        # directory creation in case the builder creates the cache folder
        # before dumping, so the test leaves nothing behind.
        with patch("os.path.exists", return_value=False), patch("os.makedirs"):
            args = self.build_args("pinecone", alpha=0.5)
            store = build_vector_store_from_args(args, data_manager=mock_data_manager)
        assert isinstance(store, PineconeVectorStore)
        assert store.bm25_encoder is not None
        mock_bm25_encoder.fit.assert_called_once()
        mock_bm25_encoder.dump.assert_called_once_with(self.build_bm25_cache_path())

    def test_builds_pinecone_vector_store_with_cached_bm25_encoder(self, pinecone_store, mock_bm25_encoder):
        """With alpha=0.5 and a cache file reported present, the encoder is loaded, not refit."""
        with patch("os.path.exists", return_value=True):
            args = self.build_args("pinecone", alpha=0.5)
            store = build_vector_store_from_args(args)
        assert isinstance(store, PineconeVectorStore)
        assert store.bm25_encoder is not None
        mock_bm25_encoder.load.assert_called_once_with(path=self.build_bm25_cache_path())
        mock_bm25_encoder.fit.assert_not_called()

    def test_builds_pinecone_vector_store_without_bm25_encoder(self, pinecone_store):
        """With alpha=1.0 no BM25 encoder is attached."""
        args = self.build_args("pinecone", alpha=1.0)
        store = build_vector_store_from_args(args)
        assert isinstance(store, PineconeVectorStore)
        assert store.bm25_encoder is None

    def test_builds_marqo_vector_store(self, marqo_store):
        """The "marqo" provider yields a MarqoVectorStore.

        ``marqo_store`` is requested only so ``marqo.Client`` stays mocked while
        the builder constructs its own store (mirrors the Pinecone builder tests).
        """
        args = self.build_args("marqo")
        store = build_vector_store_from_args(args)
        assert isinstance(store, MarqoVectorStore)

    def test_raises_value_error_for_unrecognized_provider(self):
        """An unknown provider raises ValueError naming it."""
        args = Namespace(vector_store_provider="unknown")
        with pytest.raises(ValueError, match="Unrecognized vector store type unknown"):
            build_vector_store_from_args(args)
if __name__ == "__main__":
    # Run only this module's tests (forwarding any extra CLI arguments) and
    # propagate pytest's exit status so failures yield a non-zero exit code.
    # A bare pytest.main() would collect from the current directory instead.
    import sys

    raise SystemExit(pytest.main([__file__, *sys.argv[1:]]))
|