Mihail Eric committed on
Commit
d8387cd
·
1 Parent(s): ca05bc9

move punkt download to vector store where it's actually used

Browse files
Files changed (4) hide show
  1. README.md +3 -12
  2. sage/.sample-env +0 -3
  3. sage/index.py +1 -18
  4. sage/vector_store.py +14 -0
README.md CHANGED
@@ -104,18 +104,9 @@ Export the API key of the desired provider:
104
  export ANTHROPIC_API_KEY=...
105
  ```
106
 
107
- For easier configuration, create a `.sage-env` file with the following contents (change the API key names based on your desired setup):
108
- ```
109
- # Embeddings
110
- export OPENAI_API_KEY=
111
- # Vector store
112
- export PINECONE_API_KEY=
113
- # Reranking
114
- export NVIDIA_API_KEY=
115
- # Generation LLM
116
- export ANTHROPIC_API_KEY=
117
- # Github issues
118
- export GITHUB_TOKEN=
119
  ```
120
  </details>
121
 
 
104
  export ANTHROPIC_API_KEY=...
105
  ```
106
 
107
+ For easier configuration, adapt the entries within the sample `.sage-env` (change the API key names based on your desired setup) and run:
108
+ ```
109
+ source .sage-env
 
 
 
 
 
 
 
 
 
110
  ```
111
  </details>
112
 
sage/.sample-env DELETED
@@ -1,3 +0,0 @@
1
- OPENAI_API_KEY=
2
- PINECONE_API_KEY=
3
- GITHUB_TOKEN=
 
 
 
 
sage/index.py CHANGED
@@ -14,20 +14,10 @@ from sage.embedder import build_batch_embedder_from_flags
14
  from sage.github import GitHubIssuesChunker, GitHubIssuesManager
15
  from sage.vector_store import build_vector_store_from_args
16
 
17
-
18
- from nltk.data import find
19
-
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger()
22
  logger.setLevel(logging.INFO)
23
 
24
- def is_punkt_downloaded():
25
- try:
26
- find('tokenizers/punkt_tab')
27
- return True
28
- except LookupError:
29
- return False
30
-
31
 
32
  def main():
33
  parser = configargparse.ArgParser(
@@ -53,14 +43,7 @@ def main():
53
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
54
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
55
 
56
- # We need nltk tokenizers for bm25 tokenization
57
- if is_punkt_downloaded():
58
- print("punkt is already downloaded")
59
- else:
60
- print("punkt is not downloaded")
61
- # Optionally download it
62
- nltk.download('punkt_tab')
63
-
64
  ######################
65
  # Step 1: Embeddings #
66
  ######################
 
14
  from sage.github import GitHubIssuesChunker, GitHubIssuesManager
15
  from sage.vector_store import build_vector_store_from_args
16
 
 
 
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger()
19
  logger.setLevel(logging.INFO)
20
 
 
 
 
 
 
 
 
21
 
22
  def main():
23
  parser = configargparse.ArgParser(
 
43
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
44
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
45
 
46
+
 
 
 
 
 
 
 
47
  ######################
48
  # Step 1: Embeddings #
49
  ######################
sage/vector_store.py CHANGED
@@ -12,6 +12,7 @@ from langchain_community.vectorstores import Marqo
12
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
13
  from langchain_core.documents import Document
14
  from langchain_core.embeddings import Embeddings
 
15
  from pinecone import Pinecone, ServerlessSpec
16
  from pinecone_text.sparse import BM25Encoder
17
 
@@ -20,6 +21,12 @@ from sage.data_manager import DataManager
20
 
21
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
22
 
 
 
 
 
 
 
23
 
24
  class VectorStore(ABC):
25
  """Abstract class for a vector store."""
@@ -71,6 +78,13 @@ class PineconeVectorStore(VectorStore):
71
  if alpha < 1.0:
72
  if bm25_cache and os.path.exists(bm25_cache):
73
  logging.info("Loading BM25 encoder from cache.")
 
 
 
 
 
 
 
74
  self.bm25_encoder = BM25Encoder()
75
  self.bm25_encoder.load(path=bm25_cache)
76
  else:
 
12
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
13
  from langchain_core.documents import Document
14
  from langchain_core.embeddings import Embeddings
15
+ from nltk.data import find
16
  from pinecone import Pinecone, ServerlessSpec
17
  from pinecone_text.sparse import BM25Encoder
18
 
 
21
 
22
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
23
 
24
+ def is_punkt_downloaded():
25
+ try:
26
+ find('tokenizers/punkt_tab')
27
+ return True
28
+ except LookupError:
29
+ return False
30
 
31
  class VectorStore(ABC):
32
  """Abstract class for a vector store."""
 
78
  if alpha < 1.0:
79
  if bm25_cache and os.path.exists(bm25_cache):
80
  logging.info("Loading BM25 encoder from cache.")
81
+ # We need nltk tokenizers for bm25 tokenization
82
+ if is_punkt_downloaded():
83
+ print("punkt is already downloaded")
84
+ else:
85
+ print("punkt is not downloaded")
86
+ # Optionally download it
87
+ nltk.download('punkt_tab')
88
  self.bm25_encoder = BM25Encoder()
89
  self.bm25_encoder.load(path=bm25_cache)
90
  else: