Mihail Eric Mihail Eric committed on
Commit
7ca251e
·
1 Parent(s): c389c2f

add support for google gemini embeddings as an embedder (#56)

Browse files

* wip on retrieve

* add ir measures

* add support for google embeddings

* voyage api key

* update embeddings and update reqs

* remove todo

* github head ref

* fix workflow

* revert workflow

---------

Co-authored-by: Mihail Eric <mihaileric@Mihails-MacBook-Pro.local>

sage/.sage-env → .sage-env RENAMED
File without changes
benchmarks/retrieval/assets/embeddings.png CHANGED
benchmarks/retrieval/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ dotenv
2
+ ir_measures
benchmarks/retrieval/retrieve.py CHANGED
@@ -2,22 +2,24 @@
2
 
3
  Make sure to `pip install ir_measures` before running this script.
4
  """
5
-
6
  import json
7
  import logging
8
  import os
9
  import time
10
 
11
  import configargparse
 
12
  from ir_measures import MAP, MRR, P, Qrel, R, Rprec, ScoredDoc, calc_aggregate, nDCG
13
 
14
  import sage.config
 
15
  from sage.retriever import build_retriever_from_args
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger()
19
  logger.setLevel(logging.INFO)
20
 
 
21
 
22
  def main():
23
  parser = configargparse.ArgParser(
@@ -35,17 +37,27 @@ def main():
35
  default=None,
36
  help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
37
  )
38
- parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
39
 
 
40
  sage.config.add_config_args(parser)
41
  sage.config.add_llm_args(parser) # Needed for --multi-query-retriever, which rewrites the query with an LLM.
42
  sage.config.add_embedding_args(parser)
43
  sage.config.add_vector_store_args(parser)
44
  sage.config.add_reranking_args(parser)
 
 
45
  args = parser.parse_args()
46
  sage.config.validate_vector_store_args(args)
47
-
48
- retriever = build_retriever_from_args(args)
 
 
 
 
 
 
 
 
49
 
50
  with open(args.benchmark, "r") as f:
51
  benchmark = json.load(f)
@@ -70,7 +82,7 @@ def main():
70
  item["retrieved"] = []
71
  for doc_idx, doc in enumerate(retrieved):
72
  # The absolute value of the scores below does not affect the metrics; it merely determines the ranking of
73
- # the retrived documents. The key of the score varies depending on the underlying retriever. If there's no
74
  # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
75
  score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
76
  retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
@@ -83,7 +95,6 @@ def main():
83
  print("Calculating metrics...")
84
  results = calc_aggregate([Rprec, P @ 1, R @ 3, nDCG @ 3, MAP, MRR], golden_docs, retrieved_docs)
85
  results = {str(key): value for key, value in results.items()}
86
-
87
  if args.logs_dir:
88
  if not os.path.exists(args.logs_dir):
89
  os.makedirs(args.logs_dir)
 
2
 
3
  Make sure to `pip install ir_measures` before running this script.
4
  """
 
5
  import json
6
  import logging
7
  import os
8
  import time
9
 
10
  import configargparse
11
+ from dotenv import load_dotenv
12
  from ir_measures import MAP, MRR, P, Qrel, R, Rprec, ScoredDoc, calc_aggregate, nDCG
13
 
14
  import sage.config
15
+ from sage.data_manager import GitHubRepoManager
16
  from sage.retriever import build_retriever_from_args
17
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger()
20
  logger.setLevel(logging.INFO)
21
 
22
+ load_dotenv()
23
 
24
  def main():
25
  parser = configargparse.ArgParser(
 
37
  default=None,
38
  help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
39
  )
 
40
 
41
+ parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
42
  sage.config.add_config_args(parser)
43
  sage.config.add_llm_args(parser) # Needed for --multi-query-retriever, which rewrites the query with an LLM.
44
  sage.config.add_embedding_args(parser)
45
  sage.config.add_vector_store_args(parser)
46
  sage.config.add_reranking_args(parser)
47
+ sage.config.add_repo_args(parser)
48
+ sage.config.add_indexing_args(parser)
49
  args = parser.parse_args()
50
  sage.config.validate_vector_store_args(args)
51
+ repo_manager = GitHubRepoManager(
52
+ args.repo_id,
53
+ commit_hash=args.commit_hash,
54
+ access_token=os.getenv("GITHUB_TOKEN"),
55
+ local_dir=args.local_dir,
56
+ inclusion_file=args.include,
57
+ exclusion_file=args.exclude,
58
+ )
59
+ repo_manager.download()
60
+ retriever = build_retriever_from_args(args, repo_manager)
61
 
62
  with open(args.benchmark, "r") as f:
63
  benchmark = json.load(f)
 
82
  item["retrieved"] = []
83
  for doc_idx, doc in enumerate(retrieved):
84
  # The absolute value of the scores below does not affect the metrics; it merely determines the ranking of
85
+ # the retrieved documents. The key of the score varies depending on the underlying retriever. If there's no
86
  # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
87
  score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
88
  retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
 
95
  print("Calculating metrics...")
96
  results = calc_aggregate([Rprec, P @ 1, R @ 3, nDCG @ 3, MAP, MRR], golden_docs, retrieved_docs)
97
  results = {str(key): value for key, value in results.items()}
 
98
  if args.logs_dir:
99
  if not os.path.exists(args.logs_dir):
100
  os.makedirs(args.logs_dir)
requirements.txt CHANGED
@@ -3,6 +3,7 @@ Pygments==2.18.0
3
  cohere==5.9.2
4
  configargparse
5
  fastapi==0.112.2
 
6
  gradio>=4.26.0
7
  langchain==0.2.16
8
  langchain-anthropic==0.1.23
@@ -10,6 +11,7 @@ langchain-cohere==0.2.4
10
  langchain-community==0.2.17
11
  langchain-core==0.2.41
12
  langchain-experimental==0.0.65
 
13
  langchain-nvidia-ai-endpoints==0.2.2
14
  langchain-ollama==0.1.3
15
  langchain-openai==0.1.25
 
3
  cohere==5.9.2
4
  configargparse
5
  fastapi==0.112.2
6
+ google-ai-generativelanguage==0.6.6
7
  gradio>=4.26.0
8
  langchain==0.2.16
9
  langchain-anthropic==0.1.23
 
11
  langchain-community==0.2.17
12
  langchain-core==0.2.41
13
  langchain-experimental==0.0.65
14
+ langchain-google-genai
15
  langchain-nvidia-ai-endpoints==0.2.2
16
  langchain-ollama==0.1.3
17
  langchain-openai==0.1.25
sage/config.py CHANGED
@@ -11,6 +11,11 @@ from configargparse import ArgumentParser
11
 
12
  from sage.reranker import RerankerProvider
13
 
 
 
 
 
 
14
  MARQO_MAX_CHUNKS_PER_BATCH = 64
15
  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
16
  OPENAI_MAX_TOKENS_PER_CHUNK = 8192
@@ -82,7 +87,7 @@ def add_repo_args(parser: ArgumentParser) -> Callable:
82
 
83
  def add_embedding_args(parser: ArgumentParser) -> Callable:
84
  """Adds embedding-related arguments to the parser and returns a validator."""
85
- parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo"])
86
  parser.add(
87
  "--embedding-model",
88
  type=str,
@@ -304,6 +309,26 @@ def _validate_marqo_embedding_args(args):
304
  )
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  def validate_embedding_args(args):
308
  """Validates the configuration of the batch embedder and sets defaults."""
309
  if args.embedding_provider == "openai":
@@ -312,6 +337,8 @@ def validate_embedding_args(args):
312
  _validate_voyage_embedding_args(args)
313
  elif args.embedding_provider == "marqo":
314
  _validate_marqo_embedding_args(args)
 
 
315
  else:
316
  raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")
317
 
 
11
 
12
  from sage.reranker import RerankerProvider
13
 
14
+ # Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
15
+ # NOTE: MAX_CHUNKS_PER_BATCH isn't documented anywhere but we pick a reasonable value
16
+ GEMINI_MAX_CHUNKS_PER_BATCH = 64
17
+ GEMINI_MAX_TOKENS_PER_CHUNK = 2048
18
+
19
  MARQO_MAX_CHUNKS_PER_BATCH = 64
20
  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
21
  OPENAI_MAX_TOKENS_PER_CHUNK = 8192
 
87
 
88
  def add_embedding_args(parser: ArgumentParser) -> Callable:
89
  """Adds embedding-related arguments to the parser and returns a validator."""
90
+ parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo", "gemini"])
91
  parser.add(
92
  "--embedding-model",
93
  type=str,
 
309
  )
310
 
311
 
312
def _validate_gemini_embedding_args(args):
    """Validates the configuration of the Gemini batch embedder and sets defaults.

    Fills in the default embedding model, chunks-per-batch, tokens-per-chunk and
    embedding size when unset, and clamps chunks_per_batch to Gemini's limit.
    """
    if not args.embedding_model:
        args.embedding_model = "models/text-embedding-004"

    # Use .get() so a missing variable fails the assert (showing the helpful
    # message below) instead of raising a bare KeyError before the assert runs.
    assert os.environ.get(
        "GOOGLE_API_KEY"
    ), "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."

    if not args.chunks_per_batch:
        args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
    elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
        args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
        logging.warning(
            f"Gemini enforces a limit of {GEMINI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
            "Overwriting embeddings.chunks_per_batch."
        )

    if not args.tokens_per_chunk:
        args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
    if not args.embedding_size:
        # text-embedding-004 outputs 768-dimensional vectors by default.
        args.embedding_size = 768
331
+
332
  def validate_embedding_args(args):
333
  """Validates the configuration of the batch embedder and sets defaults."""
334
  if args.embedding_provider == "openai":
 
337
  _validate_voyage_embedding_args(args)
338
  elif args.embedding_provider == "marqo":
339
  _validate_marqo_embedding_args(args)
340
+ elif args.embedding_provider == "gemini":
341
+ _validate_gemini_embedding_args(args)
342
  else:
343
  raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")
344
 
sage/embedder.py CHANGED
@@ -4,16 +4,25 @@ import json
4
  import logging
5
  import os
6
  import time
7
- from abc import ABC, abstractmethod
 
8
  from collections import Counter
9
- from typing import Dict, Generator, List, Optional, Tuple
 
 
 
 
10
 
 
11
  import marqo
12
  import requests
13
  from openai import OpenAI
14
- from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 
15
 
16
- from sage.chunker import Chunk, Chunker
 
17
  from sage.constants import TEXT_FIELD
18
  from sage.data_manager import DataManager
19
 
@@ -63,7 +72,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
63
 
64
  if len(batch) > chunks_per_batch:
65
  for i in range(0, len(batch), chunks_per_batch):
66
- sub_batch = batch[i : i + chunks_per_batch]
67
  openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
68
  batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
69
  if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
@@ -233,7 +242,7 @@ class VoyageBatchEmbedder(BatchEmbedder):
233
 
234
  if len(batch) > chunks_per_batch:
235
  for i in range(0, len(batch), chunks_per_batch):
236
- sub_batch = batch[i : i + chunks_per_batch]
237
  logging.info("Embedding %d chunks...", len(sub_batch))
238
  result = self._make_batch_request(sub_batch)
239
  for chunk, datum in zip(sub_batch, result["data"]):
@@ -305,7 +314,7 @@ class MarqoEmbedder(BatchEmbedder):
305
 
306
  if len(batch) > chunks_per_batch:
307
  for i in range(0, len(batch), chunks_per_batch):
308
- sub_batch = batch[i : i + chunks_per_batch]
309
  logging.info("Indexing %d chunks...", len(sub_batch))
310
  self.index.add_documents(
311
  documents=[chunk.metadata for chunk in sub_batch],
@@ -335,6 +344,79 @@ class MarqoEmbedder(BatchEmbedder):
335
  return []
336
 
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
339
  if args.embedding_provider == "openai":
340
  return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
@@ -344,5 +426,7 @@ def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker,
344
  return MarqoEmbedder(
345
  data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
346
  )
 
 
347
  else:
348
  raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
 
4
  import logging
5
  import os
6
  import time
7
+ from abc import ABC
8
+ from abc import abstractmethod
9
  from collections import Counter
10
+ from typing import Dict
11
+ from typing import Generator
12
+ from typing import List
13
+ from typing import Optional
14
+ from typing import Tuple
15
 
16
+ import google.generativeai as genai
17
  import marqo
18
  import requests
19
  from openai import OpenAI
20
+ from tenacity import retry
21
+ from tenacity import stop_after_attempt
22
+ from tenacity import wait_random_exponential
23
 
24
+ from sage.chunker import Chunk
25
+ from sage.chunker import Chunker
26
  from sage.constants import TEXT_FIELD
27
  from sage.data_manager import DataManager
28
 
 
72
 
73
  if len(batch) > chunks_per_batch:
74
  for i in range(0, len(batch), chunks_per_batch):
75
+ sub_batch = batch[i: i + chunks_per_batch]
76
  openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
77
  batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
78
  if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
 
242
 
243
  if len(batch) > chunks_per_batch:
244
  for i in range(0, len(batch), chunks_per_batch):
245
+ sub_batch = batch[i: i + chunks_per_batch]
246
  logging.info("Embedding %d chunks...", len(sub_batch))
247
  result = self._make_batch_request(sub_batch)
248
  for chunk, datum in zip(sub_batch, result["data"]):
 
314
 
315
  if len(batch) > chunks_per_batch:
316
  for i in range(0, len(batch), chunks_per_batch):
317
+ sub_batch = batch[i: i + chunks_per_batch]
318
  logging.info("Indexing %d chunks...", len(sub_batch))
319
  self.index.add_documents(
320
  documents=[chunk.metadata for chunk in sub_batch],
 
344
  return []
345
 
346
 
347
class GeminiBatchEmbedder(BatchEmbedder):
    """Batch embedder that calls the Gemini embedding API synchronously.

    Unlike the OpenAI embedder, there is no asynchronous batch job: embeddings
    are computed inline and accumulated in memory, so `embeddings_are_ready` is
    trivially True and `download_embeddings` just replays the stored results.
    """

    def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: str):
        self.data_manager = data_manager
        self.chunker = chunker
        # Accumulates (chunk_metadata, embedding) pairs as batches complete.
        self.embedding_data = []
        self.embedding_model = embedding_model
        # The config validation requires GOOGLE_API_KEY, so prefer it here;
        # fall back to GEMINI_API_KEY for backward compatibility. (Previously
        # only GEMINI_API_KEY was read, which crashed with KeyError even after
        # validation passed with GOOGLE_API_KEY set.)
        api_key = os.environ.get("GOOGLE_API_KEY") or os.environ["GEMINI_API_KEY"]
        genai.configure(api_key=api_key)

    def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
        """Embeds one batch of chunks; returns the raw Gemini response dict."""
        return genai.embed_content(
            model=self.embedding_model,
            content=[chunk.content for chunk in chunks],
            task_type="retrieval_document",
        )

    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: Optional[int] = None):
        """Embeds the entire dataset, batching requests and respecting rate limits.

        Args:
            chunks_per_batch: Maximum number of chunks per embedding request.
            max_embedding_jobs: Optional cap on the number of embedding requests
                issued; embedding stops early once reached.
        """
        batch = []
        chunk_count = 0

        request_count = 0
        last_request_time = time.time()

        for content, metadata in self.data_manager.walk():
            chunks = self.chunker.chunk(content, metadata)
            chunk_count += len(chunks)
            batch.extend(chunks)

            if len(batch) > chunks_per_batch:
                for i in range(0, len(batch), chunks_per_batch):
                    sub_batch = batch[i : i + chunks_per_batch]
                    logging.info("Embedding %d chunks...", len(sub_batch))
                    result = self._make_batch_request(sub_batch)
                    for chunk, embedding in zip(sub_batch, result["embedding"]):
                        self.embedding_data.append((chunk.metadata, embedding))
                    request_count += 1

                    # Honor the job cap, mirroring OpenAIBatchEmbedder. (This
                    # parameter was previously accepted but silently ignored.)
                    if max_embedding_jobs and request_count >= max_embedding_jobs:
                        logging.info("Reached maximum number of embedding jobs. Stopping.")
                        return

                    # Rate limits: https://ai.google.dev/gemini-api/docs/models/gemini
                    # The documented cap is 1500 requests/minute; pause at 1400
                    # to leave headroom.
                    current_time = time.time()
                    elapsed_time = current_time - last_request_time
                    if elapsed_time < 60 and request_count >= 1400:
                        logging.info("Reached rate limit, pausing for 60 seconds...")
                        time.sleep(60)
                        last_request_time = current_time
                        request_count = 0
                    # Reset the window if more than 60 seconds have passed.
                    elif elapsed_time > 60:
                        last_request_time = current_time
                        request_count = 0

                batch = []

        # Finally, commit the last (partial) batch.
        if batch:
            logging.info("Embedding %d chunks...", len(batch))
            result = self._make_batch_request(batch)
            for chunk, embedding in zip(batch, result["embedding"]):
                self.embedding_data.append((chunk.metadata, embedding))

        logging.info(f"Successfully embedded {chunk_count} chunks.")

    def embeddings_are_ready(self, *args, **kwargs) -> bool:
        """Checks whether the batch embedding jobs are done (always True: embedding is synchronous)."""
        return True

    def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
        """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
        for chunk_metadata, embedding in self.embedding_data:
            yield chunk_metadata, embedding
419
+
420
  def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
421
  if args.embedding_provider == "openai":
422
  return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
 
426
  return MarqoEmbedder(
427
  data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
428
  )
429
+ elif args.embedding_provider == "gemini":
430
+ return GeminiBatchEmbedder(data_manager, chunker, embedding_model=args.embedding_model)
431
  else:
432
  raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
sage/reranker.py CHANGED
@@ -10,7 +10,6 @@ from langchain_core.documents import BaseDocumentCompressor
10
  from langchain_nvidia_ai_endpoints import NVIDIARerank
11
  from langchain_voyageai import VoyageAIRerank
12
 
13
-
14
  class RerankerProvider(Enum):
15
  NONE = "none"
16
  HUGGINGFACE = "huggingface"
 
10
  from langchain_nvidia_ai_endpoints import NVIDIARerank
11
  from langchain_voyageai import VoyageAIRerank
12
 
 
13
  class RerankerProvider(Enum):
14
  NONE = "none"
15
  HUGGINGFACE = "huggingface"
sage/retriever.py CHANGED
@@ -1,24 +1,30 @@
 
 
1
  from langchain.retrievers import ContextualCompressionRetriever
2
  from langchain.retrievers.multi_query import MultiQueryRetriever
 
3
  from langchain_openai import OpenAIEmbeddings
4
  from langchain_voyageai import VoyageAIEmbeddings
5
 
 
6
  from sage.llm import build_llm_via_langchain
7
  from sage.reranker import build_reranker
8
  from sage.vector_store import build_vector_store_from_args
9
 
10
 
11
- def build_retriever_from_args(args):
12
  """Builds a retriever (with optional reranking) from command-line arguments."""
13
 
14
  if args.embedding_provider == "openai":
15
  embeddings = OpenAIEmbeddings(model=args.embedding_model)
16
  elif args.embedding_provider == "voyage":
17
  embeddings = VoyageAIEmbeddings(model=args.embedding_model)
 
 
18
  else:
19
  embeddings = None
20
 
21
- retriever = build_vector_store_from_args(args).as_retriever(
22
  top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
23
  )
24
 
 
1
+ from typing import Optional
2
+
3
  from langchain.retrievers import ContextualCompressionRetriever
4
  from langchain.retrievers.multi_query import MultiQueryRetriever
5
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
6
  from langchain_openai import OpenAIEmbeddings
7
  from langchain_voyageai import VoyageAIEmbeddings
8
 
9
+ from sage.data_manager import DataManager
10
  from sage.llm import build_llm_via_langchain
11
  from sage.reranker import build_reranker
12
  from sage.vector_store import build_vector_store_from_args
13
 
14
 
15
+ def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
16
  """Builds a retriever (with optional reranking) from command-line arguments."""
17
 
18
  if args.embedding_provider == "openai":
19
  embeddings = OpenAIEmbeddings(model=args.embedding_model)
20
  elif args.embedding_provider == "voyage":
21
  embeddings = VoyageAIEmbeddings(model=args.embedding_model)
22
+ elif args.embedding_provider == "gemini":
23
+ embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
24
  else:
25
  embeddings = None
26
 
27
+ retriever = build_vector_store_from_args(args, data_manager).as_retriever(
28
  top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
29
  )
30
 
sage/vector_store.py CHANGED
@@ -75,7 +75,6 @@ class PineconeVectorStore(VectorStore):
75
  self.dimension = dimension
76
  self.client = Pinecone()
77
  self.alpha = alpha
78
-
79
  if alpha < 1.0:
80
  if bm25_cache and os.path.exists(bm25_cache):
81
  logging.info("Loading BM25 encoder from cache.")
@@ -192,9 +191,14 @@ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager]
192
  """
193
  if args.vector_store_provider == "pinecone":
194
  bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
195
-
196
- if not os.path.exists(bm25_cache) and data_manager:
197
  logging.info("Fitting BM25 encoder on the corpus...")
 
 
 
 
 
 
198
  corpus = [content for content, _ in data_manager.walk()]
199
  bm25_encoder = BM25Encoder()
200
  bm25_encoder.fit(corpus)
 
75
  self.dimension = dimension
76
  self.client = Pinecone()
77
  self.alpha = alpha
 
78
  if alpha < 1.0:
79
  if bm25_cache and os.path.exists(bm25_cache):
80
  logging.info("Loading BM25 encoder from cache.")
 
191
  """
192
  if args.vector_store_provider == "pinecone":
193
  bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
194
+ if args.retrieval_alpha < 1.0 and not os.path.exists(bm25_cache) and data_manager:
 
195
  logging.info("Fitting BM25 encoder on the corpus...")
196
+ if is_punkt_downloaded():
197
+ print("punkt is already downloaded")
198
+ else:
199
+ print("punkt is not downloaded")
200
+ # Optionally download it
201
+ nltk.download('punkt_tab')
202
  corpus = [content for content, _ in data_manager.walk()]
203
  bm25_encoder = BM25Encoder()
204
  bm25_encoder.fit(corpus)