Junhoee committed on
Commit
29cb348
·
verified ·
1 Parent(s): 021a05e

Update megumin_agent/retrieval.py

Browse files
Files changed (1) hide show
  1. megumin_agent/retrieval.py +51 -7
megumin_agent/retrieval.py CHANGED
@@ -9,6 +9,7 @@ from dataclasses import dataclass
9
  from functools import lru_cache
10
  from pathlib import Path
11
  from typing import Any
 
12
 
13
  import faiss
14
  import numpy as np
@@ -42,6 +43,8 @@ FAISS_METADATA_FILENAME = os.getenv(
42
  "MEGUMIN_FAISS_METADATA_FILENAME",
43
  "megumin_questions_meta.json",
44
  )
 
 
45
 
46
 
47
  def _normalize_text(value: Any) -> str:
@@ -58,6 +61,11 @@ def _safe_excerpt(text: str, limit: int = 220) -> str:
58
  return compact[: limit - 3].rstrip() + "..."
59
 
60
 
 
 
 
 
 
61
  @dataclass(frozen=True)
62
  class QaRecord:
63
  question: str
@@ -165,14 +173,29 @@ def _load_metadata_records(path: Path) -> tuple[QaRecord, ...]:
165
  return tuple(records)
166
 
167
 
168
- @lru_cache(maxsize=8)
169
- def _load_records(dataset_dir: str) -> tuple[QaRecord, ...]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  root = Path(dataset_dir)
171
  if not root.exists():
172
  return tuple()
173
 
174
  all_records: list[QaRecord] = []
175
- for path in sorted(root.glob("*.json")):
176
  try:
177
  all_records.extend(_load_json_records(path))
178
  except OSError:
@@ -238,9 +261,10 @@ def build_and_save_faiss_index(
238
  output_dimensionality: int = EMBEDDING_DIMENSION,
239
  index_filename: str = FAISS_INDEX_FILENAME,
240
  metadata_filename: str = FAISS_METADATA_FILENAME,
 
241
  ) -> tuple[Path, Path]:
242
  root = Path(dataset_dir)
243
- records = _load_records(str(root.resolve()))
244
  if not records:
245
  raise FileNotFoundError(f"No JSON records found under {root}")
246
 
@@ -283,9 +307,17 @@ def _load_vector_store(
283
  dataset_dir: str,
284
  embedding_model: str,
285
  output_dimensionality: int,
 
 
 
286
  ) -> VectorStore:
287
- index_path, metadata_path = _index_artifact_paths(dataset_dir)
288
- if index_path.exists() and metadata_path.exists():
 
 
 
 
 
289
  index = faiss.read_index(str(index_path))
290
  records = _load_metadata_records(metadata_path)
291
  if index.ntotal != len(records):
@@ -299,7 +331,7 @@ def _load_vector_store(
299
  dimension=index.d,
300
  )
301
 
302
- records = _load_records(dataset_dir)
303
  if not records:
304
  empty_index = faiss.IndexFlatIP(output_dimensionality)
305
  return VectorStore(
@@ -334,16 +366,25 @@ class JsonQaRetriever:
334
  *,
335
  embedding_model: str = EMBEDDING_MODEL_NAME,
336
  output_dimensionality: int = EMBEDDING_DIMENSION,
 
 
 
337
  ):
338
  self.dataset_dir = Path(dataset_dir)
339
  self.embedding_model = embedding_model
340
  self.output_dimensionality = output_dimensionality
 
 
 
341
 
342
  def warmup(self) -> None:
343
  _load_vector_store(
344
  str(self.dataset_dir.resolve()),
345
  self.embedding_model,
346
  self.output_dimensionality,
 
 
 
347
  )
348
 
349
  def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
@@ -376,6 +417,9 @@ class JsonQaRetriever:
376
  str(self.dataset_dir.resolve()),
377
  self.embedding_model,
378
  self.output_dimensionality,
 
 
 
379
  )
380
  if not store.records:
381
  return {
 
9
  from functools import lru_cache
10
  from pathlib import Path
11
  from typing import Any
12
+ from typing import Iterable
13
 
14
  import faiss
15
  import numpy as np
 
43
  "MEGUMIN_FAISS_METADATA_FILENAME",
44
  "megumin_questions_meta.json",
45
  )
46
+ PERSONA_DATASET_PATTERNS = ("megumin_qa_dataset.json",)
47
+ FACT_DATASET_PATTERNS = ("namuwiki*.json",)
48
 
49
 
50
  def _normalize_text(value: Any) -> str:
 
61
  return compact[: limit - 3].rstrip() + "..."
62
 
63
 
64
+ def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
65
+ normalized = tuple(pattern.strip() for pattern in (patterns or ()) if pattern.strip())
66
+ return normalized
67
+
68
+
69
  @dataclass(frozen=True)
70
  class QaRecord:
71
  question: str
 
173
  return tuple(records)
174
 
175
 
176
def _iter_matching_paths(root: Path, include_patterns: tuple[str, ...]) -> list[Path]:
    """Return the JSON files under *root* selected by *include_patterns*.

    Args:
        root: Directory whose entries are globbed.
        include_patterns: Glob patterns passed to ``root.glob``. When empty,
            the single pattern ``"*.json"`` is used instead.

    Returns:
        Matching paths grouped in pattern order and sorted within each
        pattern. A path matched by several patterns appears only once, at
        its first match. Only regular files whose suffix is ``.json``
        (compared case-insensitively) are included.
    """
    # Route the default through the same loop so both cases share the
    # suffix, file-type and de-duplication rules.
    patterns = include_patterns or ("*.json",)

    seen: set[Path] = set()
    paths: list[Path] = []
    for pattern in patterns:
        for path in sorted(root.glob(pattern)):
            if path in seen or path.suffix.lower() != ".json":
                continue
            if not path.is_file():
                # A directory (or other non-file entry) named like
                # "foo.json" is not a dataset file and cannot be read as one.
                continue
            seen.add(path)
            paths.append(path)
    return paths
189
+
190
+
191
+ @lru_cache(maxsize=16)
192
+ def _load_records(dataset_dir: str, include_patterns: tuple[str, ...] = ()) -> tuple[QaRecord, ...]:
193
  root = Path(dataset_dir)
194
  if not root.exists():
195
  return tuple()
196
 
197
  all_records: list[QaRecord] = []
198
+ for path in _iter_matching_paths(root, include_patterns):
199
  try:
200
  all_records.extend(_load_json_records(path))
201
  except OSError:
 
261
  output_dimensionality: int = EMBEDDING_DIMENSION,
262
  index_filename: str = FAISS_INDEX_FILENAME,
263
  metadata_filename: str = FAISS_METADATA_FILENAME,
264
+ include_patterns: Iterable[str] | None = None,
265
  ) -> tuple[Path, Path]:
266
  root = Path(dataset_dir)
267
+ records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
268
  if not records:
269
  raise FileNotFoundError(f"No JSON records found under {root}")
270
 
 
307
  dataset_dir: str,
308
  embedding_model: str,
309
  output_dimensionality: int,
310
+ include_patterns: tuple[str, ...] = (),
311
+ index_filename: str | None = FAISS_INDEX_FILENAME,
312
+ metadata_filename: str | None = FAISS_METADATA_FILENAME,
313
  ) -> VectorStore:
314
+ if index_filename and metadata_filename:
315
+ index_path = Path(dataset_dir) / index_filename
316
+ metadata_path = Path(dataset_dir) / metadata_filename
317
+ else:
318
+ index_path = metadata_path = None
319
+
320
+ if index_path and metadata_path and index_path.exists() and metadata_path.exists():
321
  index = faiss.read_index(str(index_path))
322
  records = _load_metadata_records(metadata_path)
323
  if index.ntotal != len(records):
 
331
  dimension=index.d,
332
  )
333
 
334
+ records = _load_records(dataset_dir, include_patterns)
335
  if not records:
336
  empty_index = faiss.IndexFlatIP(output_dimensionality)
337
  return VectorStore(
 
366
  *,
367
  embedding_model: str = EMBEDDING_MODEL_NAME,
368
  output_dimensionality: int = EMBEDDING_DIMENSION,
369
+ include_patterns: Iterable[str] | None = None,
370
+ index_filename: str | None = FAISS_INDEX_FILENAME,
371
+ metadata_filename: str | None = FAISS_METADATA_FILENAME,
372
  ):
373
  self.dataset_dir = Path(dataset_dir)
374
  self.embedding_model = embedding_model
375
  self.output_dimensionality = output_dimensionality
376
+ self.include_patterns = _normalize_patterns(include_patterns)
377
+ self.index_filename = index_filename
378
+ self.metadata_filename = metadata_filename
379
 
380
  def warmup(self) -> None:
381
  _load_vector_store(
382
  str(self.dataset_dir.resolve()),
383
  self.embedding_model,
384
  self.output_dimensionality,
385
+ self.include_patterns,
386
+ self.index_filename,
387
+ self.metadata_filename,
388
  )
389
 
390
  def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
 
417
  str(self.dataset_dir.resolve()),
418
  self.embedding_model,
419
  self.output_dimensionality,
420
+ self.include_patterns,
421
+ self.index_filename,
422
+ self.metadata_filename,
423
  )
424
  if not store.records:
425
  return {