juliaturc committed on
Commit
d5c979a
·
1 Parent(s): 559dd34

Add inclusion and exclusion sets.

Browse files
Files changed (5) hide show
  1. README.md +17 -2
  2. src/chat.py +14 -5
  3. src/index.py +22 -1
  4. src/repo_manager.py +40 -32
  5. src/sample-exclude.txt +62 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  # Overview
2
- `repo2vec` enables you to chat with your codebase by simply running two python scripts:
3
  ```
4
  pip install -r requirements.txt
5
 
@@ -11,7 +11,12 @@ export PINECONE_INDEX_NAME=...
11
  python src/index.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
12
  python src/chat.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
13
  ```
14
- This will bring up a `gradio` app where you can ask questions about your codebase. The assistant responses always include GitHub links to the documents retrieved for each query.
 
 
 
 
 
15
 
16
  Here is, for example, a conversation about the repo [Storia-AI/image-eval](https://github.com/Storia-AI/image-eval):
17
  ![screenshot](assets/chat_screenshot.png)
@@ -29,6 +34,16 @@ The `src/index.py` script performs the following steps:
29
  4. **Stores embeddings in a vector store**. See [VectorStore](src/vector_store.py).
30
  - By default, we use [Pinecone](https://pinecone.io) as a vector store, but you can easily plug in your own.
31
 
 
 
 
 
 
 
 
 
 
 
32
  ## Chatting via RAG
33
  The `src/chat.py` brings up a [Gradio app](https://www.gradio.app/) with a chat interface as shown above. We use [LangChain](https://langchain.com) to define a RAG chain which, given a user query about the repository:
34
 
 
1
  # Overview
2
+ `repo2vec` enables you to index your codebase and chat with it by simply running two python scripts:
3
  ```
4
  pip install -r requirements.txt
5
 
 
11
  python src/index.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
12
  python src/chat.py $GITHUB_REPO_NAME --pinecone_index_name=$PINECONE_INDEX_NAME
13
  ```
14
+ This will index your entire codebase in a vector DB, then bring up a `gradio` app where you can ask questions about it. The assistant responses always include GitHub links to the documents retrieved for each query.
15
+
16
+ To make the gradio chat app accessible publicly, you can set `--share=true`:
17
+ ```
18
+ python src/chat.py $GITHUB_REPO_NAME --share=true ...
19
+ ```
20
 
21
  Here is, for example, a conversation about the repo [Storia-AI/image-eval](https://github.com/Storia-AI/image-eval):
22
  ![screenshot](assets/chat_screenshot.png)
 
34
  4. **Stores embeddings in a vector store**. See [VectorStore](src/vector_store.py).
35
  - By default, we use [Pinecone](https://pinecone.io) as a vector store, but you can easily plug in your own.
36
 
37
+ Note that you can specify an inclusion or exclusion set for the file extensions you want indexed. To specify an extension inclusion set, you can add the `--include` flag:
38
+ ```
39
+ python src/index.py repo-org/repo-name --include=/path/to/file/with/extensions
40
+ ```
41
+ Conversely, to specify an extension exclusion set, you can add the `--exclude` flag:
42
+ ```
43
+ python src/index.py repo-org/repo-name --exclude=src/sample-exclude.txt
44
+ ```
45
+ Extensions must be specified one per line, in the form `.ext`.
46
+
47
  ## Chatting via RAG
48
  The `src/chat.py` brings up a [Gradio app](https://www.gradio.app/) with a chat interface as shown above. We use [LangChain](https://langchain.com) to define a RAG chain which, given a user query about the repository:
49
 
src/chat.py CHANGED
@@ -86,11 +86,18 @@ if __name__ == "__main__":
86
  parser = argparse.ArgumentParser(description="UI to chat with your codebase")
87
  parser.add_argument("repo_id", help="The ID of the repository to index")
88
  parser.add_argument(
89
- "--openai_model", default="gpt-4", help="The OpenAI model to use for response generation"
 
 
90
  )
91
  parser.add_argument(
92
  "--pinecone_index_name", required=True, help="Pinecone index name"
93
  )
 
 
 
 
 
94
  args = parser.parse_args()
95
 
96
  rag_chain = build_rag_chain(args)
@@ -108,7 +115,9 @@ if __name__ == "__main__":
108
  answer = append_sources_to_response(response)
109
  return answer
110
 
111
- gr.ChatInterface(_predict,
112
- title=args.repo_id,
113
- description=f"Code sage for your repo: {args.repo_id}",
114
- examples=["What does this repo do?", "Give me some sample code."]).launch()
 
 
 
86
  parser = argparse.ArgumentParser(description="UI to chat with your codebase")
87
  parser.add_argument("repo_id", help="The ID of the repository to index")
88
  parser.add_argument(
89
+ "--openai_model",
90
+ default="gpt-4",
91
+ help="The OpenAI model to use for response generation",
92
  )
93
  parser.add_argument(
94
  "--pinecone_index_name", required=True, help="Pinecone index name"
95
  )
96
+ parser.add_argument(
97
+ "--share",
98
+ default=False,
99
+ help="Whether to make the gradio app publicly accessible.",
100
+ )
101
  args = parser.parse_args()
102
 
103
  rag_chain = build_rag_chain(args)
 
115
  answer = append_sources_to_response(response)
116
  return answer
117
 
118
+ gr.ChatInterface(
119
+ _predict,
120
+ title=args.repo_id,
121
+ description=f"Code sage for your repo: {args.repo_id}",
122
+ examples=["What does this repo do?", "Give me some sample code."],
123
+ ).launch(share=args.share)
src/index.py CHANGED
@@ -21,6 +21,11 @@ MAX_CHUNKS_PER_BATCH = (
21
  MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
22
 
23
 
 
 
 
 
 
24
  def main():
25
  parser = argparse.ArgumentParser(description="Batch-embeds a repository")
26
  parser.add_argument("repo_id", help="The ID of the repository to index")
@@ -41,6 +46,12 @@ def main():
41
  parser.add_argument(
42
  "--pinecone_index_name", required=True, help="Pinecone index name"
43
  )
 
 
 
 
 
 
44
 
45
  args = parser.parse_args()
46
 
@@ -55,9 +66,19 @@ def main():
55
  )
56
  if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
57
  parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
 
 
 
 
 
58
 
59
  logging.info("Cloning the repository...")
60
- repo_manager = RepoManager(args.repo_id, local_dir=args.local_dir)
 
 
 
 
 
61
  repo_manager.clone()
62
 
63
  logging.info("Issuing embedding jobs...")
 
21
  MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
22
 
23
 
24
+ def _read_extensions(path):
25
+ with open(path, "r") as f:
26
+ return {line.strip().lower() for line in f}
27
+
28
+
29
  def main():
30
  parser = argparse.ArgumentParser(description="Batch-embeds a repository")
31
  parser.add_argument("repo_id", help="The ID of the repository to index")
 
46
  parser.add_argument(
47
  "--pinecone_index_name", required=True, help="Pinecone index name"
48
  )
49
+ parser.add_argument(
50
+ "--include", help="Path to a file containing a list of extensions to include. One extension per line."
51
+ )
52
+ parser.add_argument(
53
+ "--exclude", help="Path to a file containing a list of extensions to exclude. One extension per line."
54
+ )
55
 
56
  args = parser.parse_args()
57
 
 
66
  )
67
  if args.tokens_per_chunk * args.chunks_per_batch >= MAX_TOKENS_PER_JOB:
68
  parser.error(f"The maximum number of chunks per job is {MAX_TOKENS_PER_JOB}.")
69
+ if args.include and args.exclude:
70
+ parser.error("At most one of --include and --exclude can be specified.")
71
+
72
+ included_extensions = _read_extensions(args.include) if args.include else None
73
+ excluded_extensions = _read_extensions(args.exclude) if args.exclude else None
74
 
75
  logging.info("Cloning the repository...")
76
+ repo_manager = RepoManager(
77
+ args.repo_id,
78
+ local_dir=args.local_dir,
79
+ included_extensions=included_extensions,
80
+ excluded_extensions=excluded_extensions,
81
+ )
82
  repo_manager.clone()
83
 
84
  logging.info("Issuing embedding jobs...")
src/repo_manager.py CHANGED
@@ -11,7 +11,13 @@ from git import GitCommandError, Repo
11
  class RepoManager:
12
  """Class to manage a local clone of a GitHub repository."""
13
 
14
- def __init__(self, repo_id: str, local_dir: str = None):
 
 
 
 
 
 
15
  """
16
  Args:
17
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/repo2vec".
@@ -23,11 +29,15 @@ class RepoManager:
23
  os.makedirs(self.local_dir)
24
  self.local_path = os.path.join(self.local_dir, repo_id)
25
  self.access_token = os.getenv("GITHUB_TOKEN")
 
 
26
 
27
  @cached_property
28
  def is_public(self) -> bool:
29
  """Checks whether a GitHub repository is publicly visible."""
30
- response = requests.get(f"https://api.github.com/repos/{self.repo_id}", timeout=10)
 
 
31
  # Note that the response will be 404 for both private and non-existent repos.
32
  return response.status_code == 200
33
 
@@ -40,13 +50,17 @@ class RepoManager:
40
  if self.access_token:
41
  headers["Authorization"] = f"token {self.access_token}"
42
 
43
- response = requests.get(f"https://api.github.com/repos/{self.repo_id}", headers=headers)
 
 
44
  if response.status_code == 200:
45
  branch = response.json().get("default_branch", "main")
46
  else:
47
  # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
48
  # most common naming for the default branch ("main").
49
- logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}")
 
 
50
  branch = "main"
51
  return branch
52
 
@@ -73,12 +87,20 @@ class RepoManager:
73
  return False
74
  return True
75
 
76
- def walk(
77
- self,
78
- included_extensions: set = None,
79
- excluded_extensions: set = None,
80
- log_dir: str = None,
81
- ):
 
 
 
 
 
 
 
 
82
  """Walks the local repository path and yields a tuple of (filepath, content) for each file.
83
  The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
84
 
@@ -87,24 +109,6 @@ class RepoManager:
87
  excluded_extensions: Optional set of extensions to exclude.
88
  log_dir: Optional directory where to log the included and excluded files.
89
  """
90
- # Convert included and excluded extensions to lowercase.
91
- if included_extensions:
92
- included_extensions = {ext.lower() for ext in included_extensions}
93
- if excluded_extensions:
94
- excluded_extensions = {ext.lower() for ext in excluded_extensions}
95
-
96
- def include(file_path: str) -> bool:
97
- _, extension = os.path.splitext(file_path)
98
- extension = extension.lower()
99
- if included_extensions and extension not in included_extensions:
100
- return False
101
- if excluded_extensions and extension in excluded_extensions:
102
- return False
103
- # Exclude hidden files and directories.
104
- if any(part.startswith(".") for part in file_path.split(os.path.sep)):
105
- return False
106
- return True
107
-
108
  # We will keep appending to these files during the iteration, so we need to clear them first.
109
  if log_dir:
110
  repo_name = self.repo_id.replace("/", "_")
@@ -117,7 +121,7 @@ class RepoManager:
117
 
118
  for root, _, files in os.walk(self.local_path):
119
  file_paths = [os.path.join(root, file) for file in files]
120
- included_file_paths = [f for f in file_paths if include(f)]
121
 
122
  if log_dir:
123
  with open(included_log_file, "a") as f:
@@ -136,11 +140,15 @@ class RepoManager:
136
  try:
137
  contents = f.read()
138
  except UnicodeDecodeError:
139
- logging.warning("Unable to decode file %s. Skipping.", file_path)
 
 
140
  continue
141
  yield file_path[len(self.local_dir) + 1 :], contents
142
 
143
  def github_link_for_file(self, file_path: str) -> str:
144
  """Converts a repository file path to a GitHub link."""
145
- file_path = file_path[len(self.repo_id):]
146
- return f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{file_path}"
 
 
 
11
  class RepoManager:
12
  """Class to manage a local clone of a GitHub repository."""
13
 
14
+ def __init__(
15
+ self,
16
+ repo_id: str,
17
+ local_dir: str = None,
18
+ included_extensions: set = None,
19
+ excluded_extensions: set = None,
20
+ ):
21
  """
22
  Args:
23
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/repo2vec".
 
29
  os.makedirs(self.local_dir)
30
  self.local_path = os.path.join(self.local_dir, repo_id)
31
  self.access_token = os.getenv("GITHUB_TOKEN")
32
+ self.included_extensions = included_extensions
33
+ self.excluded_extensions = excluded_extensions
34
 
35
  @cached_property
36
  def is_public(self) -> bool:
37
  """Checks whether a GitHub repository is publicly visible."""
38
+ response = requests.get(
39
+ f"https://api.github.com/repos/{self.repo_id}", timeout=10
40
+ )
41
  # Note that the response will be 404 for both private and non-existent repos.
42
  return response.status_code == 200
43
 
 
50
  if self.access_token:
51
  headers["Authorization"] = f"token {self.access_token}"
52
 
53
+ response = requests.get(
54
+ f"https://api.github.com/repos/{self.repo_id}", headers=headers
55
+ )
56
  if response.status_code == 200:
57
  branch = response.json().get("default_branch", "main")
58
  else:
59
  # This happens sometimes when we exceed the Github rate limit. The best bet in this case is to assume the
60
  # most common naming for the default branch ("main").
61
+ logging.warn(
62
+ f"Unable to fetch default branch for {self.repo_id}: {response.text}"
63
+ )
64
  branch = "main"
65
  return branch
66
 
 
87
  return False
88
  return True
89
 
90
+ def _should_include(self, file_path: str) -> bool:
91
+ """Checks whether the file should be indexed, based on the included and excluded extensions."""
92
+ _, extension = os.path.splitext(file_path)
93
+ extension = extension.lower()
94
+ if self.included_extensions and extension not in self.included_extensions:
95
+ return False
96
+ if self.excluded_extensions and extension in self.excluded_extensions:
97
+ return False
98
+ # Exclude hidden files and directories.
99
+ if any(part.startswith(".") for part in file_path.split(os.path.sep)):
100
+ return False
101
+ return True
102
+
103
+ def walk(self, log_dir: str = None):
104
  """Walks the local repository path and yields a tuple of (filepath, content) for each file.
105
  The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").
106
 
 
109
  excluded_extensions: Optional set of extensions to exclude.
110
  log_dir: Optional directory where to log the included and excluded files.
111
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # We will keep appending to these files during the iteration, so we need to clear them first.
113
  if log_dir:
114
  repo_name = self.repo_id.replace("/", "_")
 
121
 
122
  for root, _, files in os.walk(self.local_path):
123
  file_paths = [os.path.join(root, file) for file in files]
124
+ included_file_paths = [f for f in file_paths if self._should_include(f)]
125
 
126
  if log_dir:
127
  with open(included_log_file, "a") as f:
 
140
  try:
141
  contents = f.read()
142
  except UnicodeDecodeError:
143
+ logging.warning(
144
+ "Unable to decode file %s. Skipping.", file_path
145
+ )
146
  continue
147
  yield file_path[len(self.local_dir) + 1 :], contents
148
 
149
def github_link_for_file(self, file_path: str) -> str:
    """Converts a repository file path to a GitHub link.

    The incoming path is prefixed with the repo ID (e.g.
    "org/repo/src/file.py", as yielded by `walk`); that prefix is stripped
    before building the URL. We also drop the path separator left over
    from the slice, which would otherwise produce a malformed URL with a
    double slash ("blob/main//src/file.py").
    """
    relative_path = file_path[len(self.repo_id) :].lstrip("/")
    return (
        f"https://github.com/{self.repo_id}/blob/{self.default_branch}/{relative_path}"
    )
src/sample-exclude.txt ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .avi
2
+ .bazel
3
+ .bin
4
+ .binpb
5
+ .bmp
6
+ .crt
7
+ .css
8
+ .dat
9
+ .db
10
+ .duckdb
11
+ .eot
12
+ .exe
13
+ .gif
14
+ .gguf
15
+ .glb
16
+ .gz
17
+ .ico
18
+ .icns
19
+ .inp
20
+ .ipynb
21
+ .isl
22
+ .jar
23
+ .jpeg
24
+ .jpg
25
+ .json
26
+ .key
27
+ .lock
28
+ .mo
29
+ .model
30
+ .mov
31
+ .mp3
32
+ .mp4
33
+ .otf
34
+ .out
35
+ .Packages
36
+ .pb
37
+ .pdf
38
+ .pem
39
+ .pickle
40
+ .png
41
+ .pt
42
+ .ptl
43
+ .s
44
+ .sqlite
45
+ .stl
46
+ .sum
47
+ .svg
48
+ .tar
49
+ .th
50
+ .tgz
51
+ .toml
52
+ .ts-fixture
53
+ .ttf
54
+ .wav
55
+ .webp
56
+ .wmv
57
+ .woff
58
+ .woff2
59
+ .xml
60
+ .yaml
61
+ .yml
62
+ .zip