Add YAML configurations (#38)
Files changed:
- MANIFEST.in +3 -1
- README.md +16 -49
- sage/chat.py +20 -41
- sage/config.py +293 -0
- sage/configs/local.yaml +16 -0
- sage/configs/remote.yaml +18 -0
- sage/embedder.py +4 -4
- sage/index.py +23 -173
- sage/llm.py +3 -3
- sage/vector_store.py +7 -26
MANIFEST.in CHANGED
@@ -1 +1,3 @@
-include sage/sample-exclude.txt
+include sage/sample-exclude.txt
+include sage/configs/local.yaml
+include sage/configs/remote.yaml
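The MANIFEST.in additions matter because sage/config.py resolves the bundled YAML defaults at runtime via `pkg_resources`; if the files are not packaged, `--mode` has nothing to load. A minimal sanity check, sketched under the assumption that the package is installed as `sage`:

```python
# Sketch: verify the bundled configs resolve the same way sage/config.py
# looks them up (pkg_resources is provided by setuptools).
import os
import pkg_resources

for mode in ("local", "remote"):
    path = pkg_resources.resource_filename("sage", f"configs/{mode}.yaml")
    assert os.path.exists(path), f"configs/{mode}.yaml was not packaged"
print("Both YAML configs are packaged correctly.")
```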
README.md CHANGED
@@ -49,7 +49,7 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 2. Enables chatting via LLM + RAG (requiring access to an LLM)

 <details open>
-<summary><strong>:computer: Running locally</strong></summary>
+<summary><strong>:computer: Running locally (lower quality)</strong></summary>

 1. To index the codebase locally, we use the open-source project <a href="https://github.com/marqo-ai/marqo">Marqo</a>, which is both an embedder and a vector store. To bring up a Marqo instance:

@@ -70,7 +70,7 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 </details>

 <details>
-<summary><strong>:cloud: Using external providers</strong></summary>
+<summary><strong>:cloud: Using external providers (higher quality)</strong></summary>

 1. We support <a href="https://openai.com/">OpenAI</a> for embeddings (they have a super fast batch embedding API) and <a href="https://www.pinecone.io/">Pinecone</a> for the vector store. So you will need two API keys:

@@ -84,37 +84,27 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 export PINECONE_INDEX_NAME=...
 ```

-
-
+3. For reranking, we use <a href="https://cohere.com/rerank">Cohere</a> by default, but you can also try rerankers from <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
 ```
-export
+export COHERE_API_KEY=... # or
+export NVIDIA_API_KEY=... # or
+export JINA_API_KEY=...
 ```

-
-
-<br>
-<summary><strong>Optional</strong></summary>
-
-- By default, we use an <a href="https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2">open-source re-ranker</a>. For higher accuracy, you can use <a href="https://cohere.com/rerank">Cohere</a>, <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a> or <a href="https://jina.ai/reranker/">Jina</a>:
-
+4. For chatting with an LLM, we support OpenAI and Anthropic. For the latter, set an additional API key:
 ```
-export COHERE_API_KEY=...
-export NVIDIA_API_KEY=...
-export JINA_API_KEY=...
+export ANTHROPIC_API_KEY=...
 ```

-
+</details>

-
+### Optional
+If you are planning on indexing GitHub issues in addition to the codebase, you will need a GitHub token:

 export GITHUB_TOKEN=...

-
 ## Running it

-<details open>
-<summary><strong>:computer: Run locally</strong></summary>
-
 1. Select your desired repository:
 ```
 export GITHUB_REPO=huggingface/transformers
@@ -124,41 +114,18 @@ pip install git+https://github.com/Storia-AI/sage.git@main
 ```
 sage-index $GITHUB_REPO
 ```
+To use external providers instead of running locally, set `--mode=remote`.

 3. Chat with the repository, once it's indexed:
 ```
 sage-chat $GITHUB_REPO
 ```
-To get a public URL for your chat app, set `--share=true`.
-
+To use external providers instead of running locally, set `--mode=remote`.
 </details>

-
-
-
-1. Select your desired repository:
-```
-export GITHUB_REPO=huggingface/transformers
-```
-
-2. Index the repository. This might take a few minutes, depending on its size.
-```
-sage-index $GITHUB_REPO \
-  --embedder-type=openai \
-  --vector-store=pinecone \
-  --index-name=$PINECONE_INDEX_NAME
-```
-
-3. Chat with the repository, once it's indexed:
-```
-sage-chat $GITHUB_REPO \
-  --vector-store-type=pinecone \
-  --index-name=$PINECONE_INDEX_NAME \
-  --llm-provider=openai \
-  --llm-model=gpt-4
-```
-To get a public URL for your chat app, set `--share=true`.
-</details>
+### Notes:
+- To get a public URL for your chat app, set `--share=true`.
+- You can overwrite the default settings (e.g. desired embedding model or LLM) via command line flags. Run `sage-index --help` or `sage-chat --help` for a full list.

 ## Additional features
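Since the remote path needs several keys at once, a small pre-flight check can save a failed indexing run. The helper below is illustrative only (it is not part of sage, which does its own validation in sage/config.py); the variable names mirror the README above:

```python
# Hypothetical pre-flight check before running with --mode=remote.
import os

REQUIRED_FOR_REMOTE = ("OPENAI_API_KEY", "PINECONE_API_KEY", "PINECONE_INDEX_NAME")

missing = [name for name in REQUIRED_FOR_REMOTE if not os.getenv(name)]
if missing:
    raise SystemExit(f"Set these before running with --mode=remote: {', '.join(missing)}")
```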
sage/chat.py CHANGED
@@ -3,11 +3,11 @@
 You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
 """

-import argparse
 import logging
-import os

+import configargparse
 import gradio as gr
+import pkg_resources
 from dotenv import load_dotenv
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -15,9 +15,10 @@ from langchain.retrievers import ContextualCompressionRetriever
 from langchain.schema import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

-import sage.
+import sage.config as sage_config
 from sage.llm import build_llm_via_langchain
-from sage.reranker import build_reranker
+from sage.reranker import build_reranker
+from sage.vector_store import build_vector_store_from_args

 load_dotenv()

@@ -27,7 +28,7 @@ def build_rag_chain(args):
     llm = build_llm_via_langchain(args.llm_provider, args.llm_model)

     retriever_top_k = 5 if args.reranker_provider == "none" else 25
-    retriever =
+    retriever = build_vector_store_from_args(args).as_retriever(top_k=retriever_top_k)
     compressor = build_reranker(args.reranker_provider, args.reranker_model)
     if compressor:
         retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
@@ -70,49 +71,27 @@


 def main():
-    parser =
-
-    parser.add_argument("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
-    parser.add_argument(
-        "--llm-model",
-        help="The LLM name. Must be supported by the provider specified via --llm-provider.",
-    )
-    parser.add_argument("--vector-store-type", default="marqo", choices=["pinecone", "marqo"])
-    parser.add_argument("--index-name", help="Vector store index name. Required for Pinecone.")
-    parser.add_argument(
-        "--marqo-url",
-        default="http://localhost:8882",
-        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
-    )
-    parser.add_argument("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
-    parser.add_argument(
-        "--reranker-model",
-        help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
-        "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
-    )
-    parser.add_argument(
+    parser = configargparse.ArgParser(
+        description="Batch-embeds a GitHub repository and its issues.", ignore_unknown_config_file_keys=True
+    )
+    parser.add(
         "--share",
         default=False,
         help="Whether to make the gradio app publicly accessible.",
     )
-
-
-
-
-
-
-
+    sage_config.add_config_args(parser)
+
+    arg_validators = [
+        sage_config.add_repo_args(parser),
+        sage_config.add_vector_store_args(parser),
+        sage_config.add_reranking_args(parser),
+        sage_config.add_llm_args(parser),
+    ]
+
     args = parser.parse_args()

-    if not args.llm_model:
-        if args.llm_provider == "openai":
-            args.llm_model = "gpt-4"
-        elif args.llm_provider == "anthropic":
-            args.llm_model = "claude-3-opus-20240229"
-        elif args.llm_provider == "ollama":
-            args.llm_model = "llama3.1"
-        else:
-            raise ValueError("Please specify --llm_model")
+    for validator in arg_validators:
+        validator(args)

     rag_chain = build_rag_chain(args)
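The pattern introduced here is register flags, collect validators, validate after parsing: each `sage_config.add_*_args` call both registers a flag group and returns the function that checks it once YAML defaults and CLI overrides have been merged. A stripped-down sketch of the same shape (illustrative names and checks, not sage's actual helpers):

```python
import configargparse

def add_llm_args(parser):
    # Register the flag group...
    parser.add("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
    parser.add("--llm-model")

    # ...and hand back the check to run after parsing.
    def validate(args) -> None:
        if args.llm_provider == "openai" and not args.llm_model:
            raise ValueError("Pick an --llm-model for OpenAI")
    return validate

parser = configargparse.ArgParser()
validators = [add_llm_args(parser)]
args = parser.parse_args(["--llm-provider", "ollama"])
for validate in validators:
    validate(args)
```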
sage/config.py ADDED
@@ -0,0 +1,293 @@
+"""Utility methods to define and validate flags."""
+
+import argparse
+import logging
+import os
+import re
+from typing import Callable
+
+import pkg_resources
+from configargparse import ArgumentParser
+
+from sage.reranker import RerankerProvider
+
+MARQO_MAX_CHUNKS_PER_BATCH = 64
+# The ADA embedder from OpenAI has a maximum of 8192 tokens.
+OPENAI_MAX_TOKENS_PER_CHUNK = 8192
+# The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
+OPENAI_MAX_CHUNKS_PER_BATCH = 2048
+# The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
+OPENAI_MAX_TOKENS_PER_JOB = 3_000_000
+
+# Note that OpenAI embedding models have fixed dimensions, however, taking a slice of them is possible.
+# See "Reducing embedding dimensions" under https://platform.openai.com/docs/guides/embeddings/use-cases and
+# https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
+OPENAI_DEFAULT_EMBEDDING_SIZE = {
+    "text-embedding-ada-002": 1536,
+    "text-embedding-3-small": 1536,
+    "text-embedding-3-large": 3072,
+}
+
+
+def add_config_args(parser: ArgumentParser):
+    """Adds configuration-related arguments to the parser."""
+    parser.add(
+        "--mode",
+        choices=["local", "remote"],
+        default="local",
+        help="Whether to use local-only resources or call third-party providers.",
+    )
+    parser.add(
+        "--config",
+        is_config_file=True,
+        help="Path to .yaml configuration file.",
+    )
+    args, _ = parser.parse_known_args()
+    config_file = pkg_resources.resource_filename(__name__, f"configs/{args.mode}.yaml")
+    parser.set_defaults(config=config_file)
+
+
+def add_repo_args(parser: ArgumentParser) -> Callable:
+    """Adds repository-related arguments to the parser and returns a validator."""
+    parser.add("repo_id", help="The ID of the repository to index")
+    parser.add("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
+    parser.add(
+        "--local-dir",
+        default="repos",
+        help="The local directory to store the repository",
+    )
+    return validate_repo_args
+
+
+def add_embedding_args(parser: ArgumentParser) -> Callable:
+    """Adds embedding-related arguments to the parser and returns a validator."""
+    parser.add("--embedding-provider", default="marqo", choices=["openai", "marqo"])
+    parser.add(
+        "--embedding-model",
+        type=str,
+        default=None,
+        help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
+    )
+    parser.add(
+        "--embedding-size",
+        type=int,
+        default=None,
+        help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
+        "large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
+    )
+    parser.add(
+        "--tokens-per-chunk",
+        type=int,
+        default=800,
+        help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
+    )
+    parser.add(
+        "--chunks-per-batch",
+        type=int,
+        help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
+    )
+    parser.add(
+        "--max-embedding-jobs",
+        type=int,
+        help="Maximum number of embedding jobs to run. Specifying this might result in "
+        "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
+    )
+    return validate_embedding_args
+
+
+def add_vector_store_args(parser: ArgumentParser) -> Callable:
+    """Adds vector store-related arguments to the parser and returns a validator."""
+    parser.add("--vector-store-provider", default="marqo", choices=["pinecone", "marqo"])
+    parser.add(
+        "--pinecone-index-name",
+        default=None,
+        help="Pinecone index name. Required if using Pinecone as the vector store. If the index doesn't exist already, "
+        "we will create it.",
+    )
+    parser.add(
+        "--index-namespace",
+        default=None,
+        help="Index namespace for this repo. When not specified, we default it to a derivative of the repo name.",
+    )
+    parser.add(
+        "--marqo-url",
+        default="http://localhost:8882",
+        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
+    )
+    parser.add(
+        "--hybrid-retrieval",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
+        "retrieval. This is only relevant if using Pinecone as the vector store.",
+    )
+    return validate_vector_store_args
+
+
+def add_indexing_args(parser: ArgumentParser) -> Callable:
+    """Adds indexing-related arguments to the parser and returns a validator."""
+    parser.add(
+        "--include",
+        help="Path to a file containing a list of extensions to include. One extension per line.",
+    )
+    parser.add(
+        "--exclude",
+        help="Path to a file containing a list of extensions to exclude. One extension per line.",
+    )
+    # Pass --no-index-repo in order to not index the repository.
+    parser.add(
+        "--index-repo",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
+    )
+    # Pass --no-index-issues in order to not index the issues.
+    parser.add(
+        "--index-issues",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True. When "
+        "--index-issues is set, you must also set a GITHUB_TOKEN environment variable.",
+    )
+    # Pass --no-index-issue-comments in order to not index the comments of GitHub issues.
+    parser.add(
+        "--index-issue-comments",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Whether to index the comments of GitHub issues. This is only relevant if --index-issues is set. "
+        "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
+        "of the gains anyway.",
+    )
+    return validate_indexing_args
+
+
+def add_reranking_args(parser: ArgumentParser) -> Callable:
+    """Adds reranking-related arguments to the parser."""
+    parser.add("--reranker-provider", default="huggingface", choices=[r.value for r in RerankerProvider])
+    parser.add(
+        "--reranker-model",
+        help="The reranker model name. When --reranker-provider=huggingface, we suggest choosing a model from the "
+        "SentenceTransformers Cross-Encoders library https://huggingface.co/cross-encoder?sort_models=downloads#models",
+    )
+    # Trivial validator (nothing to check).
+    return lambda _: True
+
+
+def add_llm_args(parser: ArgumentParser) -> Callable:
+    """Adds language model-related arguments to the parser."""
+    parser.add("--llm-provider", default="ollama", choices=["openai", "anthropic", "ollama"])
+    parser.add(
+        "--llm-model",
+        help="The LLM name. Must be supported by the provider specified via --llm-provider.",
+    )
+    # Trivial validator (nothing to check).
+    return lambda _: True
+
+
+def validate_repo_args(args):
+    """Validates the configuration of the repository."""
+    if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
+        raise ValueError("repo_id must be in the format 'owner/repo'")
+
+
+def _validate_openai_embedding_args(args):
+    """Validates the configuration of the OpenAI batch embedder and sets defaults."""
+    if args.embedding_provider == "openai" and not os.getenv("OPENAI_API_KEY"):
+        raise ValueError("Please set the OPENAI_API_KEY environment variable.")
+
+    if not args.embedding_model:
+        args.embedding_model = "text-embedding-ada-002"
+
+    if args.embedding_model not in OPENAI_DEFAULT_EMBEDDING_SIZE.keys():
+        raise ValueError(f"Unrecognized embeddings.model={args.embedding_model}")
+
+    if not args.embedding_size:
+        args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
+
+    if not args.tokens_per_chunk:
+        # https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.
+        args.tokens_per_chunk = 800
+    elif args.tokens_per_chunk > OPENAI_MAX_TOKENS_PER_CHUNK:
+        args.tokens_per_chunk = OPENAI_MAX_TOKENS_PER_CHUNK
+        logging.warning(
+            f"OpenAI enforces a limit of {OPENAI_MAX_TOKENS_PER_CHUNK} tokens per chunk. "
+            "Overwriting embeddings.tokens_per_chunk."
+        )
+
+    if not args.chunks_per_batch:
+        args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
+    elif args.chunks_per_batch > OPENAI_MAX_CHUNKS_PER_BATCH:
+        args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
+        logging.warning(
+            f"OpenAI enforces a limit of {OPENAI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
+            "Overwriting embeddings.chunks_per_batch."
+        )
+
+    tokens_per_job = args.tokens_per_chunk * args.chunks_per_batch
+    if tokens_per_job >= OPENAI_MAX_TOKENS_PER_JOB:
+        raise ValueError(f"The maximum number of tokens per job is {OPENAI_MAX_TOKENS_PER_JOB}. Got {tokens_per_job}")
+
+
+def _validate_marqo_embedding_args(args):
+    """Validates the configuration of the Marqo batch embedder and sets defaults."""
+    if not args.embedding_model:
+        args.embedding_model = "hf/e5-base-v2"
+
+    if not args.chunks_per_batch:
+        args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
+    elif args.chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
+        args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
+        logging.warning(
+            f"Marqo enforces a limit of {MARQO_MAX_CHUNKS_PER_BATCH} chunks per batch. "
+            "Overwriting embeddings.chunks_per_batch."
+        )
+
+
+def validate_embedding_args(args):
+    """Validates the configuration of the batch embedder and sets defaults."""
+    if args.embedding_provider == "openai":
+        _validate_openai_embedding_args(args)
+    elif args.embedding_provider == "marqo":
+        _validate_marqo_embedding_args(args)
+    else:
+        raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")
+
+
+def validate_vector_store_args(args):
+    """Validates the configuration of the vector store and sets defaults."""
+
+    if not args.index_namespace:
+        args.index_namespace = args.repo_id
+        if args.commit_hash:
+            args.index_namespace += "/" + args.commit_hash
+        if args.vector_store_provider == "marqo":
+            # Marqo doesn't allow slashes in the index namespace.
+            args.index_namespace = args.index_namespace.replace("/", "_")
+
+    if args.vector_store_provider == "marqo":
+        if not args.marqo_url:
+            args.marqo_url = "http://localhost:8882"
+        if "/" in args.index_namespace:
+            raise ValueError(f"Marqo doesn't allow slashes in --index-namespace={args.index_namespace}.")
+
+    elif args.vector_store_provider == "pinecone":
+        if not os.getenv("PINECONE_API_KEY"):
+            raise ValueError("Please set the PINECONE_API_KEY environment variable.")
+        if not args.pinecone_index_name:
+            raise ValueError("Please set the vector_store.pinecone_index_name value.")
+
+
+def validate_indexing_args(args):
+    """Validates the indexing configuration and sets defaults."""
+    if args.include and args.exclude:
+        raise ValueError("At most one of indexing.include and indexing.exclude can be specified.")
+    if not args.include and not args.exclude:
+        args.exclude = pkg_resources.resource_filename(__name__, "sample-exclude.txt")
+    if args.include and not os.path.exists(args.include):
+        raise ValueError(f"Path --include={args.include} does not exist.")
+    if args.exclude and not os.path.exists(args.exclude):
+        raise ValueError(f"Path --exclude={args.exclude} does not exist.")
+    if not args.index_repo and not args.index_issues:
+        raise ValueError("Either --index_repo or --index_issues must be set to true.")
+    if args.index_issues and not os.getenv("GITHUB_TOKEN"):
+        raise ValueError("Please set the GITHUB_TOKEN environment variable.")
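The comment above `OPENAI_DEFAULT_EMBEDDING_SIZE` refers to OpenAI's support for truncating `text-embedding-3*` vectors server-side. A minimal sketch of what `--embedding-size` ultimately maps to, using the official `openai` v1 client (requires `OPENAI_API_KEY`; the input string is just an example):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.embeddings.create(
    model="text-embedding-3-large",
    input=["def build_vector_store_from_args(args): ..."],
    dimensions=1536,  # slice the native 3072-dim embedding down to 1536
)
assert len(response.data[0].embedding) == 1536
```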
sage/configs/local.yaml ADDED
@@ -0,0 +1,16 @@
+# Embeddings
+embedding-provider: marqo
+embedding-model: hf/e5-base-v2
+tokens-per-chunk: 800
+chunks-per-batch: 64
+
+# Vector store
+vector-store-provider: marqo
+
+# LLM
+llm-provider: ollama
+llm-model: llama3.1
+
+# Reranking
+reranker-provider: huggingface
+reranker-model: cross-encoder/ms-marco-MiniLM-L-6-v2
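`add_config_args` in sage/config.py peeks at `--mode` with `parse_known_args` and then points the `--config` default at the matching bundled YAML, so this file is what `--mode=local` resolves to. The two-phase trick in isolation (a sketch; it assumes you run it from the repository root so the relative path resolves):

```python
import configargparse

parser = configargparse.ArgParser(ignore_unknown_config_file_keys=True)
parser.add("--mode", choices=["local", "remote"], default="local")
parser.add("--config", is_config_file=True)

# First pass: read just enough to know which bundled defaults to use.
known, _ = parser.parse_known_args(["--mode", "remote"])
parser.set_defaults(config=f"sage/configs/{known.mode}.yaml")

# Second pass: the full parse layers remote.yaml under any explicit flags.
args = parser.parse_args(["--mode", "remote"])
print(args.config)  # sage/configs/remote.yaml
```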
sage/configs/remote.yaml ADDED
@@ -0,0 +1,18 @@
+# Embeddings
+embedding-provider: openai
+embedding-model: text-embedding-ada-002
+tokens-per-chunk: 800
+chunks-per-batch: 2000
+
+# Vector store
+vector-store-provider: pinecone
+pinecone-index-name: sage
+hybrid-retrieval: true
+
+# LLM
+llm-provider: openai
+llm-model: gpt-4
+
+# Reranking
+reranker-provider: cohere
+reranker-model: rerank-english-v3.0
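With configargparse, a config-file key such as `llm-provider` fills the flag of the same name, and explicit command-line flags still win over the YAML. A sketch of that precedence (again assuming the repository root as working directory; the `gpt-4o` value is just an illustrative override):

```python
import configargparse

parser = configargparse.ArgParser(ignore_unknown_config_file_keys=True)
parser.add("--config", is_config_file=True)
parser.add("--llm-provider", default="ollama")
parser.add("--llm-model")

# remote.yaml sets llm-provider: openai and llm-model: gpt-4;
# the command line overrides only the model.
args = parser.parse_args(["--config", "sage/configs/remote.yaml", "--llm-model", "gpt-4o"])
assert args.llm_provider == "openai"  # from the YAML
assert args.llm_model == "gpt-4o"     # CLI beats config file
```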
sage/embedder.py CHANGED
@@ -268,11 +268,11 @@ class MarqoEmbedder(BatchEmbedder):


 def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
-    if args.embedder_type == "openai":
+    if args.embedding_provider == "openai":
         return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
-    elif args.embedder_type == "marqo":
+    elif args.embedding_provider == "marqo":
         return MarqoEmbedder(
-            data_manager, chunker, index_name=args.
+            data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
         )
     else:
-        raise ValueError(f"Unrecognized embedder type {args.embedder_type}")
+        raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
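Because the factory only reads attributes off `args`, it can also be exercised without the CLI. The sketch below assumes the constructors merely capture their arguments; the stand-in variables would be built the way sage/index.py builds them:

```python
from argparse import Namespace

from sage.embedder import build_batch_embedder_from_flags

args = Namespace(
    embedding_provider="openai",
    embedding_model="text-embedding-ada-002",
    embedding_size=1536,
    local_dir="repos",
)

# Placeholders: construct these as sage/index.py does, via GitHubRepoManager
# and UniversalFileChunker respectively.
data_manager, chunker = None, None

embedder = build_batch_embedder_from_flags(data_manager, chunker, args)
```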
sage/index.py CHANGED
@@ -1,196 +1,46 @@
 """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""

-import argparse
 import logging
-import os
 import time

+import configargparse
 import pkg_resources

+import sage.config as sage_config
 from sage.chunker import UniversalFileChunker
 from sage.data_manager import GitHubRepoManager
 from sage.embedder import build_batch_embedder_from_flags
 from sage.github import GitHubIssuesChunker, GitHubIssuesManager
-from sage.vector_store import
+from sage.vector_store import build_vector_store_from_args

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)

-MARQO_MAX_CHUNKS_PER_BATCH = 64
-
-OPENAI_MAX_TOKENS_PER_CHUNK = 8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
-OPENAI_MAX_CHUNKS_PER_BATCH = 2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
-OPENAI_MAX_TOKENS_PER_JOB = (
-    3_000_000  # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
-)
-
-# Note that OpenAI embedding models have fixed dimensions, however, taking a slice of them is possible.
-# See "Reducing embedding dimensions" under https://platform.openai.com/docs/guides/embeddings/use-cases and
-# https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-dimensions
-OPENAI_DEFAULT_EMBEDDING_SIZE = {
-    "text-embedding-ada-002": 1536,
-    "text-embedding-3-small": 1536,
-    "text-embedding-3-large": 3072,
-}
-

 def main():
-    parser =
-
-    parser.add_argument("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
-    parser.add_argument("--embedder-type", default="marqo", choices=["openai", "marqo"])
-    parser.add_argument(
-        "--embedding-model",
-        type=str,
-        default=None,
-        help="The embedding model. Defaults to `text-embedding-ada-002` for OpenAI and `hf/e5-base-v2` for Marqo.",
-    )
-    parser.add_argument(
-        "--embedding-size",
-        type=int,
-        default=None,
-        help="The embedding size to use for OpenAI text-embedding-3* models. Defaults to 1536 for small and 3072 for "
-        "large. Note that no other OpenAI models support a dynamic embedding size, nor do models used with Marqo.",
-    )
-    parser.add_argument("--vector-store-type", default="marqo", choices=["pinecone", "marqo"])
-    parser.add_argument(
-        "--local-dir",
-        default="repos",
-        help="The local directory to store the repository",
-    )
-    parser.add_argument(
-        "--tokens-per-chunk",
-        type=int,
-        default=800,
-        help="https://arxiv.org/pdf/2406.14497 recommends a value between 200-800.",
-    )
-    parser.add_argument(
-        "--chunks-per-batch",
-        type=int,
-        help="Maximum chunks per batch. We recommend 2000 for the OpenAI embedder. Marqo enforces a limit of 64.",
-    )
-    parser.add_argument(
-        "--pinecone-index-name",
-        default=None,
-        help="Pinecone index name. Required if using Pinecone as the vector store. If the index doesn't exist already, "
-        "we will create it.",
-    )
-    parser.add_argument(
-        "--index-namespace",
-        default=None,
-        help="Index namespace for this repo. When not specified, we default it to a derivative of the repo name."
-    )
-    parser.add_argument(
-        "--include",
-        help="Path to a file containing a list of extensions to include. One extension per line.",
-    )
-    parser.add_argument(
-        "--exclude",
-        help="Path to a file containing a list of extensions to exclude. One extension per line.",
-    )
-    parser.add_argument(
-        "--max-embedding-jobs",
-        type=int,
-        help="Maximum number of embedding jobs to run. Specifying this might result in "
-        "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
-    )
-    parser.add_argument(
-        "--marqo-url",
-        default="http://localhost:8882",
-        help="URL for the Marqo server. Required if using Marqo as embedder or vector store.",
-    )
-    # Pass --no-index-repo in order to not index the repository.
-    parser.add_argument(
-        "--index-repo",
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help="Whether to index the repository. At least one of --index-repo and --index-issues must be True.",
-    )
-    # Pass --no-index-issues in order to not index the issues.
-    parser.add_argument(
-        "--index-issues",
-        action=argparse.BooleanOptionalAction,
-        default=False,
-        help="Whether to index GitHub issues. At least one of --index-repo and --index-issues must be True. When "
-        "--index-issues is set, you must also set a GITHUB_TOKEN environment variable.",
-    )
-    parser.add_argument(
-        "--index-issue-comments",
-        action=argparse.BooleanOptionalAction,
-        default=False,
-        help="Whether to index the comments of GitHub issues. This is only relevant if --index-issues is set. "
-        "GitHub's API for downloading comments is quite slow. Indexing solely the body of an issue seems to bring most "
-        "of the gains anyway.",
-    )
-    parser.add_argument(
-        "--hybrid-retrieval",
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help="Whether to use a hybrid of vector DB + BM25 retrieval. When set to False, we only use vector DB "
-        "retrieval. This is only relevant if using Pinecone as the vector store.",
-    )
-    args = parser.parse_args()
-
-
-
-
-
-
-
-    if "/" in args.index_namespace:
-        parser.error("The index namespace cannot contain slashes when using Marqo as the vector store.")
-    elif args.vector_store_type == "pinecone" and not args.pinecone_index_name:
-        parser.error("When using Pinecone as the vector store, you must specify --pinecone-index-name")
-
-
-    if args.embedder_type == "marqo":
-        if args.embedding_model is None:
-            args.embedding_model = "hf/e5-base-v2"
-        if args.chunks_per_batch is None:
-            args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
-        elif args.chunks_per_batch > MARQO_MAX_CHUNKS_PER_BATCH:
-            args.chunks_per_batch = MARQO_MAX_CHUNKS_PER_BATCH
-            logging.warning(
-                f"Marqo enforces a limit of {MARQO_MAX_CHUNKS_PER_BATCH} chunks per batch. "
-                "Overwriting --chunks_per_batch."
-            )
-    elif args.embedder_type == "openai":
-        if args.tokens_per_chunk > OPENAI_MAX_TOKENS_PER_CHUNK:
-            args.tokens_per_chunk = OPENAI_MAX_TOKENS_PER_CHUNK
-            logging.warning(
-                f"OpenAI enforces a limit of {OPENAI_MAX_TOKENS_PER_CHUNK} tokens per chunk. "
-                "Overwriting --tokens_per_chunk."
-            )
-        if args.chunks_per_batch is None:
-            args.chunks_per_batch = 2000
-        elif args.chunks_per_batch > OPENAI_MAX_CHUNKS_PER_BATCH:
-            args.chunks_per_batch = OPENAI_MAX_CHUNKS_PER_BATCH
-            logging.warning(
-                f"OpenAI enforces a limit of {OPENAI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
-                "Overwriting --chunks_per_batch."
-            )
-        if args.tokens_per_chunk * args.chunks_per_batch >= OPENAI_MAX_TOKENS_PER_JOB:
-            parser.error(f"The maximum number of chunks per job is {OPENAI_MAX_TOKENS_PER_JOB}.")
-        if args.embedding_model is None:
-            args.embedding_model = "text-embedding-ada-002"
-        if args.embedding_size is None:
-            args.embedding_size = OPENAI_DEFAULT_EMBEDDING_SIZE.get(args.embedding_model)
-
-
-
-    if not args.include and not args.exclude:
-        args.exclude = pkg_resources.resource_filename(__name__, "sample-exclude.txt")
-    if not args.index_repo and not args.index_issues:
-        parser.error("At least one of --index-repo and --index-issues must be true.")
-
-    #
-    if args.include and not os.path.exists(args.include):
-        parser.error(f"Path --include={args.include} does not exist.")
-    if args.exclude and not os.path.exists(args.exclude):
-        parser.error(f"Path --exclude={args.exclude} does not exist.")
-    if args.index_issues and not os.getenv("GITHUB_TOKEN"):
-        parser.error("Please set the GITHUB_TOKEN environment variable.")
+    parser = configargparse.ArgParser(
+        description="Batch-embeds a GitHub repository and its issues.", ignore_unknown_config_file_keys=True
+    )
+    sage_config.add_config_args(parser)
+
+    arg_validators = [
+        sage_config.add_repo_args(parser),
+        sage_config.add_embedding_args(parser),
+        sage_config.add_vector_store_args(parser),
+        sage_config.add_indexing_args(parser),
+    ]
+
+    args = parser.parse_args()
+
+    for validator in arg_validators:
+        validator(args)
+
+    # Additionally validate embedder and vector store compatibility.
+    if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
+        parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
+    if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
+        parser.error("When using the marqo embedder, the vector store type must also be marqo.")

     ######################
     # Step 1: Embeddings #
@@ -228,7 +78,7 @@ def main():
     # Step 2: Vector Store #
     ########################

-    if args.vector_store_type == "marqo":
+    if args.vector_store_provider == "marqo":
         # Marqo computes embeddings and stores them in the vector store at once, so we're done.
         logging.info("Done!")
         return
@@ -240,7 +90,7 @@ def main():
         time.sleep(30)

     logging.info("Moving embeddings to the repo vector store...")
-    repo_vector_store =
+    repo_vector_store = build_vector_store_from_args(args)
     repo_vector_store.ensure_exists()
     repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file))
@@ -251,7 +101,7 @@ def main():
         time.sleep(30)

     logging.info("Moving embeddings to the issues vector store...")
-    issues_vector_store =
+    issues_vector_store = build_vector_store_from_args(args)
     issues_vector_store.ensure_exists()
     issues_vector_store.upsert(issues_embedder.download_embeddings(issues_jobs_file))
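The new compatibility check boils down to a simple predicate: OpenAI embeddings must land in Pinecone, while Marqo must both embed and store. A self-contained restatement, handy for testing the rule in isolation (sketch, not sage's actual code):

```python
def compatible(embedding_provider: str, vector_store_provider: str) -> bool:
    """Mirrors the embedder/vector-store pairing rule enforced in sage/index.py."""
    if embedding_provider == "openai":
        return vector_store_provider == "pinecone"
    if embedding_provider == "marqo":
        return vector_store_provider == "marqo"
    return False

assert compatible("openai", "pinecone")
assert compatible("marqo", "marqo")
assert not compatible("openai", "marqo")
assert not compatible("marqo", "pinecone")
```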
sage/llm.py CHANGED
@@ -10,12 +10,12 @@ def build_llm_via_langchain(provider: str, model: str):
     if provider == "openai":
         if "OPENAI_API_KEY" not in os.environ:
             raise ValueError("Please set the OPENAI_API_KEY environment variable.")
-        return ChatOpenAI(model=model)
+        return ChatOpenAI(model=model or "gpt-4")
     elif provider == "anthropic":
         if "ANTHROPIC_API_KEY" not in os.environ:
             raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
-        return ChatAnthropic(model=model)
+        return ChatAnthropic(model=model or "claude-3-opus-20240229")
     elif provider == "ollama":
-        return ChatOllama(model=model)
+        return ChatOllama(model=model or "llama3.1")
     else:
         raise ValueError(f"Unrecognized LLM provider {provider}. Contributions are welcome!")
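With the fallbacks above, callers can now pass `model=None` and still get a concrete chat model per provider. A usage sketch (each branch requires the corresponding langchain integration package, and API keys for the hosted providers):

```python
from sage.llm import build_llm_via_langchain

# No model given: falls back to the provider default ("llama3.1" for Ollama).
llm = build_llm_via_langchain("ollama", model=None)

# An explicit model always wins over the fallback.
llm = build_llm_via_langchain("anthropic", model="claude-3-opus-20240229")
```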
sage/vector_store.py CHANGED
@@ -146,35 +146,16 @@ class MarqoVectorStore(VectorStore):
         return vectorstore.as_retriever(search_kwargs={"k": top_k})


-def
+def build_vector_store_from_args(args: dict) -> VectorStore:
     """Builds a vector store from the given command-line arguments."""
-    if args.vector_store_type == "pinecone":
-        if not args.pinecone_index_name:
-            raise ValueError("Please specify --pinecone-index-name for Pinecone.")
-        dimension = args.embedding_size if "embedding_size" in args else None
-
-        index_namespace = args.index_namespace
-        if not index_namespace:
-            index_namespace = args.repo_id
-            if args.commit_hash:
-                namespace += "/" + args.commit_hash
-
+    if args.vector_store_provider == "pinecone":
         return PineconeVectorStore(
             index_name=args.pinecone_index_name,
-            namespace=index_namespace,
-            dimension=dimension,
+            namespace=args.index_namespace,
+            dimension=args.embedding_size if "embedding_size" in args else None,
             hybrid=args.hybrid_retrieval,
         )
-    elif args.vector_store_type == "marqo":
-
-        index_namespace = args.index_namespace
-        if not index_namespace:
-            # Marqo doesn't allow slashes in the index name.
-            index_namespace = args.repo_id.split("/")[1]
-            if args.commit_hash:
-                index_namespace += "_" + args.commit_hash
-
-        return MarqoVectorStore(url=marqo_url, index_name=index_namespace)
+    elif args.vector_store_provider == "marqo":
+        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_namespace)
+    else:
-        raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
+        raise ValueError(f"Unrecognized vector store type {args.vector_store_provider}")
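The namespace-defaulting logic that used to live in this factory now runs in `sage.config.validate_vector_store_args`, so the factory can assume `args.index_namespace` is already populated. An end-to-end sketch of how the two fit together (an `argparse.Namespace` stands in for parsed flags; the marqo client package and a running Marqo server are needed before any real queries):

```python
from argparse import Namespace

from sage.config import validate_vector_store_args
from sage.vector_store import build_vector_store_from_args

args = Namespace(
    repo_id="huggingface/transformers",
    commit_hash=None,
    index_namespace=None,
    vector_store_provider="marqo",
    marqo_url="http://localhost:8882",
)
validate_vector_store_args(args)
# Marqo forbids slashes, so the defaulted namespace becomes:
assert args.index_namespace == "huggingface_transformers"

vector_store = build_vector_store_from_args(args)
```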