Spaces:
Running
Running
Add option for Pinecone + BM25 hybrid retrieval. (#36)
Browse files
- README.md +1 -1
- sage/chat.py +9 -9
- sage/chunker.py +3 -1
- sage/constants.py +3 -0
- sage/embedder.py +4 -3
- sage/github.py +2 -1
- sage/index.py +7 -0
- sage/vector_store.py +65 -15
README.md
CHANGED
|
@@ -80,7 +80,7 @@ pip install git+https://github.com/Storia-AI/sage.git@main
|
|
| 80 |
export PINECONE_API_KEY=...
|
| 81 |
```
|
| 82 |
|
| 83 |
-
2. Create a Pinecone index
|
| 84 |
```
|
| 85 |
export PINECONE_INDEX_NAME=...
|
| 86 |
```
|
|
|
|
| 80 |
export PINECONE_API_KEY=...
|
| 81 |
```
|
| 82 |
|
| 83 |
+
2. Create a Pinecone account. Export the desired index name (if it doesn't exist yet, we'll create it):
|
| 84 |
```
|
| 85 |
export PINECONE_INDEX_NAME=...
|
| 86 |
```
|
sage/chat.py
CHANGED
|
@@ -28,7 +28,8 @@ def build_rag_chain(args):
|
|
| 28 |
"""Builds a RAG chain via LangChain."""
|
| 29 |
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
|
| 30 |
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
if args.reranker_provider == "none":
|
| 34 |
compressor = None
|
|
@@ -78,14 +79,6 @@ def build_rag_chain(args):
|
|
| 78 |
return rag_chain
|
| 79 |
|
| 80 |
|
| 81 |
-
def append_sources_to_response(response):
|
| 82 |
-
"""Given an OpenAI completion response, appends to it GitHub links of the context sources."""
|
| 83 |
-
urls = [document.metadata["url"] for document in response["context"]]
|
| 84 |
-
# Deduplicate urls while preserving their order.
|
| 85 |
-
urls = list(dict.fromkeys(urls))
|
| 86 |
-
return response["answer"] + "\n\nSources:\n" + "\n".join(urls)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
def main():
|
| 90 |
parser = argparse.ArgumentParser(description="UI to chat with your codebase")
|
| 91 |
parser.add_argument("repo_id", help="The ID of the repository to index")
|
|
@@ -112,6 +105,13 @@ def main():
|
|
| 112 |
default=False,
|
| 113 |
help="Whether to make the gradio app publicly accessible.",
|
| 114 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
args = parser.parse_args()
|
| 116 |
|
| 117 |
if not args.index_name:
|
|
|
|
| 28 |
"""Builds a RAG chain via LangChain."""
|
| 29 |
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
|
| 30 |
|
| 31 |
+
retriever_top_k = 5 if args.reranker_provider == "none" else 25
|
| 32 |
+
retriever = vector_store.build_from_args(args).as_retriever(top_k=retriever_top_k)
|
| 33 |
|
| 34 |
if args.reranker_provider == "none":
|
| 35 |
compressor = None
|
|
|
|
| 79 |
return rag_chain
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def main():
|
| 83 |
parser = argparse.ArgumentParser(description="UI to chat with your codebase")
|
| 84 |
parser.add_argument("repo_id", help="The ID of the repository to index")
|
|
|
|
| 105 |
default=False,
|
| 106 |
help="Whether to make the gradio app publicly accessible.",
|
| 107 |
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--hybrid-retrieval",
|
| 110 |
+
action=argparse.BooleanOptionalAction,
|
| 111 |
+
default=True,
|
| 112 |
+
help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
|
| 113 |
+
"retrieval. This is only relevant if using Pinecone as the vector store.",
|
| 114 |
+
)
|
| 115 |
args = parser.parse_args()
|
| 116 |
|
| 117 |
if not args.index_name:
|
sage/chunker.py
CHANGED
|
@@ -14,6 +14,8 @@ from semchunk import chunk as chunk_via_semchunk
|
|
| 14 |
from tree_sitter import Node
|
| 15 |
from tree_sitter_language_pack import get_parser
|
| 16 |
|
|
|
|
|
|
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 19 |
|
|
@@ -62,7 +64,7 @@ class FileChunk(Chunk):
|
|
| 62 |
# Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
|
| 63 |
# size limit. In that case, you can simply store the start/end bytes above, and fetch the content
|
| 64 |
# directly from the repository when needed.
|
| 65 |
-
|
| 66 |
}
|
| 67 |
chunk_metadata.update(self.file_metadata)
|
| 68 |
return chunk_metadata
|
|
|
|
| 14 |
from tree_sitter import Node
|
| 15 |
from tree_sitter_language_pack import get_parser
|
| 16 |
|
| 17 |
+
from sage.constants import TEXT_FIELD
|
| 18 |
+
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 21 |
|
|
|
|
| 64 |
# Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
|
| 65 |
# size limit. In that case, you can simply store the start/end bytes above, and fetch the content
|
| 66 |
# directly from the repository when needed.
|
| 67 |
+
TEXT_FIELD: self.content,
|
| 68 |
}
|
| 69 |
chunk_metadata.update(self.file_metadata)
|
| 70 |
return chunk_metadata
|
sage/constants.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is the key in the metadata that points to the actual text content of a document or chunk.
|
| 2 |
+
# It can mostly be an arbitrary string, but certain classes in LangChain do expect it to be "text" specifically.
|
| 3 |
+
TEXT_FIELD = "text"
|
sage/embedder.py
CHANGED
|
@@ -12,6 +12,7 @@ import marqo
|
|
| 12 |
from openai import OpenAI
|
| 13 |
|
| 14 |
from sage.chunker import Chunk, Chunker
|
|
|
|
| 15 |
from sage.data_manager import DataManager
|
| 16 |
|
| 17 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
|
@@ -139,7 +140,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 139 |
and "start_byte" in metadata
|
| 140 |
and "end_byte" in metadata
|
| 141 |
):
|
| 142 |
-
metadata.pop(
|
| 143 |
embedding = datum["embedding"]
|
| 144 |
yield (metadata, embedding)
|
| 145 |
|
|
@@ -240,7 +241,7 @@ class MarqoEmbedder(BatchEmbedder):
|
|
| 240 |
logging.info("Indexing %d chunks...", len(sub_batch))
|
| 241 |
self.index.add_documents(
|
| 242 |
documents=[chunk.metadata for chunk in sub_batch],
|
| 243 |
-
tensor_fields=[
|
| 244 |
)
|
| 245 |
job_count += 1
|
| 246 |
|
|
@@ -251,7 +252,7 @@ class MarqoEmbedder(BatchEmbedder):
|
|
| 251 |
|
| 252 |
# Finally, commit the last batch.
|
| 253 |
if batch:
|
| 254 |
-
self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[
|
| 255 |
logging.info(f"Successfully embedded {chunk_count} chunks.")
|
| 256 |
|
| 257 |
def embeddings_are_ready(self) -> bool:
|
|
|
|
| 12 |
from openai import OpenAI
|
| 13 |
|
| 14 |
from sage.chunker import Chunk, Chunker
|
| 15 |
+
from sage.constants import TEXT_FIELD
|
| 16 |
from sage.data_manager import DataManager
|
| 17 |
|
| 18 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
|
|
|
| 140 |
and "start_byte" in metadata
|
| 141 |
and "end_byte" in metadata
|
| 142 |
):
|
| 143 |
+
metadata.pop(TEXT_FIELD, None)
|
| 144 |
embedding = datum["embedding"]
|
| 145 |
yield (metadata, embedding)
|
| 146 |
|
|
|
|
| 241 |
logging.info("Indexing %d chunks...", len(sub_batch))
|
| 242 |
self.index.add_documents(
|
| 243 |
documents=[chunk.metadata for chunk in sub_batch],
|
| 244 |
+
tensor_fields=[TEXT_FIELD],
|
| 245 |
)
|
| 246 |
job_count += 1
|
| 247 |
|
|
|
|
| 252 |
|
| 253 |
# Finally, commit the last batch.
|
| 254 |
if batch:
|
| 255 |
+
self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD])
|
| 256 |
logging.info(f"Successfully embedded {chunk_count} chunks.")
|
| 257 |
|
| 258 |
def embeddings_are_ready(self) -> bool:
|
sage/github.py
CHANGED
|
@@ -9,6 +9,7 @@ import requests
|
|
| 9 |
import tiktoken
|
| 10 |
|
| 11 |
from sage.chunker import Chunk, Chunker
|
|
|
|
| 12 |
from sage.data_manager import DataManager
|
| 13 |
|
| 14 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
|
@@ -179,7 +180,7 @@ class IssueChunk(Chunk):
|
|
| 179 |
# Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
|
| 180 |
# size limit. In that case, you can simply store the start/end comment indices above, and fetch the
|
| 181 |
# content of the issue on demand from the URL.
|
| 182 |
-
|
| 183 |
}
|
| 184 |
|
| 185 |
@property
|
|
|
|
| 9 |
import tiktoken
|
| 10 |
|
| 11 |
from sage.chunker import Chunk, Chunker
|
| 12 |
+
from sage.constants import TEXT_FIELD
|
| 13 |
from sage.data_manager import DataManager
|
| 14 |
|
| 15 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
|
|
|
| 180 |
# Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
|
| 181 |
# size limit. In that case, you can simply store the start/end comment indices above, and fetch the
|
| 182 |
# content of the issue on demand from the URL.
|
| 183 |
+
TEXT_FIELD: self.content,
|
| 184 |
}
|
| 185 |
|
| 186 |
@property
|
sage/index.py
CHANGED
|
@@ -118,6 +118,13 @@ def main():
|
|
| 118 |
"GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
|
| 119 |
"of the gains anyway.",
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
args = parser.parse_args()
|
| 122 |
|
| 123 |
# Validate embedder and vector store compatibility.
|
|
|
|
| 118 |
"GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
|
| 119 |
"of the gains anyway.",
|
| 120 |
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--hybrid-retrieval",
|
| 123 |
+
action=argparse.BooleanOptionalAction,
|
| 124 |
+
default=True,
|
| 125 |
+
help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
|
| 126 |
+
"retrieval. This is only relevant if using Pinecone as the vector store.",
|
| 127 |
+
)
|
| 128 |
args = parser.parse_args()
|
| 129 |
|
| 130 |
# Validate embedder and vector store compatibility.
|
sage/vector_store.py
CHANGED
|
@@ -1,14 +1,19 @@
|
|
| 1 |
"""Vector store abstraction and implementations."""
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
|
|
|
| 4 |
from typing import Dict, Generator, List, Tuple
|
| 5 |
|
| 6 |
import marqo
|
|
|
|
| 7 |
from langchain_community.vectorstores import Marqo
|
| 8 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 9 |
from langchain_core.documents import Document
|
| 10 |
from langchain_openai import OpenAIEmbeddings
|
| 11 |
-
from pinecone import Pinecone
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 14 |
|
|
@@ -36,34 +41,77 @@ class VectorStore(ABC):
|
|
| 36 |
self.upsert_batch(batch)
|
| 37 |
|
| 38 |
@abstractmethod
|
| 39 |
-
def
|
| 40 |
-
"""Converts the vector store to a LangChain
|
| 41 |
|
| 42 |
|
| 43 |
class PineconeVectorStore(VectorStore):
|
| 44 |
"""Vector store implementation using Pinecone."""
|
| 45 |
|
| 46 |
-
def __init__(self, index_name: str, namespace: str, dimension: int):
|
| 47 |
self.index_name = index_name
|
| 48 |
self.dimension = dimension
|
| 49 |
self.client = Pinecone()
|
| 50 |
-
self.index = self.client.Index(self.index_name)
|
| 51 |
self.namespace = namespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def ensure_exists(self):
|
| 54 |
if self.index_name not in self.client.list_indexes().names():
|
| 55 |
-
self.client.create_index(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def upsert_batch(self, vectors: List[Vector]):
|
| 58 |
-
pinecone_vectors = [
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
|
| 62 |
|
| 63 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
return LangChainPinecone.from_existing_index(
|
| 65 |
index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
|
| 66 |
-
)
|
| 67 |
|
| 68 |
|
| 69 |
class MarqoVectorStore(VectorStore):
|
|
@@ -80,7 +128,7 @@ class MarqoVectorStore(VectorStore):
|
|
| 80 |
# Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
|
| 81 |
pass
|
| 82 |
|
| 83 |
-
def
|
| 84 |
vectorstore = Marqo(client=self.client, index_name=self.index_name)
|
| 85 |
|
| 86 |
# Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
|
|
@@ -88,21 +136,23 @@ class MarqoVectorStore(VectorStore):
|
|
| 88 |
def patched_method(self, results):
|
| 89 |
documents: List[Document] = []
|
| 90 |
for result in results["hits"]:
|
| 91 |
-
content = result.pop(
|
| 92 |
documents.append(Document(page_content=content, metadata=result))
|
| 93 |
return documents
|
| 94 |
|
| 95 |
vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
|
| 96 |
vectorstore, vectorstore.__class__
|
| 97 |
)
|
| 98 |
-
return vectorstore
|
| 99 |
|
| 100 |
|
| 101 |
def build_from_args(args: dict) -> VectorStore:
|
| 102 |
"""Builds a vector store from the given command-line arguments."""
|
| 103 |
if args.vector_store_type == "pinecone":
|
| 104 |
dimension = args.embedding_size if "embedding_size" in args else None
|
| 105 |
-
return PineconeVectorStore(
|
|
|
|
|
|
|
| 106 |
elif args.vector_store_type == "marqo":
|
| 107 |
return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
|
| 108 |
else:
|
|
|
|
| 1 |
"""Vector store abstraction and implementations."""
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
+
from functools import cached_property
|
| 5 |
from typing import Dict, Generator, List, Tuple
|
| 6 |
|
| 7 |
import marqo
|
| 8 |
+
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
| 9 |
from langchain_community.vectorstores import Marqo
|
| 10 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
| 11 |
from langchain_core.documents import Document
|
| 12 |
from langchain_openai import OpenAIEmbeddings
|
| 13 |
+
from pinecone import Pinecone, ServerlessSpec
|
| 14 |
+
from pinecone_text.sparse import BM25Encoder
|
| 15 |
+
|
| 16 |
+
from sage.constants import TEXT_FIELD
|
| 17 |
|
| 18 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 19 |
|
|
|
|
| 41 |
self.upsert_batch(batch)
|
| 42 |
|
| 43 |
@abstractmethod
|
| 44 |
+
def as_retriever(self, top_k: int):
|
| 45 |
+
"""Converts the vector store to a LangChain retriever object."""
|
| 46 |
|
| 47 |
|
| 48 |
class PineconeVectorStore(VectorStore):
|
| 49 |
"""Vector store implementation using Pinecone."""
|
| 50 |
|
| 51 |
+
def __init__(self, index_name: str, namespace: str, dimension: int, hybrid: bool = True):
|
| 52 |
self.index_name = index_name
|
| 53 |
self.dimension = dimension
|
| 54 |
self.client = Pinecone()
|
|
|
|
| 55 |
self.namespace = namespace
|
| 56 |
+
self.hybrid = hybrid
|
| 57 |
+
# The default BM25 encoder was fit on the MS MARCO dataset.
|
| 58 |
+
# See https://docs.pinecone.io/guides/data/encode-sparse-vectors
|
| 59 |
+
# In the future, we should fit the encoder on the current dataset. It's somewhat non-trivial for large datasets,
|
| 60 |
+
# because most BM25 implementations require the entire dataset to fit in memory.
|
| 61 |
+
self.bm25_encoder = BM25Encoder.default() if hybrid else None
|
| 62 |
+
|
| 63 |
+
@cached_property
|
| 64 |
+
def index(self):
|
| 65 |
+
self.ensure_exists()
|
| 66 |
+
index = self.client.Index(self.index_name)
|
| 67 |
+
|
| 68 |
+
# Hack around the fact that PineconeRetriever expects the content of the chunk to be in a "text" field,
|
| 69 |
+
# while PineconeHybridSearchRetriever expects it to be in a "context" field.
|
| 70 |
+
original_query = index.query
|
| 71 |
+
|
| 72 |
+
def patched_query(*args, **kwargs):
|
| 73 |
+
result = original_query(*args, **kwargs)
|
| 74 |
+
for res in result["matches"]:
|
| 75 |
+
res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
|
| 76 |
+
return result
|
| 77 |
+
|
| 78 |
+
index.query = patched_query
|
| 79 |
+
return index
|
| 80 |
|
| 81 |
def ensure_exists(self):
|
| 82 |
if self.index_name not in self.client.list_indexes().names():
|
| 83 |
+
self.client.create_index(
|
| 84 |
+
name=self.index_name,
|
| 85 |
+
dimension=self.dimension,
|
| 86 |
+
# See https://www.pinecone.io/learn/hybrid-search-intro/
|
| 87 |
+
metric="dotproduct" if self.hybrid else "cosine",
|
| 88 |
+
spec=ServerlessSpec(cloud="aws", region="us-east-1"),
|
| 89 |
+
)
|
| 90 |
|
| 91 |
def upsert_batch(self, vectors: List[Vector]):
|
| 92 |
+
pinecone_vectors = []
|
| 93 |
+
for i, (metadata, embedding) in enumerate(vectors):
|
| 94 |
+
vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
|
| 95 |
+
if self.bm25_encoder:
|
| 96 |
+
vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
|
| 97 |
+
pinecone_vectors.append(vector)
|
| 98 |
+
|
| 99 |
self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
|
| 100 |
|
| 101 |
+
def as_retriever(self, top_k: int):
|
| 102 |
+
if self.bm25_encoder:
|
| 103 |
+
return PineconeHybridSearchRetriever(
|
| 104 |
+
embeddings=OpenAIEmbeddings(),
|
| 105 |
+
sparse_encoder=self.bm25_encoder,
|
| 106 |
+
index=self.index,
|
| 107 |
+
namespace=self.namespace,
|
| 108 |
+
top_k=top_k,
|
| 109 |
+
alpha=0.5,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
return LangChainPinecone.from_existing_index(
|
| 113 |
index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
|
| 114 |
+
).as_retriever(search_kwargs={"k": top_k})
|
| 115 |
|
| 116 |
|
| 117 |
class MarqoVectorStore(VectorStore):
|
|
|
|
| 128 |
# Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
|
| 129 |
pass
|
| 130 |
|
| 131 |
+
def as_retriever(self, top_k: int):
|
| 132 |
vectorstore = Marqo(client=self.client, index_name=self.index_name)
|
| 133 |
|
| 134 |
# Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
|
|
|
|
| 136 |
def patched_method(self, results):
|
| 137 |
documents: List[Document] = []
|
| 138 |
for result in results["hits"]:
|
| 139 |
+
content = result.pop(TEXT_FIELD)
|
| 140 |
documents.append(Document(page_content=content, metadata=result))
|
| 141 |
return documents
|
| 142 |
|
| 143 |
vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
|
| 144 |
vectorstore, vectorstore.__class__
|
| 145 |
)
|
| 146 |
+
return vectorstore.as_retriever(search_kwargs={"k": top_k})
|
| 147 |
|
| 148 |
|
| 149 |
def build_from_args(args: dict) -> VectorStore:
|
| 150 |
"""Builds a vector store from the given command-line arguments."""
|
| 151 |
if args.vector_store_type == "pinecone":
|
| 152 |
dimension = args.embedding_size if "embedding_size" in args else None
|
| 153 |
+
return PineconeVectorStore(
|
| 154 |
+
index_name=args.index_name, namespace=args.repo_id, dimension=dimension, hybrid=args.hybrid_retrieval
|
| 155 |
+
)
|
| 156 |
elif args.vector_store_type == "marqo":
|
| 157 |
return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
|
| 158 |
else:
|