Spaces:
Sleeping
Sleeping
Update megumin_agent/retrieval.py
Browse files- megumin_agent/retrieval.py +24 -15
megumin_agent/retrieval.py
CHANGED
|
@@ -36,6 +36,7 @@ ANSWER_KEYS = (
|
|
| 36 |
COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
|
| 37 |
EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
|
| 38 |
EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def _normalize_text(value: Any) -> str:
|
|
@@ -180,22 +181,30 @@ def _embed_texts(
|
|
| 180 |
if not texts:
|
| 181 |
return np.zeros((0, output_dimensionality), dtype="float32")
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
return np.zeros((0, output_dimensionality), dtype="float32")
|
| 197 |
-
|
| 198 |
-
return vectors
|
| 199 |
|
| 200 |
|
| 201 |
@lru_cache(maxsize=8)
|
|
|
|
| 36 |
COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
|
| 37 |
EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
|
| 38 |
EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
|
| 39 |
+
EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
|
| 40 |
|
| 41 |
|
| 42 |
def _normalize_text(value: Any) -> str:
|
|
|
|
| 181 |
if not texts:
|
| 182 |
return np.zeros((0, output_dimensionality), dtype="float32")
|
| 183 |
|
| 184 |
+
batches: list[np.ndarray] = []
|
| 185 |
+
batch_size = max(1, min(EMBEDDING_BATCH_SIZE, 100))
|
| 186 |
+
for start in range(0, len(texts), batch_size):
|
| 187 |
+
chunk = texts[start : start + batch_size]
|
| 188 |
+
response = _get_genai_client().models.embed_content(
|
| 189 |
+
model=embedding_model,
|
| 190 |
+
contents=chunk,
|
| 191 |
+
config=types.EmbedContentConfig(
|
| 192 |
+
task_type=task_type,
|
| 193 |
+
output_dimensionality=output_dimensionality,
|
| 194 |
+
),
|
| 195 |
+
)
|
| 196 |
+
vectors = np.array(
|
| 197 |
+
[embedding.values for embedding in response.embeddings],
|
| 198 |
+
dtype="float32",
|
| 199 |
+
)
|
| 200 |
+
if vectors.size == 0:
|
| 201 |
+
continue
|
| 202 |
+
faiss.normalize_L2(vectors)
|
| 203 |
+
batches.append(vectors)
|
| 204 |
+
|
| 205 |
+
if not batches:
|
| 206 |
return np.zeros((0, output_dimensionality), dtype="float32")
|
| 207 |
+
return np.vstack(batches)
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
@lru_cache(maxsize=8)
|