Spaces:
Running
Running
GitHub Actions committed on
Commit ·
ba41aa8
1
Parent(s): 7ca251e
Auto-format code with isort and black
Browse files
- benchmarks/retrieval/retrieve.py +8 -6
- sage/config.py +3 -1
- sage/embedder.py +10 -19
- sage/reranker.py +1 -0
- sage/vector_store.py +1 -1
benchmarks/retrieval/retrieve.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
|
| 3 |
Make sure to `pip install ir_measures` before running this script.
|
| 4 |
"""
|
|
|
|
| 5 |
import json
|
| 6 |
import logging
|
| 7 |
import os
|
|
@@ -21,6 +22,7 @@ logger.setLevel(logging.INFO)
|
|
| 21 |
|
| 22 |
load_dotenv()
|
| 23 |
|
|
|
|
| 24 |
def main():
|
| 25 |
parser = configargparse.ArgParser(
|
| 26 |
description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
|
|
@@ -49,12 +51,12 @@ def main():
|
|
| 49 |
args = parser.parse_args()
|
| 50 |
sage.config.validate_vector_store_args(args)
|
| 51 |
repo_manager = GitHubRepoManager(
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
)
|
| 59 |
repo_manager.download()
|
| 60 |
retriever = build_retriever_from_args(args, repo_manager)
|
|
|
|
| 2 |
|
| 3 |
Make sure to `pip install ir_measures` before running this script.
|
| 4 |
"""
|
| 5 |
+
|
| 6 |
import json
|
| 7 |
import logging
|
| 8 |
import os
|
|
|
|
| 22 |
|
| 23 |
load_dotenv()
|
| 24 |
|
| 25 |
+
|
| 26 |
def main():
|
| 27 |
parser = configargparse.ArgParser(
|
| 28 |
description="Runs retrieval on a benchmark dataset.", ignore_unknown_config_file_keys=True
|
|
|
|
| 51 |
args = parser.parse_args()
|
| 52 |
sage.config.validate_vector_store_args(args)
|
| 53 |
repo_manager = GitHubRepoManager(
|
| 54 |
+
args.repo_id,
|
| 55 |
+
commit_hash=args.commit_hash,
|
| 56 |
+
access_token=os.getenv("GITHUB_TOKEN"),
|
| 57 |
+
local_dir=args.local_dir,
|
| 58 |
+
inclusion_file=args.include,
|
| 59 |
+
exclusion_file=args.exclude,
|
| 60 |
)
|
| 61 |
repo_manager.download()
|
| 62 |
retriever = build_retriever_from_args(args, repo_manager)
|
sage/config.py
CHANGED
|
@@ -313,7 +313,9 @@ def _validate_gemini_embedding_args(args):
|
|
| 313 |
"""Validates the configuration of the Gemini batch embedder and sets defaults."""
|
| 314 |
if not args.embedding_model:
|
| 315 |
args.embedding_model = "models/text-embedding-004"
|
| 316 |
-
assert os.environ[
|
|
|
|
|
|
|
| 317 |
if not args.chunks_per_batch:
|
| 318 |
args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
|
| 319 |
elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
|
|
|
|
| 313 |
"""Validates the configuration of the Gemini batch embedder and sets defaults."""
|
| 314 |
if not args.embedding_model:
|
| 315 |
args.embedding_model = "models/text-embedding-004"
|
| 316 |
+
assert os.environ[
|
| 317 |
+
"GOOGLE_API_KEY"
|
| 318 |
+
], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
|
| 319 |
if not args.chunks_per_batch:
|
| 320 |
args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
|
| 321 |
elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
|
sage/embedder.py
CHANGED
|
@@ -4,25 +4,17 @@ import json
|
|
| 4 |
import logging
|
| 5 |
import os
|
| 6 |
import time
|
| 7 |
-
from abc import ABC
|
| 8 |
-
from abc import abstractmethod
|
| 9 |
from collections import Counter
|
| 10 |
-
from typing import Dict
|
| 11 |
-
from typing import Generator
|
| 12 |
-
from typing import List
|
| 13 |
-
from typing import Optional
|
| 14 |
-
from typing import Tuple
|
| 15 |
|
| 16 |
import google.generativeai as genai
|
| 17 |
import marqo
|
| 18 |
import requests
|
| 19 |
from openai import OpenAI
|
| 20 |
-
from tenacity import retry
|
| 21 |
-
from tenacity import stop_after_attempt
|
| 22 |
-
from tenacity import wait_random_exponential
|
| 23 |
|
| 24 |
-
from sage.chunker import Chunk
|
| 25 |
-
from sage.chunker import Chunker
|
| 26 |
from sage.constants import TEXT_FIELD
|
| 27 |
from sage.data_manager import DataManager
|
| 28 |
|
|
@@ -72,7 +64,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 72 |
|
| 73 |
if len(batch) > chunks_per_batch:
|
| 74 |
for i in range(0, len(batch), chunks_per_batch):
|
| 75 |
-
sub_batch = batch[i: i + chunks_per_batch]
|
| 76 |
openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
|
| 77 |
batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
|
| 78 |
if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
|
|
@@ -242,7 +234,7 @@ class VoyageBatchEmbedder(BatchEmbedder):
|
|
| 242 |
|
| 243 |
if len(batch) > chunks_per_batch:
|
| 244 |
for i in range(0, len(batch), chunks_per_batch):
|
| 245 |
-
sub_batch = batch[i: i + chunks_per_batch]
|
| 246 |
logging.info("Embedding %d chunks...", len(sub_batch))
|
| 247 |
result = self._make_batch_request(sub_batch)
|
| 248 |
for chunk, datum in zip(sub_batch, result["data"]):
|
|
@@ -314,7 +306,7 @@ class MarqoEmbedder(BatchEmbedder):
|
|
| 314 |
|
| 315 |
if len(batch) > chunks_per_batch:
|
| 316 |
for i in range(0, len(batch), chunks_per_batch):
|
| 317 |
-
sub_batch = batch[i: i + chunks_per_batch]
|
| 318 |
logging.info("Indexing %d chunks...", len(sub_batch))
|
| 319 |
self.index.add_documents(
|
| 320 |
documents=[chunk.metadata for chunk in sub_batch],
|
|
@@ -356,9 +348,8 @@ class GeminiBatchEmbedder(BatchEmbedder):
|
|
| 356 |
|
| 357 |
def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
|
| 358 |
return genai.embed_content(
|
| 359 |
-
model=self.embedding_model,
|
| 360 |
-
|
| 361 |
-
task_type="retrieval_document")
|
| 362 |
|
| 363 |
def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
|
| 364 |
"""Issues batch embedding jobs for the entire dataset."""
|
|
@@ -375,7 +366,7 @@ class GeminiBatchEmbedder(BatchEmbedder):
|
|
| 375 |
|
| 376 |
if len(batch) > chunks_per_batch:
|
| 377 |
for i in range(0, len(batch), chunks_per_batch):
|
| 378 |
-
sub_batch = batch[i: i + chunks_per_batch]
|
| 379 |
logging.info("Embedding %d chunks...", len(sub_batch))
|
| 380 |
result = self._make_batch_request(sub_batch)
|
| 381 |
for chunk, embedding in zip(sub_batch, result["embedding"]):
|
|
|
|
| 4 |
import logging
|
| 5 |
import os
|
| 6 |
import time
|
| 7 |
+
from abc import ABC, abstractmethod
|
|
|
|
| 8 |
from collections import Counter
|
| 9 |
+
from typing import Dict, Generator, List, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
import google.generativeai as genai
|
| 12 |
import marqo
|
| 13 |
import requests
|
| 14 |
from openai import OpenAI
|
| 15 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
from sage.chunker import Chunk, Chunker
|
|
|
|
| 18 |
from sage.constants import TEXT_FIELD
|
| 19 |
from sage.data_manager import DataManager
|
| 20 |
|
|
|
|
| 64 |
|
| 65 |
if len(batch) > chunks_per_batch:
|
| 66 |
for i in range(0, len(batch), chunks_per_batch):
|
| 67 |
+
sub_batch = batch[i : i + chunks_per_batch]
|
| 68 |
openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
|
| 69 |
batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
|
| 70 |
if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
|
|
|
|
| 234 |
|
| 235 |
if len(batch) > chunks_per_batch:
|
| 236 |
for i in range(0, len(batch), chunks_per_batch):
|
| 237 |
+
sub_batch = batch[i : i + chunks_per_batch]
|
| 238 |
logging.info("Embedding %d chunks...", len(sub_batch))
|
| 239 |
result = self._make_batch_request(sub_batch)
|
| 240 |
for chunk, datum in zip(sub_batch, result["data"]):
|
|
|
|
| 306 |
|
| 307 |
if len(batch) > chunks_per_batch:
|
| 308 |
for i in range(0, len(batch), chunks_per_batch):
|
| 309 |
+
sub_batch = batch[i : i + chunks_per_batch]
|
| 310 |
logging.info("Indexing %d chunks...", len(sub_batch))
|
| 311 |
self.index.add_documents(
|
| 312 |
documents=[chunk.metadata for chunk in sub_batch],
|
|
|
|
| 348 |
|
| 349 |
def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
|
| 350 |
return genai.embed_content(
|
| 351 |
+
model=self.embedding_model, content=[chunk.content for chunk in chunks], task_type="retrieval_document"
|
| 352 |
+
)
|
|
|
|
| 353 |
|
| 354 |
def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
|
| 355 |
"""Issues batch embedding jobs for the entire dataset."""
|
|
|
|
| 366 |
|
| 367 |
if len(batch) > chunks_per_batch:
|
| 368 |
for i in range(0, len(batch), chunks_per_batch):
|
| 369 |
+
sub_batch = batch[i : i + chunks_per_batch]
|
| 370 |
logging.info("Embedding %d chunks...", len(sub_batch))
|
| 371 |
result = self._make_batch_request(sub_batch)
|
| 372 |
for chunk, embedding in zip(sub_batch, result["embedding"]):
|
sage/reranker.py
CHANGED
|
@@ -10,6 +10,7 @@ from langchain_core.documents import BaseDocumentCompressor
|
|
| 10 |
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
| 11 |
from langchain_voyageai import VoyageAIRerank
|
| 12 |
|
|
|
|
| 13 |
class RerankerProvider(Enum):
|
| 14 |
NONE = "none"
|
| 15 |
HUGGINGFACE = "huggingface"
|
|
|
|
| 10 |
from langchain_nvidia_ai_endpoints import NVIDIARerank
|
| 11 |
from langchain_voyageai import VoyageAIRerank
|
| 12 |
|
| 13 |
+
|
| 14 |
class RerankerProvider(Enum):
|
| 15 |
NONE = "none"
|
| 16 |
HUGGINGFACE = "huggingface"
|
sage/vector_store.py
CHANGED
|
@@ -198,7 +198,7 @@ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager]
|
|
| 198 |
else:
|
| 199 |
print("punkt is not downloaded")
|
| 200 |
# Optionally download it
|
| 201 |
-
nltk.download(
|
| 202 |
corpus = [content for content, _ in data_manager.walk()]
|
| 203 |
bm25_encoder = BM25Encoder()
|
| 204 |
bm25_encoder.fit(corpus)
|
|
|
|
| 198 |
else:
|
| 199 |
print("punkt is not downloaded")
|
| 200 |
# Optionally download it
|
| 201 |
+
nltk.download("punkt_tab")
|
| 202 |
corpus = [content for content, _ in data_manager.walk()]
|
| 203 |
bm25_encoder = BM25Encoder()
|
| 204 |
bm25_encoder.fit(corpus)
|