Spaces:
Running
Running
Support marqo on the inference side and format code.
Browse files- src/chat.py +38 -18
- src/chunker.py +9 -22
- src/embedder.py +7 -15
- src/index.py +18 -23
- src/repo_manager.py +7 -21
- src/vector_store.py +3 -5
src/chat.py
CHANGED
|
@@ -4,14 +4,17 @@ You must run main.py first in order to index the codebase into a vector store.
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
| 7 |
-
|
| 8 |
-
from dotenv import load_dotenv
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 13 |
from langchain.schema import AIMessage, HumanMessage
|
| 14 |
-
from langchain_community.vectorstores import Pinecone
|
|
|
|
| 15 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 16 |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
| 17 |
|
|
@@ -24,10 +27,29 @@ def build_rag_chain(args):
|
|
| 24 |
"""Builds a RAG chain via LangChain."""
|
| 25 |
llm = ChatOpenAI(model=args.openai_model)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
retriever = vectorstore.as_retriever()
|
|
@@ -45,9 +67,7 @@ def build_rag_chain(args):
|
|
| 45 |
("human", "{input}"),
|
| 46 |
]
|
| 47 |
)
|
| 48 |
-
history_aware_retriever = create_history_aware_retriever(
|
| 49 |
-
llm, retriever, contextualize_q_prompt
|
| 50 |
-
)
|
| 51 |
|
| 52 |
qa_system_prompt = (
|
| 53 |
f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
|
|
@@ -76,9 +96,7 @@ def append_sources_to_response(response):
|
|
| 76 |
# Deduplicate filenames while preserving their order.
|
| 77 |
filenames = list(dict.fromkeys(filenames))
|
| 78 |
repo_manager = RepoManager(args.repo_id)
|
| 79 |
-
github_links = [
|
| 80 |
-
repo_manager.github_link_for_file(filename) for filename in filenames
|
| 81 |
-
]
|
| 82 |
return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
|
| 83 |
|
| 84 |
|
|
@@ -90,8 +108,12 @@ if __name__ == "__main__":
|
|
| 90 |
default="gpt-4",
|
| 91 |
help="The OpenAI model to use for response generation",
|
| 92 |
)
|
|
|
|
|
|
|
| 93 |
parser.add_argument(
|
| 94 |
-
"--
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
parser.add_argument(
|
| 97 |
"--share",
|
|
@@ -109,9 +131,7 @@ if __name__ == "__main__":
|
|
| 109 |
history_langchain_format.append(HumanMessage(content=human))
|
| 110 |
history_langchain_format.append(AIMessage(content=ai))
|
| 111 |
history_langchain_format.append(HumanMessage(content=message))
|
| 112 |
-
response = rag_chain.invoke(
|
| 113 |
-
{"input": message, "chat_history": history_langchain_format}
|
| 114 |
-
)
|
| 115 |
answer = append_sources_to_response(response)
|
| 116 |
return answer
|
| 117 |
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
| 7 |
+
from typing import List
|
|
|
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
+
import marqo
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
from langchain.chains import (create_history_aware_retriever,
|
| 13 |
+
create_retrieval_chain)
|
| 14 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 15 |
from langchain.schema import AIMessage, HumanMessage
|
| 16 |
+
from langchain_community.vectorstores import Marqo, Pinecone
|
| 17 |
+
from langchain_core.documents import Document
|
| 18 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 19 |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
| 20 |
|
|
|
|
| 27 |
"""Builds a RAG chain via LangChain."""
|
| 28 |
llm = ChatOpenAI(model=args.openai_model)
|
| 29 |
|
| 30 |
+
if args.vector_store_type == "pinecone":
|
| 31 |
+
vectorstore = Pinecone.from_existing_index(
|
| 32 |
+
index_name=args.pinecone_index_name,
|
| 33 |
+
embedding=OpenAIEmbeddings(),
|
| 34 |
+
namespace=args.repo_id,
|
| 35 |
+
)
|
| 36 |
+
elif args.vector_store_type == "marqo":
|
| 37 |
+
marqo_client = marqo.Client(url=args.marqo_url)
|
| 38 |
+
vectorstore = Marqo(
|
| 39 |
+
client=marqo_client,
|
| 40 |
+
index_name=args.index_name,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in the
|
| 44 |
+
# result, and instead take the "filename" directly from the result.
|
| 45 |
+
def patched_method(self, results):
|
| 46 |
+
documents: List[Document] = []
|
| 47 |
+
for res in results["hits"]:
|
| 48 |
+
documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
|
| 49 |
+
return documents
|
| 50 |
+
|
| 51 |
+
vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
|
| 52 |
+
vectorstore, vectorstore.__class__
|
| 53 |
)
|
| 54 |
|
| 55 |
retriever = vectorstore.as_retriever()
|
|
|
|
| 67 |
("human", "{input}"),
|
| 68 |
]
|
| 69 |
)
|
| 70 |
+
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
|
|
|
|
|
|
|
| 71 |
|
| 72 |
qa_system_prompt = (
|
| 73 |
f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
|
|
|
|
| 96 |
# Deduplicate filenames while preserving their order.
|
| 97 |
filenames = list(dict.fromkeys(filenames))
|
| 98 |
repo_manager = RepoManager(args.repo_id)
|
| 99 |
+
github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
|
|
|
|
|
|
|
| 100 |
return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
|
| 101 |
|
| 102 |
|
|
|
|
| 108 |
default="gpt-4",
|
| 109 |
help="The OpenAI model to use for response generation",
|
| 110 |
)
|
| 111 |
+
parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
|
| 112 |
+
parser.add_argument("--index_name", required=True, help="Vector store index name")
|
| 113 |
parser.add_argument(
|
| 114 |
+
"--marqo_url",
|
| 115 |
+
default="http://localhost:8882",
|
| 116 |
+
help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
|
| 117 |
)
|
| 118 |
parser.add_argument(
|
| 119 |
"--share",
|
|
|
|
| 131 |
history_langchain_format.append(HumanMessage(content=human))
|
| 132 |
history_langchain_format.append(AIMessage(content=ai))
|
| 133 |
history_langchain_format.append(HumanMessage(content=message))
|
| 134 |
+
response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
|
|
|
|
|
|
|
| 135 |
answer = append_sources_to_response(response)
|
| 136 |
return answer
|
| 137 |
|
src/chunker.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
"""Chunker abstraction and implementations."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
-
import nbformat
|
| 5 |
from abc import ABC, abstractmethod
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from functools import lru_cache
|
| 8 |
from typing import List, Optional
|
| 9 |
|
|
|
|
| 10 |
import pygments
|
| 11 |
import tiktoken
|
| 12 |
from semchunk import chunk as chunk_via_semchunk
|
|
@@ -31,7 +31,7 @@ class Chunk:
|
|
| 31 |
return self._content
|
| 32 |
|
| 33 |
@property
|
| 34 |
-
def
|
| 35 |
"""Converts the chunk to a dictionary that can be passed to a vector store."""
|
| 36 |
# Some vector stores require the IDs to be ASCII.
|
| 37 |
filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
|
|
@@ -49,9 +49,7 @@ class Chunk:
|
|
| 49 |
|
| 50 |
def populate_content(self, file_content: str):
|
| 51 |
"""Populates the content of the chunk with the file path and file content."""
|
| 52 |
-
self._content =
|
| 53 |
-
self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
|
| 54 |
-
)
|
| 55 |
|
| 56 |
def num_tokens(self, tokenizer):
|
| 57 |
"""Counts the number of tokens in the chunk."""
|
|
@@ -115,9 +113,7 @@ class CodeChunker(Chunker):
|
|
| 115 |
|
| 116 |
if not node.children:
|
| 117 |
# This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
|
| 118 |
-
return self.text_chunker.chunk(
|
| 119 |
-
filename, file_content[node.start_byte : node.end_byte]
|
| 120 |
-
)
|
| 121 |
|
| 122 |
chunks = []
|
| 123 |
for child in node.children:
|
|
@@ -133,11 +129,7 @@ class CodeChunker(Chunker):
|
|
| 133 |
for chunk in chunks:
|
| 134 |
if not merged_chunks:
|
| 135 |
merged_chunks.append(chunk)
|
| 136 |
-
elif (
|
| 137 |
-
merged_chunks[-1].num_tokens(self.tokenizer)
|
| 138 |
-
+ chunk.num_tokens(self.tokenizer)
|
| 139 |
-
< self.max_tokens - 50
|
| 140 |
-
):
|
| 141 |
# There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
|
| 142 |
# at this point, because tokenization is not necessarily additive.
|
| 143 |
merged = Chunk(
|
|
@@ -203,9 +195,7 @@ class CodeChunker(Chunker):
|
|
| 203 |
# a bug in the code.
|
| 204 |
assert chunk.content
|
| 205 |
size = chunk.num_tokens(self.tokenizer)
|
| 206 |
-
assert
|
| 207 |
-
size <= self.max_tokens
|
| 208 |
-
), f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
|
| 209 |
|
| 210 |
return chunks
|
| 211 |
|
|
@@ -217,17 +207,13 @@ class TextChunker(Chunker):
|
|
| 217 |
self.max_tokens = max_tokens
|
| 218 |
|
| 219 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 220 |
-
self.count_tokens = lambda text: len(
|
| 221 |
-
tokenizer.encode(text, disallowed_special=())
|
| 222 |
-
)
|
| 223 |
|
| 224 |
def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
|
| 225 |
"""Chunks a text file into smaller pieces."""
|
| 226 |
# We need to allocate some tokens for the filename, which is part of the chunk content.
|
| 227 |
extra_tokens = self.count_tokens(file_path + "\n\n")
|
| 228 |
-
text_chunks = chunk_via_semchunk(
|
| 229 |
-
file_content, self.max_tokens - extra_tokens, self.count_tokens
|
| 230 |
-
)
|
| 231 |
|
| 232 |
chunks = []
|
| 233 |
start = 0
|
|
@@ -252,6 +238,7 @@ class IPYNBChunker(Chunker):
|
|
| 252 |
|
| 253 |
Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
|
| 254 |
"""
|
|
|
|
| 255 |
def __init__(self, code_chunker: CodeChunker):
|
| 256 |
self.code_chunker = code_chunker
|
| 257 |
|
|
|
|
| 1 |
"""Chunker abstraction and implementations."""
|
| 2 |
|
| 3 |
import logging
|
|
|
|
| 4 |
from abc import ABC, abstractmethod
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from functools import lru_cache
|
| 7 |
from typing import List, Optional
|
| 8 |
|
| 9 |
+
import nbformat
|
| 10 |
import pygments
|
| 11 |
import tiktoken
|
| 12 |
from semchunk import chunk as chunk_via_semchunk
|
|
|
|
| 31 |
return self._content
|
| 32 |
|
| 33 |
@property
|
| 34 |
+
def to_metadata(self):
|
| 35 |
"""Converts the chunk to a dictionary that can be passed to a vector store."""
|
| 36 |
# Some vector stores require the IDs to be ASCII.
|
| 37 |
filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
|
|
|
|
| 49 |
|
| 50 |
def populate_content(self, file_content: str):
|
| 51 |
"""Populates the content of the chunk with the file path and file content."""
|
| 52 |
+
self._content = self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def num_tokens(self, tokenizer):
|
| 55 |
"""Counts the number of tokens in the chunk."""
|
|
|
|
| 113 |
|
| 114 |
if not node.children:
|
| 115 |
# This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
|
| 116 |
+
return self.text_chunker.chunk(filename, file_content[node.start_byte : node.end_byte])
|
|
|
|
|
|
|
| 117 |
|
| 118 |
chunks = []
|
| 119 |
for child in node.children:
|
|
|
|
| 129 |
for chunk in chunks:
|
| 130 |
if not merged_chunks:
|
| 131 |
merged_chunks.append(chunk)
|
| 132 |
+
elif merged_chunks[-1].num_tokens(self.tokenizer) + chunk.num_tokens(self.tokenizer) < self.max_tokens - 50:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
# There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
|
| 134 |
# at this point, because tokenization is not necessarily additive.
|
| 135 |
merged = Chunk(
|
|
|
|
| 195 |
# a bug in the code.
|
| 196 |
assert chunk.content
|
| 197 |
size = chunk.num_tokens(self.tokenizer)
|
| 198 |
+
assert size <= self.max_tokens, f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
|
|
|
|
|
|
|
| 199 |
|
| 200 |
return chunks
|
| 201 |
|
|
|
|
| 207 |
self.max_tokens = max_tokens
|
| 208 |
|
| 209 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
| 210 |
+
self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))
|
|
|
|
|
|
|
| 211 |
|
| 212 |
def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
|
| 213 |
"""Chunks a text file into smaller pieces."""
|
| 214 |
# We need to allocate some tokens for the filename, which is part of the chunk content.
|
| 215 |
extra_tokens = self.count_tokens(file_path + "\n\n")
|
| 216 |
+
text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)
|
|
|
|
|
|
|
| 217 |
|
| 218 |
chunks = []
|
| 219 |
start = 0
|
|
|
|
| 238 |
|
| 239 |
Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
|
| 240 |
"""
|
| 241 |
+
|
| 242 |
def __init__(self, code_chunker: CodeChunker):
|
| 243 |
self.code_chunker = code_chunker
|
| 244 |
|
src/embedder.py
CHANGED
|
@@ -7,11 +7,11 @@ from abc import ABC, abstractmethod
|
|
| 7 |
from collections import Counter
|
| 8 |
from typing import Dict, Generator, List, Tuple
|
| 9 |
|
|
|
|
| 10 |
from openai import OpenAI
|
| 11 |
|
| 12 |
from chunker import Chunk, Chunker
|
| 13 |
from repo_manager import RepoManager
|
| 14 |
-
import marqo
|
| 15 |
|
| 16 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 17 |
|
|
@@ -63,7 +63,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 63 |
openai_batch_id = self._issue_job_for_chunks(
|
| 64 |
sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
|
| 65 |
)
|
| 66 |
-
self.openai_batch_ids[openai_batch_id] = [chunk.
|
| 67 |
if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
|
| 68 |
logging.info("Reached the maximum number of embedding jobs. Stopping.")
|
| 69 |
return
|
|
@@ -72,7 +72,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 72 |
# Finally, commit the last batch.
|
| 73 |
if batch:
|
| 74 |
openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
|
| 75 |
-
self.openai_batch_ids[openai_batch_id] = [chunk.
|
| 76 |
logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
|
| 77 |
|
| 78 |
# Save the job IDs to a file, just in case this script is terminated by mistake.
|
|
@@ -179,12 +179,7 @@ class MarqoEmbedder(BatchEmbedder):
|
|
| 179 |
Embeddings can be stored locally (in which case `url` the constructor should point to localhost) or in the cloud.
|
| 180 |
"""
|
| 181 |
|
| 182 |
-
def __init__(self,
|
| 183 |
-
repo_manager: RepoManager,
|
| 184 |
-
chunker: Chunker,
|
| 185 |
-
index_name: str,
|
| 186 |
-
url: str,
|
| 187 |
-
model="hf/e5-base-v2"):
|
| 188 |
self.repo_manager = repo_manager
|
| 189 |
self.chunker = chunker
|
| 190 |
self.client = marqo.Client(url=url)
|
|
@@ -212,8 +207,8 @@ class MarqoEmbedder(BatchEmbedder):
|
|
| 212 |
sub_batch = batch[i : i + chunks_per_batch]
|
| 213 |
logging.info("Indexing %d chunks...", len(sub_batch))
|
| 214 |
self.index.add_documents(
|
| 215 |
-
documents=[chunk.
|
| 216 |
-
tensor_fields=["text"]
|
| 217 |
)
|
| 218 |
|
| 219 |
if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
|
|
@@ -223,10 +218,7 @@ class MarqoEmbedder(BatchEmbedder):
|
|
| 223 |
|
| 224 |
# Finally, commit the last batch.
|
| 225 |
if batch:
|
| 226 |
-
self.index.add_documents(
|
| 227 |
-
documents=[chunk.to_dict for chunk in batch],
|
| 228 |
-
tensor_fields=["text"]
|
| 229 |
-
)
|
| 230 |
logging.info(f"Successfully embedded {chunk_count} chunks.")
|
| 231 |
|
| 232 |
def embeddings_are_ready(self) -> bool:
|
|
|
|
| 7 |
from collections import Counter
|
| 8 |
from typing import Dict, Generator, List, Tuple
|
| 9 |
|
| 10 |
+
import marqo
|
| 11 |
from openai import OpenAI
|
| 12 |
|
| 13 |
from chunker import Chunk, Chunker
|
| 14 |
from repo_manager import RepoManager
|
|
|
|
| 15 |
|
| 16 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 17 |
|
|
|
|
| 63 |
openai_batch_id = self._issue_job_for_chunks(
|
| 64 |
sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
|
| 65 |
)
|
| 66 |
+
self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in sub_batch]
|
| 67 |
if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
|
| 68 |
logging.info("Reached the maximum number of embedding jobs. Stopping.")
|
| 69 |
return
|
|
|
|
| 72 |
# Finally, commit the last batch.
|
| 73 |
if batch:
|
| 74 |
openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
|
| 75 |
+
self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in batch]
|
| 76 |
logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
|
| 77 |
|
| 78 |
# Save the job IDs to a file, just in case this script is terminated by mistake.
|
|
|
|
| 179 |
Embeddings can be stored locally (in which case `url` the constructor should point to localhost) or in the cloud.
|
| 180 |
"""
|
| 181 |
|
| 182 |
+
def __init__(self, repo_manager: RepoManager, chunker: Chunker, index_name: str, url: str, model="hf/e5-base-v2"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
self.repo_manager = repo_manager
|
| 184 |
self.chunker = chunker
|
| 185 |
self.client = marqo.Client(url=url)
|
|
|
|
| 207 |
sub_batch = batch[i : i + chunks_per_batch]
|
| 208 |
logging.info("Indexing %d chunks...", len(sub_batch))
|
| 209 |
self.index.add_documents(
|
| 210 |
+
documents=[chunk.to_metadata for chunk in sub_batch],
|
| 211 |
+
tensor_fields=["text"],
|
| 212 |
)
|
| 213 |
|
| 214 |
if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
|
|
|
|
| 218 |
|
| 219 |
# Finally, commit the last batch.
|
| 220 |
if batch:
|
| 221 |
+
self.index.add_documents(documents=[chunk.to_metadata for chunk in batch], tensor_fields=["text"])
|
|
|
|
|
|
|
|
|
|
| 222 |
logging.info(f"Successfully embedded {chunk_count} chunks.")
|
| 223 |
|
| 224 |
def embeddings_are_ready(self) -> bool:
|
src/index.py
CHANGED
|
@@ -5,19 +5,15 @@ import logging
|
|
| 5 |
import time
|
| 6 |
|
| 7 |
from chunker import UniversalChunker
|
| 8 |
-
from embedder import
|
| 9 |
from repo_manager import RepoManager
|
| 10 |
from vector_store import PineconeVectorStore
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
| 13 |
|
| 14 |
OPENAI_EMBEDDING_SIZE = 1536
|
| 15 |
-
MAX_TOKENS_PER_CHUNK =
|
| 16 |
-
|
| 17 |
-
)
|
| 18 |
-
MAX_CHUNKS_PER_BATCH = (
|
| 19 |
-
2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
|
| 20 |
-
)
|
| 21 |
MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
|
| 22 |
|
| 23 |
|
|
@@ -43,11 +39,12 @@ def main():
|
|
| 43 |
help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
|
| 44 |
)
|
| 45 |
parser.add_argument(
|
| 46 |
-
"--chunks_per_batch",
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
)
|
|
|
|
| 51 |
parser.add_argument(
|
| 52 |
"--include",
|
| 53 |
help="Path to a file containing a list of extensions to include. One extension per line.",
|
|
@@ -58,7 +55,8 @@ def main():
|
|
| 58 |
help="Path to a file containing a list of extensions to exclude. One extension per line.",
|
| 59 |
)
|
| 60 |
parser.add_argument(
|
| 61 |
-
"--max_embedding_jobs",
|
|
|
|
| 62 |
help="Maximum number of embedding jobs to run. Specifying this might result in "
|
| 63 |
"indexing only part of the repository, but prevents you from burning through OpenAI credits.",
|
| 64 |
)
|
|
@@ -79,16 +77,15 @@ def main():
|
|
| 79 |
parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
|
| 80 |
if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
|
| 81 |
parser.error("When using the marqo embedder, the vector store type must also be marqo.")
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Validate other arguments.
|
| 84 |
if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
|
| 85 |
-
parser.error(
|
| 86 |
-
f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}."
|
| 87 |
-
)
|
| 88 |
if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
|
| 89 |
-
parser.error(
|
| 90 |
-
f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}."
|
| 91 |
-
)
|
| 92 |
if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
|
| 93 |
parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
|
| 94 |
if args.include and args.exclude:
|
|
@@ -112,11 +109,9 @@ def main():
|
|
| 112 |
if args.embedder_type == "openai":
|
| 113 |
embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
|
| 114 |
elif args.embedder_type == "marqo":
|
| 115 |
-
embedder = MarqoEmbedder(
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
url=args.marqo_url,
|
| 119 |
-
model=args.marqo_embedding_model)
|
| 120 |
else:
|
| 121 |
raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
|
| 122 |
|
|
|
|
| 5 |
import time
|
| 6 |
|
| 7 |
from chunker import UniversalChunker
|
| 8 |
+
from embedder import MarqoEmbedder, OpenAIBatchEmbedder
|
| 9 |
from repo_manager import RepoManager
|
| 10 |
from vector_store import PineconeVectorStore
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
| 13 |
|
| 14 |
OPENAI_EMBEDDING_SIZE = 1536
|
| 15 |
+
MAX_TOKENS_PER_CHUNK = 8192 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
|
| 16 |
+
MAX_CHUNKS_PER_BATCH = 2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
|
| 18 |
|
| 19 |
|
|
|
|
| 39 |
help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
|
| 40 |
)
|
| 41 |
parser.add_argument(
|
| 42 |
+
"--chunks_per_batch",
|
| 43 |
+
type=int,
|
| 44 |
+
default=2000,
|
| 45 |
+
help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
|
| 46 |
)
|
| 47 |
+
parser.add_argument("--index_name", required=True, help="Vector store index name")
|
| 48 |
parser.add_argument(
|
| 49 |
"--include",
|
| 50 |
help="Path to a file containing a list of extensions to include. One extension per line.",
|
|
|
|
| 55 |
help="Path to a file containing a list of extensions to exclude. One extension per line.",
|
| 56 |
)
|
| 57 |
parser.add_argument(
|
| 58 |
+
"--max_embedding_jobs",
|
| 59 |
+
type=int,
|
| 60 |
help="Maximum number of embedding jobs to run. Specifying this might result in "
|
| 61 |
"indexing only part of the repository, but prevents you from burning through OpenAI credits.",
|
| 62 |
)
|
|
|
|
| 77 |
parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
|
| 78 |
if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
|
| 79 |
parser.error("When using the marqo embedder, the vector store type must also be marqo.")
|
| 80 |
+
if args.embedder_type == "marqo" and args.chunks_per_batch > 64:
|
| 81 |
+
args.chunks_per_batch = 64
|
| 82 |
+
logging.warning("Marqo enforces a limit of 64 chunks per batch. Setting --chunks_per_batch to 64.")
|
| 83 |
|
| 84 |
# Validate other arguments.
|
| 85 |
if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
|
| 86 |
+
parser.error(f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}.")
|
|
|
|
|
|
|
| 87 |
if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
|
| 88 |
+
parser.error(f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}.")
|
|
|
|
|
|
|
| 89 |
if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
|
| 90 |
parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
|
| 91 |
if args.include and args.exclude:
|
|
|
|
| 109 |
if args.embedder_type == "openai":
|
| 110 |
embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
|
| 111 |
elif args.embedder_type == "marqo":
|
| 112 |
+
embedder = MarqoEmbedder(
|
| 113 |
+
repo_manager, chunker, index_name=args.index_name, url=args.marqo_url, model=args.marqo_embedding_model
|
| 114 |
+
)
|
|
|
|
|
|
|
| 115 |
else:
|
| 116 |
raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
|
| 117 |
|
src/repo_manager.py
CHANGED
|
@@ -35,9 +35,7 @@ class RepoManager:
|
|
| 35 |
@cached_property
|
| 36 |
def is_public(self) -> bool:
|
| 37 |
"""Checks whether a GitHub repository is publicly visible."""
|
| 38 |
-
response = requests.get(
|
| 39 |
-
f"https://api.github.com/repos/{self.repo_id}", timeout=10
|
| 40 |
-
)
|
| 41 |
# Note that the response will be 404 for both private and non-existent repos.
|
| 42 |
return response.status_code == 200
|
| 43 |
|
|
@@ -50,17 +48,13 @@ class RepoManager:
|
|
| 50 |
if self.access_token:
|
| 51 |
headers["Authorization"] = f"token {self.access_token}"
|
| 52 |
|
| 53 |
-
response = requests.get(
|
| 54 |
-
f"https://api.github.com/repos/{self.repo_id}", headers=headers
|
| 55 |
-
)
|
| 56 |
if response.status_code == 200:
|
| 57 |
branch = response.json().get("default_branch", "main")
|
| 58 |
else:
|
| 59 |
# This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
|
| 60 |
# most common naming for the default branch ("main").
|
| 61 |
-
logging.warn(
|
| 62 |
-
f"Unable to fetch default branch for {self.repo_id}: {response.text}"
|
| 63 |
-
)
|
| 64 |
branch = "main"
|
| 65 |
return branch
|
| 66 |
|
|
@@ -81,9 +75,7 @@ class RepoManager:
|
|
| 81 |
try:
|
| 82 |
Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
|
| 83 |
except GitCommandError as e:
|
| 84 |
-
logging.error(
|
| 85 |
-
"Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e
|
| 86 |
-
)
|
| 87 |
return False
|
| 88 |
return True
|
| 89 |
|
|
@@ -130,9 +122,7 @@ class RepoManager:
|
|
| 130 |
for path in included_file_paths:
|
| 131 |
f.write(path + "\n")
|
| 132 |
|
| 133 |
-
excluded_file_paths = set(file_paths).difference(
|
| 134 |
-
set(included_file_paths)
|
| 135 |
-
)
|
| 136 |
with open(excluded_log_file, "a") as f:
|
| 137 |
for path in excluded_file_paths:
|
| 138 |
f.write(path + "\n")
|
|
@@ -142,15 +132,11 @@ class RepoManager:
|
|
| 142 |
try:
|
| 143 |
contents = f.read()
|
| 144 |
except UnicodeDecodeError:
|
| 145 |
-
logging.warning(
|
| 146 |
-
"Unable to decode file %s. Skipping.", file_path
|
| 147 |
-
)
|
| 148 |
continue
|
| 149 |
yield file_path[len(self.local_dir) + 1 :], contents
|
| 150 |
|
| 151 |
def github_link_for_file(self, file_path: str) -> str:
|
| 152 |
"""Converts a repository file path to a GitHub link."""
|
| 153 |
file_path = file_path[len(self.repo_id) :]
|
| 154 |
-
return
|
| 155 |
-
f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
|
| 156 |
-
)
|
|
|
|
| 35 |
@cached_property
|
| 36 |
def is_public(self) -> bool:
|
| 37 |
"""Checks whether a GitHub repository is publicly visible."""
|
| 38 |
+
response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
|
|
|
|
|
|
|
| 39 |
# Note that the response will be 404 for both private and non-existent repos.
|
| 40 |
return response.status_code == 200
|
| 41 |
|
|
|
|
| 48 |
if self.access_token:
|
| 49 |
headers["Authorization"] = f"token {self.access_token}"
|
| 50 |
|
| 51 |
+
response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
|
|
|
|
|
|
|
| 52 |
if response.status_code == 200:
|
| 53 |
branch = response.json().get("default_branch", "main")
|
| 54 |
else:
|
| 55 |
# This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
|
| 56 |
# most common naming for the default branch ("main").
|
| 57 |
+
logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
|
|
|
|
|
|
|
| 58 |
branch = "main"
|
| 59 |
return branch
|
| 60 |
|
|
|
|
| 75 |
try:
|
| 76 |
Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
|
| 77 |
except GitCommandError as e:
|
| 78 |
+
logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
|
|
|
|
|
|
|
| 79 |
return False
|
| 80 |
return True
|
| 81 |
|
|
|
|
| 122 |
for path in included_file_paths:
|
| 123 |
f.write(path + "\n")
|
| 124 |
|
| 125 |
+
excluded_file_paths = set(file_paths).difference(set(included_file_paths))
|
|
|
|
|
|
|
| 126 |
with open(excluded_log_file, "a") as f:
|
| 127 |
for path in excluded_file_paths:
|
| 128 |
f.write(path + "\n")
|
|
|
|
| 132 |
try:
|
| 133 |
contents = f.read()
|
| 134 |
except UnicodeDecodeError:
|
| 135 |
+
logging.warning("Unable to decode file %s. Skipping.", file_path)
|
|
|
|
|
|
|
| 136 |
continue
|
| 137 |
yield file_path[len(self.local_dir) + 1 :], contents
|
| 138 |
|
| 139 |
def github_link_for_file(self, file_path: str) -> str:
|
| 140 |
"""Converts a repository file path to a GitHub link."""
|
| 141 |
file_path = file_path[len(self.repo_id) :]
|
| 142 |
+
return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
|
|
|
|
|
|
src/vector_store.py
CHANGED
|
@@ -10,6 +10,7 @@ Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
|
| 10 |
|
| 11 |
class VectorStore(ABC):
|
| 12 |
"""Abstract class for a vector store."""
|
|
|
|
| 13 |
@abstractmethod
|
| 14 |
def ensure_exists(self):
|
| 15 |
"""Ensures that the vector store exists. Creates it if it doesn't."""
|
|
@@ -42,13 +43,10 @@ class PineconeVectorStore(VectorStore):
|
|
| 42 |
|
| 43 |
def ensure_exists(self):
|
| 44 |
if self.index_name not in self.client.list_indexes().names():
|
| 45 |
-
self.client.create_index(
|
| 46 |
-
name=self.index_name, dimension=self.dimension, metric="cosine"
|
| 47 |
-
)
|
| 48 |
|
| 49 |
def upsert_batch(self, vectors: List[Vector]):
|
| 50 |
pinecone_vectors = [
|
| 51 |
-
(metadata.get("id", str(i)), embedding, metadata)
|
| 52 |
-
for i, (metadata, embedding) in enumerate(vectors)
|
| 53 |
]
|
| 54 |
self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
|
|
|
|
| 10 |
|
| 11 |
class VectorStore(ABC):
|
| 12 |
"""Abstract class for a vector store."""
|
| 13 |
+
|
| 14 |
@abstractmethod
|
| 15 |
def ensure_exists(self):
|
| 16 |
"""Ensures that the vector store exists. Creates it if it doesn't."""
|
|
|
|
| 43 |
|
| 44 |
def ensure_exists(self):
|
| 45 |
if self.index_name not in self.client.list_indexes().names():
|
| 46 |
+
self.client.create_index(name=self.index_name, dimension=self.dimension, metric="cosine")
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def upsert_batch(self, vectors: List[Vector]):
|
| 49 |
pinecone_vectors = [
|
| 50 |
+
(metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
|
|
|
|
| 51 |
]
|
| 52 |
self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
|