juliaturc commited on
Commit
559dd34
·
1 Parent(s): bf938c6

Add code for indexing and chatting.

Browse files
Files changed (12) hide show
  1. .gitignore +3 -0
  2. .pylintrc +8 -0
  3. README.md +45 -2
  4. requirements.txt +14 -0
  5. src/.sample-env +3 -0
  6. src/__init__.py +0 -0
  7. src/chat.py +114 -0
  8. src/chunker.py +237 -0
  9. src/embedder.py +199 -0
  10. src/index.py +86 -0
  11. src/repo_manager.py +146 -0
  12. src/vector_store.py +54 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ __pycache__
3
+ *.cpython.*
.pylintrc ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [FORMAT]
2
+ max-line-length=120
3
+
4
+ [DESIGN]
5
+ min-public-methods=1
6
+
7
+ [MASTER]
8
+ init-hook='import sys; sys.path.append(".")'
README.md CHANGED
@@ -1,2 +1,45 @@
1
- # repo2vec
2
- From GitHub repo to vector store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Overview
2
+ `repo2vec` enables you to chat with your codebase by simply running two python scripts:
3
+ ```
4
+ pip install -r requirements.txt
5
+
6
+ export GITHUB_REPO_NAME=...
7
+ export OPENAI_API_KEY=...
8
+ export PINECONE_API_KEY=...
9
+ export PINECONE_INDEX_NAME=...
10
+
11
+ python src/index.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
12
+ python src/chat.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
13
+ ```
14
+ This will bring up a `gradio` app where you can ask questions about your codebase. The assistant responses always include GitHub links to the documents retrieved for each query.
15
+
16
+ Here is, for example, a conversation about the repo [Storia-AI/image-eval](https://github.com/Storia-AI/image-eval):
17
+ ![screenshot](assets/chat_screenshot.png)
18
+
19
+ # Under the hood
20
+
21
+ ## Indexing the repo
22
+ The `src/index.py` script performs the following steps:
23
+ 1. **Clones a GitHub repository**. See [RepoManager](src/repo_manager.py).
24
+ - Make sure to set the `GITHUB_TOKEN` environment variable for private repositories.
25
+ 2. **Chunks files**. See [Chunker](src/chunker.py).
26
+ - For code files, we implement a special `CodeChunker` that takes the parse tree into account.
27
+ 3. **Batch-embeds chunks**. See [Embedder](src/embedder.py).
28
+ - By default, we use OpenAI's [batch embedding API](https://platform.openai.com/docs/guides/batch/overview), which is much faster and cheaper than the regular synchronous embedding API.
29
+ 4. **Stores embeddings in a vector store**. See [VectorStore](src/vector_store.py).
30
+ - By default, we use [Pinecone](https://pinecone.io) as a vector store, but you can easily plug in your own.
31
+
32
+ ## Chatting via RAG
33
+ The `src/chat.py` brings up a [Gradio app](https://www.gradio.app/) with a chat interface as shown above. We use [LangChain](https://langchain.com) to define a RAG chain which, given a user query about the repository:
34
+
35
+ 1. Rewrites the query to be self-contained based on previous queries
36
+ 2. Embeds the rewritten query using OpenAI embeddings
37
+ 3. Retrieves relevant documents from the vector store
38
+ 4. Calls an OpenAI LLM to respond to the user query based on the retrieved documents.
39
+
40
+ The sources are conveniently surfaced in the chat and linked directly to GitHub.
41
+
42
+ # Extensions & Contributions
43
+ We built the code purposefully modular so that you can plug in your desired embeddings, LLM and vector stores providers by simply implementing the relevant abstract classes.
44
+
45
+ Feel free to send feature requests to [founders@storia.ai](mailto:founders@storia.ai) or make a pull request!
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GitPython==3.1.43
2
+ Pygments==2.18.0
3
+ gradio==4.42.0
4
+ langchain==0.2.14
5
+ langchain-community==0.2.12
6
+ langchain-openai==0.1.22
7
+ openai==1.42.0
8
+ pinecone==5.0.1
9
+ python-dotenv==1.0.1
10
+ requests==2.32.3
11
+ semchunk==2.2.0
12
+ tiktoken==0.7.0
13
+ tree-sitter==0.22.3
14
+ tree-sitter-language-pack==0.2.0
src/.sample-env ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ OPENAI_API_KEY=
2
+ PINECONE_API_KEY=
3
+ GITHUB_TOKEN=
src/__init__.py ADDED
File without changes
src/chat.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A gradio app that enables users to chat with their codebase.
2
+
3
+ You must run index.py first in order to index the codebase into a vector store.
4
+ """
5
+
6
+ import argparse
7
+
8
+ from dotenv import load_dotenv
9
+
10
+ import gradio as gr
11
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
12
+ from langchain.chains.combine_documents import create_stuff_documents_chain
13
+ from langchain.schema import AIMessage, HumanMessage
14
+ from langchain_community.vectorstores import Pinecone
15
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
16
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
17
+
18
+ from repo_manager import RepoManager
19
+
20
+ load_dotenv()
21
+
22
+
23
def build_rag_chain(args):
    """Builds a RAG chain via LangChain.

    The chain (1) rewrites the latest user query to be self-contained given the
    chat history, (2) retrieves relevant chunks from the Pinecone index, and
    (3) answers the query with an OpenAI chat model grounded in those chunks.

    Args:
        args: Parsed CLI arguments with `openai_model`, `pinecone_index_name`
            and `repo_id` attributes.

    Returns:
        A runnable LangChain retrieval chain; invoke it with
        {"input": ..., "chat_history": [...]}.
    """
    llm = ChatOpenAI(model=args.openai_model)

    vectorstore = Pinecone.from_existing_index(
        index_name=args.pinecone_index_name,
        embedding=OpenAIEmbeddings(),
        # Each repository lives in its own Pinecone namespace (see src/index.py).
        namespace=args.repo_id,
    )

    retriever = vectorstore.as_retriever()

    # Prompt to contextualize the latest query based on the chat history.
    # (Fixed typo: "formualte" -> "formulate".)
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question which might reference context in the chat history, "
        "formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    # Prompt for the final answer; "{context}" is filled with the retrieved chunks.
    # (Fixed missing space between the first two sentences.)
    qa_system_prompt = (
        f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}. "
        "Assume I am an advanced developer and answer my questions in the most succinct way possible."
        "\n\n"
        "Here are some snippets from the codebase."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )

    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
    return rag_chain
71
+
72
+
73
def append_sources_to_response(response, repo_id=None):
    """Given an OpenAI completion response, appends to it GitHub links of the context sources.

    Args:
        response: A LangChain retrieval-chain response dict with an "answer" string and a
            "context" list of documents whose metadata contains "filename".
        repo_id: Repository in "owner/repo" form. Defaults to the module-level CLI
            argument `args.repo_id` for backward compatibility with existing callers.

    Returns:
        The answer string with a deduplicated "Sources:" section of GitHub links appended.
    """
    if repo_id is None:
        # Fall back to the global CLI args, which are set in the __main__ block. The original
        # hard-coded this global; the parameter makes the function usable standalone.
        repo_id = args.repo_id
    filenames = [document.metadata["filename"] for document in response["context"]]
    # Deduplicate filenames while preserving their order.
    filenames = list(dict.fromkeys(filenames))
    repo_manager = RepoManager(repo_id)
    github_links = [
        repo_manager.github_link_for_file(filename) for filename in filenames
    ]
    return response["answer"] + "\n\nSources:\n" + "\n".join(github_links)
83
+
84
+
85
+ if __name__ == "__main__":
86
+ parser = argparse.ArgumentParser(description="UI to chat with your codebase")
87
+ parser.add_argument("repo_id", help="The ID of the repository to index")
88
+ parser.add_argument(
89
+ "--openai_model", default="gpt-4", help="The OpenAI model to use for response generation"
90
+ )
91
+ parser.add_argument(
92
+ "--pinecone_index_name", required=True, help="Pinecone index name"
93
+ )
94
+ args = parser.parse_args()
95
+
96
+ rag_chain = build_rag_chain(args)
97
+
98
+ def _predict(message, history):
99
+ """Performs one RAG operation."""
100
+ history_langchain_format = []
101
+ for human, ai in history:
102
+ history_langchain_format.append(HumanMessage(content=human))
103
+ history_langchain_format.append(AIMessage(content=ai))
104
+ history_langchain_format.append(HumanMessage(content=message))
105
+ response = rag_chain.invoke(
106
+ {"input": message, "chat_history": history_langchain_format}
107
+ )
108
+ answer = append_sources_to_response(response)
109
+ return answer
110
+
111
+ gr.ChatInterface(_predict,
112
+ title=args.repo_id,
113
+ description=f"Code sage for your repo: {args.repo_id}",
114
+ examples=["What does this repo do?", "Give me some sample code."]).launch()
src/chunker.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chunker abstraction and implementations."""
2
+
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import List, Optional
8
+
9
+ import pygments
10
+ import tiktoken
11
+ from semchunk import chunk as chunk_via_semchunk
12
+ from tree_sitter import Node
13
+ from tree_sitter_language_pack import get_parser
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@dataclass
class Chunk:
    """A chunk of code or text extracted from a file in the repository."""

    filename: str
    start_byte: int
    end_byte: int
    _content: Optional[str] = None

    @property
    def content(self) -> Optional[str]:
        """The text content to be embedded. Might contain information beyond just the text snippet from the file."""
        return self._content

    def populate_content(self, file_content: str):
        """Fills in the embeddable content: the file path followed by the raw snippet."""
        snippet = file_content[self.start_byte : self.end_byte]
        self._content = f"{self.filename}\n\n{snippet}"

    def num_tokens(self, tokenizer):
        """Counts the number of tokens in the chunk. Requires populated content."""
        if not self.content:
            raise ValueError("Content not populated.")
        return Chunk._cached_num_tokens(self.content, tokenizer)

    @staticmethod
    @lru_cache(maxsize=1024)
    def _cached_num_tokens(content: str, tokenizer):
        """Caches token counts, since the same content is re-counted during chunk merging."""
        return len(tokenizer.encode(content, disallowed_special=()))

    def __eq__(self, other):
        # Identity is determined by location only; _content is derived data.
        if not isinstance(other, Chunk):
            return False
        return (self.filename, self.start_byte, self.end_byte) == (
            other.filename,
            other.start_byte,
            other.end_byte,
        )

    def __hash__(self):
        return hash((self.filename, self.start_byte, self.end_byte))
61
+
62
+
63
class Chunker(ABC):
    """Abstract class for chunking a file into smaller pieces."""

    @abstractmethod
    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Chunks a file into smaller pieces.

        Args:
            file_path: Path of the file (used in each chunk's metadata).
            file_content: Full text content of the file.

        Returns:
            A list of Chunk objects covering the file.
        """
69
+
70
+
71
class CodeChunker(Chunker):
    """Splits a code file into chunks of at most `max_tokens` tokens each.

    Uses a tree-sitter parse tree so that chunk boundaries respect code structure,
    and falls back to plain text chunking for leaf nodes that are still too long.
    """

    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        # Fallback for oversized leaf nodes that cannot be split structurally.
        self.text_chunker = TextChunker(max_tokens)

    @staticmethod
    def _get_language_from_filename(filename: str):
        """Returns a canonical name for the language of the file, based on its extension.
        Returns None if the language is unknown to the pygments lexer.
        """
        try:
            lexer = pygments.lexers.get_lexer_for_filename(filename)
            return lexer.name.lower()
        except pygments.util.ClassNotFound:
            return None

    def _chunk_node(self, node: Node, filename: str, file_content: str) -> List[Chunk]:
        """Splits a node in the parse tree into a flat list of chunks.

        Recursively descends into children when the node itself exceeds max_tokens,
        then greedily merges neighboring chunks back together up to the limit.
        """
        node_chunk = Chunk(filename, node.start_byte, node.end_byte)
        node_chunk.populate_content(file_content)

        if node_chunk.num_tokens(self.tokenizer) <= self.max_tokens:
            return [node_chunk]

        if not node.children:
            # This is a leaf node, but it's too long. We'll have to split it with a text tokenizer.
            return self.text_chunker.chunk(
                filename, file_content[node.start_byte : node.end_byte]
            )

        chunks = []
        for child in node.children:
            chunks.extend(self._chunk_node(child, filename, file_content))

        for chunk in chunks:
            # This should always be true. Otherwise there must be a bug in the code.
            assert chunk.content and chunk.num_tokens(self.tokenizer) <= self.max_tokens

        # Merge neighboring chunks if their combined size doesn't exceed max_tokens. The goal is to avoid
        # pathologically small chunks that end up being undeservedly preferred by the retriever.
        merged_chunks = []
        for chunk in chunks:
            if not merged_chunks:
                merged_chunks.append(chunk)
            elif (
                merged_chunks[-1].num_tokens(self.tokenizer)
                + chunk.num_tokens(self.tokenizer)
                < self.max_tokens - 50
            ):
                # There's a good chance that merging these two chunks will be under the token limit. We're not
                # 100% sure at this point, because tokenization is not necessarily additive.
                merged = Chunk(
                    merged_chunks[-1].filename,
                    merged_chunks[-1].start_byte,
                    chunk.end_byte,
                )
                merged.populate_content(file_content)
                if merged.num_tokens(self.tokenizer) <= self.max_tokens:
                    merged_chunks[-1] = merged
                else:
                    merged_chunks.append(chunk)
            else:
                # BUGFIX: the original had no else-branch here, silently dropping every chunk
                # whose combined size with its predecessor reached max_tokens - 50.
                merged_chunks.append(chunk)

        for chunk in merged_chunks:
            # This should always be true. Otherwise there's a bug worth investigating.
            assert chunk.content and chunk.num_tokens(self.tokenizer) <= self.max_tokens

        return merged_chunks

    @staticmethod
    def is_code_file(filename: str) -> bool:
        """Checks whether pygment & tree_sitter can parse the file as code."""
        language = CodeChunker._get_language_from_filename(filename)
        # bool() so we honor the -> bool annotation instead of returning None for unknown files.
        return bool(language and language not in ["text only", "None"])

    @staticmethod
    def parse_tree(filename: str, content: str):
        """Parses the code in a file.

        Returns:
            A tree_sitter Tree, or None if the file cannot be parsed as code.
            (The original annotation claimed List[str], which was incorrect.)
        """
        language = CodeChunker._get_language_from_filename(filename)

        if not language or language in ["text only", "None"]:
            logger.debug("%s doesn't seem to be a code file.", filename)
            return None

        try:
            parser = get_parser(language)
        except LookupError:
            logger.debug("%s doesn't seem to be a code file.", filename)
            return None

        tree = parser.parse(bytes(content, "utf8"))

        if not tree.root_node.children or tree.root_node.children[0].type == "ERROR":
            logger.warning("Failed to parse code in %s.", filename)
            return None
        return tree

    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Chunks a code file into smaller pieces."""
        tree = self.parse_tree(file_path, file_content)
        if tree is None:
            return []

        chunks = self._chunk_node(tree.root_node, file_path, file_content)
        for chunk in chunks:
            # Make sure that the chunk has content and doesn't exceed the max_tokens limit. Otherwise there
            # must be a bug in the code.
            assert chunk.content
            size = chunk.num_tokens(self.tokenizer)
            assert (
                size <= self.max_tokens
            ), f"Chunk size {size} exceeds max_tokens {self.max_tokens}."

        return chunks
188
+
189
+
190
class TextChunker(Chunker):
    """Wrapper around semchunk: https://github.com/umarbutler/semchunk."""

    def __init__(self, max_tokens: int):
        self.max_tokens = max_tokens

        tokenizer = tiktoken.get_encoding("cl100k_base")
        # disallowed_special=() so special tokens appearing verbatim in files don't raise.
        self.count_tokens = lambda text: len(
            tokenizer.encode(text, disallowed_special=())
        )

    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Chunks a text file into smaller pieces.

        Args:
            file_path: Path of the file, prepended to each chunk's content.
            file_content: Full text of the file.

        Returns:
            Chunks whose (start_byte, end_byte) spans locate each piece in file_content.
        """
        # We need to allocate some tokens for the filename, which is part of the chunk content.
        extra_tokens = self.count_tokens(file_path + "\n\n")
        text_chunks = chunk_via_semchunk(
            file_content, self.max_tokens - extra_tokens, self.count_tokens
        )

        chunks = []
        search_from = 0
        for text_chunk in text_chunks:
            # This assertion should always be true. Otherwise there's a bug worth finding.
            assert self.count_tokens(text_chunk) <= self.max_tokens - extra_tokens

            # Find the start/end positions of the chunk. BUGFIX: the original used str.index,
            # which raises ValueError instead of returning -1, making the -1 branch dead code
            # and leaving `end` unbound on a miss.
            start = file_content.find(text_chunk, search_from)
            if start == -1:
                logger.warning("Couldn't find semchunk in content: %s", text_chunk)
                continue
            end = start + len(text_chunk)

            # Populate content from the file so the filename prefix is included, matching
            # Chunk.populate_content and the extra_tokens budget reserved above. (The original
            # stored the bare text, despite reserving tokens for the filename.)
            piece = Chunk(file_path, start, end)
            piece.populate_content(file_content)
            chunks.append(piece)

            search_from = end
        return chunks
225
+
226
+
227
class UniversalChunker(Chunker):
    """Chunks a file into smaller pieces, regardless of whether it's code or text."""

    def __init__(self, max_tokens: int):
        self.code_chunker = CodeChunker(max_tokens)
        self.text_chunker = TextChunker(max_tokens)

    def chunk(self, file_path: str, file_content: str) -> List[Chunk]:
        """Dispatches to the code chunker for parseable code files, else the text chunker."""
        chunker = (
            self.code_chunker
            if CodeChunker.is_code_file(file_path)
            else self.text_chunker
        )
        return chunker.chunk(file_path, file_content)
src/embedder.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Batch embedder abstraction and implementations."""
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ from abc import ABC, abstractmethod
7
+ from collections import Counter
8
+ from typing import Dict, Generator, List, Tuple
9
+
10
+ from openai import OpenAI
11
+
12
+ from chunker import Chunk, Chunker
13
+ from repo_manager import RepoManager
14
+
15
# A vector is the unit stored in the vector store: chunk metadata plus its embedding.
Vector = Tuple[Dict, List[float]]  # (metadata, embedding)


class BatchEmbedder(ABC):
    """Abstract class for batch embedding of a repository."""

    @abstractmethod
    def embed_repo(self, chunks_per_batch: int):
        """Issues batch embedding jobs for the entire repository.

        Args:
            chunks_per_batch: Maximum number of chunks to include in a single job.
        """

    @abstractmethod
    def embeddings_are_ready(self) -> bool:
        """Checks whether the batch embedding jobs are done."""

    @abstractmethod
    def download_embeddings(self) -> Generator[Vector, None, None]:
        """Yields (chunk_metadata, embedding) pairs for each chunk in the repository."""
32
+
33
+
34
class OpenAIBatchEmbedder(BatchEmbedder):
    """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""

    def __init__(self, repo_manager: RepoManager, chunker: Chunker, local_dir: str):
        self.repo_manager = repo_manager
        self.chunker = chunker
        self.local_dir = local_dir
        # IDs issued by OpenAI for each batch job mapped to metadata about the chunks.
        self.openai_batch_ids = {}
        self.client = OpenAI()

    def embed_repo(self, chunks_per_batch: int):
        """Issues batch embedding jobs for the entire repository.

        Args:
            chunks_per_batch: Maximum number of chunks per job.

        Raises:
            ValueError: If jobs were already issued by this embedder instance.
        """
        if self.openai_batch_ids:
            raise ValueError("Embeddings are in progress.")

        batch = []
        chunk_count = 0
        repo_name = self.repo_manager.repo_id.split("/")[-1]

        for filepath, content in self.repo_manager.walk():
            chunks = self.chunker.chunk(filepath, content)
            chunk_count += len(chunks)
            batch.extend(chunks)

            # Issue a job for each full batch and keep the remainder for the next iteration.
            # BUGFIX: the original sliced `batch` while reassigning it inside the loop, which
            # dropped all chunks after the first slice and issued empty follow-up jobs.
            while len(batch) >= chunks_per_batch:
                full_batch, batch = batch[:chunks_per_batch], batch[chunks_per_batch:]
                self._issue_and_register(full_batch, repo_name)

        # Finally, commit the last batch.
        if batch:
            self._issue_and_register(batch, repo_name)
        logging.info(
            "Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count
        )

        # Save the job IDs to a file, just in case this script is terminated by mistake.
        metadata_file = os.path.join(self.local_dir, "openai_batch_ids.json")
        with open(metadata_file, "w") as f:
            json.dump(self.openai_batch_ids, f)
        logging.info("Job metadata saved at %s", metadata_file)

    def _issue_and_register(self, chunks: List[Chunk], repo_name: str):
        """Issues one embedding job and records its OpenAI ID -> chunk metadata mapping."""
        openai_batch_id = self._issue_job_for_chunks(
            chunks, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
        )
        self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(chunks)

    def embeddings_are_ready(self) -> bool:
        """Checks whether the embeddings jobs are done (either completed or failed)."""
        if not self.openai_batch_ids:
            raise ValueError("No embeddings in progress.")
        job_ids = self.openai_batch_ids.keys()
        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]
        are_ready = all(status.status in ["completed", "failed"] for status in statuses)
        status_counts = Counter(status.status for status in statuses)
        logging.info("Job statuses: %s", status_counts)
        return are_ready

    def download_embeddings(self) -> Generator[Vector, None, None]:
        """Yields a (chunk_metadata, embedding) pair for each chunk in the repository."""
        job_ids = self.openai_batch_ids.keys()
        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]

        for status in statuses:
            if status.status == "failed":
                logging.error("Job failed: %s", status)
                continue

            if not status.output_file_id:
                error = self.client.files.content(status.error_file_id)
                logging.error("Job %s failed with error: %s", status.id, error.text)
                continue

            batch_metadata = self.openai_batch_ids[status.id]
            file_response = self.client.files.content(status.output_file_id)
            data = json.loads(file_response.text)["response"]["body"]["data"]
            logging.info("Job %s generated %d embeddings.", status.id, len(data))

            for datum in data:
                # "index" positions the embedding within the batch's input list.
                chunk_idx = int(datum["index"])
                metadata = batch_metadata[chunk_idx]
                embedding = datum["embedding"]
                yield (metadata, embedding)

    def _issue_job_for_chunks(self, chunks: List[Chunk], batch_id: str) -> str:
        """Issues a batch embedding job for the given chunks. Returns the job ID."""
        logging.info("*" * 100)
        logging.info("Issuing job for batch %s with %d chunks.", batch_id, len(chunks))

        # Create a .jsonl file with the batch.
        request = OpenAIBatchEmbedder._chunks_to_request(chunks, batch_id)
        input_file = os.path.join(self.local_dir, f"batch_{batch_id}.jsonl")
        OpenAIBatchEmbedder._export_to_jsonl([request], input_file)

        # Upload the file and issue the embedding job. (Context manager so the handle is closed.)
        with open(input_file, "rb") as f:
            batch_input_file = self.client.files.create(file=f, purpose="batch")
        batch_status = self._create_batch_job(batch_input_file.id)
        if batch_status is None:
            # The original dereferenced None and crashed with AttributeError; fail clearly instead.
            raise RuntimeError(f"Failed to create batch job for batch {batch_id}.")
        logging.info("Created job with ID %s", batch_status.id)
        return batch_status.id

    def _create_batch_job(self, input_file_id: str):
        """Creates a batch embedding job on OpenAI. Returns the job object, or None on failure."""
        try:
            return self.client.batches.create(
                input_file_id=input_file_id,
                endpoint="/v1/embeddings",
                completion_window="24h",  # This is the only allowed value for now.
                timeout=3 * 60,  # 3 minutes
                metadata={},
            )
        except Exception as e:  # Broad on purpose: one bad job shouldn't kill the whole run.
            logging.error(
                "Failed to create batch job with input_file_id=%s. Error: %s",
                input_file_id,
                e,
            )
            return None

    @staticmethod
    def _export_to_jsonl(list_of_dicts: List[Dict], output_file: str):
        """Exports a list of dictionaries to a .jsonl file."""
        directory = os.path.dirname(output_file)
        if directory:
            # exist_ok avoids crashing if the directory was created in the meantime; the original
            # also crashed when output_file had no directory component (dirname == "").
            os.makedirs(directory, exist_ok=True)
        with open(output_file, "w") as f:
            for item in list_of_dicts:
                json.dump(item, f)
                f.write("\n")

    @staticmethod
    def _chunks_to_request(chunks: List[Chunk], batch_id: str):
        """Converts a list of chunks to a single /v1/embeddings batch request."""
        return {
            "custom_id": batch_id,
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {
                "model": "text-embedding-ada-002",
                "input": [chunk.content for chunk in chunks],
            },
        }

    @staticmethod
    def _metadata_for_chunks(chunks):
        """Builds the vector-store metadata dict for each chunk."""
        metadata = []
        for chunk in chunks:
            filename_ascii = chunk.filename.encode("ascii", "ignore").decode("ascii")
            metadata.append(
                {
                    # Some vector stores require the IDs to be ASCII.
                    "id": f"{filename_ascii}_{chunk.start_byte}_{chunk.end_byte}",
                    "filename": chunk.filename,
                    "start_byte": chunk.start_byte,
                    "end_byte": chunk.end_byte,
                    # Note to developer: When choosing a large chunk size, you might exceed the vector store's
                    # metadata size limit. In that case, you can simply store the start/end bytes above, and
                    # fetch the content directly from the repository when needed.
                    "text": chunk.content,
                }
            )
        return metadata
src/index.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
2
+
3
+ import argparse
4
+ import logging
5
+ import time
6
+
7
+ from chunker import UniversalChunker
8
+ from embedder import OpenAIBatchEmbedder
9
+ from repo_manager import RepoManager
10
+ from vector_store import PineconeVectorStore
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
# Dimension of the embedding vectors, passed to the vector store (see main() below).
OPENAI_EMBEDDING_SIZE = 1536
MAX_TOKENS_PER_CHUNK = (
    8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
)
MAX_CHUNKS_PER_BATCH = (
    2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
)
MAX_TOKENS_PER_JOB = 3_000_000  # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
22
+
23
+
24
def main():
    """Indexes a GitHub repository: clone -> chunk -> batch-embed -> upsert into the vector store."""
    parser = argparse.ArgumentParser(description="Batch-embeds a repository")
    parser.add_argument("repo_id", help="The ID of the repository to index")
    parser.add_argument(
        "--local_dir",
        default="repos",
        help="The local directory to store the repository",
    )
    parser.add_argument(
        "--tokens_per_chunk",
        type=int,
        default=800,
        help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
    )
    parser.add_argument(
        "--chunks_per_batch", type=int, default=2000, help="Maximum chunks per batch"
    )
    parser.add_argument(
        "--pinecone_index_name", required=True, help="Pinecone index name"
    )

    args = parser.parse_args()

    # Validate the arguments.
    if args.tokens_per_chunk > MAX_TOKENS_PER_CHUNK:
        parser.error(
            f"The maximum number of tokens per chunk is {MAX_TOKENS_PER_CHUNK}."
        )
    if args.chunks_per_batch > MAX_CHUNKS_PER_BATCH:
        parser.error(
            f"The maximum number of chunks per batch is {MAX_CHUNKS_PER_BATCH}."
        )
    if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
        # Fixed message: this limit is on tokens per job, not chunks.
        parser.error(f"The maximum number of tokens per job is {MAX_TOKENS_PER_JOB}.")

    logging.info("Cloning the repository...")
    repo_manager = RepoManager(args.repo_id, local_dir=args.local_dir)
    if not repo_manager.clone():
        # clone() returns False on failure; the original ignored it and crashed later.
        parser.error(f"Unable to clone repository {args.repo_id}.")

    logging.info("Issuing embedding jobs...")
    chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
    embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
    embedder.embed_repo(args.chunks_per_batch)

    logging.info("Waiting for embeddings to be ready...")
    while not embedder.embeddings_are_ready():
        logging.info("Sleeping for 30 seconds...")
        time.sleep(30)

    logging.info("Moving embeddings to the vector store...")
    # Note to developer: Replace this with your preferred vector store.
    vector_store = PineconeVectorStore(
        index_name=args.pinecone_index_name,
        dimension=OPENAI_EMBEDDING_SIZE,
        namespace=repo_manager.repo_id,
    )
    vector_store.ensure_exists()
    vector_store.upsert(embedder.download_embeddings())
    logging.info("Done!")


if __name__ == "__main__":
    main()
src/repo_manager.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility classes to maniuplate GitHub repositories."""
2
+
3
+ import logging
4
+ import os
5
+ from functools import cached_property
6
+
7
+ import requests
8
+ from git import GitCommandError, Repo
9
+
10
+
11
class RepoManager:
    """Class to manage a local clone of a GitHub repository."""

    def __init__(self, repo_id: str, local_dir: str = None):
        """
        Args:
            repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/repo2vec".
            local_dir: The local directory where the repository will be cloned.
        """
        self.repo_id = repo_id
        self.local_dir = local_dir or "/tmp/"
        if not os.path.exists(self.local_dir):
            os.makedirs(self.local_dir)
        self.local_path = os.path.join(self.local_dir, repo_id)
        # Needed for private repositories; also raises GitHub's API rate limits.
        self.access_token = os.getenv("GITHUB_TOKEN")

    @cached_property
    def is_public(self) -> bool:
        """Checks whether a GitHub repository is publicly visible."""
        response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
        # Note that the response will be 404 for both private and non-existent repos.
        return response.status_code == 200

    @cached_property
    def default_branch(self) -> str:
        """Fetches the default branch of the repository from GitHub."""
        headers = {
            "Accept": "application/vnd.github.v3+json",
        }
        if self.access_token:
            headers["Authorization"] = f"token {self.access_token}"

        # timeout added for consistency with is_public; the original request could hang forever.
        response = requests.get(
            f"https://api.github.com/repos/{self.repo_id}", headers=headers, timeout=10
        )
        if response.status_code == 200:
            branch = response.json().get("default_branch", "main")
        else:
            # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to
            # assume the most common naming for the default branch ("main").
            # logging.warning replaces the deprecated logging.warn; lazy %-args replace the f-string.
            logging.warning(
                "Unable to fetch default branch for %s: %s", self.repo_id, response.text
            )
            branch = "main"
        return branch

    def clone(self) -> bool:
        """Clones the repository to the local directory, if it's not already cloned.

        Returns:
            True on success (or if already cloned), False if cloning failed.

        Raises:
            ValueError: If the repo is private/non-existent and no GITHUB_TOKEN is set.
        """
        if os.path.exists(self.local_path):
            # The repository is already cloned.
            return True

        if not self.is_public and not self.access_token:
            raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")

        if self.access_token:
            clone_url = f"https://{self.access_token}@github.com/{self.repo_id}.git"
        else:
            clone_url = f"https://github.com/{self.repo_id}.git"

        try:
            Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
        except GitCommandError as e:
            # Do not log clone_url: it may embed the access token.
            # NOTE(review): the GitCommandError message itself may still echo the git command
            # (and thus the URL) — consider sanitizing `e` as well.
            logging.error("Unable to clone %s. Error: %s", self.repo_id, e)
            return False
        return True

    def walk(
        self,
        included_extensions: set = None,
        excluded_extensions: set = None,
        log_dir: str = None,
    ):
        """Walks the local repository path and yields a tuple of (filepath, content) for each file.
        The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").

        Args:
            included_extensions: Optional set of extensions to include.
            excluded_extensions: Optional set of extensions to exclude.
            log_dir: Optional directory where to log the included and excluded files.
        """
        # Convert included and excluded extensions to lowercase.
        if included_extensions:
            included_extensions = {ext.lower() for ext in included_extensions}
        if excluded_extensions:
            excluded_extensions = {ext.lower() for ext in excluded_extensions}

        def include(file_path: str) -> bool:
            """Applies the extension filters and excludes hidden files/directories."""
            _, extension = os.path.splitext(file_path)
            extension = extension.lower()
            if included_extensions and extension not in included_extensions:
                return False
            if excluded_extensions and extension in excluded_extensions:
                return False
            # Exclude hidden files and directories.
            if any(part.startswith(".") for part in file_path.split(os.path.sep)):
                return False
            return True

        # We will keep appending to these files during the iteration, so we need to clear them first.
        if log_dir:
            repo_name = self.repo_id.replace("/", "_")
            included_log_file = os.path.join(log_dir, f"included_{repo_name}.txt")
            excluded_log_file = os.path.join(log_dir, f"excluded_{repo_name}.txt")
            if os.path.exists(included_log_file):
                os.remove(included_log_file)
            if os.path.exists(excluded_log_file):
                os.remove(excluded_log_file)

        for root, _, files in os.walk(self.local_path):
            file_paths = [os.path.join(root, file) for file in files]
            included_file_paths = [f for f in file_paths if include(f)]

            if log_dir:
                with open(included_log_file, "a") as f:
                    for path in included_file_paths:
                        f.write(path + "\n")

                excluded_file_paths = set(file_paths).difference(set(included_file_paths))
                with open(excluded_log_file, "a") as f:
                    for path in excluded_file_paths:
                        f.write(path + "\n")

            for file_path in included_file_paths:
                with open(file_path, "r") as f:
                    try:
                        contents = f.read()
                    except UnicodeDecodeError:
                        logging.warning("Unable to decode file %s. Skipping.", file_path)
                        continue
                # BUGFIX: os.path.relpath handles a trailing separator in local_dir; the original
                # `file_path[len(self.local_dir) + 1:]` mangled paths when local_dir ended with
                # "/" (the default "/tmp/"), cutting two characters off the repo owner's name.
                yield os.path.relpath(file_path, self.local_dir), contents

    def github_link_for_file(self, file_path: str) -> str:
        """Converts a repository-relative file path (e.g. "org/repo/src/x.py") to a GitHub link."""
        # Strip the "org/repo" prefix AND the leading slash; the original kept the slash,
        # producing links like ".../blob/main//src/x.py".
        file_path = file_path[len(self.repo_id):].lstrip("/")
        return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
src/vector_store.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Vector store abstraction and implementations."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, Generator, List, Tuple
5
+
6
+ from pinecone import Pinecone
7
+
8
# The unit stored in a vector store: chunk metadata plus its embedding.
Vector = Tuple[Dict, List[float]]  # (metadata, embedding)


class VectorStore(ABC):
    """Abstract class for a vector store."""

    @abstractmethod
    def ensure_exists(self):
        """Ensures that the vector store exists. Creates it if it doesn't."""

    @abstractmethod
    def upsert_batch(self, vectors: List[Vector]):
        """Upserts a batch of vectors."""

    def upsert(self, vectors: Generator[Vector, None, None]):
        """Upserts in batches of 100, since vector stores have a limit on upsert size."""
        pending = []
        for vector in vectors:
            pending.append(vector)
            if len(pending) == 100:
                self.upsert_batch(pending)
                pending = []
        # Flush whatever is left over after the generator is exhausted.
        if pending:
            self.upsert_batch(pending)
+
32
+
33
class PineconeVectorStore(VectorStore):
    """Vector store implementation using Pinecone."""

    def __init__(self, index_name: str, dimension: int, namespace: str):
        """
        Args:
            index_name: Name of the Pinecone index.
            dimension: Dimensionality of the embedding vectors.
            namespace: Pinecone namespace to write into (one per repository; see src/index.py).
        """
        self.index_name = index_name
        self.dimension = dimension
        # Pinecone() with no args reads PINECONE_API_KEY from the environment (see src/.sample-env).
        self.client = Pinecone()
        # NOTE(review): the Index handle is obtained before ensure_exists() can create the index;
        # presumably the handle is lazy so this is fine — confirm against the pinecone client docs.
        self.index = self.client.Index(self.index_name)
        self.namespace = namespace

    def ensure_exists(self):
        # Creates the index if it's missing.
        # NOTE(review): recent pinecone clients require a `spec` (serverless/pod) argument to
        # create_index — verify this call succeeds with the pinned pinecone==5.0.1.
        if self.index_name not in self.client.list_indexes().names():
            self.client.create_index(
                name=self.index_name, dimension=self.dimension, metric="cosine"
            )

    def upsert_batch(self, vectors: List[Vector]):
        # Pinecone expects (id, values, metadata) tuples. Falls back to the batch-local position
        # as the ID when metadata lacks one; note this fallback can collide across batches.
        pinecone_vectors = [
            (metadata.get("id", str(i)), embedding, metadata)
            for i, (metadata, embedding) in enumerate(vectors)
        ]
        self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)