Mihail Eric committed
Commit · 7ca251e
1 Parent(s): c389c2f
add support for google gemini embeddings as an embedder (#56)
* wip on retrieve
* add ir measures
* add support for google embeddings
* voyage api key
* update embeddings and update reqs
* remove todo
* github head ref
* fix workflow
* revert workflow
---------
Co-authored-by: Mihail Eric <mihaileric@Mihails-MacBook-Pro.local>
- sage/.sage-env → .sage-env +0 -0
- benchmarks/retrieval/assets/embeddings.png +0 -0
- benchmarks/retrieval/requirements.txt +2 -0
- benchmarks/retrieval/retrieve.py +17 -6
- requirements.txt +2 -0
- sage/config.py +28 -1
- sage/embedder.py +91 -7
- sage/reranker.py +0 -1
- sage/retriever.py +8 -2
- sage/vector_store.py +7 -3
sage/.sage-env → .sage-env RENAMED
File without changes
benchmarks/retrieval/assets/embeddings.png CHANGED
[binary image updated; no text diff]
benchmarks/retrieval/requirements.txt ADDED
@@ -0,0 +1,2 @@
+dotenv
+ir_measures
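The two new dependencies back the retrieve.py changes below: dotenv provides load_dotenv() for pulling API keys from an env file, and ir_measures computes the retrieval metrics. A minimal sketch of the dotenv side (hedged: the load_dotenv import comes from the python-dotenv distribution, which by default reads a file named ".env"; the repo's own .sage-env path is not wired up in this snippet):

import os

from dotenv import load_dotenv

load_dotenv()  # reads KEY=VALUE pairs from a .env file into os.environ
print(os.getenv("GITHUB_TOKEN"))  # consumed by GitHubRepoManager in retrieve.py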
benchmarks/retrieval/retrieve.py CHANGED
@@ -2,22 +2,24 @@
 
 Make sure to `pip install ir_measures` before running this script.
 """
-
 import json
 import logging
 import os
 import time
 
 import configargparse
+from dotenv import load_dotenv
 from ir_measures import MAP, MRR, P, Qrel, R, Rprec, ScoredDoc, calc_aggregate, nDCG
 
 import sage.config
+from sage.data_manager import GitHubRepoManager
 from sage.retriever import build_retriever_from_args
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
+load_dotenv()
 
 def main():
     parser = configargparse.ArgParser(
@@ -35,17 +37,27 @@ def main():
         default=None,
         help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
     )
-    parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
 
+    parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
     sage.config.add_config_args(parser)
     sage.config.add_llm_args(parser)  # Needed for --multi-query-retriever, which rewrites the query with an LLM.
     sage.config.add_embedding_args(parser)
     sage.config.add_vector_store_args(parser)
     sage.config.add_reranking_args(parser)
+    sage.config.add_repo_args(parser)
+    sage.config.add_indexing_args(parser)
     args = parser.parse_args()
     sage.config.validate_vector_store_args(args)
-
-
+    repo_manager = GitHubRepoManager(
+        args.repo_id,
+        commit_hash=args.commit_hash,
+        access_token=os.getenv("GITHUB_TOKEN"),
+        local_dir=args.local_dir,
+        inclusion_file=args.include,
+        exclusion_file=args.exclude,
+    )
+    repo_manager.download()
+    retriever = build_retriever_from_args(args, repo_manager)
 
     with open(args.benchmark, "r") as f:
         benchmark = json.load(f)
@@ -70,7 +82,7 @@ def main():
         item["retrieved"] = []
         for doc_idx, doc in enumerate(retrieved):
             # The absolute value of the scores below does not affect the metrics; it merely determines the ranking of
-            # the
+            # the retrieved documents. The key of the score varies depending on the underlying retriever. If there's no
             # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
             score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
             retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
@@ -83,7 +95,6 @@ def main():
     print("Calculating metrics...")
     results = calc_aggregate([Rprec, P @ 1, R @ 3, nDCG @ 3, MAP, MRR], golden_docs, retrieved_docs)
     results = {str(key): value for key, value in results.items()}
-
    if args.logs_dir:
        if not os.path.exists(args.logs_dir):
            os.makedirs(args.logs_dir)
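To make the fallback scoring above concrete, here is a toy ir_measures run with hypothetical document IDs, showing that 1/(doc_idx+1) preserves the retriever's ordering when no native score is available:

from ir_measures import MRR, Qrel, ScoredDoc, calc_aggregate

golden_docs = [Qrel(query_id="q1", doc_id="sage/embedder.py", relevance=1)]
# No retriever score available, so fall back to 1/(rank+1): strictly decreasing
# in rank, hence the original ordering is preserved.
retrieved_docs = [
    ScoredDoc(query_id="q1", doc_id=doc_id, score=1 / (idx + 1))
    for idx, doc_id in enumerate(["sage/config.py", "sage/embedder.py"])
]
print(calc_aggregate([MRR], golden_docs, retrieved_docs))  # MRR = 0.5: the gold doc sits at rank 2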
requirements.txt CHANGED
@@ -3,6 +3,7 @@ Pygments==2.18.0
 cohere==5.9.2
 configargparse
 fastapi==0.112.2
+google-ai-generativelanguage==0.6.6
 gradio>=4.26.0
 langchain==0.2.16
 langchain-anthropic==0.1.23
@@ -10,6 +11,7 @@ langchain-cohere==0.2.4
 langchain-community==0.2.17
 langchain-core==0.2.41
 langchain-experimental==0.0.65
+langchain-google-genai
 langchain-nvidia-ai-endpoints==0.2.2
 langchain-ollama==0.1.3
 langchain-openai==0.1.25
sage/config.py CHANGED
@@ -11,6 +11,11 @@ from configargparse import ArgumentParser
 
 from sage.reranker import RerankerProvider
 
+# Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
+# NOTE: MAX_CHUNKS_PER_BATCH isn't documented anywhere but we pick a reasonable value
+GEMINI_MAX_CHUNKS_PER_BATCH = 64
+GEMINI_MAX_TOKENS_PER_CHUNK = 2048
+
 MARQO_MAX_CHUNKS_PER_BATCH = 64
 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
 OPENAI_MAX_TOKENS_PER_CHUNK = 8192
@@ -82,7 +87,7 @@ def add_repo_args(parser: ArgumentParser) -> Callable:
 
 def add_embedding_args(parser: ArgumentParser) -> Callable:
     """Adds embedding-related arguments to the parser and returns a validator."""
-    parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo"])
+    parser.add("--embedding-provider", default="marqo", choices=["openai", "voyage", "marqo", "gemini"])
     parser.add(
         "--embedding-model",
         type=str,
@@ -304,6 +309,26 @@ def _validate_marqo_embedding_args(args):
     )
 
 
+def _validate_gemini_embedding_args(args):
+    """Validates the configuration of the Gemini batch embedder and sets defaults."""
+    if not args.embedding_model:
+        args.embedding_model = "models/text-embedding-004"
+    assert os.environ["GOOGLE_API_KEY"], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
+    if not args.chunks_per_batch:
+        args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
+    elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
+        args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
+        logging.warning(
+            f"Gemini enforces a limit of {GEMINI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
+            "Overwriting embeddings.chunks_per_batch."
+        )
+
+    if not args.tokens_per_chunk:
+        args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
+    if not args.embedding_size:
+        args.embedding_size = 768
+
+
 def validate_embedding_args(args):
     """Validates the configuration of the batch embedder and sets defaults."""
     if args.embedding_provider == "openai":
@@ -312,6 +337,8 @@ def validate_embedding_args(args):
         _validate_voyage_embedding_args(args)
     elif args.embedding_provider == "marqo":
         _validate_marqo_embedding_args(args)
+    elif args.embedding_provider == "gemini":
+        _validate_gemini_embedding_args(args)
     else:
         raise ValueError(f"Unrecognized --embedding-provider={args.embedding_provider}")
 
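A quick sketch of the defaults the new validator applies, with a SimpleNamespace standing in for the real configargparse result (hedged: _validate_gemini_embedding_args is private to sage.config; also note that os.environ["GOOGLE_API_KEY"] raises KeyError before the assert message can show if the variable is unset):

import os
from types import SimpleNamespace

from sage.config import _validate_gemini_embedding_args

os.environ.setdefault("GOOGLE_API_KEY", "fake-key-for-illustration")
args = SimpleNamespace(
    embedding_provider="gemini",
    embedding_model=None,
    chunks_per_batch=None,
    tokens_per_chunk=None,
    embedding_size=None,
)
_validate_gemini_embedding_args(args)
print(args.embedding_model, args.chunks_per_batch, args.tokens_per_chunk, args.embedding_size)
# -> models/text-embedding-004 64 2048 768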
sage/embedder.py CHANGED
@@ -4,16 +4,25 @@ import json
 import logging
 import os
 import time
-from abc import ABC
+from abc import ABC
+from abc import abstractmethod
 from collections import Counter
-from typing import Dict
+from typing import Dict
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Tuple
 
+import google.generativeai as genai
 import marqo
 import requests
 from openai import OpenAI
-from tenacity import retry
+from tenacity import retry
+from tenacity import stop_after_attempt
+from tenacity import wait_random_exponential
 
-from sage.chunker import Chunk
+from sage.chunker import Chunk
+from sage.chunker import Chunker
 from sage.constants import TEXT_FIELD
 from sage.data_manager import DataManager
 
@@ -63,7 +72,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
 
         if len(batch) > chunks_per_batch:
             for i in range(0, len(batch), chunks_per_batch):
-                sub_batch = batch[i
+                sub_batch = batch[i: i + chunks_per_batch]
                 openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
                 batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
                 if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
@@ -233,7 +242,7 @@ class VoyageBatchEmbedder(BatchEmbedder):
 
         if len(batch) > chunks_per_batch:
             for i in range(0, len(batch), chunks_per_batch):
-                sub_batch = batch[i
+                sub_batch = batch[i: i + chunks_per_batch]
                 logging.info("Embedding %d chunks...", len(sub_batch))
                 result = self._make_batch_request(sub_batch)
                 for chunk, datum in zip(sub_batch, result["data"]):
@@ -305,7 +314,7 @@ class MarqoEmbedder(BatchEmbedder):
 
         if len(batch) > chunks_per_batch:
             for i in range(0, len(batch), chunks_per_batch):
-                sub_batch = batch[i
+                sub_batch = batch[i: i + chunks_per_batch]
                 logging.info("Indexing %d chunks...", len(sub_batch))
                 self.index.add_documents(
                     documents=[chunk.metadata for chunk in sub_batch],
@@ -335,6 +344,79 @@ class MarqoEmbedder(BatchEmbedder):
         return []
 
 
+class GeminiBatchEmbedder(BatchEmbedder):
+    """Batch embedder that calls Gemini."""
+
+    def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: str):
+        self.data_manager = data_manager
+        self.chunker = chunker
+        self.embedding_data = []
+        self.embedding_model = embedding_model
+        genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+    def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
+        return genai.embed_content(
+            model=self.embedding_model,
+            content=[chunk.content for chunk in chunks],
+            task_type="retrieval_document")
+
+    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
+        """Issues batch embedding jobs for the entire dataset."""
+        batch = []
+        chunk_count = 0
+
+        request_count = 0
+        last_request_time = time.time()
+
+        for content, metadata in self.data_manager.walk():
+            chunks = self.chunker.chunk(content, metadata)
+            chunk_count += len(chunks)
+            batch.extend(chunks)
+
+            if len(batch) > chunks_per_batch:
+                for i in range(0, len(batch), chunks_per_batch):
+                    sub_batch = batch[i: i + chunks_per_batch]
+                    logging.info("Embedding %d chunks...", len(sub_batch))
+                    result = self._make_batch_request(sub_batch)
+                    for chunk, embedding in zip(sub_batch, result["embedding"]):
+                        self.embedding_data.append((chunk.metadata, embedding))
+                    request_count += 1
+
+                    # Check if we've made more than 1500 requests in the last minute
+                    # Rate limits here: https://ai.google.dev/gemini-api/docs/models/gemini
+                    current_time = time.time()
+                    elapsed_time = current_time - last_request_time
+                    if elapsed_time < 60 and request_count >= 1400:
+                        logging.info("Reached rate limit, pausing for 60 seconds...")
+                        time.sleep(60)
+                        last_request_time = current_time
+                        request_count = 0
+                    # Reset the last request time and request count if more than 60 sec have passed
+                    elif elapsed_time > 60:
+                        last_request_time = current_time
+                        request_count = 0
+
+                batch = []
+
+        # Finally, commit the last batch.
+        if batch:
+            logging.info("Embedding %d chunks...", len(batch))
+            result = self._make_batch_request(batch)
+            for chunk, embedding in zip(batch, result["embedding"]):
+                self.embedding_data.append((chunk.metadata, embedding))
+
+        logging.info(f"Successfully embedded {chunk_count} chunks.")
+
+    def embeddings_are_ready(self, *args, **kwargs) -> bool:
+        """Checks whether the batch embedding jobs are done."""
+        return True
+
+    def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
+        """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
+        for chunk_metadata, embedding in self.embedding_data:
+            yield chunk_metadata, embedding
+
+
 def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker, args) -> BatchEmbedder:
     if args.embedding_provider == "openai":
         return OpenAIBatchEmbedder(data_manager, chunker, args.local_dir, args.embedding_model, args.embedding_size)
@@ -344,5 +426,7 @@ def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker,
         return MarqoEmbedder(
             data_manager, chunker, index_name=args.index_namespace, url=args.marqo_url, model=args.embedding_model
         )
+    elif args.embedding_provider == "gemini":
+        return GeminiBatchEmbedder(data_manager, chunker, embedding_model=args.embedding_model)
     else:
         raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
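A standalone sketch of the Gemini call that _make_batch_request wraps, assuming GEMINI_API_KEY is set (the sample strings stand in for chunk contents):

import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
result = genai.embed_content(
    model="models/text-embedding-004",
    content=["def add(a, b): return a + b", "class Chunk: ..."],  # stand-in chunk texts
    task_type="retrieval_document",
)
# For a list input, result["embedding"] is a list of vectors, one per chunk.
print(len(result["embedding"]), len(result["embedding"][0]))  # 2 chunks, 768 dims each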
sage/reranker.py CHANGED
@@ -10,7 +10,6 @@ from langchain_core.documents import BaseDocumentCompressor
 from langchain_nvidia_ai_endpoints import NVIDIARerank
 from langchain_voyageai import VoyageAIRerank
 
-
 class RerankerProvider(Enum):
     NONE = "none"
     HUGGINGFACE = "huggingface"
sage/retriever.py CHANGED
@@ -1,24 +1,30 @@
+from typing import Optional
+
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.multi_query import MultiQueryRetriever
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
 from langchain_openai import OpenAIEmbeddings
 from langchain_voyageai import VoyageAIEmbeddings
 
+from sage.data_manager import DataManager
 from sage.llm import build_llm_via_langchain
 from sage.reranker import build_reranker
 from sage.vector_store import build_vector_store_from_args
 
 
-def build_retriever_from_args(args):
+def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
     """Builds a retriever (with optional reranking) from command-line arguments."""
 
     if args.embedding_provider == "openai":
         embeddings = OpenAIEmbeddings(model=args.embedding_model)
     elif args.embedding_provider == "voyage":
         embeddings = VoyageAIEmbeddings(model=args.embedding_model)
+    elif args.embedding_provider == "gemini":
+        embeddings = GoogleGenerativeAIEmbeddings(model=args.embedding_model)
     else:
         embeddings = None
 
-    retriever = build_vector_store_from_args(args).as_retriever(
+    retriever = build_vector_store_from_args(args, data_manager).as_retriever(
         top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
     )
 
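For reference, the gemini branch above resolves to langchain-google-genai's embeddings client, which reads GOOGLE_API_KEY from the environment; a minimal usage sketch:

from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
query_vector = embeddings.embed_query("Where does sage fit the BM25 encoder?")
doc_vectors = embeddings.embed_documents(["chunk one", "chunk two"])
print(len(query_vector), len(doc_vectors))  # a 768-dim query vector; 2 document vectors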
sage/vector_store.py CHANGED
@@ -75,7 +75,6 @@ class PineconeVectorStore(VectorStore):
         self.dimension = dimension
         self.client = Pinecone()
         self.alpha = alpha
-
         if alpha < 1.0:
             if bm25_cache and os.path.exists(bm25_cache):
                 logging.info("Loading BM25 encoder from cache.")
@@ -192,9 +191,14 @@ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager]
     """
     if args.vector_store_provider == "pinecone":
         bm25_cache = os.path.join(".bm25_cache", args.index_namespace, "bm25_encoder.json")
-
-        if not os.path.exists(bm25_cache) and data_manager:
+        if args.retrieval_alpha < 1.0 and not os.path.exists(bm25_cache) and data_manager:
             logging.info("Fitting BM25 encoder on the corpus...")
+            if is_punkt_downloaded():
+                print("punkt is already downloaded")
+            else:
+                print("punkt is not downloaded")
+                # Optionally download it
+                nltk.download('punkt_tab')
             corpus = [content for content, _ in data_manager.walk()]
             bm25_encoder = BM25Encoder()
             bm25_encoder.fit(corpus)
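Note that is_punkt_downloaded() and nltk are used above but neither defined nor imported in this diff view; a plausible helper (hypothetical, not part of the commit) could look like:

import nltk

def is_punkt_downloaded() -> bool:
    """Hypothetical helper: True if NLTK's punkt_tab tokenizer data is available locally."""
    try:
        nltk.data.find("tokenizers/punkt_tab")
        return True
    except LookupError:
        return False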