Lucifer Akirami committed on
Commit 210c3c5 · 1 Parent(s): 1dc0365

Closes [FEATURE REQUEST] Expand to other vector stores beyond Pinecone (#102)

Files changed (4)
  1. pyproject.toml +4 -0
  2. sage/config.py +11 -7
  3. sage/index.py +6 -3
  4. sage/vector_store.py +246 -3
pyproject.toml CHANGED
@@ -26,6 +26,7 @@ dependencies = [
     "anytree==2.12.1",
     "cohere==5.9.2",
     "configargparse",
+    "faiss-cpu==1.9.0",
     "fastapi==0.112.2",
     "google-ai-generativelanguage==0.6.6",
     "gradio>=4.26.0",
@@ -41,6 +42,9 @@ dependencies = [
     "langchain-openai==0.1.25",
     "langchain-text-splitters==0.2.4",
     "langchain-voyageai==0.1.1",
+    "langchain-milvus==0.1.6",
+    "langchain-chroma==0.1.4",
+    "langchain-qdrant==0.1.4",
     "marqo==3.7.0",
     "nbformat==5.10.4",
     "openai==1.42.0",
sage/config.py CHANGED
@@ -122,12 +122,16 @@ def add_embedding_args(parser: ArgumentParser) -> Callable:
 
 def add_vector_store_args(parser: ArgumentParser) -> Callable:
     """Adds vector store-related arguments to the parser and returns a validator."""
-    parser.add("--vector-store-provider", default="marqo", choices=["pinecone", "marqo"])
     parser.add(
-        "--pinecone-index-name",
-        default=None,
-        help="Pinecone index name. Required if using Pinecone as the vector store. If the index doesn't exist already, "
-        "we will create it.",
+        "--vector-store-provider", default="marqo", choices=["pinecone", "marqo", "chroma", "faiss", "milvus", "qdrant"]
+    )
+    parser.add(
+        "--index-name", default="sage_index", help="Name of the vector store index. Defaults to sage_index."
+    )
+    parser.add(
+        "--milvus-uri",
+        default="milvus_sage.db",
+        help="URI for Milvus. Defaults to milvus_sage.db.",
     )
     parser.add(
         "--index-namespace",
@@ -402,8 +406,8 @@ def validate_vector_store_args(args):
     elif args.vector_store_provider == "pinecone":
         if not os.getenv("PINECONE_API_KEY"):
             raise ValueError("Please set the PINECONE_API_KEY environment variable.")
-        if not args.pinecone_index_name:
-            raise ValueError(f"Please set the vector_store.pinecone_index_name value.")
+        if not args.index_name:
+            raise ValueError("Please set the vector_store.index_name value.")
 
 
 def validate_indexing_args(args):
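To illustrate how the replacement flags behave, a standalone sketch (not the project's code) using the same configargparse `add` API as sage/config.py above; the flag names, choices, and defaults come from the diff:

    import configargparse

    parser = configargparse.ArgParser()
    parser.add(
        "--vector-store-provider", default="marqo",
        choices=["pinecone", "marqo", "chroma", "faiss", "milvus", "qdrant"],
    )
    parser.add("--index-name", default="sage_index")
    parser.add("--milvus-uri", default="milvus_sage.db")

    args = parser.parse_args(["--vector-store-provider", "qdrant"])
    assert args.vector_store_provider == "qdrant"
    assert args.index_name == "sage_index"  # the new shared default replaces --pinecone-index-name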
sage/index.py CHANGED
@@ -11,7 +11,7 @@ from sage.chunker import UniversalFileChunker
 from sage.data_manager import GitHubRepoManager
 from sage.embedder import build_batch_embedder_from_flags
 from sage.github import GitHubIssuesChunker, GitHubIssuesManager
-from sage.vector_store import build_vector_store_from_args
+from sage.vector_store import VectorStoreProvider, build_vector_store_from_args
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
@@ -41,8 +41,11 @@ def main():
         return
 
     # Additionally validate embedder and vector store compatibility.
-    if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
-        parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
+    vector_store_providers = [member.value for member in VectorStoreProvider]
+    if args.embedding_provider == "openai" and args.vector_store_provider not in vector_store_providers:
+        parser.error(
+            f"When using the OpenAI embedder, the vector store type must be one of {vector_store_providers}."
+        )
     if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
         parser.error("When using the marqo embedder, the vector store type must also be marqo.")
 
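In effect, the OpenAI embedder is now compatible with every provider in the enum rather than Pinecone only. A standalone sketch of what the new list comprehension yields (the enum is copied verbatim from sage/vector_store.py below):

    from enum import Enum

    class VectorStoreProvider(Enum):
        PINECONE = "pinecone"
        MARQO = "marqo"
        CHROMA = "chroma"
        FAISS = "faiss"
        MILVUS = "milvus"
        QDRANT = "qdrant"

    vector_store_providers = [member.value for member in VectorStoreProvider]
    print(vector_store_providers)
    # ['pinecone', 'marqo', 'chroma', 'faiss', 'milvus', 'qdrant']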
sage/vector_store.py CHANGED
@@ -3,20 +3,33 @@
 import logging
 import os
 from abc import ABC, abstractmethod
+from enum import Enum
 from functools import cached_property
 from typing import Dict, Generator, List, Optional, Tuple
+from uuid import uuid4
 
+import chromadb
+import faiss
 import marqo
 import nltk
 from langchain.retrievers import EnsembleRetriever
+from langchain_chroma import Chroma as LangChainChroma
+from langchain_community.docstore.in_memory import InMemoryDocstore
 from langchain_community.retrievers import BM25Retriever
-from langchain_community.vectorstores import Marqo
+from langchain_community.vectorstores import FAISS, Marqo
 from langchain_community.vectorstores import Pinecone as LangChainPinecone
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from langchain_milvus import Milvus
+from langchain_openai import OpenAIEmbeddings
+from langchain_qdrant import QdrantVectorStore as LangChainQdrant
+from langchain_voyageai import VoyageAIEmbeddings
 from nltk.data import find
 from pinecone import Pinecone, ServerlessSpec
 from pinecone_text.sparse import BM25Encoder
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams
 
 from sage.constants import TEXT_FIELD
 from sage.data_manager import DataManager
@@ -24,6 +37,15 @@ from sage.data_manager import DataManager
 Vector = Tuple[Dict, List[float]]  # (metadata, embedding)
 
 
+class VectorStoreProvider(Enum):
+    PINECONE = "pinecone"
+    MARQO = "marqo"
+    CHROMA = "chroma"
+    FAISS = "faiss"
+    MILVUS = "milvus"
+    QDRANT = "qdrant"
+
+
 def is_punkt_downloaded():
     try:
         find("tokenizers/punkt_tab")
@@ -156,6 +178,207 @@ class PineconeVectorStore(VectorStore):
         return dense_retriever
 
 
+class ChromaVectorStore(VectorStore):
+    """Vector store implementation using ChromaDB."""
+
+    def __init__(self, index_name: str, alpha: float = None, bm25_cache: Optional[str] = None):
+        """
+        Args:
+            index_name: The name of the Chroma collection/index to use. If it doesn't exist already, we'll create it.
+            alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
+                BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
+        """
+        self.index_name = index_name
+        self.alpha = alpha
+        self.client = chromadb.PersistentClient()
+
+    @cached_property
+    def index(self):
+        index = self.client.get_or_create_collection(self.index_name)
+        return index
+
+    def ensure_exists(self):
+        pass
+
+    def upsert_batch(self, vectors: List[Vector], namespace: str):
+        del namespace
+
+        ids = []
+        embeddings = []
+        metadatas = []
+        documents = []
+
+        for i, (metadata, embedding) in enumerate(vectors):
+            ids.append(metadata.get("id", str(i)))
+            embeddings.append(embedding)
+            metadatas.append(metadata)
+            documents.append(metadata[TEXT_FIELD])
+
+        self.index.upsert(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
+
+    def as_retriever(self, top_k: int, embeddings: Embeddings = None, namespace: str = None):
+        vector_store = LangChainChroma(
+            collection_name=self.index_name, embedding_function=embeddings, client=self.client
+        )
+
+        return vector_store.as_retriever(search_kwargs={"k": top_k})
+
+
+class FAISSVectorStore(VectorStore):
+    """Vector store implementation using FAISS."""
+
+    def __init__(self, index_name: str, dimension: int, embeddings: Embeddings = None):
+        """
+        Args:
+            index_name: The name of the FAISS index to use. If it doesn't exist already, we'll create it.
+            dimension: The dimension of the vectors.
+            embeddings: The embedding function used to generate embeddings.
+        """
+        self.index_name = index_name
+        self.dimension = dimension
+        self.embeddings = embeddings
+
+        if os.path.exists(self.index_name):
+            # Load the existing index from disk.
+            self.vector_store = FAISS.load_local(
+                folder_path=self.index_name, embeddings=self.embeddings, allow_dangerous_deserialization=True
+            )
+        else:
+            # Create a new index.
+            self.vector_store = FAISS(
+                embedding_function=self.embeddings,
+                index=self.index,
+                docstore=InMemoryDocstore(),
+                index_to_docstore_id={},
+            )
+
+    @cached_property
+    def index(self):
+        return faiss.IndexFlatL2(self.dimension)
+
+    def ensure_exists(self):
+        pass
+
+    def upsert_batch(self, vectors: List[Vector], namespace: str):
+        del namespace
+
+        ids = []
+        documents = []
+
+        for i, (meta_data, embedding) in enumerate(vectors):
+            ids.append(meta_data.get("id", str(i)))
+            document = Document(page_content=meta_data[TEXT_FIELD], metadata=meta_data)
+            documents.append(document)
+
+        self.vector_store.add_documents(documents=documents, ids=ids)
+
+        # Persist the index to disk after every batch upsert.
+        self.vector_store.save_local(self.index_name)
+        logging.info("Saved FAISS index to %s", self.index_name)
+
+    def as_retriever(self, top_k, embeddings, namespace):
+        del embeddings
+        del namespace
+
+        return self.vector_store.as_retriever(search_kwargs={"k": top_k})
+
+
+class MilvusVectorStore(VectorStore):
+    """Vector store implementation using Milvus."""
+
+    def __init__(self, uri: str, index_name: str, embeddings: Embeddings = None):
+        """
+        Args:
+            uri: The connection URI for Milvus.
+            index_name: The name of the Milvus collection to use. If it doesn't exist already, we'll create it.
+            embeddings: The embedding function used to generate embeddings.
+        """
+        self.uri = uri
+        self.index_name = index_name
+        self.embeddings = embeddings
+
+        self.vector_store = Milvus(
+            embedding_function=embeddings, connection_args={"uri": self.uri}, collection_name=self.index_name
+        )
+
+    def ensure_exists(self):
+        pass
+
+    def upsert_batch(self, vectors: List[Vector], namespace: str):
+        del namespace
+
+        ids = []
+        documents = []
+
+        for i, (meta_data, embedding) in enumerate(vectors):
+            ids.append(meta_data.get("id", str(i)))
+            # "text" is a reserved keyword in Milvus, so we move the text into a "content" field.
+            page_content = meta_data[TEXT_FIELD]
+            meta_data["content"] = page_content
+            del meta_data[TEXT_FIELD]
+
+            document = Document(page_content=page_content, metadata=meta_data)
+            documents.append(document)
+
+        self.vector_store.add_documents(documents=documents, ids=ids)
+
+    def as_retriever(self, top_k, embeddings, namespace):
+        del embeddings
+        del namespace
+
+        return self.vector_store.as_retriever(search_kwargs={"k": top_k})
+
+
+class QdrantVectorStore(VectorStore):
+    """Vector store implementation using Qdrant."""
+
+    def __init__(self, index_name: str, dimension: int, embeddings: Embeddings = None):
+        """
+        Args:
+            index_name: The name of the Qdrant collection to use. If it doesn't exist already, we'll create it.
+            dimension: The dimension of the vectors.
+            embeddings: The embedding function used to generate embeddings.
+        """
+        self.index_name = index_name
+        self.dimension = dimension
+        self.embeddings = embeddings
+        self.client = QdrantClient(path="qdrantdb")
+        self.vector_store = self.index
+
+    @cached_property
+    def index(self):
+        self.ensure_exists()
+        vector_store = LangChainQdrant(client=self.client, collection_name=self.index_name, embedding=self.embeddings)
+        return vector_store
+
+    def ensure_exists(self):
+        if not self.client.collection_exists(self.index_name):
+            self.client.create_collection(
+                collection_name=self.index_name,
+                vectors_config=VectorParams(size=self.dimension, distance=Distance.COSINE),
+            )
+
+    def upsert_batch(self, vectors: List[Vector], namespace: str):
+        del namespace
+
+        ids = []
+        documents = []
+
+        for meta_data, embedding in vectors:
+            ids.append(str(uuid4()))
+            document = Document(page_content=meta_data[TEXT_FIELD], metadata=meta_data)
+            documents.append(document)
+
+        self.vector_store.add_documents(documents=documents, ids=ids)
+
+    def as_retriever(self, top_k, embeddings, namespace):
+        del embeddings
+        del namespace
+
+        return self.vector_store.as_retriever(search_kwargs={"k": top_k})
+
+
 class MarqoVectorStore(VectorStore):
     """Vector store implementation using Marqo."""
 
@@ -191,12 +414,22 @@ class MarqoVectorStore(VectorStore):
         return vectorstore.as_retriever(search_kwargs={"k": top_k})
 
 
-def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager] = None) -> VectorStore:
+def build_vector_store_from_args(
+    args: dict,
+    data_manager: Optional[DataManager] = None,
+) -> VectorStore:
     """Builds a vector store from the given command-line arguments.
 
     When `data_manager` is specified and hybrid retrieval is requested, we'll use it to fit a BM25 encoder on the corpus
     of documents.
     """
+    if args.embedding_provider == "openai":
+        embeddings = OpenAIEmbeddings(model=args.embedding_model)
+    elif args.embedding_provider == "voyage":
+        embeddings = VoyageAIEmbeddings(model=args.embedding_model)
+    elif args.embedding_provider == "gemini":
+        embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
+
     if args.vector_store_provider == "pinecone":
         bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
         if args.retrieval_alpha < 1.0 and not os.path.exists(bm25_cache) and data_manager:
@@ -217,11 +450,21 @@ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager]
             bm25_encoder.dump(bm25_cache)
 
         return PineconeVectorStore(
-            index_name=args.pinecone_index_name,
+            index_name=args.index_name,
            dimension=args.embedding_size if "embedding_size" in args else None,
            alpha=args.retrieval_alpha,
            bm25_cache=bm25_cache,
        )
+    elif args.vector_store_provider == "chroma":
+        return ChromaVectorStore(
+            index_name=args.index_name,
+        )
+    elif args.vector_store_provider == "faiss":
+        return FAISSVectorStore(index_name=args.index_name, dimension=args.embedding_size, embeddings=embeddings)
+    elif args.vector_store_provider == "milvus":
+        return MilvusVectorStore(uri=args.milvus_uri, index_name=args.index_name, embeddings=embeddings)
+    elif args.vector_store_provider == "qdrant":
+        return QdrantVectorStore(index_name=args.index_name, dimension=args.embedding_size, embeddings=embeddings)
     elif args.vector_store_provider == "marqo":
         return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
     else:
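For reference, a hypothetical end-to-end sketch (not part of this commit) of the new API surface: indexing a couple of documents into the Chroma backend and querying them back. The embedding model name and sample texts are made up, and OPENAI_API_KEY is assumed to be set; class signatures and the TEXT_FIELD metadata key follow the diff above.

    from langchain_openai import OpenAIEmbeddings

    from sage.constants import TEXT_FIELD
    from sage.vector_store import ChromaVectorStore

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")  # assumed model
    store = ChromaVectorStore(index_name="sage_index")

    texts = ["def foo(): ...", "class Bar: ..."]  # made-up corpus
    vectors = [
        ({"id": str(i), TEXT_FIELD: text}, embeddings.embed_query(text))
        for i, text in enumerate(texts)
    ]
    store.upsert_batch(vectors, namespace="")  # namespace is ignored by the Chroma backend

    retriever = store.as_retriever(top_k=1, embeddings=embeddings)
    print(retriever.invoke("where is foo defined?"))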