Junhoee committed on
Commit
c40a1c4
·
verified ·
1 Parent(s): 7344c5b

Update megumin_agent/retrieval.py

Browse files
Files changed (1) hide show
  1. megumin_agent/retrieval.py +24 -15
megumin_agent/retrieval.py CHANGED
@@ -36,6 +36,7 @@ ANSWER_KEYS = (
36
  COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
37
  EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
38
  EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
 
39
 
40
 
41
  def _normalize_text(value: Any) -> str:
@@ -180,22 +181,30 @@ def _embed_texts(
180
  if not texts:
181
  return np.zeros((0, output_dimensionality), dtype="float32")
182
 
183
- response = _get_genai_client().models.embed_content(
184
- model=embedding_model,
185
- contents=texts,
186
- config=types.EmbedContentConfig(
187
- task_type=task_type,
188
- output_dimensionality=output_dimensionality,
189
- ),
190
- )
191
- vectors = np.array(
192
- [embedding.values for embedding in response.embeddings],
193
- dtype="float32",
194
- )
195
- if vectors.size == 0:
 
 
 
 
 
 
 
 
 
196
  return np.zeros((0, output_dimensionality), dtype="float32")
197
- faiss.normalize_L2(vectors)
198
- return vectors
199
 
200
 
201
  @lru_cache(maxsize=8)
 
36
  COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
37
  EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
38
  EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
39
+ EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
40
 
41
 
42
  def _normalize_text(value: Any) -> str:
 
181
  if not texts:
182
  return np.zeros((0, output_dimensionality), dtype="float32")
183
 
184
+ batches: list[np.ndarray] = []
185
+ batch_size = max(1, min(EMBEDDING_BATCH_SIZE, 100))
186
+ for start in range(0, len(texts), batch_size):
187
+ chunk = texts[start : start + batch_size]
188
+ response = _get_genai_client().models.embed_content(
189
+ model=embedding_model,
190
+ contents=chunk,
191
+ config=types.EmbedContentConfig(
192
+ task_type=task_type,
193
+ output_dimensionality=output_dimensionality,
194
+ ),
195
+ )
196
+ vectors = np.array(
197
+ [embedding.values for embedding in response.embeddings],
198
+ dtype="float32",
199
+ )
200
+ if vectors.size == 0:
201
+ continue
202
+ faiss.normalize_L2(vectors)
203
+ batches.append(vectors)
204
+
205
+ if not batches:
206
  return np.zeros((0, output_dimensionality), dtype="float32")
207
+ return np.vstack(batches)
 
208
 
209
 
210
  @lru_cache(maxsize=8)