juliaturc committed
Commit 9802b75 · 1 Parent(s): df1b188

Add YAML configurations (#38)
MANIFEST.in CHANGED
@@ -1 +1,3 @@
- include sage/sample-exclude.txt
+ include sage/sample-exclude.txt
+ include sage/configs/local.yaml
+ include sage/configs/remote.yaml
README.md CHANGED
@@ -49,7 +49,7 @@ pip install git+https://github.com/Storia-AI/sage.git@main
  2. Enables chatting via LLM + RAG (requiring access to an LLM)

  <details open>
- <summary><strong>:computer: Running locally</strong></summary>
+ <summary><strong>:computer: Running locally (lower quality)</strong></summary>

  1. To index the codebase locally, we use the open-source project <a href="https://github.com/marqo-ai/marqo">Marqo</a>, which is both an embedder and a vector store. To bring up a Marqo instance:

@@ -70,7 +70,7 @@ pip install git+https://github.com/Storia-AI/sage.git@main
  </details>

  <details>
- <summary><strong>:cloud: Using external providers</strong></summary>
+ <summary><strong>:cloud: Using external providers (higher quality)</strong></summary>

  1. We support <a href="https://openai.com/">OpenAI</a> for embeddings (they have a super fast batch embedding API) and <a href="https://www.pinecone.io/">Pinecone</a> for the vector store. So you will need two API keys:

@@ -84,37 +84,27 @@ pip install git+https://github.com/Storia-AI/sage.git@main
  export PINECONE_INDEX_NAME=...
  ```

- 2. For chatting with an LLM, we support OpenAI and Anthropic. For the latter, set an additional API key:
-
+ 3. For reranking, we use <a href="https://cohere.com/rerank">Cohere</a> by default, but you can also try rerankers from <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
  ```
- export ANTHROPIC_API_KEY=...
+ export COHERE_API_KEY=... # or
+ export NVIDIA_API_KEY=... # or
+ export JINA_API_KEY=...
  ```

- </details>
-
- <br>
- <summary><strong>Optional</strong></summary>
-
- - By default, we use an <a href="https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2">open-source re-ranker</a>. For higher accuracy, you can use <a href="https://cohere.com/rerank">Cohere</a>, <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
-
+ 4. For chatting with an LLM, we support OpenAI and Anthropic. For the latter, set an additional API key:
  ```
- export COHERE_API_KEY=...
- export NVIDIA_API_KEY=...
- export JINA_API_KEY=...
+ export ANTHROPIC_API_KEY=...
  ```

- We are seeing significant gains in accuracy from these proprietary rerankers.
+ </details>

- - If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:
+ ### Optional
+ If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:

  export GITHUB_TOKEN=...

-
  ## Running it

- <details open>
- <summary><strong>:computer: Run locally</strong></summary>
-
  1. Select your desired repository:
  ```
  export GITHUB_REPO=huggingface/transformers
@@ -124,41 +114,18 @@ pip install git+https://github.com/Storia-AI/sage.git@main
  ```
  sage-index $GITHUB_REPO
  ```
+ To use external providers instead of running locally, set `--mode=remote`.

  3. Chat with the repository, once it's indexed:
  ```
  sage-chat $GITHUB_REPO
  ```
- To get a public URL for your chat app, set `--share=true`.
-
+ To use external providers instead of running locally, set `--mode=remote`.
  </details>

- <details>
- <summary><strong>:cloud: Use external providers</strong></summary>
-
- 1. Select your desired repository:
- ```
- export GITHUB_REPO=huggingface/transformers
- ```
-
- 2. Index the repository. This might take a few minutes, depending on its size.
- ```
- sage-index $GITHUB_REPO \
-     --embedder-type=openai \
-     --vector-store=pinecone \
-     --index-name=$PINECONE_INDEX_NAME
- ```
-
- 3. Chat with the repository, once it's indexed:
- ```
- sage-chat $GITHUB_REPO \
-     --vector-store-type=pinecone \
-     --index-name=$PINECONE_INDEX_NAME \
-     --llm-provider=openai \
-     --llm-model=gpt-4
- ```
- To get a public URL for your chat app, set `--share=true`.
- </details>
+ ### Notes:
+ - To get a public URL for your chat app, set `--share=true`.
+ - You can overwrite the default settings (e.g. desired embedding model or LLM) via command line flags. Run `sage-index --help` or `sage-chat --help` for a full list.

  ## Additional features

sage/chat.py CHANGED
@@ -3,11 +3,11 @@
  You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
  """

- import argparse
  import logging
- import os

+ import configargparse
  import gradio as gr
+ import pkg_resources
  from dotenv import load_dotenv
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
  from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -15,9 +15,10 @@ from langchain.retrievers import ContextualCompressionRetriever
  from langchain.schema import AIMessage, HumanMessage
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

- import sage.vector_store as vector_store
+ import sage.config as sage_config
  from sage.llm import build_llm_via_langchain
- from sage.reranker import build_reranker, RerankerProvider
+ from sage.reranker import build_reranker
+ from sage.vector_store import build_vector_store_from_args

  load_dotenv()
@@ -27,7 +28,7 @@ def build_rag_chain(args):
      llm = build_llm_via_langchain(args.llm_provider, args.llm_model)

      retriever_top_k = 5 if args.reranker_provider == "none" else 25
-     retriever = vector_store.build_from_args(args).as_retriever(top_k=retriever_top_k)
+     retriever = build_vector_store_from_args(args).as_retriever(top_k=retriever_top_k)
      compressor = build_reranker(args.reranker_provider, args.reranker_model)
      if compressor:
          retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
@@ -70,49 +71,27 @@ def build_rag_chain(args):


  def main():
-     parser = argparse.ArgumentParser(description="UI to chat with your codebase")
-     parser.add_argument("repo_id", help="The ID of the repository to index")
-     parser.add_argument("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
-     parser.add_argument(
-         "--llm-model",
-         help="The LLM name. Must be supported by the provider specified via --llm-provider.",
+     parser = configargparse.ArgParser(
+         description="UI to chat with your codebase", ignore_unknown_config_file_keys=True
      )
-     parser.add_argument("--vector-store-type", default="marqo", choices=["pinecone", "marqo"])
-     parser.add_argument("--index-name", help="Vector store index name. Required for Pinecone.")
-     parser.add_argument(
-         "--marqo-url",
-         default="http://localhost:8882",
-         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
-     )
-     parser.add_argument("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
-     parser.add_argument(
-         "--reranker-model",
-         help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
-         "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
-     )
-     parser.add_argument(
+     parser.add(
          "--share",
          default=False,
          help="Whether to make the gradio app publicly accessible.",
      )
-     parser.add_argument(
-         "--hybrid-retrieval",
-         action=argparse.BooleanOptionalAction,
-         default=True,
-         help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
-         "retrieval. This is only relevant if using Pinecone as the vector store.",
-     )
+     sage_config.add_config_args(parser)
+
+     arg_validators = [
+         sage_config.add_repo_args(parser),
+         sage_config.add_vector_store_args(parser),
+         sage_config.add_reranking_args(parser),
+         sage_config.add_llm_args(parser),
+     ]
+
      args = parser.parse_args()

-     if not args.llm_model:
-         if args.llm_provider == "openai":
-             args.llm_model = "gpt-4"
-         elif args.llm_provider == "anthropic":
-             args.llm_model = "claude-3-opus-20240229"
-         elif args.llm_provider == "ollama":
-             args.llm_model = "llama3.1"
-         else:
-             raise ValueError("Please specify --llm_model")
+     for validator in arg_validators:
+         validator(args)

      rag_chain = build_rag_chain(args)

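The rewritten `main()` relies on each `sage_config.add_*_args` helper registering flags and handing back a validator callable that runs after parsing. A minimal sketch of that pattern, using stdlib `argparse` instead of `configargparse` (the flag names and the ollama default mirror the diff; the specific validation rule shown is an illustrative assumption):

```python
import argparse

def add_llm_args(parser):
    """Register LLM flags and hand back a validator (mirrors sage.config's pattern)."""
    parser.add_argument("--llm-provider", default="ollama")
    parser.add_argument("--llm-model", default=None)

    def validate(args):
        # Hypothetical rule for illustration: non-ollama providers need an explicit model.
        if args.llm_provider != "ollama" and not args.llm_model:
            raise ValueError("Please specify --llm-model")

    return validate

parser = argparse.ArgumentParser()
# Each add_*_args call both registers flags and contributes a validator.
validators = [add_llm_args(parser)]
args = parser.parse_args(["--llm-provider", "openai", "--llm-model", "gpt-4"])
for validate in validators:
    validate(args)
print(args.llm_model)  # -> gpt-4
```

Returning the validator from the same function that registers the flags keeps each flag group and its consistency checks in one place, which is what lets `chat.py` and `index.py` share them.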
sage/config.py ADDED
@@ -0,0 +1,293 @@
+ """Utility methods to define and validate flags."""
+
+ import argparse
+ import logging
+ import os
+ import re
+ from typing import Callable
+
+ import pkg_resources
+ from configargparse import ArgumentParser
+
+ from sage.reranker import RerankerProvider
+
+ MARQO_MAX_CHUNKS_PER_BATCH = 64
+ # The ADA embedder from OpenAI has a maximum of 8192 tokens.
+ OPENAI_MAX_TOKENS_PER_CHUNK = 8192
+ # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
+ OPENAI_MAX_CHUNKS_PER_BATCH = 2048
+ # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
+ OPENAI_MAX_TOKENS_PER_JOB = 3_000_000
+
+ # Note that OpenAI embedding models have fixed dimensions; however, taking a slice of them is possible.
+ # See "Reducing embedding dimensions" under https://platform.openai.com/docs/guides/embeddings/use-cases and
+ # https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
+ OPENAI_DEFAULT_EMBEDDING_SIZE = {
+     "text-embedding-ada-002": 1536,
+     "text-embedding-3-small": 1536,
+     "text-embedding-3-large": 3072,
+ }
+
+
+ def add_config_args(parser: ArgumentParser):
+     """Adds configuration-related arguments to the parser."""
+     parser.add(
+         "--mode",
+         choices=["local", "remote"],
+         default="local",
+         help="Whether to use local-only resources or call third-party providers.",
+     )
+     parser.add(
+         "--config",
+         is_config_file=True,
+         help="Path to .yaml configuration file.",
+     )
+     args, _ = parser.parse_known_args()
+     config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
+     parser.set_defaults(config=config_file)
+
+
+ def add_repo_args(parser: ArgumentParser) -> Callable:
+     """Adds repository-related arguments to the parser and returns a validator."""
+     parser.add("repo_id", help="The ID of the repository to index")
+     parser.add("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
+     parser.add(
+         "--local-dir",
+         default="repos",
+         help="The local directory to store the repository",
+     )
+     return validate_repo_args
+
+
+ def add_embedding_args(parser: ArgumentParser) -> Callable:
+     """Adds embedding-related arguments to the parser and returns a validator."""
+     parser.add("--embedding-provider", default="marqo", choices=["openai", "marqo"])
+     parser.add(
+         "--embedding-model",
+         type=str,
+         default=None,
+         help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
+     )
+     parser.add(
+         "--embedding-size",
+         type=int,
+         default=None,
+         help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
+         "large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
+     )
+     parser.add(
+         "--tokens-per-chunk",
+         type=int,
+         default=800,
+         help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
+     )
+     parser.add(
+         "--chunks-per-batch",
+         type=int,
+         help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
+     )
+     parser.add(
+         "--max-embedding-jobs",
+         type=int,
+         help="Maximum number of embedding jobs to run. Specifying this might result in "
+         "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
+     )
+     return validate_embedding_args
+
+
+ def add_vector_store_args(parser: ArgumentParser) -> Callable:
+     """Adds vector store-related arguments to the parser and returns a validator."""
+     parser.add("--vector-store-provider", default="marqo", choices=["pinecone", "marqo"])
+     parser.add(
+         "--pinecone-index-name",
+         default=None,
+         help="Pinecone index name. Required if using Pinecone as the vector store. If the index doesn't exist already, "
+         "we will create it.",
+     )
+     parser.add(
+         "--index-namespace",
+         default=None,
+         help="Index namespace for this repo. When not specified, we default it to a derivative of the repo name.",
+     )
+     parser.add(
+         "--marqo-url",
+         default="http://localhost:8882",
+         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
+     )
+     parser.add(
+         "--hybrid-retrieval",
+         action=argparse.BooleanOptionalAction,
+         default=True,
+         help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
+         "retrieval. This is only relevant if using Pinecone as the vector store.",
+     )
+     return validate_vector_store_args
+
+
+ def add_indexing_args(parser: ArgumentParser) -> Callable:
+     """Adds indexing-related arguments to the parser and returns a validator."""
+     parser.add(
+         "--include",
+         help="Path to a file containing a list of extensions to include. One extension per line.",
+     )
+     parser.add(
+         "--exclude",
+         help="Path to a file containing a list of extensions to exclude. One extension per line.",
+     )
+     # Pass --no-index-repo in order to not index the repository.
+     parser.add(
+         "--index-repo",
+         action=argparse.BooleanOptionalAction,
+         default=True,
+         help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
+     )
+     # Pass --no-index-issues in order to not index the issues.
+     parser.add(
+         "--index-issues",
+         action=argparse.BooleanOptionalAction,
+         default=False,
+         help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True. When "
+         "--index-issues is set, you must also set a GITHUB_TOKEN environment variable.",
+     )
+     # Pass --no-index-issue-comments in order to not index the comments of GitHub issues.
+     parser.add(
+         "--index-issue-comments",
+         action=argparse.BooleanOptionalAction,
+         default=False,
+         help="Whether to index the comments of GitHub issues. This is only relevant if --index-issues is set. "
+         "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
+         "of the gains anyway.",
+     )
+     return validate_indexing_args
+
+
+ def add_reranking_args(parser: ArgumentParser) -> Callable:
+     """Adds reranking-related arguments to the parser."""
+     parser.add("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
+     parser.add(
+         "--reranker-model",
+         help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
+         "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
+     )
+     # Trivial validator (nothing to check).
+     return lambda _: True
+
+
+ def add_llm_args(parser: ArgumentParser) -> Callable:
+     """Adds language model-related arguments to the parser."""
+     parser.add("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
+     parser.add(
+         "--llm-model",
+         help="The LLM name. Must be supported by the provider specified via --llm-provider.",
+     )
+     # Trivial validator (nothing to check).
+     return lambda _: True
+
+
+ def validate_repo_args(args):
+     """Validates the configuration of the repository."""
+     if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
+         raise ValueError("repo_id must be in the format 'owner/repo'")
+
+
+ def _validate_openai_embedding_args(args):
+     """Validates the configuration of the OpenAI batch embedder and sets defaults."""
+     if args.embedding_provider == "openai" and not os.getenv("OPENAI_API_KEY"):
+         raise ValueError("Please set the OPENAI_API_KEY environment variable.")
+
+     if not args.embedding_model:
+         args.embedding_model = "text-embedding-ada-002"
+
+     if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE.keys():
+         raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
+
+     if not args.embedding_size:
+         args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
+
+     if not args.tokens_per_chunk:
+         # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
+         args.tokens_per_chunk = 800
+     elif args.tokens_per_chunk > OPENAI_MAX_TOKENS_PER_CHUNK:
+         args.tokens_per_chunk = OPENAI_MAX_TOKENS_PER_CHUNK
+         logging.warning(
+             f"OpenAI enforces a limit of {OPENAI_MAX_TOKENS_PER_CHUNK} tokens per chunk. "
+             "Overwriting embeddings.tokens_per_chunk."
+         )
+
+     if not args.chunks_per_batch:
+         args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
+     elif args.chunks_per_batch > OPENAI_MAX_CHUNKS_PER_BATCH:
+         args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
+         logging.warning(
+             f"OpenAI enforces a limit of {OPENAI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
+             "Overwriting embeddings.chunks_per_batch."
+         )
+
+     tokens_per_job = args.tokens_per_chunk * args.chunks_per_batch
+     if tokens_per_job >= OPENAI_MAX_TOKENS_PER_JOB:
+         raise ValueError(f"The maximum number of tokens per job is {OPENAI_MAX_TOKENS_PER_JOB}. Got {tokens_per_job}")
+
+
+ def _validate_marqo_embedding_args(args):
+     """Validates the configuration of the Marqo batch embedder and sets defaults."""
+     if not args.embedding_model:
+         args.embedding_model = "hf/e5-base-v2"
+
+     if not args.chunks_per_batch:
+         args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
+     elif args.chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
+         args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
+         logging.warning(
+             f"Marqo enforces a limit of {MARQO_MAX_CHUNKS_PER_BATCH} chunks per batch. "
+             "Overwriting embeddings.chunks_per_batch."
+         )
+
+
+ def validate_embedding_args(args):
+     """Validates the configuration of the batch embedder and sets defaults."""
+     if args.embedding_provider == "openai":
+         _validate_openai_embedding_args(args)
+     elif args.embedding_provider == "marqo":
+         _validate_marqo_embedding_args(args)
+     else:
+         raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")
+
+
+ def validate_vector_store_args(args):
+     """Validates the configuration of the vector store and sets defaults."""
+
+     if not args.index_namespace:
+         args.index_namespace = args.repo_id
+         if args.commit_hash:
+             args.index_namespace += "/" + args.commit_hash
+         if args.vector_store_provider == "marqo":
+             # Marqo doesn't allow slashes in the index namespace.
+             args.index_namespace = args.index_namespace.replace("/", "_")
+
+     if args.vector_store_provider == "marqo":
+         if not args.marqo_url:
+             args.marqo_url = "http://localhost:8882"
+         if "/" in args.index_namespace:
+             raise ValueError(f"Marqo doesn't allow slashes in --index-namespace={args.index_namespace}.")
+
+     elif args.vector_store_provider == "pinecone":
+         if not os.getenv("PINECONE_API_KEY"):
+             raise ValueError("Please set the PINECONE_API_KEY environment variable.")
+         if not args.pinecone_index_name:
+             raise ValueError("Please set the vector_store.pinecone_index_name value.")
+
+
+ def validate_indexing_args(args):
+     """Validates the indexing configuration and sets defaults."""
+     if args.include and args.exclude:
+         raise ValueError("At most one of indexing.include and indexing.exclude can be specified.")
+     if not args.include and not args.exclude:
+         args.exclude = pkg_resources.resource_filename(__name__, "sample-exclude.txt")
+     if args.include and not os.path.exists(args.include):
+         raise ValueError(f"Path --include={args.include} does not exist.")
+     if args.exclude and not os.path.exists(args.exclude):
+         raise ValueError(f"Path --exclude={args.exclude} does not exist.")
+     if not args.index_repo and not args.index_issues:
+         raise ValueError("Either --index-repo or --index-issues must be set to true.")
+     if args.index_issues and not os.getenv("GITHUB_TOKEN"):
+         raise ValueError("Please set the GITHUB_TOKEN environment variable.")
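`add_config_args` above resolves `--mode` before the full parse so it can point `--config` at the bundled YAML. The same two-pass idea can be sketched with stdlib `argparse` (the real code uses `configargparse`'s `is_config_file` and `pkg_resources` to locate the packaged file; the relative path below is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["local", "remote"], default="local")
parser.add_argument("--config", help="Path to .yaml configuration file.")

# First pass: only look at --mode, tolerating any flags registered later.
known, _ = parser.parse_known_args(["--mode", "remote"])
# Point --config at the YAML matching the chosen mode, unless the user overrides it.
parser.set_defaults(config=f"configs/{known.mode}.yaml")

args = parser.parse_args(["--mode", "remote"])
print(args.config)  # -> configs/remote.yaml
```

Because the YAML only changes defaults, any flag passed explicitly on the command line still wins over the config file.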
sage/configs/local.yaml ADDED
@@ -0,0 +1,16 @@
+ # Embeddings
+ embedding-provider: marqo
+ embedding-model: hf/e5-base-v2
+ tokens-per-chunk: 800
+ chunks-per-batch: 64
+
+ # Vector store
+ vector-store-provider: marqo
+
+ # LLM
+ llm-provider: ollama
+ llm-model: llama3.1
+
+ # Reranking
+ reranker-provider: huggingface
+ reranker-model: cross-encoder/ms-marco-MiniLM-L-6-v2
sage/configs/remote.yaml ADDED
@@ -0,0 +1,18 @@
+ # Embeddings
+ embedding-provider: openai
+ embedding-model: text-embedding-ada-002
+ tokens-per-chunk: 800
+ chunks-per-batch: 2000
+
+ # Vector store
+ vector-store-provider: pinecone
+ pinecone-index-name: sage
+ hybrid-retrieval: true
+
+ # LLM
+ llm-provider: openai
+ llm-model: gpt-4
+
+ # Reranking
+ reranker-provider: cohere
+ reranker-model: rerank-english-v3.0
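The `chunks-per-batch` values in the two YAMLs (2000 for OpenAI, 64 for Marqo) sit under the provider caps that the validators in `sage/config.py` enforce. A simplified sketch of that defaulting-and-clamping logic (the function name is illustrative; the constants come from the diff):

```python
import logging

OPENAI_MAX_CHUNKS_PER_BATCH = 2048  # OpenAI batch embedding API limit
MARQO_MAX_CHUNKS_PER_BATCH = 64     # Marqo limit

def clamp_chunks_per_batch(requested, provider):
    """Default and clamp chunks-per-batch, as sage.config's validators do (sketch)."""
    limit = OPENAI_MAX_CHUNKS_PER_BATCH if provider == "openai" else MARQO_MAX_CHUNKS_PER_BATCH
    if requested is None:
        return limit  # no value given: fall back to the provider maximum
    if requested > limit:
        logging.warning("%s enforces a limit of %d chunks per batch; overwriting.", provider, limit)
        return limit
    return requested

print(clamp_chunks_per_batch(2000, "openai"))  # -> 2000
print(clamp_chunks_per_batch(2000, "marqo"))   # -> 64
```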
sage/embedder.py CHANGED
@@ -268,11 +268,11 @@ class MarqoEmbedder(BatchEmbedder):


  def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
-     if args.embedder_type == "openai":
+     if args.embedding_provider == "openai":
          return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
-     elif args.embedder_type == "marqo":
+     elif args.embedding_provider == "marqo":
          return MarqoEmbedder(
-             data_manager, chunker, index_name=args.index_name, url=args.marqo_url, model=args.embedding_model
+             data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
          )
      else:
-         raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
+         raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
sage/index.py CHANGED
@@ -1,196 +1,46 @@
  """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""

- import argparse
  import logging
- import os
  import time

  import pkg_resources

  from sage.chunker import UniversalFileChunker
  from sage.data_manager import GitHubRepoManager
  from sage.embedder import build_batch_embedder_from_flags
  from sage.github import GitHubIssuesChunker, GitHubIssuesManager
- from sage.vector_store import build_from_args

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger()
  logger.setLevel(logging.INFO)

- MARQO_MAX_CHUNKS_PER_BATCH = 64
-
- OPENAI_MAX_TOKENS_PER_CHUNK = 8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
- OPENAI_MAX_CHUNKS_PER_BATCH = 2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
- OPENAI_MAX_TOKENS_PER_JOB = (
-     3_000_000  # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
- )
-
- # Note that OpenAI embedding models have fixed dimensions, however, taking a slice of them is possible.
- # See "Reducing embedding dimensions" under https://platform.openai.com/docs/guides/embeddings/use-cases and
- # https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
- OPENAI_DEFAULT_EMBEDDING_SIZE = {
-     "text-embedding-ada-002": 1536,
-     "text-embedding-3-small": 1536,
-     "text-embedding-3-large": 3072,
- }
-

  def main():
-     parser = argparse.ArgumentParser(description="Batch-embeds a GitHub repository and its issues.")
-     parser.add_argument("repo_id", help="The ID of the repository to index")
-     parser.add_argument("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
-     parser.add_argument("--embedder-type", default="marqo", choices=["openai", "marqo"])
-     parser.add_argument(
-         "--embedding-model",
-         type=str,
-         default=None,
-         help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
-     )
-     parser.add_argument(
-         "--embedding-size",
-         type=int,
-         default=None,
-         help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
-         "large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
-     )
-     parser.add_argument("--vector-store-type", default="marqo", choices=["pinecone", "marqo"])
-     parser.add_argument(
-         "--local-dir",
-         default="repos",
-         help="The local directory to store the repository",
-     )
-     parser.add_argument(
-         "--tokens-per-chunk",
-         type=int,
-         default=800,
-         help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
-     )
-     parser.add_argument(
-         "--chunks-per-batch",
-         type=int,
-         help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
-     )
-     parser.add_argument(
-         "--pinecone-index-name",
-         default=None,
-         help="Pinecone index name. Required if using Pinecone as the vector store. If the index doesn't exist already, "
-         "we will create it.",
-     )
-     parser.add_argument(
-         "--index-namespace",
-         default=None,
-         help="Index namespace for this repo. When not specified, we default it to a derivative of the repo name."
-     )
-     parser.add_argument(
-         "--include",
-         help="Path to a file containing a list of extensions to include. One extension per line.",
-     )
-     parser.add_argument(
-         "--exclude",
-         help="Path to a file containing a list of extensions to exclude. One extension per line.",
-     )
-     parser.add_argument(
-         "--max-embedding-jobs",
-         type=int,
-         help="Maximum number of embedding jobs to run. Specifying this might result in "
-         "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
      )
-     parser.add_argument(
-         "--marqo-url",
-         default="http://localhost:8882",
-         help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
-     )
-     # Pass --no-index-repo in order to not index the repository.
-     parser.add_argument(
-         "--index-repo",
-         action=argparse.BooleanOptionalAction,
-         default=True,
-         help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
-     )
-     # Pass --no-index-issues in order to not index the issues.
-     parser.add_argument(
-         "--index-issues",
-         action=argparse.BooleanOptionalAction,
-         default=False,
-         help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True. When "
-         "--index-issues is set, you must also set a GITHUB_TOKEN environment variable.",
-     )
-     parser.add_argument(
-         "--index-issue-comments",
-         action=argparse.BooleanOptionalAction,
-         default=False,
-         help="Whether to index the comments of GitHub issues. This is only relevant if --index-issues is set. "
-         "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
-         "of the gains anyway.",
-     )
-     parser.add_argument(
-         "--hybrid-retrieval",
-         action=argparse.BooleanOptionalAction,
-         default=True,
-         help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
-         "retrieval. This is only relevant if using Pinecone as the vector store.",
-     )
-     args = parser.parse_args()

-     # Validate embedder and vector store compatibility.
-     if args.embedder_type == "openai" and args.vector_store_type != "pinecone":
-         parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
-     if args.embedder_type == "marqo" and args.vector_store_type != "marqo":
-         parser.error("When using the marqo embedder, the vector store type must also be marqo.")
-     if args.vector_store_type == "marqo":
-         if "/" in args.index_namespace:
-             parser.error("The index namespace cannot contain slashes when using Marqo as the vector store.")
-     elif args.vector_store_type == "pinecone" and not args.pinecone_index_name:
-         parser.error("When using Pinecone as the vector store, you must specify --pinecone-index-name")

-     # Validate embedder parameters.
-     if args.embedder_type == "marqo":
-         if args.embedding_model is None:
-             args.embedding_model = "hf/e5-base-v2"
-         if args.chunks_per_batch is None:
-             args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
-         elif args.chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
-             args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
-             logging.warning(
-                 f"Marqo enforces a limit of {MARQO_MAX_CHUNKS_PER_BATCH} chunks per batch. "
-                 "Overwriting --chunks_per_batch."
-             )
-     elif args.embedder_type == "openai":
-         if args.tokens_per_chunk > OPENAI_MAX_TOKENS_PER_CHUNK:
-             args.tokens_per_chunk = OPENAI_MAX_TOKENS_PER_CHUNK
-             logging.warning(
162
- f"OpenAI enforces a limit of {OPENAI_MAX_TOKENS_PER_CHUNK} tokens per chunk. "
163
- "Overwriting --tokens_per_chunk."
164
- )
165
- if args.chunks_per_batch is None:
166
- args.chunks_per_batch = 2000
167
- elif args.chunks_per_batch > OPENAI_MAX_CHUNKS_PER_BATCH:
168
- args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
169
- logging.warning(
170
- f"OpenAI enforces a limit of {OPENAI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
171
- "Overwriting --chunks_per_batch."
172
- )
173
- if args.tokens_per_chunk * args.chunks_per_batch >= OPENAI_MAX_TOKENS_PER_JOB:
174
- parser.error(f"The maximum number of chunks per job is {OPENAI_MAX_TOKENS_PER_JOB}.")
175
- if args.embedding_model is None:
176
- args.embedding_model = "text-embedding-ada-002"
177
- if args.embedding_size is None:
178
- args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
179
 
180
- if args.include and args.exclude:
181
- parser.error("At most one of --include and --exclude can be specified.")
182
- if not args.include and not args.exclude:
183
- args.exclude = pkg_resources.resource_filename(__name__, "sample-exclude.txt")
184
- if not args.index_repo and not args.index_issues:
185
- parser.error("At least one of --index-repo and --index-issues must be true.")
186
 
187
- # Fail early on missing environment variables.
188
- if args.embedder_type == "openai" and not os.getenv("OPENAI_API_KEY"):
189
- parser.error("Please set the OPENAI_API_KEY environment variable.")
190
- if args.vector_store_type == "pinecone" and not os.getenv("PINECONE_API_KEY"):
191
- parser.error("Please set the PINECONE_API_KEY environment variable.")
192
- if args.index_issues and not os.getenv("GITHUB_TOKEN"):
193
- parser.error("Please set the GITHUB_TOKEN environment variable.")
194
 
195
  ######################
196
  # Step 1: Embeddings #
@@ -228,7 +78,7 @@ def main():
228
  # Step 2: Vector Store #
229
  ########################
230
 
231
- if args.vector_store_type == "marqo":
232
  # Marqo computes embeddings and stores them in the vector store at once, so we're done.
233
  logging.info("Done!")
234
  return
@@ -240,7 +90,7 @@ def main():
240
  time.sleep(30)
241
 
242
  logging.info("Moving embeddings to the repo vector store...")
243
- repo_vector_store = build_from_args(args)
244
  repo_vector_store.ensure_exists()
245
  repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file))
246
 
@@ -251,7 +101,7 @@ def main():
251
  time.sleep(30)
252
 
253
  logging.info("Moving embeddings to the issues vector store...")
254
- issues_vector_store = build_from_args(args)
255
  issues_vector_store.ensure_exists()
256
  issues_vector_store.upsert(issues_embedder.download_embeddings(issues_jobs_file))
257
 
 
1
  """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
2
 
 
3
  import logging
 
4
  import time
5
 
6
+ import configargparse
7
  import pkg_resources
8
 
9
+ import sage.config as sage_config
10
  from sage.chunker import UniversalFileChunker
11
  from sage.data_manager import GitHubRepoManager
12
  from sage.embedder import build_batch_embedder_from_flags
13
  from sage.github import GitHubIssuesChunker, GitHubIssuesManager
14
+ from sage.vector_store import build_vector_store_from_args
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger()
18
  logger.setLevel(logging.INFO)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def main():
22
+ parser = configargparse.ArgParser(
23
+ description="Batch-embeds a GitHub repository and its issues.", ignore_unknown_config_file_keys=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
25
+ sage_config.add_config_args(parser)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ arg_validators = [
28
+ sage_config.add_repo_args(parser),
29
+ sage_config.add_embedding_args(parser),
30
+ sage_config.add_vector_store_args(parser),
31
+ sage_config.add_indexing_args(parser),
32
+ ]
 
 
 
 
33
 
34
+ args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ for validator in arg_validators:
37
+ validator(args)
 
 
 
 
38
 
39
+ # Additionally validate embedder and vector store compatibility.
40
+ if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
41
+ parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
42
+ if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
43
+ parser.error("When using the marqo embedder, the vector store type must also be marqo.")
 
 
44
 
45
  ######################
46
  # Step 1: Embeddings #
 
78
  # Step 2: Vector Store #
79
  ########################
80
 
81
+ if args.vector_store_provider == "marqo":
82
  # Marqo computes embeddings and stores them in the vector store at once, so we're done.
83
  logging.info("Done!")
84
  return
 
90
  time.sleep(30)
91
 
92
  logging.info("Moving embeddings to the repo vector store...")
93
+ repo_vector_store = build_vector_store_from_args(args)
94
  repo_vector_store.ensure_exists()
95
  repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file))
96
 
 
101
  time.sleep(30)
102
 
103
  logging.info("Moving embeddings to the issues vector store...")
104
+ issues_vector_store = build_vector_store_from_args(args)
105
  issues_vector_store.ensure_exists()
106
  issues_vector_store.upsert(issues_embedder.download_embeddings(issues_jobs_file))
107
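
The refactored `main()` above replaces dozens of inline `parser.add_argument` calls with shared `sage.config` helpers, each of which registers a group of flags and returns a validator to be run after `parse_args()`. A minimal sketch of that pattern, using stdlib `argparse` and illustrative flag names (not the actual `sage.config` implementation):

```python
import argparse


def add_vector_store_args(parser: argparse.ArgumentParser):
    """Registers vector-store flags and returns a validator for them."""
    parser.add_argument("--vector-store-provider", default="marqo",
                        choices=["marqo", "pinecone"])
    parser.add_argument("--pinecone-index-name", default=None)

    def validate(args: argparse.Namespace):
        # Cross-flag consistency checks run after parsing.
        if args.vector_store_provider == "pinecone" and not args.pinecone_index_name:
            parser.error("When using Pinecone, you must specify --pinecone-index-name")

    return validate


parser = argparse.ArgumentParser()
arg_validators = [add_vector_store_args(parser)]
args = parser.parse_args(["--vector-store-provider", "marqo"])
for validator in arg_validators:
    validator(args)
print(args.vector_store_provider)  # marqo
```

Returning the validator from the same function that registers the flags keeps each option group and its consistency checks in one place, so scripts like this one only loop over `arg_validators` instead of duplicating the checks.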
 
sage/llm.py CHANGED
@@ -10,12 +10,12 @@ def build_llm_via_langchain(provider: str, model: str):
     if provider == "openai":
         if "OPENAI_API_KEY" not in os.environ:
             raise ValueError("Please set the OPENAI_API_KEY environment variable.")
-        return ChatOpenAI(model=model)
+        return ChatOpenAI(model=model or "gpt-4")
     elif provider == "anthropic":
         if "ANTHROPIC_API_KEY" not in os.environ:
             raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
-        return ChatAnthropic(model=model)
+        return ChatAnthropic(model=model or "claude-3-opus-20240229")
     elif provider == "ollama":
-        return ChatOllama(model=model)
+        return ChatOllama(model=model or "llama3.1")
     else:
         raise ValueError(f"Unrecognized LLM provider {provider}. Contributions are welcome!")
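
The `sage/llm.py` change above makes each provider fall back to a default model when `model` is unset. The `model or default` idiom can be sketched in isolation (the defaults mirror the diff; `resolve_model` is an illustrative helper, not part of sage):

```python
from typing import Optional


def resolve_model(provider: str, model: Optional[str]) -> str:
    # Per-provider default models, as chosen in the diff above.
    defaults = {
        "openai": "gpt-4",
        "anthropic": "claude-3-opus-20240229",
        "ollama": "llama3.1",
    }
    if provider not in defaults:
        raise ValueError(f"Unrecognized LLM provider {provider}.")
    # `or` falls back when model is None or an empty string.
    return model or defaults[provider]


print(resolve_model("ollama", None))      # llama3.1
print(resolve_model("openai", "gpt-4o"))  # gpt-4o
```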
sage/vector_store.py CHANGED
@@ -146,35 +146,16 @@ class MarqoVectorStore(VectorStore):
         return vectorstore.as_retriever(search_kwargs={"k": top_k})


-def build_from_args(args: dict) -> VectorStore:
+def build_vector_store_from_args(args: dict) -> VectorStore:
     """Builds a vector store from the given command-line arguments."""
-    if args.vector_store_type == "pinecone":
-        if not args.pinecone_index_name:
-            raise ValueError("Please specify --pinecone-index-name for Pinecone.")
-        dimension = args.embedding_size if "embedding_size" in args else None
-
-        index_namespace = args.index_namespace
-        if not index_namespace:
-            index_namespace = args.repo_id
-            if args.commit_hash:
-                index_namespace += "/" + args.commit_hash
-
+    if args.vector_store_provider == "pinecone":
         return PineconeVectorStore(
             index_name=args.pinecone_index_name,
-            namespace=index_namespace,
-            dimension=dimension,
+            namespace=args.index_namespace,
+            dimension=args.embedding_size if "embedding_size" in args else None,
             hybrid=args.hybrid_retrieval,
         )
-    elif args.vector_store_type == "marqo":
-        marqo_url = args.marqo_url or "http://localhost:8882"
-
-        index_namespace = args.index_namespace
-        if not index_namespace:
-            # Marqo doesn't allow slashes in the index name.
-            index_namespace = args.repo_id.split("/")[1]
-            if args.commit_hash:
-                index_namespace += "_" + args.commit_hash
-
-        return MarqoVectorStore(url=marqo_url, index_name=index_namespace)
+    elif args.vector_store_provider == "marqo":
+        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
     else:
-        raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
+        raise ValueError(f"Unrecognized vector store type {args.vector_store_provider}")
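
After this change, `build_vector_store_from_args` assumes `args.index_namespace` and `args.marqo_url` are already populated, presumably by the shared `sage.config` defaults and validators. The namespace-defaulting logic removed above can be sketched as a standalone helper (behavior inferred from the deleted lines; the real `sage.config` code may differ):

```python
from typing import Optional


def default_index_namespace(repo_id: str, commit_hash: Optional[str], provider: str) -> str:
    """Derives a vector-store namespace from the repo, per the deleted logic."""
    if provider == "marqo":
        # Marqo doesn't allow slashes in the index name, so keep only the repo name.
        namespace = repo_id.split("/")[1]
        if commit_hash:
            namespace += "_" + commit_hash
    else:
        namespace = repo_id
        if commit_hash:
            namespace += "/" + commit_hash
    return namespace


print(default_index_namespace("Storia-AI/sage", "9802b75", "marqo"))  # sage_9802b75
```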