Mihail Eric commited on
Commit
ca05bc9
·
1 Parent(s): 4e68d3a

download nltk if not downloaded

Browse files
Files changed (2) hide show
  1. README.md +16 -1
  2. sage/index.py +19 -0
README.md CHANGED
@@ -89,7 +89,9 @@ pip install git+https://github.com/Storia-AI/sage.git@main
89
  export PINECONE_INDEX_NAME=...
90
  ```
91
 
92
- 3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>. According to [our experiments](benchmark/retrieval/README.md), NVIDIA performs best. Export the API key of the desired provider:
 
 
93
  ```
94
  export NVIDIA_API_KEY=... # or
95
  export VOYAGE_API_KEY=... # or
@@ -102,6 +104,19 @@ pip install git+https://github.com/Storia-AI/sage.git@main
102
  export ANTHROPIC_API_KEY=...
103
  ```
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  </details>
106
 
107
  ### Optional
 
89
  export PINECONE_INDEX_NAME=...
90
  ```
91
 
92
+ 3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>. According to [our experiments](benchmark/retrieval/README.md), NVIDIA performs best. Note: for NVIDIA you should use the `nvidia/nv-rerankqa-mistral-4b-v3` reranker.
93
+
94
+ Export the API key of the desired provider:
95
  ```
96
  export NVIDIA_API_KEY=... # or
97
  export VOYAGE_API_KEY=... # or
 
104
  export ANTHROPIC_API_KEY=...
105
  ```
106
 
107
+ For easier configuration, create a `.sage-env` file with the following contents (change the API keys names based on your desired setup):
108
+ ```
109
+ # Embeddings
110
+ export OPENAI_API_KEY=
111
+ # Vector store
112
+ export PINECONE_API_KEY=
113
+ # Reranking
114
+ export NVIDIA_API_KEY=
115
+ # Generation LLM
116
+ export ANTHROPIC_API_KEY=
117
+ # Github issues
118
+ export GITHUB_TOKEN=
119
+ ```
120
  </details>
121
 
122
  ### Optional
sage/index.py CHANGED
@@ -1,6 +1,7 @@
1
  """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
2
 
3
  import logging
 
4
  import time
5
 
6
  import configargparse
@@ -13,10 +14,20 @@ from sage.embedder import build_batch_embedder_from_flags
13
  from sage.github import GitHubIssuesChunker, GitHubIssuesManager
14
  from sage.vector_store import build_vector_store_from_args
15
 
 
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger()
18
  logger.setLevel(logging.INFO)
19
 
 
 
 
 
 
 
 
20
 
21
  def main():
22
  parser = configargparse.ArgParser(
@@ -42,6 +53,14 @@ def main():
42
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
43
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
44
 
 
 
 
 
 
 
 
 
45
  ######################
46
  # Step 1: Embeddings #
47
  ######################
 
1
  """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
2
 
3
  import logging
4
+ import nltk
5
  import time
6
 
7
  import configargparse
 
14
  from sage.github import GitHubIssuesChunker, GitHubIssuesManager
15
  from sage.vector_store import build_vector_store_from_args
16
 
17
+
18
+ from nltk.data import find
19
+
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger()
22
  logger.setLevel(logging.INFO)
23
 
24
+ def is_punkt_downloaded():
25
+ try:
26
+ find('tokenizers/punkt_tab')
27
+ return True
28
+ except LookupError:
29
+ return False
30
+
31
 
32
  def main():
33
  parser = configargparse.ArgParser(
 
53
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
54
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
55
 
56
+ # We need nltk tokenizers for
57
+ if is_punkt_downloaded():
58
+ print("punkt is already downloaded")
59
+ else:
60
+ print("punkt is not downloaded")
61
+ # Optionally download it
62
+ nltk.download('punkt_tab')
63
+
64
  ######################
65
  # Step 1: Embeddings #
66
  ######################