GitHub Actions committed on
Commit
ba41aa8
·
1 Parent(s): 7ca251e

Auto-format code with isort and black

Browse files
benchmarks/retrieval/retrieve.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  Make sure to `pip install ir_measures` before running this script.
4
  """
 
5
  import json
6
  import logging
7
  import os
@@ -21,6 +22,7 @@ logger.setLevel(logging.INFO)
21
 
22
  load_dotenv()
23
 
 
24
  def main():
25
  parser = configargparse.ArgParser(
26
  description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
@@ -49,12 +51,12 @@ def main():
49
  args = parser.parse_args()
50
  sage.config.validate_vector_store_args(args)
51
  repo_manager = GitHubRepoManager(
52
- args.repo_id,
53
- commit_hash=args.commit_hash,
54
- access_token=os.getenv("GITHUB_TOKEN"),
55
- local_dir=args.local_dir,
56
- inclusion_file=args.include,
57
- exclusion_file=args.exclude,
58
  )
59
  repo_manager.download()
60
  retriever = build_retriever_from_args(args, repo_manager)
 
2
 
3
  Make sure to `pip install ir_measures` before running this script.
4
  """
5
+
6
  import json
7
  import logging
8
  import os
 
22
 
23
  load_dotenv()
24
 
25
+
26
  def main():
27
  parser = configargparse.ArgParser(
28
  description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
 
51
  args = parser.parse_args()
52
  sage.config.validate_vector_store_args(args)
53
  repo_manager = GitHubRepoManager(
54
+ args.repo_id,
55
+ commit_hash=args.commit_hash,
56
+ access_token=os.getenv("GITHUB_TOKEN"),
57
+ local_dir=args.local_dir,
58
+ inclusion_file=args.include,
59
+ exclusion_file=args.exclude,
60
  )
61
  repo_manager.download()
62
  retriever = build_retriever_from_args(args, repo_manager)
sage/config.py CHANGED
@@ -313,7 +313,9 @@ def _validate_gemini_embedding_args(args):
313
  """Validates the configuration of the Gemini batch embedder and sets defaults."""
314
  if not args.embedding_model:
315
  args.embedding_model = "models/text-embedding-004"
316
- assert os.environ["GOOGLE_API_KEY"], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
 
 
317
  if not args.chunks_per_batch:
318
  args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
319
  elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
 
313
  """Validates the configuration of the Gemini batch embedder and sets defaults."""
314
  if not args.embedding_model:
315
  args.embedding_model = "models/text-embedding-004"
316
+ assert os.environ[
317
+ "GOOGLE_API_KEY"
318
+ ], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
319
  if not args.chunks_per_batch:
320
  args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
321
  elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
sage/embedder.py CHANGED
@@ -4,25 +4,17 @@ import json
4
  import logging
5
  import os
6
  import time
7
- from abc import ABC
8
- from abc import abstractmethod
9
  from collections import Counter
10
- from typing import Dict
11
- from typing import Generator
12
- from typing import List
13
- from typing import Optional
14
- from typing import Tuple
15
 
16
  import google.generativeai as genai
17
  import marqo
18
  import requests
19
  from openai import OpenAI
20
- from tenacity import retry
21
- from tenacity import stop_after_attempt
22
- from tenacity import wait_random_exponential
23
 
24
- from sage.chunker import Chunk
25
- from sage.chunker import Chunker
26
  from sage.constants import TEXT_FIELD
27
  from sage.data_manager import DataManager
28
 
@@ -72,7 +64,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
72
 
73
  if len(batch) > chunks_per_batch:
74
  for i in range(0, len(batch), chunks_per_batch):
75
- sub_batch = batch[i: i + chunks_per_batch]
76
  openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
77
  batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
78
  if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
@@ -242,7 +234,7 @@ class VoyageBatchEmbedder(BatchEmbedder):
242
 
243
  if len(batch) > chunks_per_batch:
244
  for i in range(0, len(batch), chunks_per_batch):
245
- sub_batch = batch[i: i + chunks_per_batch]
246
  logging.info("Embedding %d chunks...", len(sub_batch))
247
  result = self._make_batch_request(sub_batch)
248
  for chunk, datum in zip(sub_batch, result["data"]):
@@ -314,7 +306,7 @@ class MarqoEmbedder(BatchEmbedder):
314
 
315
  if len(batch) > chunks_per_batch:
316
  for i in range(0, len(batch), chunks_per_batch):
317
- sub_batch = batch[i: i + chunks_per_batch]
318
  logging.info("Indexing %d chunks...", len(sub_batch))
319
  self.index.add_documents(
320
  documents=[chunk.metadata for chunk in sub_batch],
@@ -356,9 +348,8 @@ class GeminiBatchEmbedder(BatchEmbedder):
356
 
357
  def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
358
  return genai.embed_content(
359
- model=self.embedding_model,
360
- content=[chunk.content for chunk in chunks],
361
- task_type="retrieval_document")
362
 
363
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
364
  """Issues batch embedding jobs for the entire dataset."""
@@ -375,7 +366,7 @@ class GeminiBatchEmbedder(BatchEmbedder):
375
 
376
  if len(batch) > chunks_per_batch:
377
  for i in range(0, len(batch), chunks_per_batch):
378
- sub_batch = batch[i: i + chunks_per_batch]
379
  logging.info("Embedding %d chunks...", len(sub_batch))
380
  result = self._make_batch_request(sub_batch)
381
  for chunk, embedding in zip(sub_batch, result["embedding"]):
 
4
  import logging
5
  import os
6
  import time
7
+ from abc import ABC, abstractmethod
 
8
  from collections import Counter
9
+ from typing import Dict, Generator, List, Optional, Tuple
 
 
 
 
10
 
11
  import google.generativeai as genai
12
  import marqo
13
  import requests
14
  from openai import OpenAI
15
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 
16
 
17
+ from sage.chunker import Chunk, Chunker
 
18
  from sage.constants import TEXT_FIELD
19
  from sage.data_manager import DataManager
20
 
 
64
 
65
  if len(batch) > chunks_per_batch:
66
  for i in range(0, len(batch), chunks_per_batch):
67
+ sub_batch = batch[i : i + chunks_per_batch]
68
  openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
69
  batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
70
  if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
 
234
 
235
  if len(batch) > chunks_per_batch:
236
  for i in range(0, len(batch), chunks_per_batch):
237
+ sub_batch = batch[i : i + chunks_per_batch]
238
  logging.info("Embedding %d chunks...", len(sub_batch))
239
  result = self._make_batch_request(sub_batch)
240
  for chunk, datum in zip(sub_batch, result["data"]):
 
306
 
307
  if len(batch) > chunks_per_batch:
308
  for i in range(0, len(batch), chunks_per_batch):
309
+ sub_batch = batch[i : i + chunks_per_batch]
310
  logging.info("Indexing %d chunks...", len(sub_batch))
311
  self.index.add_documents(
312
  documents=[chunk.metadata for chunk in sub_batch],
 
348
 
349
  def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
350
  return genai.embed_content(
351
+ model=self.embedding_model, content=[chunk.content for chunk in chunks], task_type="retrieval_document"
352
+ )
 
353
 
354
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
355
  """Issues batch embedding jobs for the entire dataset."""
 
366
 
367
  if len(batch) > chunks_per_batch:
368
  for i in range(0, len(batch), chunks_per_batch):
369
+ sub_batch = batch[i : i + chunks_per_batch]
370
  logging.info("Embedding %d chunks...", len(sub_batch))
371
  result = self._make_batch_request(sub_batch)
372
  for chunk, embedding in zip(sub_batch, result["embedding"]):
sage/reranker.py CHANGED
@@ -10,6 +10,7 @@ from langchain_core.documents import BaseDocumentCompressor
10
  from langchain_nvidia_ai_endpoints import NVIDIARerank
11
  from langchain_voyageai import VoyageAIRerank
12
 
 
13
  class RerankerProvider(Enum):
14
  NONE = "none"
15
  HUGGINGFACE = "huggingface"
 
10
  from langchain_nvidia_ai_endpoints import NVIDIARerank
11
  from langchain_voyageai import VoyageAIRerank
12
 
13
+
14
  class RerankerProvider(Enum):
15
  NONE = "none"
16
  HUGGINGFACE = "huggingface"
sage/vector_store.py CHANGED
@@ -198,7 +198,7 @@ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager]
198
  else:
199
  print("punkt is not downloaded")
200
  # Optionally download it
201
- nltk.download('punkt_tab')
202
  corpus = [content for content, _ in data_manager.walk()]
203
  bm25_encoder = BM25Encoder()
204
  bm25_encoder.fit(corpus)
 
198
  else:
199
  print("punkt is not downloaded")
200
  # Optionally download it
201
+ nltk.download("punkt_tab")
202
  corpus = [content for content, _ in data_manager.walk()]
203
  bm25_encoder = BM25Encoder()
204
  bm25_encoder.fit(corpus)