Mihail Eric committed on
Commit
f2ad04a
·
2 Parent(s): 8b42d65 9581b48

download nltk if not detected (#42)

Browse files
Files changed (5) hide show
  1. README.md +7 -1
  2. sage/.sage-env +10 -0
  3. sage/.sample-env +0 -3
  4. sage/index.py +1 -0
  5. sage/vector_store.py +14 -0
README.md CHANGED
@@ -89,7 +89,9 @@ pip install git+https://github.com/Storia-AI/sage.git@main
89
  export PINECONE_INDEX_NAME=...
90
  ```
91
 
92
- 3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>. According to [our experiments](benchmark/retrieval/README.md), NVIDIA performs best. Export the API key of the desired provider:
 
 
93
  ```
94
  export NVIDIA_API_KEY=... # or
95
  export VOYAGE_API_KEY=... # or
@@ -102,6 +104,10 @@ pip install git+https://github.com/Storia-AI/sage.git@main
102
  export ANTHROPIC_API_KEY=...
103
  ```
104
 
 
 
 
 
105
  </details>
106
 
107
  ### Optional
 
89
  export PINECONE_INDEX_NAME=...
90
  ```
91
 
92
+ 3. For reranking, we support <a href="https://developer.nvidia.com/blog/enhancing-rag-pipelines-with-re-ranking/">NVIDIA</a>, <a href="https://docs.voyageai.com/docs/reranker">Voyage</a>, <a href="https://cohere.com/rerank">Cohere</a>, and <a href="https://jina.ai/reranker/">Jina</a>. According to [our experiments](benchmark/retrieval/README.md), NVIDIA performs best. Note: for NVIDIA you should use the `nvidia/nv-rerankqa-mistral-4b-v3` reranker.
93
+
94
+ Export the API key of the desired provider:
95
  ```
96
  export NVIDIA_API_KEY=... # or
97
  export VOYAGE_API_KEY=... # or
 
104
  export ANTHROPIC_API_KEY=...
105
  ```
106
 
107
+ For easier configuration, adapt the entries within the sample `.sage-env` (change the API key names based on your desired setup) and run:
108
+ ```
109
+ source .sage-env
110
+ ```
111
  </details>
112
 
113
  ### Optional
sage/.sage-env ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Embeddings
2
+ export OPENAI_API_KEY=
3
+ # Vector store
4
+ export PINECONE_API_KEY=
5
+ # Reranking
6
+ export NVIDIA_API_KEY=
7
+ # Generation LLM
8
+ export ANTHROPIC_API_KEY=
9
+ # Github issues
10
+ export GITHUB_TOKEN=
sage/.sample-env DELETED
@@ -1,3 +0,0 @@
1
- OPENAI_API_KEY=
2
- PINECONE_API_KEY=
3
- GITHUB_TOKEN=
 
 
 
 
sage/index.py CHANGED
@@ -42,6 +42,7 @@ def main():
42
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
43
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
44
 
 
45
  ######################
46
  # Step 1: Embeddings #
47
  ######################
 
42
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
43
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
44
 
45
+
46
  ######################
47
  # Step 1: Embeddings #
48
  ######################
sage/vector_store.py CHANGED
@@ -12,6 +12,7 @@ from langchain_community.vectorstores import Marqo
12
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
13
  from langchain_core.documents import Document
14
  from langchain_core.embeddings import Embeddings
 
15
  from pinecone import Pinecone, ServerlessSpec
16
  from pinecone_text.sparse import BM25Encoder
17
 
@@ -20,6 +21,12 @@ from sage.data_manager import DataManager
20
 
21
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
22
 
 
 
 
 
 
 
23
 
24
  class VectorStore(ABC):
25
  """Abstract class for a vector store."""
@@ -69,6 +76,13 @@ class PineconeVectorStore(VectorStore):
69
  if alpha < 1.0:
70
  if bm25_cache and os.path.exists(bm25_cache):
71
  logging.info("Loading BM25 encoder from cache.")
 
 
 
 
 
 
 
72
  self.bm25_encoder = BM25Encoder()
73
  self.bm25_encoder.load(path=bm25_cache)
74
  else:
 
12
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
13
  from langchain_core.documents import Document
14
  from langchain_core.embeddings import Embeddings
15
+ from nltk.data import find
16
  from pinecone import Pinecone, ServerlessSpec
17
  from pinecone_text.sparse import BM25Encoder
18
 
 
21
 
22
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
23
 
24
def is_punkt_downloaded(resource: str = 'tokenizers/punkt_tab') -> bool:
    """Return True if the given NLTK data resource is available locally.

    Generalized so other NLTK resources can be probed, while the default
    preserves the original behavior (checking for the 'punkt_tab'
    tokenizer needed for BM25 tokenization).

    Args:
        resource: NLTK data path to look up.

    Returns:
        True when ``nltk.data.find`` locates the resource; False when the
        lookup raises ``LookupError`` (i.e. the data is not downloaded).
    """
    try:
        find(resource)
        return True
    except LookupError:
        return False
30
 
31
  class VectorStore(ABC):
32
  """Abstract class for a vector store."""
 
76
  if alpha < 1.0:
77
  if bm25_cache and os.path.exists(bm25_cache):
78
  logging.info("Loading BM25 encoder from cache.")
79
+ # We need nltk tokenizers for bm25 tokenization
80
+ if is_punkt_downloaded():
81
+ print("punkt is already downloaded")
82
+ else:
83
+ print("punkt is not downloaded")
84
+ # Optionally download it
85
+ nltk.download('punkt_tab')
86
  self.bm25_encoder = BM25Encoder()
87
  self.bm25_encoder.load(path=bm25_cache)
88
  else: