Merge pull request #13 from Storia-AI/julia/marqo
Files changed:
- README.md (+55 -27)
- requirements.txt (+1 -0)
- src/chat.py (+14 -23)
- src/chunker.py (+25 -21)
- src/embedder.py (+63 -22)
- src/index.py (+51 -28)
- src/repo_manager.py (+7 -21)
- src/vector_store.py (+59 -6)
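In short, this PR makes both the embedder and the vector store pluggable: it adds [Marqo](https://github.com/marqo-ai/marqo), which plays both roles and can run locally via Docker, alongside the existing OpenAI + Pinecone path; threads the new `--embedder_type`, `--vector_store_type`, `--index_name`, and `--marqo_*` flags through `src/index.py` and `src/chat.py`; and rewrites the README's setup instructions around the two options.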
README.md CHANGED

@@ -7,40 +7,68 @@
 **Ok, but why chat with a codebase?**
 
 Sometimes you just want to learn how a codebase works and how to integrate it, without spending hours sifting through
 the code itself.
 
 `repo2vec` is like GitHub Copilot but with the most up-to-date information about your repo.
 
 Features:
 - **Dead-simple set-up.** Run *two scripts* and you have a functional chat interface for your code. That's really it.
 - **Heavily documented answers.** Every response shows where in the code the context for the answer was pulled from. Let's build trust in the AI.
 - **Plug-and-play.** Want to improve the algorithms powering the code understanding/generation? We've made every component of the pipeline easily swappable. Customize to your heart's content.
 
-```
-export OPENAI_API_KEY=...
-export PINECONE_API_KEY=...
-export PINECONE_INDEX_NAME=...
-```
-
-If you want to publicly host your chat experience, set `--share=true`:
-```
-python src/chat.py $GITHUB_REPO_NAME --share=true ...
-```
+# How to run it
+
+## Indexing the codebase
+
+We currently support two options for indexing the codebase:
+
+1. **Locally**, using the open-source [Marqo vector store](https://github.com/marqo-ai/marqo). Marqo is both an embedder (you can choose your favorite embedding model from Hugging Face) and a vector store.
+
+    You can bring up a Marqo instance using Docker:
+    ```
+    docker rm -f marqo
+    docker pull marqoai/marqo:latest
+    docker run --name marqo -it -p 8882:8882 marqoai/marqo:latest
+    ```
+
+    Then, to index your codebase, run:
+    ```
+    pip install -r requirements.txt
+
+    python src/index.py \
+        github-repo-name \  # e.g. Storia-AI/repo2vec
+        --embedder_type=marqo \
+        --vector_store_type=marqo \
+        --index_name=your-index-name
+    ```
+
+2. **Using external providers** (OpenAI for embeddings and [Pinecone](https://www.pinecone.io/) for the vector store). To index your codebase, run:
+    ```
+    pip install -r requirements.txt
+
+    export OPENAI_API_KEY=...
+    export PINECONE_API_KEY=...
+
+    python src/index.py \
+        github-repo-name \  # e.g. Storia-AI/repo2vec
+        --embedder_type=openai \
+        --vector_store_type=pinecone \
+        --index_name=your-index-name
+    ```
+
+We are planning on adding more providers soon, so that you can mix and match them. Contributions are also welcome!
+
+## Chatting with the codebase
+
+To bring up a `gradio` app where you can chat with your codebase, simply point it to your vector store:
+
+```
+export OPENAI_API_KEY=...
+
+python src/chat.py \
+    github-repo-name \  # e.g. Storia-AI/repo2vec
+    --vector_store_type=marqo \  # or pinecone
+    --index_name=your-index-name
+```
+
+To get a public URL for your chat app, set `--share=true`.
+
+Currently, the chat uses OpenAI's GPT-4, but we are working on adding support for other providers and local LLMs. Stay tuned!
 
 # Peeking under the hood

@@ -50,10 +78,11 @@ The `src/index.py` script performs the following steps:
    - Make sure to set the `GITHUB_TOKEN` environment variable for private repositories.
 2. **Chunks files**. See [Chunker](src/chunker.py).
    - For code files, we implement a special `CodeChunker` that takes the parse tree into account.
-3. **Batch-embeds chunks**. See [Embedder](src/embedder.py).
+3. **Batch-embeds chunks**. See [Embedder](src/embedder.py). We currently support:
+   - [Marqo](https://github.com/marqo-ai/marqo) as an embedder, which allows you to specify your favorite Hugging Face embedding model;
+   - OpenAI's [batch embedding API](https://platform.openai.com/docs/guides/batch/overview), which is much faster and cheaper than the regular synchronous embedding API.
 4. **Stores embeddings in a vector store**. See [VectorStore](src/vector_store.py).
+   - We currently support [Marqo](https://github.com/marqo-ai/marqo) and [Pinecone](https://pinecone.io), but you can easily plug in your own.
 
 Note you can specify an inclusion or exclusion set for the file extensions you want indexed. To specify an extension inclusion set, you can add the `--include` flag:
 ```

@@ -77,10 +106,9 @@ The sources are conveniently surfaced in the chat and linked directly to GitHub.
 
 # Want your repository hosted?
 
-We're working to make all code on the internet searchable and understandable for devs.
-your repository, we're onboarding a handful of repos onto our infrastructure **for free**.
+We're working to make all code on the internet searchable and understandable for devs. You can check out our early product, [Code Sage](https://sage.storia.ai). We pre-indexed a slew of OSS repos, and you can index your desired ones by simply pasting a GitHub URL.
+
+If you're the maintainer of an OSS repo and would like a dedicated page on Code Sage (e.g. `sage.storia.ai/your-repo`), then send us a message at [founders@storia.ai](mailto:founders@storia.ai). We'll do it for free!
 
 
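A note for readers following the new indexing instructions: before running `src/index.py` against Marqo, it can help to confirm the Dockerized instance is actually reachable. A minimal sketch (not part of the PR) that reuses the same client calls the diff introduces in `src/embedder.py`:

```
import marqo

# Assumes the Marqo container from the README is running on the default port.
client = marqo.Client(url="http://localhost:8882")
# get_indexes() returns {"results": [{"indexName": ...}, ...]}, as used in src/embedder.py.
print([result["indexName"] for result in client.get_indexes()["results"]])
```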
requirements.txt CHANGED

@@ -4,6 +4,7 @@ gradio==4.42.0
 langchain==0.2.14
 langchain-community==0.2.12
 langchain-openai==0.1.22
+marqo==3.7.0
 nbformat==5.10.4
 openai==1.42.0
 pinecone==5.0.1
src/chat.py CHANGED

@@ -5,16 +5,16 @@ You must run main.py first in order to index the codebase into a vector store.
 
 import argparse
 
-from dotenv import load_dotenv
-
 import gradio as gr
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from dotenv import load_dotenv
+from langchain.chains import (create_history_aware_retriever,
+                              create_retrieval_chain)
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.schema import AIMessage, HumanMessage
-from langchain_community.vectorstores import Pinecone
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain_openai import ChatOpenAI
 
+import vector_store
 from repo_manager import RepoManager
 
 load_dotenv()

@@ -23,14 +23,7 @@ load_dotenv()
 def build_rag_chain(args):
     """Builds a RAG chain via LangChain."""
     llm = ChatOpenAI(model=args.openai_model)
-
-    vectorstore = Pinecone.from_existing_index(
-        index_name=args.pinecone_index_name,
-        embedding=OpenAIEmbeddings(),
-        namespace=args.repo_id,
-    )
-
-    retriever = vectorstore.as_retriever()
+    retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
 
     # Prompt to contextualize the latest query based on the chat history.
     contextualize_q_system_prompt = (

@@ -45,9 +38,7 @@ def build_rag_chain(args):
             ("human", "{input}"),
         ]
     )
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
-    )
+    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
     qa_system_prompt = (
         f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."

@@ -76,9 +67,7 @@ def append_sources_to_response(response):
     # Deduplicate filenames while preserving their order.
     filenames = list(dict.fromkeys(filenames))
     repo_manager = RepoManager(args.repo_id)
-    github_links = [
-        repo_manager.github_link_for_file(filename) for filename in filenames
-    ]
+    github_links = [repo_manager.github_link_for_file(filename) for filename in filenames]
     return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)

@@ -90,8 +79,12 @@ if __name__ == "__main__":
         default="gpt-4",
         help="The OpenAI model to use for response generation",
     )
+    parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
+    parser.add_argument("--index_name", required=True, help="Vector store index name")
     parser.add_argument(
-        "--pinecone_index_name",
+        "--marqo_url",
+        default="http://localhost:8882",
+        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
     )
     parser.add_argument(
         "--share",

@@ -109,9 +102,7 @@ if __name__ == "__main__":
         history_langchain_format.append(HumanMessage(content=human))
         history_langchain_format.append(AIMessage(content=ai))
         history_langchain_format.append(HumanMessage(content=message))
-        response = rag_chain.invoke(
-            {"input": message, "chat_history": history_langchain_format}
-        )
+        response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
         answer = append_sources_to_response(response)
         return answer
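The net effect of the `src/chat.py` changes is that the script no longer hard-codes Pinecone: the retriever now comes from whichever store the new flags select. A sketch of the composition, with hypothetical flag values (treat this as illustrative, not a tested snippet):

```
import argparse

import vector_store  # the module refactored in this PR

# Hypothetical flag values mirroring the arguments src/chat.py now defines.
args = argparse.Namespace(
    repo_id="Storia-AI/repo2vec",
    vector_store_type="marqo",
    index_name="your-index-name",
    marqo_url="http://localhost:8882",
)

# build_from_args() picks MarqoVectorStore or PineconeVectorStore; to_langchain()
# wraps it so .as_retriever() works uniformly across both backends.
retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
docs = retriever.invoke("Where is the chunking logic?")  # returns LangChain Documents
```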
src/chunker.py CHANGED

@@ -1,12 +1,12 @@
 """Chunker abstraction and implementations."""
 
 import logging
-import nbformat
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import lru_cache
 from typing import List, Optional
 
+import nbformat
 import pygments
 import tiktoken
 from semchunk import chunk as chunk_via_semchunk

@@ -30,11 +30,26 @@ class Chunk:
         """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
         return self._content
 
+    @property
+    def to_metadata(self):
+        """Converts the chunk to a dictionary that can be passed to a vector store."""
+        # Some vector stores require the IDs to be ASCII.
+        filename_ascii = self.filename.encode("ascii", "ignore").decode("ascii")
+        return {
+            "id": f"{filename_ascii}_{self.start_byte}_{self.end_byte}",
+            "filename": self.filename,
+            "start_byte": self.start_byte,
+            "end_byte": self.end_byte,
+            # Note to developer: When choosing a large chunk size, you might exceed the vector store's metadata
+            # size limit. In that case, you can simply store the start/end bytes above, and fetch the content
+            # directly from the repository when needed.
+            "text": self.content,
+        }
+
     def populate_content(self, file_content: str):
         """Populates the content of the chunk with the file path and file content."""
-        self._content = (
-            self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
-        )
+        self._content = self.filename + "\n\n" + file_content[self.start_byte : self.end_byte]
 
     def num_tokens(self, tokenizer):
         """Counts the number of tokens in the chunk."""

@@ -98,9 +113,7 @@ class CodeChunker(Chunker):
 
         if not node.children:
             # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
-            return self.text_chunker.chunk(
-                filename, file_content[node.start_byte : node.end_byte]
-            )
+            return self.text_chunker.chunk(filename, file_content[node.start_byte : node.end_byte])
 
         chunks = []
         for child in node.children:

@@ -116,11 +129,7 @@ class CodeChunker(Chunker):
         for chunk in chunks:
             if not merged_chunks:
                 merged_chunks.append(chunk)
-            elif (
-                merged_chunks[-1].num_tokens(self.tokenizer)
-                + chunk.num_tokens(self.tokenizer)
-                < self.max_tokens - 50
-            ):
+            elif merged_chunks[-1].num_tokens(self.tokenizer) + chunk.num_tokens(self.tokenizer) < self.max_tokens - 50:
                 # There's a good chance that merging these two chunks will be under the token limit. We're not 100% sure
                 # at this point, because tokenization is not necessarily additive.
                 merged = Chunk(

@@ -186,9 +195,7 @@ class CodeChunker(Chunker):
             # a bug in the code.
             assert chunk.content
             size = chunk.num_tokens(self.tokenizer)
-            assert (
-                size <= self.max_tokens
-            ), f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
+            assert size <= self.max_tokens, f"Chunk size {size} exceeds max_tokens {self.max_tokens}."
 
         return chunks

@@ -200,17 +207,13 @@ class TextChunker(Chunker):
         self.max_tokens = max_tokens
 
         tokenizer = tiktoken.get_encoding("cl100k_base")
-        self.count_tokens = lambda text: len(
-            tokenizer.encode(text, disallowed_special=())
-        )
+        self.count_tokens = lambda text: len(tokenizer.encode(text, disallowed_special=()))
 
     def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
         """Chunks a text file into smaller pieces."""
         # We need to allocate some tokens for the filename, which is part of the chunk content.
        extra_tokens = self.count_tokens(file_path + "\n\n")
-        text_chunks = chunk_via_semchunk(
-            file_content, self.max_tokens - extra_tokens, self.count_tokens
-        )
+        text_chunks = chunk_via_semchunk(file_content, self.max_tokens - extra_tokens, self.count_tokens)
 
         chunks = []
         start = 0

@@ -235,6 +238,7 @@ class IPYNBChunker(Chunker):
 
     Based on https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/code/code_retrieval_augmented_generation.ipynb
     """
+
     def __init__(self, code_chunker: CodeChunker):
         self.code_chunker = code_chunker
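One detail worth calling out in the new `to_metadata` property is the ASCII sanitization of IDs, which matters for non-ASCII file paths. A quick illustration of the ID scheme (filename and byte offsets are made up):

```
filename = "docs/café-notes.md"    # hypothetical non-ASCII path
start_byte, end_byte = 0, 512      # made-up offsets

filename_ascii = filename.encode("ascii", "ignore").decode("ascii")
print(f"{filename_ascii}_{start_byte}_{end_byte}")  # -> docs/caf-notes.md_0_512
```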
src/embedder.py CHANGED

@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
 from collections import Counter
 from typing import Dict, Generator, List, Tuple
 
+import marqo
 from openai import OpenAI
 
 from chunker import Chunk, Chunker

@@ -19,7 +20,7 @@ class BatchEmbedder(ABC):
     """Abstract class for batch embedding of a repository."""
 
     @abstractmethod
-    def embed_repo(self, chunks_per_batch: int):
+    def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         """Issues batch embedding jobs for the entire repository."""
 
     @abstractmethod

@@ -62,7 +63,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
                     openai_batch_id = self._issue_job_for_chunks(
                         sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
                     )
-                    self.openai_batch_ids[openai_batch_id] =
+                    self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in sub_batch]
                     if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
                         logging.info("Reached the maximum number of embedding jobs. Stopping.")
                         return

@@ -71,7 +72,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
         # Finally, commit the last batch.
         if batch:
             openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
-            self.openai_batch_ids[openai_batch_id] =
+            self.openai_batch_ids[openai_batch_id] = [chunk.to_metadata for chunk in batch]
         logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
 
         # Save the job IDs to a file, just in case this script is terminated by mistake.

@@ -171,22 +172,62 @@ class OpenAIBatchEmbedder(BatchEmbedder):
             },
         }
 
+
+class MarqoEmbedder(BatchEmbedder):
+    """Embedder that uses the open-source Marqo vector search engine.
+
+    Embeddings can be stored locally (in which case the `url` passed to the constructor should point to localhost)
+    or in the cloud.
+    """
+
+    def __init__(self, repo_manager: RepoManager, chunker: Chunker, index_name: str, url: str, model="hf/e5-base-v2"):
+        self.repo_manager = repo_manager
+        self.chunker = chunker
+        self.client = marqo.Client(url=url)
+        self.index = self.client.index(index_name)
+
+        all_index_names = [result["indexName"] for result in self.client.get_indexes()["results"]]
+        if index_name not in all_index_names:
+            self.client.create_index(index_name, model=model)
+
+    def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
+        """Issues batch embedding jobs for the entire repository."""
+        if chunks_per_batch > 64:
+            raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
+
+        chunk_count = 0
+        job_count = 0  # Number of add_documents() calls issued so far.
+        batch = []
+
+        for filepath, content in self.repo_manager.walk():
+            chunks = self.chunker.chunk(filepath, content)
+            chunk_count += len(chunks)
+            batch.extend(chunks)
+
+            if len(batch) > chunks_per_batch:
+                for i in range(0, len(batch), chunks_per_batch):
+                    sub_batch = batch[i : i + chunks_per_batch]
+                    logging.info("Indexing %d chunks...", len(sub_batch))
+                    self.index.add_documents(
+                        documents=[chunk.to_metadata for chunk in sub_batch],
+                        tensor_fields=["text"],
+                    )
+                    job_count += 1
+
+                if max_embedding_jobs and job_count >= max_embedding_jobs:
+                    logging.info("Reached the maximum number of embedding jobs. Stopping.")
+                    return
+                batch = []
+
+        # Finally, commit the last batch.
+        if batch:
+            self.index.add_documents(documents=[chunk.to_metadata for chunk in batch], tensor_fields=["text"])
+        logging.info(f"Successfully embedded {chunk_count} chunks.")
+
+    def embeddings_are_ready(self) -> bool:
+        """Checks whether the batch embedding jobs are done."""
+        # Marqo indexes documents synchronously, so once embed_repo() returns, the embeddings are ready.
+        return True
+
+    def download_embeddings(self) -> Generator[Vector, None, None]:
+        """Yields (chunk_metadata, embedding) pairs for each chunk in the repository."""
+        # Marqo stores embeddings as they are created, so they're already in the vector store. No need to download them
+        # as we would with e.g. OpenAI, Cohere, or some other cloud-based embedding service.
+        return []
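For context, `MarqoEmbedder` is a thin wrapper around the standard Marqo client flow: create an index backed by a Hugging Face model, then push documents whose `text` field gets embedded. A rough standalone sketch (index name and document are invented; the calls mirror the ones in the class above):

```
import marqo

client = marqo.Client(url="http://localhost:8882")
client.create_index("demo-index", model="hf/e5-base-v2")  # embedding model from Hugging Face

index = client.index("demo-index")
# tensor_fields tells Marqo which fields to embed; the rest is stored as plain metadata,
# mirroring how MarqoEmbedder passes chunk.to_metadata with tensor_fields=["text"].
index.add_documents(
    documents=[{"filename": "src/chat.py", "text": "def build_rag_chain(args): ..."}],
    tensor_fields=["text"],
)
print(index.search("how is the RAG chain built?")["hits"][0]["filename"])
```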
src/index.py CHANGED

@@ -5,19 +5,14 @@ import logging
 import time
 
 from chunker import UniversalChunker
-from embedder import OpenAIBatchEmbedder
+from embedder import MarqoEmbedder, OpenAIBatchEmbedder
 from repo_manager import RepoManager
-from vector_store import PineconeVectorStore
+from vector_store import build_from_args
 
 logging.basicConfig(level=logging.INFO)
 
-OPENAI_EMBEDDING_SIZE = 1536
-MAX_TOKENS_PER_CHUNK = (
-    8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
-)
-MAX_CHUNKS_PER_BATCH = (
-    2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
-)
+MAX_TOKENS_PER_CHUNK = 8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
+MAX_CHUNKS_PER_BATCH = 2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
 MAX_TOKENS_PER_JOB = 3_000_000  # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.

@@ -29,6 +24,8 @@ def _read_extensions(path):
 def main():
     parser = argparse.ArgumentParser(description="Batch-embeds a repository")
     parser.add_argument("repo_id", help="The ID of the repository to index")
+    parser.add_argument("--embedder_type", default="openai", choices=["openai", "marqo"])
+    parser.add_argument("--vector_store_type", default="pinecone", choices=["pinecone", "marqo"])
     parser.add_argument(
         "--local_dir",
         default="repos",

@@ -41,11 +38,12 @@ def main():
         help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
     )
     parser.add_argument(
-        "--chunks_per_batch",
+        "--chunks_per_batch",
+        type=int,
+        default=2000,
+        help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
     )
+    parser.add_argument("--index_name", required=True, help="Vector store index name")
     parser.add_argument(
         "--include",
         help="Path to a file containing a list of extensions to include. One extension per line.",

@@ -56,22 +54,37 @@ def main():
         help="Path to a file containing a list of extensions to exclude. One extension per line.",
     )
     parser.add_argument(
-        "--max_embedding_jobs",
+        "--max_embedding_jobs",
+        type=int,
         help="Maximum number of embedding jobs to run. Specifying this might result in "
         "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
     )
+    parser.add_argument(
+        "--marqo_url",
+        default="http://localhost:8882",
+        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
+    )
+    parser.add_argument(
+        "--marqo_embedding_model",
+        default="hf/e5-base-v2",
+        help="The embedding model to use for Marqo.",
+    )
     args = parser.parse_args()
 
-    # Validate arguments.
+    # Validate embedder and vector store compatibility.
+    if args.embedder_type == "openai" and args.vector_store_type != "pinecone":
+        parser.error("When using the OpenAI embedder, the vector store type must be Pinecone.")
+    if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
+        parser.error("When using the Marqo embedder, the vector store type must also be Marqo.")
+    if args.embedder_type == "marqo" and args.chunks_per_batch > 64:
+        args.chunks_per_batch = 64
+        logging.warning("Marqo enforces a limit of 64 chunks per batch. Setting --chunks_per_batch to 64.")
+
+    # Validate other arguments.
     if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
-        parser.error(
-            f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}."
-        )
+        parser.error(f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}.")
     if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
-        parser.error(
-            f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}."
-        )
+        parser.error(f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}.")
     if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
         parser.error(f"The maximum number of tokens per job is {MAX_TOKENS_PER_JOB}.")
     if args.include and args.exclude:

@@ -91,9 +104,23 @@ def main():
 
     logging.info("Issuing embedding jobs...")
     chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
-    embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
+
+    if args.embedder_type == "openai":
+        embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
+    elif args.embedder_type == "marqo":
+        embedder = MarqoEmbedder(
+            repo_manager, chunker, index_name=args.index_name, url=args.marqo_url, model=args.marqo_embedding_model
+        )
+    else:
+        raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
+
     embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)
 
+    if args.vector_store_type == "marqo":
+        # Marqo computes embeddings and stores them in the vector store at once, so we're done.
+        logging.info("Done!")
+        return
+
     logging.info("Waiting for embeddings to be ready...")
     while not embedder.embeddings_are_ready():
         logging.info("Sleeping for 30 seconds...")

@@ -101,11 +128,7 @@ def main():
 
     logging.info("Moving embeddings to the vector store...")
     # Note to developer: Replace this with your preferred vector store.
-    vector_store = PineconeVectorStore(
-        index_name=args.pinecone_index_name,
-        dimension=OPENAI_EMBEDDING_SIZE,
-        namespace=repo_manager.repo_id,
-    )
+    vector_store = build_from_args(args)
     vector_store.ensure_exists()
     vector_store.upsert(embedder.download_embeddings())
     logging.info("Done!")
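To put the guards above in concrete terms: with the recommended 800 tokens per chunk and the default `--chunks_per_batch=2000`, one OpenAI batch covers at most 800 × 2000 = 1,600,000 tokens, comfortably under the 3M-token job cap; at 1,500 tokens per chunk the product reaches exactly 3,000,000 and the `>=` check against `MAX_TOKENS_PER_JOB` fires. On the Marqo path the batch size is clamped to 64, so that guard is effectively never the binding constraint.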
src/repo_manager.py CHANGED

@@ -35,9 +35,7 @@ class RepoManager:
     @cached_property
     def is_public(self) -> bool:
         """Checks whether a GitHub repository is publicly visible."""
-        response = requests.get(
-            f"https://api.github.com/repos/{self.repo_id}", timeout=10
-        )
+        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
         # Note that the response will be 404 for both private and non-existent repos.
         return response.status_code == 200

@@ -50,17 +48,13 @@ class RepoManager:
         if self.access_token:
             headers["Authorization"] = f"token {self.access_token}"
 
-        response = requests.get(
-            f"https://api.github.com/repos/{self.repo_id}", headers=headers
-        )
+        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
         if response.status_code == 200:
             branch = response.json().get("default_branch", "main")
         else:
             # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
             # most common naming for the default branch ("main").
-            logging.warn(
-                f"Unable to fetch default branch for {self.repo_id}: {response.text}"
-            )
+            logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
             branch = "main"
         return branch

@@ -81,9 +75,7 @@ class RepoManager:
         try:
             Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
         except GitCommandError as e:
-            logging.error(
-                "Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e
-            )
+            logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
             return False
         return True

@@ -130,9 +122,7 @@ class RepoManager:
             for path in included_file_paths:
                 f.write(path + "\n")
 
-        excluded_file_paths = set(file_paths).difference(
-            set(included_file_paths)
-        )
+        excluded_file_paths = set(file_paths).difference(set(included_file_paths))
         with open(excluded_log_file, "a") as f:
             for path in excluded_file_paths:
                 f.write(path + "\n")

@@ -142,15 +132,11 @@ class RepoManager:
             try:
                 contents = f.read()
             except UnicodeDecodeError:
-                logging.warning(
-                    "Unable to decode file %s. Skipping.", file_path
-                )
+                logging.warning("Unable to decode file %s. Skipping.", file_path)
                 continue
             yield file_path[len(self.local_dir) + 1 :], contents
 
     def github_link_for_file(self, file_path: str) -> str:
         """Converts a repository file path to a GitHub link."""
         file_path = file_path[len(self.repo_id) :]
-        return (
-            f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
-        )
+        return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
src/vector_store.py CHANGED

@@ -3,13 +3,19 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Tuple
 
+import marqo
+from langchain_community.vectorstores import Marqo
+from langchain_community.vectorstores import Pinecone as LangChainPinecone  # aliased: `pinecone.Pinecone` below is the client, not a LangChain store
+from langchain_core.documents import Document
+from langchain_openai import OpenAIEmbeddings
 from pinecone import Pinecone
 
+OPENAI_EMBEDDING_SIZE = 1536
 Vector = Tuple[Dict, List[float]]  # (metadata, embedding)
 
 
 class VectorStore(ABC):
     """Abstract class for a vector store."""
+
     @abstractmethod
     def ensure_exists(self):
         """Ensures that the vector store exists. Creates it if it doesn't."""

@@ -29,11 +35,15 @@ class VectorStore(ABC):
         if batch:
             self.upsert_batch(batch)
 
+    @abstractmethod
+    def to_langchain(self):
+        """Converts the vector store to a LangChain vector store object."""
+
 
 class PineconeVectorStore(VectorStore):
     """Vector store implementation using Pinecone."""
 
-    def __init__(self, index_name: str, namespace: str, dimension: int):
+    def __init__(self, index_name: str, namespace: str, dimension: int = OPENAI_EMBEDDING_SIZE):
         self.index_name = index_name
         self.dimension = dimension
         self.client = Pinecone()

@@ -42,13 +52,56 @@ class PineconeVectorStore(VectorStore):
 
     def ensure_exists(self):
         if self.index_name not in self.client.list_indexes().names():
-            self.client.create_index(
-                name=self.index_name, dimension=self.dimension, metric="cosine"
-            )
+            self.client.create_index(name=self.index_name, dimension=self.dimension, metric="cosine")
 
     def upsert_batch(self, vectors: List[Vector]):
         pinecone_vectors = [
-            (metadata.get("id", str(i)), embedding, metadata)
-            for i, (metadata, embedding) in enumerate(vectors)
+            (metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
         ]
         self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
+
+    def to_langchain(self):
+        return LangChainPinecone.from_existing_index(
+            index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
+        )
+
+
+class MarqoVectorStore(VectorStore):
+    """Vector store implementation using Marqo."""
+
+    def __init__(self, url: str, index_name: str):
+        self.client = marqo.Client(url=url)
+        self.index_name = index_name
+
+    def ensure_exists(self):
+        pass
+
+    def upsert_batch(self, vectors: List[Vector]):
+        # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
+        pass
+
+    def to_langchain(self):
+        vectorstore = Marqo(client=self.client, index_name=self.index_name)
+
+        # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
+        # the result, and instead take the "filename" directly from the result.
+        def patched_method(self, results):
+            documents: List[Document] = []
+            for res in results["hits"]:
+                documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
+            return documents
+
+        vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
+            vectorstore, vectorstore.__class__
+        )
+        return vectorstore
+
+
+def build_from_args(args) -> VectorStore:
+    """Builds a vector store from the given command-line arguments."""
+    if args.vector_store_type == "pinecone":
+        return PineconeVectorStore(index_name=args.index_name, namespace=args.repo_id)
+    elif args.vector_store_type == "marqo":
+        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
+    else:
+        raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
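The `patched_method.__get__(vectorstore, vectorstore.__class__)` call at the end of `to_langchain` is the standard way to bind a plain function to a single instance so it receives `self` like a method, without touching the class or other instances. A self-contained sketch of the same pattern:

```
class Greeter:
    def greet(self):
        return "hello"

def shout(self):
    # Replaces greet() for one instance only; other Greeter objects are unaffected.
    return "HELLO"

g = Greeter()
g.greet = shout.__get__(g, Greeter)  # functions are descriptors; __get__ binds `g` as self
print(g.greet())          # -> HELLO
print(Greeter().greet())  # -> hello
```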