Spaces:
Running
Running
| """Vector store abstraction and implementations.""" | |
| import logging | |
| import os | |
| from abc import ABC, abstractmethod | |
| from functools import cached_property | |
| from typing import Dict, Generator, List, Optional, Tuple | |
| import marqo | |
| import nltk | |
| from langchain_community.retrievers import PineconeHybridSearchRetriever | |
| from langchain_community.vectorstores import Marqo | |
| from langchain_community.vectorstores import Pinecone as LangChainPinecone | |
| from langchain_core.documents import Document | |
| from langchain_core.embeddings import Embeddings | |
| from nltk.data import find | |
| from pinecone import Pinecone, ServerlessSpec | |
| from pinecone_text.sparse import BM25Encoder | |
| from sage.constants import TEXT_FIELD | |
| from sage.data_manager import DataManager | |
| Vector = Tuple[Dict, List[float]] # (metadata, embedding) | |
def is_punkt_downloaded():
    """Returns True iff the NLTK punkt tokenizer data is available locally."""
    try:
        find("tokenizers/punkt_tab")
    except LookupError:
        return False
    return True
class VectorStore(ABC):
    """Abstract class for a vector store."""

    @abstractmethod
    def ensure_exists(self):
        """Ensures that the vector store exists. Creates it if it doesn't."""

    @abstractmethod
    def upsert_batch(self, vectors: List[Vector], namespace: str):
        """Upserts a batch of vectors."""

    def upsert(self, vectors: Generator[Vector, None, None], namespace: str):
        """Upserts in batches of 100, since vector stores have a limit on upsert size."""
        batch = []
        for vector in vectors:
            batch.append(vector)
            if len(batch) == 100:
                self.upsert_batch(batch, namespace)
                batch = []
        # Flush the final partial batch, if any.
        if batch:
            self.upsert_batch(batch, namespace)

    @abstractmethod
    def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
        """Converts the vector store to a LangChain retriever object."""
class PineconeVectorStore(VectorStore):
    """Vector store implementation using Pinecone."""

    def __init__(self, index_name: str, dimension: int, alpha: float, bm25_cache: Optional[str] = None):
        """
        Args:
            index_name: The name of the Pinecone index to use. If it doesn't exist already, we'll create it.
            dimension: The dimension of the vectors.
            alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
                BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
            bm25_cache: The path to the BM25 encoder file. If not specified, we'll use the default BM25 (fitted on the
                MS MARCO dataset).
        """
        self.index_name = index_name
        self.dimension = dimension
        self.client = Pinecone()
        self.alpha = alpha

        if alpha < 1.0:
            if bm25_cache and os.path.exists(bm25_cache):
                logging.info("Loading BM25 encoder from cache.")
                # BM25 tokenization requires the NLTK punkt tokenizer data; fetch it if missing.
                if not is_punkt_downloaded():
                    nltk.download("punkt_tab")
                self.bm25_encoder = BM25Encoder()
                self.bm25_encoder.load(path=bm25_cache)
            else:
                logging.info("Using default BM25 encoder (fitted to MS MARCO).")
                self.bm25_encoder = BM25Encoder.default()
        else:
            # Pure dense search; no sparse encoder needed.
            self.bm25_encoder = None

    @cached_property
    def index(self):
        """The underlying Pinecone index, created lazily on first access and cached.

        This must be a (cached) property: `upsert_batch` and `as_retriever` access it as `self.index`, which with a
        plain method would yield a bound-method object rather than the index.
        """
        self.ensure_exists()
        index = self.client.Index(self.index_name)

        # Hack around the fact that PineconeRetriever expects the content of the chunk to be in a "text" field,
        # while PineconeHybridSearchRetrieve expects it to be in a "context" field.
        original_query = index.query

        def patched_query(*args, **kwargs):
            result = original_query(*args, **kwargs)
            for res in result["matches"]:
                if TEXT_FIELD in res["metadata"]:
                    res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
            return result

        index.query = patched_query
        return index

    def ensure_exists(self):
        """Creates the Pinecone index if it doesn't already exist."""
        if self.index_name not in self.client.list_indexes().names():
            self.client.create_index(
                name=self.index_name,
                dimension=self.dimension,
                # See https://www.pinecone.io/learn/hybrid-search-intro/
                metric="dotproduct" if self.bm25_encoder else "cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
            )

    def upsert_batch(self, vectors: List[Vector], namespace: str):
        """Upserts (metadata, embedding) pairs, attaching sparse BM25 values when hybrid search is enabled."""
        pinecone_vectors = []
        for i, (metadata, embedding) in enumerate(vectors):
            # Fall back to the position in the batch when the metadata carries no explicit id.
            vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
            if self.bm25_encoder:
                vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
            pinecone_vectors.append(vector)
        self.index.upsert(vectors=pinecone_vectors, namespace=namespace)

    def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
        """Returns a LangChain retriever: hybrid (dense + BM25) when alpha < 1.0, dense-only otherwise."""
        if self.bm25_encoder:
            return PineconeHybridSearchRetriever(
                embeddings=embeddings,
                sparse_encoder=self.bm25_encoder,
                index=self.index,
                namespace=namespace,
                top_k=top_k,
                alpha=self.alpha,
            )
        return LangChainPinecone.from_existing_index(
            index_name=self.index_name, embedding=embeddings, namespace=namespace
        ).as_retriever(search_kwargs={"k": top_k})
class MarqoVectorStore(VectorStore):
    """Vector store implementation using Marqo."""

    def __init__(self, url: str, index_name: str):
        self.index_name = index_name
        self.client = marqo.Client(url=url)

    def ensure_exists(self):
        # Nothing to create up front for Marqo.
        pass

    def upsert_batch(self, vectors: List[Vector], namespace: str):
        # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
        pass

    def as_retriever(self, top_k: int, embeddings: Embeddings = None, namespace: str = None):
        # `embeddings` is unused because the Marqo vector store is also an embedder; `namespace` is unused because,
        # unlike Pinecone, Marqo doesn't differentiate between index name and namespace.
        del embeddings
        del namespace

        vectorstore = Marqo(client=self.client, index_name=self.index_name)

        # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
        # the result, and instead take the "filename" directly from the result.
        def build_documents(self, results):
            # Note: pop() removes the text field from the hit, so the remaining keys become the metadata.
            return [
                Document(page_content=hit.pop(TEXT_FIELD), metadata=hit)
                for hit in results["hits"]
            ]

        vectorstore._construct_documents_from_results_without_score = build_documents.__get__(
            vectorstore, type(vectorstore)
        )
        return vectorstore.as_retriever(search_kwargs={"k": top_k})
def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager] = None) -> VectorStore:
    """Builds a vector store from the given command-line arguments.

    When `data_manager` is specified and hybrid retrieval is requested, we'll use it to fit a BM25 encoder on the corpus
    of documents.

    NOTE(review): despite the `dict` annotation, `args` is accessed via attributes and `in`
    (argparse.Namespace-style) — confirm the caller's actual type.

    Raises:
        ValueError: If `args.vector_store_provider` is neither "pinecone" nor "marqo".
    """
    if args.vector_store_provider == "pinecone":
        bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
        if args.retrieval_alpha < 1.0 and not os.path.exists(bm25_cache) and data_manager:
            logging.info("Fitting BM25 encoder on the corpus...")
            # BM25 tokenization requires the NLTK punkt tokenizer data; fetch it if missing.
            if not is_punkt_downloaded():
                nltk.download("punkt_tab")
            corpus = [content for content, _ in data_manager.walk()]
            bm25_encoder = BM25Encoder()
            bm25_encoder.fit(corpus)
            # Make sure the folder exists, before we dump the encoder.
            os.makedirs(os.path.dirname(bm25_cache), exist_ok=True)
            bm25_encoder.dump(bm25_cache)
        return PineconeVectorStore(
            index_name=args.pinecone_index_name,
            dimension=args.embedding_size if "embedding_size" in args else None,
            alpha=args.retrieval_alpha,
            bm25_cache=bm25_cache,
        )
    elif args.vector_store_provider == "marqo":
        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
    else:
        raise ValueError(f"Unrecognized vector store type {args.vector_store_provider}")