Spaces:
Running
Running
Fixes for Gemini embeddings (#71)
Browse files
- sage/config.py +2 -9
- sage/embedder.py +1 -1
- sage/retriever.py +1 -0
sage/config.py
CHANGED
|
@@ -12,8 +12,6 @@ from configargparse import ArgumentParser
|
|
| 12 |
from sage.reranker import RerankerProvider
|
| 13 |
|
| 14 |
# Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
|
| 15 |
-
# NOTE: MAX_CHUNKS_PER_BATCH isn't documented anywhere but we pick a reasonable value
|
| 16 |
-
GEMINI_MAX_CHUNKS_PER_BATCH = 64
|
| 17 |
GEMINI_MAX_TOKENS_PER_CHUNK = 2048
|
| 18 |
|
| 19 |
MARQO_MAX_CHUNKS_PER_BATCH = 64
|
|
@@ -345,13 +343,8 @@ def _validate_gemini_embedding_args(args):
|
|
| 345 |
"GOOGLE_API_KEY"
|
| 346 |
], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
|
| 347 |
if not args.chunks_per_batch:
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
|
| 351 |
-
logging.warning(
|
| 352 |
-
f"Gemini enforces a limit of {GEMINI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
|
| 353 |
-
"Overwriting embeddings.chunks_per_batch."
|
| 354 |
-
)
|
| 355 |
|
| 356 |
if not args.tokens_per_chunk:
|
| 357 |
args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
|
|
|
|
| 12 |
from sage.reranker import RerankerProvider
|
| 13 |
|
| 14 |
# Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
|
|
|
|
|
|
|
| 15 |
GEMINI_MAX_TOKENS_PER_CHUNK = 2048
|
| 16 |
|
| 17 |
MARQO_MAX_CHUNKS_PER_BATCH = 64
|
|
|
|
| 343 |
"GOOGLE_API_KEY"
|
| 344 |
], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
|
| 345 |
if not args.chunks_per_batch:
|
| 346 |
+
# This value is reasonable but arbitrary (i.e. Gemini does not explicitly enforce a limit).
|
| 347 |
+
args.chunks_per_batch = 2000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
if not args.tokens_per_chunk:
|
| 350 |
args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
|
sage/embedder.py
CHANGED
|
@@ -344,7 +344,7 @@ class GeminiBatchEmbedder(BatchEmbedder):
|
|
| 344 |
self.chunker = chunker
|
| 345 |
self.embedding_data = []
|
| 346 |
self.embedding_model = embedding_model
|
| 347 |
-
genai.configure(api_key=os.environ["  [line truncated in capture]
|
| 348 |
|
| 349 |
def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
|
| 350 |
return genai.embed_content(
|
|
|
|
| 344 |
self.chunker = chunker
|
| 345 |
self.embedding_data = []
|
| 346 |
self.embedding_model = embedding_model
|
| 347 |
+
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
|
| 348 |
|
| 349 |
def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
|
| 350 |
return genai.embed_content(
|
sage/retriever.py
CHANGED
|
@@ -74,6 +74,7 @@ Here is the file hierarchy of the GitHub repository:
|
|
| 74 |
|
| 75 |
{self.repo_hierarchy}
|
| 76 |
"""
|
|
|
|
| 77 |
# We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
|
| 78 |
augmented_user_query = f"""
|
| 79 |
User query: {user_query}
|
|
|
|
| 74 |
|
| 75 |
{self.repo_hierarchy}
|
| 76 |
"""
|
| 77 |
+
|
| 78 |
# We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
|
| 79 |
augmented_user_query = f"""
|
| 80 |
User query: {user_query}
|