juliaturc commited on
Commit
c3e6715
·
1 Parent(s): 82df3b5

Add option for Pinecone + BM25 hybrid retrieval. (#36)

Browse files
README.md CHANGED
@@ -80,7 +80,7 @@ pip install git+https://github.com/Storia-AI/sage.git@main
80
  export PINECONE_API_KEY=...
81
  ```
82
 
83
- 2. Create a Pinecone index [on their website](https://pinecone.io) and export the name:
84
  ```
85
  export PINECONE_INDEX_NAME=...
86
  ```
 
80
  export PINECONE_API_KEY=...
81
  ```
82
 
83
+ 2. Create a Pinecone account. Export the desired index name (if it doesn't exist yet, we'll create it):
84
  ```
85
  export PINECONE_INDEX_NAME=...
86
  ```
sage/chat.py CHANGED
@@ -28,7 +28,8 @@ def build_rag_chain(args):
28
  """Builds a RAG chain via LangChain."""
29
  llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
30
 
31
- retriever = vector_store.build_from_args(args).to_langchain().as_retriever(search_kwargs={"k": 25})
 
32
 
33
  if args.reranker_provider == "none":
34
  compressor = None
@@ -78,14 +79,6 @@ def build_rag_chain(args):
78
  return rag_chain
79
 
80
 
81
- def append_sources_to_response(response):
82
- """Given an OpenAI completion response, appends to it GitHub links of the context sources."""
83
- urls = [document.metadata["url"] for document in response["context"]]
84
- # Deduplicate urls while preserving their order.
85
- urls = list(dict.fromkeys(urls))
86
- return response["answer"] + "\n\nSources:\n" + "\n".join(urls)
87
-
88
-
89
  def main():
90
  parser = argparse.ArgumentParser(description="UI to chat with your codebase")
91
  parser.add_argument("repo_id", help="The ID of the repository to index")
@@ -112,6 +105,13 @@ def main():
112
  default=False,
113
  help="Whether to make the gradio app publicly accessible.",
114
  )
 
 
 
 
 
 
 
115
  args = parser.parse_args()
116
 
117
  if not args.index_name:
 
28
  """Builds a RAG chain via LangChain."""
29
  llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
30
 
31
+ retriever_top_k = 5 if args.reranker_provider == "none" else 25
32
+ retriever = vector_store.build_from_args(args).as_retriever(top_k=retriever_top_k)
33
 
34
  if args.reranker_provider == "none":
35
  compressor = None
 
79
  return rag_chain
80
 
81
 
 
 
 
 
 
 
 
 
82
  def main():
83
  parser = argparse.ArgumentParser(description="UI to chat with your codebase")
84
  parser.add_argument("repo_id", help="The ID of the repository to index")
 
105
  default=False,
106
  help="Whether to make the gradio app publicly accessible.",
107
  )
108
+ parser.add_argument(
109
+ "--hybrid-retrieval",
110
+ action=argparse.BooleanOptionalAction,
111
+ default=True,
112
+ help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
113
+ "retrieval. This is only relevant if using Pinecone as the vector store.",
114
+ )
115
  args = parser.parse_args()
116
 
117
  if not args.index_name:
sage/chunker.py CHANGED
@@ -14,6 +14,8 @@ from semchunk import chunk as chunk_via_semchunk
14
  from tree_sitter import Node
15
  from tree_sitter_language_pack import get_parser
16
 
 
 
17
  logger = logging.getLogger(__name__)
18
  tokenizer = tiktoken.get_encoding("cl100k_base")
19
 
@@ -62,7 +64,7 @@ class FileChunk(Chunk):
62
  # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
63
  # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
64
  # directly from the repository when needed.
65
- "text": self.content,
66
  }
67
  chunk_metadata.update(self.file_metadata)
68
  return chunk_metadata
 
14
  from tree_sitter import Node
15
  from tree_sitter_language_pack import get_parser
16
 
17
+ from sage.constants import TEXT_FIELD
18
+
19
  logger = logging.getLogger(__name__)
20
  tokenizer = tiktoken.get_encoding("cl100k_base")
21
 
 
64
  # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
65
  # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
66
  # directly from the repository when needed.
67
+ TEXT_FIELD: self.content,
68
  }
69
  chunk_metadata.update(self.file_metadata)
70
  return chunk_metadata
sage/constants.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # This is the key in the metadata that points to the actual text content of a document or chunk.
2
+ # It can mostly be an arbitrary string, but certain classes in LangChain do expect it to be "text" specifically.
3
+ TEXT_FIELD = "text"
sage/embedder.py CHANGED
@@ -12,6 +12,7 @@ import marqo
12
  from openai import OpenAI
13
 
14
  from sage.chunker import Chunk, Chunker
 
15
  from sage.data_manager import DataManager
16
 
17
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
@@ -139,7 +140,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
139
  and "start_byte" in metadata
140
  and "end_byte" in metadata
141
  ):
142
- metadata.pop("text", None)
143
  embedding = datum["embedding"]
144
  yield (metadata, embedding)
145
 
@@ -240,7 +241,7 @@ class MarqoEmbedder(BatchEmbedder):
240
  logging.info("Indexing %d chunks...", len(sub_batch))
241
  self.index.add_documents(
242
  documents=[chunk.metadata for chunk in sub_batch],
243
- tensor_fields=["text"],
244
  )
245
  job_count += 1
246
 
@@ -251,7 +252,7 @@ class MarqoEmbedder(BatchEmbedder):
251
 
252
  # Finally, commit the last batch.
253
  if batch:
254
- self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=["text"])
255
  logging.info(f"Successfully embedded {chunk_count} chunks.")
256
 
257
  def embeddings_are_ready(self) -> bool:
 
12
  from openai import OpenAI
13
 
14
  from sage.chunker import Chunk, Chunker
15
+ from sage.constants import TEXT_FIELD
16
  from sage.data_manager import DataManager
17
 
18
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
 
140
  and "start_byte" in metadata
141
  and "end_byte" in metadata
142
  ):
143
+ metadata.pop(TEXT_FIELD, None)
144
  embedding = datum["embedding"]
145
  yield (metadata, embedding)
146
 
 
241
  logging.info("Indexing %d chunks...", len(sub_batch))
242
  self.index.add_documents(
243
  documents=[chunk.metadata for chunk in sub_batch],
244
+ tensor_fields=[TEXT_FIELD],
245
  )
246
  job_count += 1
247
 
 
252
 
253
  # Finally, commit the last batch.
254
  if batch:
255
+ self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD])
256
  logging.info(f"Successfully embedded {chunk_count} chunks.")
257
 
258
  def embeddings_are_ready(self) -> bool:
sage/github.py CHANGED
@@ -9,6 +9,7 @@ import requests
9
  import tiktoken
10
 
11
  from sage.chunker import Chunk, Chunker
 
12
  from sage.data_manager import DataManager
13
 
14
  tokenizer = tiktoken.get_encoding("cl100k_base")
@@ -179,7 +180,7 @@ class IssueChunk(Chunk):
179
  # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
180
  # size limit. In that case, you can simply store the start/end comment indices above, and fetch the
181
  # content of the issue on demand from the URL.
182
- "text": self.content,
183
  }
184
 
185
  @property
 
9
  import tiktoken
10
 
11
  from sage.chunker import Chunk, Chunker
12
+ from sage.constants import TEXT_FIELD
13
  from sage.data_manager import DataManager
14
 
15
  tokenizer = tiktoken.get_encoding("cl100k_base")
 
180
  # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
181
  # size limit. In that case, you can simply store the start/end comment indices above, and fetch the
182
  # content of the issue on demand from the URL.
183
+ TEXT_FIELD: self.content,
184
  }
185
 
186
  @property
sage/index.py CHANGED
@@ -118,6 +118,13 @@ def main():
118
  "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
119
  "of the gains anyway.",
120
  )
 
 
 
 
 
 
 
121
  args = parser.parse_args()
122
 
123
  # Validate embedder and vector store compatibility.
 
118
  "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
119
  "of the gains anyway.",
120
  )
121
+ parser.add_argument(
122
+ "--hybrid-retrieval",
123
+ action=argparse.BooleanOptionalAction,
124
+ default=True,
125
+ help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
126
+ "retrieval. This is only relevant if using Pinecone as the vector store.",
127
+ )
128
  args = parser.parse_args()
129
 
130
  # Validate embedder and vector store compatibility.
sage/vector_store.py CHANGED
@@ -1,14 +1,19 @@
1
  """Vector store abstraction and implementations."""
2
 
3
  from abc import ABC, abstractmethod
 
4
  from typing import Dict, Generator, List, Tuple
5
 
6
  import marqo
 
7
  from langchain_community.vectorstores import Marqo
8
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
9
  from langchain_core.documents import Document
10
  from langchain_openai import OpenAIEmbeddings
11
- from pinecone import Pinecone
 
 
 
12
 
13
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
14
 
@@ -36,34 +41,77 @@ class VectorStore(ABC):
36
  self.upsert_batch(batch)
37
 
38
  @abstractmethod
39
- def to_langchain(self):
40
- """Converts the vector store to a LangChain vector store object."""
41
 
42
 
43
  class PineconeVectorStore(VectorStore):
44
  """Vector store implementation using Pinecone."""
45
 
46
- def __init__(self, index_name: str, namespace: str, dimension: int):
47
  self.index_name = index_name
48
  self.dimension = dimension
49
  self.client = Pinecone()
50
- self.index = self.client.Index(self.index_name)
51
  self.namespace = namespace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def ensure_exists(self):
54
  if self.index_name not in self.client.list_indexes().names():
55
- self.client.create_index(name=self.index_name, dimension=self.dimension, metric="cosine")
 
 
 
 
 
 
56
 
57
  def upsert_batch(self, vectors: List[Vector]):
58
- pinecone_vectors = [
59
- (metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
60
- ]
 
 
 
 
61
  self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
62
 
63
- def to_langchain(self):
 
 
 
 
 
 
 
 
 
 
64
  return LangChainPinecone.from_existing_index(
65
  index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
66
- )
67
 
68
 
69
  class MarqoVectorStore(VectorStore):
@@ -80,7 +128,7 @@ class MarqoVectorStore(VectorStore):
80
  # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
81
  pass
82
 
83
- def to_langchain(self):
84
  vectorstore = Marqo(client=self.client, index_name=self.index_name)
85
 
86
  # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
@@ -88,21 +136,23 @@ class MarqoVectorStore(VectorStore):
88
  def patched_method(self, results):
89
  documents: List[Document] = []
90
  for result in results["hits"]:
91
- content = result.pop("text")
92
  documents.append(Document(page_content=content, metadata=result))
93
  return documents
94
 
95
  vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
96
  vectorstore, vectorstore.__class__
97
  )
98
- return vectorstore
99
 
100
 
101
  def build_from_args(args: dict) -> VectorStore:
102
  """Builds a vector store from the given command-line arguments."""
103
  if args.vector_store_type == "pinecone":
104
  dimension = args.embedding_size if "embedding_size" in args else None
105
- return PineconeVectorStore(index_name=args.index_name, namespace=args.repo_id, dimension=dimension)
 
 
106
  elif args.vector_store_type == "marqo":
107
  return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
108
  else:
 
1
  """Vector store abstraction and implementations."""
2
 
3
  from abc import ABC, abstractmethod
4
+ from functools import cached_property
5
  from typing import Dict, Generator, List, Tuple
6
 
7
  import marqo
8
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
9
  from langchain_community.vectorstores import Marqo
10
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
11
  from langchain_core.documents import Document
12
  from langchain_openai import OpenAIEmbeddings
13
+ from pinecone import Pinecone, ServerlessSpec
14
+ from pinecone_text.sparse import BM25Encoder
15
+
16
+ from sage.constants import TEXT_FIELD
17
 
18
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
19
 
 
41
  self.upsert_batch(batch)
42
 
43
  @abstractmethod
44
+ def as_retriever(self, top_k: int):
45
+ """Converts the vector store to a LangChain retriever object."""
46
 
47
 
48
  class PineconeVectorStore(VectorStore):
49
  """Vector store implementation using Pinecone."""
50
 
51
+ def __init__(self, index_name: str, namespace: str, dimension: int, hybrid: bool = True):
52
  self.index_name = index_name
53
  self.dimension = dimension
54
  self.client = Pinecone()
 
55
  self.namespace = namespace
56
+ self.hybrid = hybrid
57
+ # The default BM25 encoder was fitted on the MS MARCO dataset.
58
+ # See https://docs.pinecone.io/guides/data/encode-sparse-vectors
59
+ # In the future, we should fit the encoder on the current dataset. It's somewhat non-trivial for large datasets,
60
+ # because most BM25 implementations require the entire dataset to fit in memory.
61
+ self.bm25_encoder = BM25Encoder.default() if hybrid else None
62
+
63
+ @cached_property
64
+ def index(self):
65
+ self.ensure_exists()
66
+ index = self.client.Index(self.index_name)
67
+
68
+ # Hack around the fact that PineconeRetriever expects the content of the chunk to be in a "text" field,
69
+ # while PineconeHybridSearchRetriever expects it to be in a "context" field.
70
+ original_query = index.query
71
+
72
+ def patched_query(*args, **kwargs):
73
+ result = original_query(*args, **kwargs)
74
+ for res in result["matches"]:
75
+ res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
76
+ return result
77
+
78
+ index.query = patched_query
79
+ return index
80
 
81
  def ensure_exists(self):
82
  if self.index_name not in self.client.list_indexes().names():
83
+ self.client.create_index(
84
+ name=self.index_name,
85
+ dimension=self.dimension,
86
+ # See https://www.pinecone.io/learn/hybrid-search-intro/
87
+ metric="dotproduct" if self.hybrid else "cosine",
88
+ spec=ServerlessSpec(cloud="aws", region="us-east-1"),
89
+ )
90
 
91
  def upsert_batch(self, vectors: List[Vector]):
92
+ pinecone_vectors = []
93
+ for i, (metadata, embedding) in enumerate(vectors):
94
+ vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
95
+ if self.bm25_encoder:
96
+ vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
97
+ pinecone_vectors.append(vector)
98
+
99
  self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
100
 
101
+ def as_retriever(self, top_k: int):
102
+ if self.bm25_encoder:
103
+ return PineconeHybridSearchRetriever(
104
+ embeddings=OpenAIEmbeddings(),
105
+ sparse_encoder=self.bm25_encoder,
106
+ index=self.index,
107
+ namespace=self.namespace,
108
+ top_k=top_k,
109
+ alpha=0.5,
110
+ )
111
+
112
  return LangChainPinecone.from_existing_index(
113
  index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
114
+ ).as_retriever(search_kwargs={"k": top_k})
115
 
116
 
117
  class MarqoVectorStore(VectorStore):
 
128
  # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
129
  pass
130
 
131
+ def as_retriever(self, top_k: int):
132
  vectorstore = Marqo(client=self.client, index_name=self.index_name)
133
 
134
  # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
 
136
  def patched_method(self, results):
137
  documents: List[Document] = []
138
  for result in results["hits"]:
139
+ content = result.pop(TEXT_FIELD)
140
  documents.append(Document(page_content=content, metadata=result))
141
  return documents
142
 
143
  vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
144
  vectorstore, vectorstore.__class__
145
  )
146
+ return vectorstore.as_retriever(search_kwargs={"k": top_k})
147
 
148
 
149
  def build_from_args(args: dict) -> VectorStore:
150
  """Builds a vector store from the given command-line arguments."""
151
  if args.vector_store_type == "pinecone":
152
  dimension = args.embedding_size if "embedding_size" in args else None
153
+ return PineconeVectorStore(
154
+ index_name=args.index_name, namespace=args.repo_id, dimension=dimension, hybrid=args.hybrid_retrieval
155
+ )
156
  elif args.vector_store_type == "marqo":
157
  return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
158
  else: