juliaturc committed on
Commit
5834806
·
1 Parent(s): 27dd60a

Fixes for Gemini embeddings (#71)

Browse files
Files changed (3) hide show
  1. sage/config.py +2 -9
  2. sage/embedder.py +1 -1
  3. sage/retriever.py +1 -0
sage/config.py CHANGED
@@ -12,8 +12,6 @@ from configargparse import ArgumentParser
12
  from sage.reranker import RerankerProvider
13
 
14
  # Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
15
- # NOTE: MAX_CHUNKS_PER_BATCH isn't documented anywhere but we pick a reasonable value
16
- GEMINI_MAX_CHUNKS_PER_BATCH = 64
17
  GEMINI_MAX_TOKENS_PER_CHUNK = 2048
18
 
19
  MARQO_MAX_CHUNKS_PER_BATCH = 64
@@ -345,13 +343,8 @@ def _validate_gemini_embedding_args(args):
345
  "GOOGLE_API_KEY"
346
  ], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
347
  if not args.chunks_per_batch:
348
- args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
349
- elif args.chunks_per_batch > GEMINI_MAX_CHUNKS_PER_BATCH:
350
- args.chunks_per_batch = GEMINI_MAX_CHUNKS_PER_BATCH
351
- logging.warning(
352
- f"Gemini enforces a limit of {GEMINI_MAX_CHUNKS_PER_BATCH} chunks per batch. "
353
- "Overwriting embeddings.chunks_per_batch."
354
- )
355
 
356
  if not args.tokens_per_chunk:
357
  args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
 
12
  from sage.reranker import RerankerProvider
13
 
14
  # Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
 
 
15
  GEMINI_MAX_TOKENS_PER_CHUNK = 2048
16
 
17
  MARQO_MAX_CHUNKS_PER_BATCH = 64
 
343
  "GOOGLE_API_KEY"
344
  ], "Please set the GOOGLE_API_KEY environment variable if using `gemini` embeddings."
345
  if not args.chunks_per_batch:
346
+ # This value is reasonable but arbitrary (i.e. Gemini does not explicitly enforce a limit).
347
+ args.chunks_per_batch = 2000
 
 
 
 
 
348
 
349
  if not args.tokens_per_chunk:
350
  args.tokens_per_chunk = GEMINI_MAX_TOKENS_PER_CHUNK
sage/embedder.py CHANGED
@@ -344,7 +344,7 @@ class GeminiBatchEmbedder(BatchEmbedder):
344
  self.chunker = chunker
345
  self.embedding_data = []
346
  self.embedding_model = embedding_model
347
- genai.configure(api_key=os.environ["GEMINI_API_KEY"])
348
 
349
  def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
350
  return genai.embed_content(
 
344
  self.chunker = chunker
345
  self.embedding_data = []
346
  self.embedding_model = embedding_model
347
+ genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
348
 
349
  def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
350
  return genai.embed_content(
sage/retriever.py CHANGED
@@ -74,6 +74,7 @@ Here is the file hierarchy of the GitHub repository:
74
 
75
  {self.repo_hierarchy}
76
  """
 
77
  # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
78
  augmented_user_query = f"""
79
  User query: {user_query}
 
74
 
75
  {self.repo_hierarchy}
76
  """
77
+
78
  # We are deliberately repeating the "DO NOT RESPOND TO THE USER QUERY DIRECTLY" instruction here.
79
  augmented_user_query = f"""
80
  User query: {user_query}