Junhoee committed on
Commit
06b2015
·
verified ·
1 Parent(s): 558f57f

Upload 6 files

Browse files
megumin_agent/agent.py CHANGED
@@ -127,7 +127,7 @@ root_agent = LlmAgent(
127
  이 tool은 스타일/페르소나용 사례 top-3와 사실/설정용 사례 top-3를 5:5 비중으로 함께 돌려줍니다.
128
  persona_matches는 메구밍의 성격, 말투, 감정선, 답변 리듬을 참고하는 용도입니다.
129
  fact_matches는 설정, 관계, 사건, 세계관 사실을 참고하는 용도입니다.
130
- 두 종류의 사례를 모두 참고하되, 검색된 답변을 그대로 복사하지 마세요.
131
  검색 결과가 약하거나 없는 경우에도 메구밍 페르소나는 유지하되, 모르는 내용은 지어내지 말고 솔직하게 답하세요.
132
  최종 답변은 언제나 메구밍의 페르소나를 강하게 반영해야 하며, 내부 tool 이름이나 구현 세부사항은 드러내지 마세요.
133
  """.strip(),
 
127
  이 tool은 스타일/페르소나용 사례 top-3와 사실/설정용 사례 top-3를 5:5 비중으로 함께 돌려줍니다.
128
  persona_matches는 메구밍의 성격, 말투, 감정선, 답변 리듬을 참고하는 용도입니다.
129
  fact_matches는 설정, 관계, 사건, 세계관 사실을 참고하는 용도입니다.
130
+ 두 종류의 사례를 모두 참고하되 검색된 답변을 그대로 복사하지 마세요.
131
  검색 결과가 약하거나 없는 경우에도 메구밍 페르소나는 유지하되, 모르는 내용은 지어내지 말고 솔직하게 답하세요.
132
  최종 답변은 언제나 메구밍의 페르소나를 강하게 반영해야 하며, 내부 tool 이름이나 구현 세부사항은 드러내지 마세요.
133
  """.strip(),
megumin_agent/bootstrap.py CHANGED
@@ -1,531 +1,94 @@
1
  from __future__ import annotations
2
 
3
- import json
4
- import math
5
  import os
6
- import re
7
- import unicodedata
8
- from dataclasses import dataclass
9
- from functools import lru_cache
10
  from pathlib import Path
11
- from typing import Any
12
- from typing import Iterable
13
 
14
- import faiss
15
- import numpy as np
16
- from google import genai
17
- from google.genai import types
18
 
19
 
20
- QUESTION_KEYS = (
21
- "question",
22
- "query",
23
- "q",
24
- "prompt",
25
- "user",
26
- "instruction",
27
- "input",
28
- )
29
- ANSWER_KEYS = (
30
- "answer",
31
- "response",
32
- "a",
33
- "output",
34
- "assistant",
35
- "completion",
36
- )
37
- COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
38
- EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
39
- EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
40
- EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
41
- FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
42
- FAISS_QA_INDEX_FILENAME = os.getenv(
43
- "MEGUMIN_FAISS_QA_INDEX_FILENAME",
44
- "megumin_question_answer.faiss",
45
- )
46
- FAISS_METADATA_FILENAME = os.getenv(
47
- "MEGUMIN_FAISS_METADATA_FILENAME",
48
- "megumin_questions_meta.json",
49
- )
50
- PERSONA_DATASET_PATTERNS = ("megumin_qa_dataset.json",)
51
- FACT_DATASET_PATTERNS = ("namuwiki*.json",)
52
 
53
 
54
- def _normalize_text(value: Any) -> str:
55
- text = str(value or "")
56
- text = unicodedata.normalize("NFKC", text).strip()
57
- text = re.sub(r"\s+", " ", text)
58
- return text
59
 
60
 
61
- def _safe_excerpt(text: str, limit: int = 220) -> str:
62
- compact = re.sub(r"\s+", " ", str(text or "")).strip()
63
- if len(compact) <= limit:
64
- return compact
65
- return compact[: limit - 3].rstrip() + "..."
66
 
67
 
68
- def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
69
- normalized = tuple(pattern.strip() for pattern in (patterns or ()) if pattern.strip())
70
- return normalized
71
 
72
 
73
- def _record_search_text(record: "QaRecord", mode: str) -> str:
74
- if mode == "question_answer":
75
- return f"{record.question}\n{record.answer}".strip()
76
- return record.question
77
 
78
 
79
- @dataclass(frozen=True)
80
- class QaRecord:
81
- question: str
82
- answer: str
83
- source_file: str
84
- metadata: dict[str, Any]
85
 
86
- @property
87
- def normalized_question(self) -> str:
88
- return _normalize_text(self.question)
89
 
 
 
90
 
91
- @dataclass(frozen=True)
92
- class VectorStore:
93
- records: tuple[QaRecord, ...]
94
- index: faiss.Index
95
- embedding_model: str
96
- dimension: int
97
 
 
 
98
 
99
- def _extract_collection(payload: Any) -> list[Any]:
100
- if isinstance(payload, list):
101
- return payload
102
- if isinstance(payload, dict):
103
- for key in COLLECTION_KEYS:
104
- value = payload.get(key)
105
- if isinstance(value, list):
106
- return value
107
- return []
108
 
 
 
109
 
110
- def _pick_first(mapping: dict[str, Any], keys: tuple[str, ...]) -> str | None:
111
- lowered = {str(key).lower(): value for key, value in mapping.items()}
112
- for key in keys:
113
- if key in lowered and lowered[key] not in (None, ""):
114
- return str(lowered[key]).strip()
115
- return None
116
 
 
 
117
 
118
- def _record_from_mapping(item: dict[str, Any], source_file: str) -> QaRecord | None:
119
- question = _pick_first(item, QUESTION_KEYS)
120
- answer = _pick_first(item, ANSWER_KEYS)
121
- if not question or not answer:
122
- return None
123
 
124
- metadata = {
125
- key: value
126
- for key, value in item.items()
127
- if str(key).lower() not in QUESTION_KEYS + ANSWER_KEYS
128
- }
129
- return QaRecord(
130
- question=question,
131
- answer=answer,
132
- source_file=source_file,
133
- metadata=metadata,
134
- )
135
 
136
 
137
- def _load_json_records(path: Path) -> list[QaRecord]:
138
- raw_text = path.read_text(encoding="utf-8")
139
- stripped = raw_text.strip()
140
- if not stripped:
141
- return []
142
-
143
- records: list[QaRecord] = []
144
 
145
  try:
146
- payload = json.loads(stripped)
147
- except json.JSONDecodeError:
148
- payload = None
149
-
150
- if payload is not None:
151
- for item in _extract_collection(payload):
152
- if isinstance(item, dict):
153
- record = _record_from_mapping(item, path.name)
154
- if record:
155
- records.append(record)
156
- if records:
157
- return records
158
-
159
- for line in stripped.splitlines():
160
- line = line.strip()
161
- if not line:
162
- continue
163
- try:
164
- item = json.loads(line)
165
- except json.JSONDecodeError:
166
- continue
167
- if isinstance(item, dict):
168
- record = _record_from_mapping(item, path.name)
169
- if record:
170
- records.append(record)
171
-
172
- return records
173
-
174
-
175
- def _load_metadata_records(path: Path) -> tuple[QaRecord, ...]:
176
- payload = json.loads(path.read_text(encoding="utf-8"))
177
- records: list[QaRecord] = []
178
- for item in _extract_collection(payload):
179
- if isinstance(item, dict):
180
- record = _record_from_mapping(item, path.name)
181
- if record:
182
- records.append(record)
183
- return tuple(records)
184
-
185
-
186
- def _iter_matching_paths(root: Path, include_patterns: tuple[str, ...]) -> list[Path]:
187
- if not include_patterns:
188
- return sorted(root.glob("*.json"))
189
-
190
- seen: set[Path] = set()
191
- paths: list[Path] = []
192
- for pattern in include_patterns:
193
- for path in sorted(root.glob(pattern)):
194
- if path in seen or path.suffix.lower() != ".json":
195
- continue
196
- seen.add(path)
197
- paths.append(path)
198
- return paths
199
-
200
-
201
- @lru_cache(maxsize=16)
202
- def _load_records(dataset_dir: str, include_patterns: tuple[str, ...] = ()) -> tuple[QaRecord, ...]:
203
- root = Path(dataset_dir)
204
- if not root.exists():
205
- return tuple()
206
-
207
- all_records: list[QaRecord] = []
208
- for path in _iter_matching_paths(root, include_patterns):
209
- try:
210
- all_records.extend(_load_json_records(path))
211
- except OSError:
212
- continue
213
- except UnicodeDecodeError:
214
- continue
215
- return tuple(all_records)
216
-
217
-
218
- @lru_cache(maxsize=2)
219
- def _get_genai_client() -> genai.Client:
220
- return genai.Client()
221
-
222
-
223
- def _embed_texts(
224
- texts: list[str],
225
- *,
226
- task_type: str,
227
- embedding_model: str,
228
- output_dimensionality: int,
229
- ) -> np.ndarray:
230
- if not texts:
231
- return np.zeros((0, output_dimensionality), dtype="float32")
232
-
233
- batches: list[np.ndarray] = []
234
- batch_size = max(1, min(EMBEDDING_BATCH_SIZE, 100))
235
- for start in range(0, len(texts), batch_size):
236
- chunk = texts[start : start + batch_size]
237
- response = _get_genai_client().models.embed_content(
238
- model=embedding_model,
239
- contents=chunk,
240
- config=types.EmbedContentConfig(
241
- task_type=task_type,
242
- output_dimensionality=output_dimensionality,
243
- ),
244
  )
245
- vectors = np.array(
246
- [embedding.values for embedding in response.embeddings],
247
- dtype="float32",
248
- )
249
- if vectors.size == 0:
250
- continue
251
- faiss.normalize_L2(vectors)
252
- batches.append(vectors)
253
-
254
- if not batches:
255
- return np.zeros((0, output_dimensionality), dtype="float32")
256
- return np.vstack(batches)
257
-
258
-
259
- def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
260
- root = Path(dataset_dir)
261
- return (
262
- root / FAISS_INDEX_FILENAME,
263
- root / FAISS_METADATA_FILENAME,
264
- )
265
-
266
-
267
- def _build_index_from_records(
268
- records: tuple[QaRecord, ...],
269
- *,
270
- embedding_model: str,
271
- output_dimensionality: int,
272
- mode: str,
273
- ) -> faiss.IndexFlatIP:
274
- search_texts = [_record_search_text(record, mode) for record in records]
275
- vectors = _embed_texts(
276
- search_texts,
277
- task_type="RETRIEVAL_DOCUMENT",
278
- embedding_model=embedding_model,
279
- output_dimensionality=output_dimensionality,
280
- )
281
- if vectors.size == 0:
282
- raise RuntimeError("No embeddings were generated for the dataset records.")
283
-
284
- index = faiss.IndexFlatIP(int(vectors.shape[1]))
285
- index.add(vectors)
286
- return index
287
-
288
-
289
- def build_and_save_faiss_index(
290
- dataset_dir: str | Path,
291
- *,
292
- embedding_model: str = EMBEDDING_MODEL_NAME,
293
- output_dimensionality: int = EMBEDDING_DIMENSION,
294
- index_filename: str = FAISS_INDEX_FILENAME,
295
- qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
296
- metadata_filename: str = FAISS_METADATA_FILENAME,
297
- include_patterns: Iterable[str] | None = None,
298
- ) -> tuple[Path, Path, Path]:
299
- root = Path(dataset_dir)
300
- records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
301
- if not records:
302
- raise FileNotFoundError(f"No JSON records found under {root}")
303
-
304
- question_index = _build_index_from_records(
305
- records,
306
- embedding_model=embedding_model,
307
- output_dimensionality=output_dimensionality,
308
- mode="question",
309
- )
310
- qa_index = _build_index_from_records(
311
- records,
312
- embedding_model=embedding_model,
313
- output_dimensionality=output_dimensionality,
314
- mode="question_answer",
315
- )
316
- index_path = root / index_filename
317
- qa_index_path = root / qa_index_filename
318
- metadata_path = root / metadata_filename
319
- faiss.write_index(question_index, str(index_path))
320
- faiss.write_index(qa_index, str(qa_index_path))
321
- metadata_payload = {
322
- "items": [
323
- {
324
- "question": record.question,
325
- "answer": record.answer,
326
- "source_file": record.source_file,
327
- **record.metadata,
328
- }
329
- for record in records
330
- ]
331
- }
332
- metadata_path.write_text(
333
- json.dumps(metadata_payload, ensure_ascii=False, indent=2),
334
- encoding="utf-8",
335
- )
336
- return index_path, qa_index_path, metadata_path
337
-
338
-
339
- @lru_cache(maxsize=8)
340
- def _load_vector_store(
341
- dataset_dir: str,
342
- embedding_model: str,
343
- output_dimensionality: int,
344
- include_patterns: tuple[str, ...] = (),
345
- index_filename: str | None = FAISS_INDEX_FILENAME,
346
- qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
347
- metadata_filename: str | None = FAISS_METADATA_FILENAME,
348
- mode: str = "question",
349
- ) -> VectorStore:
350
- selected_index_filename = index_filename if mode == "question" else qa_index_filename
351
- if selected_index_filename and metadata_filename:
352
- index_path = Path(dataset_dir) / selected_index_filename
353
- metadata_path = Path(dataset_dir) / metadata_filename
354
- else:
355
- index_path = metadata_path = None
356
-
357
- if index_path and metadata_path and index_path.exists() and metadata_path.exists():
358
- index = faiss.read_index(str(index_path))
359
- records = _load_metadata_records(metadata_path)
360
- if index.ntotal != len(records):
361
- raise ValueError(
362
- f"FAISS index size ({index.ntotal}) does not match metadata size ({len(records)})."
363
- )
364
- return VectorStore(
365
- records=records,
366
- index=index,
367
- embedding_model=embedding_model,
368
- dimension=index.d,
369
- )
370
-
371
- records = _load_records(dataset_dir, include_patterns)
372
- if not records:
373
- empty_index = faiss.IndexFlatIP(output_dimensionality)
374
- return VectorStore(
375
- records=tuple(),
376
- index=empty_index,
377
- embedding_model=embedding_model,
378
- dimension=output_dimensionality,
379
- )
380
-
381
- index = _build_index_from_records(
382
- records,
383
- embedding_model=embedding_model,
384
- output_dimensionality=output_dimensionality,
385
- mode=mode,
386
- )
387
- return VectorStore(
388
- records=records,
389
- index=index,
390
- embedding_model=embedding_model,
391
- dimension=index.d,
392
- )
393
-
394
-
395
- class JsonQaRetriever:
396
- def __init__(
397
- self,
398
- dataset_dir: str | Path,
399
- *,
400
- embedding_model: str = EMBEDDING_MODEL_NAME,
401
- output_dimensionality: int = EMBEDDING_DIMENSION,
402
- include_patterns: Iterable[str] | None = None,
403
- index_filename: str | None = FAISS_INDEX_FILENAME,
404
- qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
405
- metadata_filename: str | None = FAISS_METADATA_FILENAME,
406
- ):
407
- self.dataset_dir = Path(dataset_dir)
408
- self.embedding_model = embedding_model
409
- self.output_dimensionality = output_dimensionality
410
- self.include_patterns = _normalize_patterns(include_patterns)
411
- self.index_filename = index_filename
412
- self.qa_index_filename = qa_index_filename
413
- self.metadata_filename = metadata_filename
414
-
415
- def warmup(self) -> None:
416
- _load_vector_store(
417
- str(self.dataset_dir.resolve()),
418
- self.embedding_model,
419
- self.output_dimensionality,
420
- self.include_patterns,
421
- self.index_filename,
422
- self.qa_index_filename,
423
- self.metadata_filename,
424
- "question",
425
- )
426
- _load_vector_store(
427
- str(self.dataset_dir.resolve()),
428
- self.embedding_model,
429
- self.output_dimensionality,
430
- self.include_patterns,
431
- self.index_filename,
432
- self.qa_index_filename,
433
- self.metadata_filename,
434
- "question_answer",
435
- )
436
-
437
- def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
438
- if not matches:
439
- return [
440
- "No strong example was retrieved, so stay in Megumin's persona without inventing unsupported canon facts.",
441
- ]
442
-
443
- notes = [
444
- "Answer in first person as Megumin, with respectful but dramatic confidence.",
445
- "Use the retrieved cases to mirror tone and answer shape, but do not copy them verbatim.",
446
- "Prefer the retrieved answers as evidence for facts, relationships, and recurring phrasing.",
447
- ]
448
-
449
- long_answers = sum(
450
- 1 for match in matches if len(match.get("answer", "")) >= 180
451
- )
452
- if long_answers >= max(1, math.ceil(len(matches) / 2)):
453
- notes.append(
454
- "The retrieved examples skew narrative, so a short anecdotal lead-in is acceptable."
455
- )
456
- else:
457
- notes.append(
458
- "The retrieved examples are compact, so keep the answer concise and pointed."
459
- )
460
- return notes
461
-
462
- def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
463
- question_store = _load_vector_store(
464
- str(self.dataset_dir.resolve()),
465
- self.embedding_model,
466
- self.output_dimensionality,
467
- self.include_patterns,
468
- self.index_filename,
469
- self.qa_index_filename,
470
- self.metadata_filename,
471
- "question",
472
- )
473
- qa_store = _load_vector_store(
474
- str(self.dataset_dir.resolve()),
475
- self.embedding_model,
476
- self.output_dimensionality,
477
- self.include_patterns,
478
- self.index_filename,
479
- self.qa_index_filename,
480
- self.metadata_filename,
481
- "question_answer",
482
- )
483
- if not question_store.records:
484
- return {
485
- "query": query,
486
- "match_count": 0,
487
- "matches": [],
488
- "style_notes": [
489
- "No processed JSON dataset was found for retrieval.",
490
- ],
491
- }
492
-
493
- query_vector = _embed_texts(
494
- [_normalize_text(query) or query],
495
- task_type="RETRIEVAL_QUERY",
496
- embedding_model=question_store.embedding_model,
497
- output_dimensionality=question_store.dimension,
498
- )
499
- search_k = max(1, min(top_k, len(question_store.records)))
500
-
501
- candidates: dict[int, dict[str, Any]] = {}
502
- for store_name, store in (("question", question_store), ("question_answer", qa_store)):
503
- scores, indices = store.index.search(query_vector, search_k)
504
- for score, index in zip(scores[0], indices[0]):
505
- if index < 0:
506
  continue
507
- record = store.records[int(index)]
508
- current = candidates.get(int(index))
509
- score_value = round(float(score), 6)
510
- if current is None or score_value > current["score"]:
511
- candidates[int(index)] = {
512
- "question": record.question,
513
- "answer": _safe_excerpt(record.answer),
514
- "score": score_value,
515
- "source_file": record.source_file,
516
- "metadata": record.metadata,
517
- "matched_via": store_name,
518
- }
519
-
520
- matches = sorted(
521
- candidates.values(),
522
- key=lambda item: item["score"],
523
- reverse=True,
524
- )[:top_k]
525
-
526
- return {
527
- "query": query,
528
- "match_count": len(matches),
529
- "matches": matches,
530
- "style_notes": self._style_notes(matches),
531
- }
 
1
  from __future__ import annotations
2
 
 
 
3
  import os
4
+ import sys
 
 
 
5
  from pathlib import Path
 
 
6
 
7
+ from dotenv import load_dotenv
8
+ from huggingface_hub import hf_hub_download
 
 
9
 
10
 
11
# Repository root: the directory one level above this module's package.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Optional vendored ADK source tree; bootstrap_environment() prepends it to
# sys.path when it exists.
ADK_SRC = PROJECT_ROOT.joinpath("adk-python", "src")
# Dataset directory shipped with the project; resolve_dataset_dir() falls
# back to it when the Hub download fails.
LOCAL_DATASET_DIR = PROJECT_ROOT.joinpath("data", "processed")
# Target directory for artifacts downloaded from the Hugging Face Hub.
RUNTIME_DATASET_DIR = PROJECT_ROOT.joinpath("data", "_runtime_processed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
+ def _dataset_repo_id() -> str:
18
+ return os.getenv("MEGUMIN_HF_DATASET_REPO_ID", "Junhoee/megumin-chat")
 
 
 
19
 
20
 
21
+ def _dataset_filename() -> str:
22
+ return os.getenv("MEGUMIN_HF_DATASET_FILENAME", "megumin_qa_dataset.json")
 
 
 
23
 
24
 
25
+ def _index_filename() -> str:
26
+ return os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
 
27
 
28
 
29
+ def _qa_index_filename() -> str:
30
+ return os.getenv("MEGUMIN_FAISS_QA_INDEX_FILENAME", "megumin_question_answer.faiss")
 
 
31
 
32
 
33
+ def _metadata_filename() -> str:
34
+ return os.getenv("MEGUMIN_FAISS_METADATA_FILENAME", "megumin_questions_meta.json")
 
 
 
 
35
 
 
 
 
36
 
37
+ def _fact_dataset_filename() -> str:
38
+ return os.getenv("MEGUMIN_HF_FACT_DATASET_FILENAME", "namuwiki_qa.json")
39
 
 
 
 
 
 
 
40
 
41
+ def _fact_index_filename() -> str:
42
+ return os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")
43
 
 
 
 
 
 
 
 
 
 
44
 
45
+ def _fact_qa_index_filename() -> str:
46
+ return os.getenv("MEGUMIN_HF_FACT_QA_INDEX_FILENAME", "namuwiki_question_answer.faiss")
47
 
 
 
 
 
 
 
48
 
49
+ def _fact_metadata_filename() -> str:
50
+ return os.getenv("MEGUMIN_HF_FACT_METADATA_FILENAME", "namuwiki_questions_meta.json")
51
 
 
 
 
 
 
52
 
53
def bootstrap_environment() -> None:
    """Load the project ``.env`` file and expose the vendored ADK sources.

    ``PROJECT_ROOT / ".env"`` is loaded with ``override=True``, so values
    from the file take precedence over variables already present in the
    process environment. If ``ADK_SRC`` exists and is not yet on
    ``sys.path``, it is inserted at position 0 so it is searched first.
    """
    load_dotenv(PROJECT_ROOT / ".env", override=True)

    if not ADK_SRC.exists():
        return

    vendored_path = str(ADK_SRC)
    if vendored_path in sys.path:
        # Already registered; avoid stacking duplicate entries.
        return
    sys.path.insert(0, vendored_path)
 
 
 
 
 
59
 
60
 
61
def resolve_dataset_dir() -> Path:
    """Return the directory that holds the datasets and FAISS artifacts.

    The function first tries to download every known artifact from the
    Hugging Face dataset repo into ``RUNTIME_DATASET_DIR``. The two dataset
    JSON files are required; all other artifacts (indexes and metadata) are
    treated as optional and a failed download of one of them is logged and
    skipped. If anything required fails, including creating the runtime
    directory itself, the function falls back to ``LOCAL_DATASET_DIR``,
    provided that directory exists and contains at least one ``*.json`` file.

    Returns:
        ``RUNTIME_DATASET_DIR`` after a successful download, otherwise
        ``LOCAL_DATASET_DIR``.

    Raises:
        Exception: the original failure is re-raised when the download path
            fails and no usable local dataset directory is available.
    """
    # Function-scope import keeps this block self-contained; the module's
    # top-level imports do not include logging.
    import logging

    log = logging.getLogger(__name__)

    try:
        # Creating the directory is part of the guarded section so that an
        # unwritable filesystem still reaches the local fallback below.
        RUNTIME_DATASET_DIR.mkdir(parents=True, exist_ok=True)

        hf_token = os.getenv("HF_TOKEN") or None  # blank token == anonymous
        repo_id = _dataset_repo_id()
        # Computed once: only these two files make a download failure fatal.
        required_names = {_dataset_filename(), _fact_dataset_filename()}
        artifact_names = (
            _dataset_filename(),
            _index_filename(),
            _qa_index_filename(),
            _metadata_filename(),
            _fact_dataset_filename(),
            _fact_index_filename(),
            _fact_qa_index_filename(),
            _fact_metadata_filename(),
        )
        for artifact_name in artifact_names:
            try:
                hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=artifact_name,
                    token=hf_token,
                    local_dir=str(RUNTIME_DATASET_DIR),
                )
            except Exception:
                if artifact_name in required_names:
                    raise
                # Best-effort artifact: keep going, but leave a trace so a
                # missing index is diagnosable.
                log.warning(
                    "Optional artifact %r could not be downloaded from %r; skipping.",
                    artifact_name,
                    repo_id,
                    exc_info=True,
                )
        return RUNTIME_DATASET_DIR
    except Exception:
        if LOCAL_DATASET_DIR.exists() and any(LOCAL_DATASET_DIR.glob("*.json")):
            log.warning(
                "Falling back to local dataset directory %s after download failure.",
                LOCAL_DATASET_DIR,
                exc_info=True,
            )
            return LOCAL_DATASET_DIR
        raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
megumin_agent/retrieval.py CHANGED
@@ -39,6 +39,10 @@ EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-00
39
  EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
40
  EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
41
  FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
 
 
 
 
42
  FAISS_METADATA_FILENAME = os.getenv(
43
  "MEGUMIN_FAISS_METADATA_FILENAME",
44
  "megumin_questions_meta.json",
@@ -66,6 +70,12 @@ def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
66
  return normalized
67
 
68
 
 
 
 
 
 
 
69
  @dataclass(frozen=True)
70
  class QaRecord:
71
  question: str
@@ -254,36 +264,60 @@ def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
254
  )
255
 
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  def build_and_save_faiss_index(
258
  dataset_dir: str | Path,
259
  *,
260
  embedding_model: str = EMBEDDING_MODEL_NAME,
261
  output_dimensionality: int = EMBEDDING_DIMENSION,
262
  index_filename: str = FAISS_INDEX_FILENAME,
 
263
  metadata_filename: str = FAISS_METADATA_FILENAME,
264
  include_patterns: Iterable[str] | None = None,
265
- ) -> tuple[Path, Path]:
266
  root = Path(dataset_dir)
267
  records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
268
  if not records:
269
  raise FileNotFoundError(f"No JSON records found under {root}")
270
 
271
- questions = [record.normalized_question or record.question for record in records]
272
- question_vectors = _embed_texts(
273
- questions,
274
- task_type="RETRIEVAL_DOCUMENT",
275
  embedding_model=embedding_model,
276
  output_dimensionality=output_dimensionality,
 
 
 
 
 
 
 
277
  )
278
- if question_vectors.size == 0:
279
- raise RuntimeError("No embeddings were generated for the dataset questions.")
280
-
281
- index = faiss.IndexFlatIP(int(question_vectors.shape[1]))
282
- index.add(question_vectors)
283
-
284
  index_path = root / index_filename
 
285
  metadata_path = root / metadata_filename
286
- faiss.write_index(index, str(index_path))
 
287
  metadata_payload = {
288
  "items": [
289
  {
@@ -299,7 +333,7 @@ def build_and_save_faiss_index(
299
  json.dumps(metadata_payload, ensure_ascii=False, indent=2),
300
  encoding="utf-8",
301
  )
302
- return index_path, metadata_path
303
 
304
 
305
  @lru_cache(maxsize=8)
@@ -309,10 +343,13 @@ def _load_vector_store(
309
  output_dimensionality: int,
310
  include_patterns: tuple[str, ...] = (),
311
  index_filename: str | None = FAISS_INDEX_FILENAME,
 
312
  metadata_filename: str | None = FAISS_METADATA_FILENAME,
 
313
  ) -> VectorStore:
314
- if index_filename and metadata_filename:
315
- index_path = Path(dataset_dir) / index_filename
 
316
  metadata_path = Path(dataset_dir) / metadata_filename
317
  else:
318
  index_path = metadata_path = None
@@ -341,21 +378,17 @@ def _load_vector_store(
341
  dimension=output_dimensionality,
342
  )
343
 
344
- questions = [record.normalized_question or record.question for record in records]
345
- question_vectors = _embed_texts(
346
- questions,
347
- task_type="RETRIEVAL_DOCUMENT",
348
  embedding_model=embedding_model,
349
  output_dimensionality=output_dimensionality,
 
350
  )
351
- dimension = int(question_vectors.shape[1])
352
- index = faiss.IndexFlatIP(dimension)
353
- index.add(question_vectors)
354
  return VectorStore(
355
  records=records,
356
  index=index,
357
  embedding_model=embedding_model,
358
- dimension=dimension,
359
  )
360
 
361
 
@@ -368,6 +401,7 @@ class JsonQaRetriever:
368
  output_dimensionality: int = EMBEDDING_DIMENSION,
369
  include_patterns: Iterable[str] | None = None,
370
  index_filename: str | None = FAISS_INDEX_FILENAME,
 
371
  metadata_filename: str | None = FAISS_METADATA_FILENAME,
372
  ):
373
  self.dataset_dir = Path(dataset_dir)
@@ -375,6 +409,7 @@ class JsonQaRetriever:
375
  self.output_dimensionality = output_dimensionality
376
  self.include_patterns = _normalize_patterns(include_patterns)
377
  self.index_filename = index_filename
 
378
  self.metadata_filename = metadata_filename
379
 
380
  def warmup(self) -> None:
@@ -384,7 +419,19 @@ class JsonQaRetriever:
384
  self.output_dimensionality,
385
  self.include_patterns,
386
  self.index_filename,
 
 
 
 
 
 
 
 
 
 
 
387
  self.metadata_filename,
 
388
  )
389
 
390
  def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
@@ -413,15 +460,27 @@ class JsonQaRetriever:
413
  return notes
414
 
415
  def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
416
- store = _load_vector_store(
417
  str(self.dataset_dir.resolve()),
418
  self.embedding_model,
419
  self.output_dimensionality,
420
  self.include_patterns,
421
  self.index_filename,
 
422
  self.metadata_filename,
 
423
  )
424
- if not store.records:
 
 
 
 
 
 
 
 
 
 
425
  return {
426
  "query": query,
427
  "match_count": 0,
@@ -434,26 +493,35 @@ class JsonQaRetriever:
434
  query_vector = _embed_texts(
435
  [_normalize_text(query) or query],
436
  task_type="RETRIEVAL_QUERY",
437
- embedding_model=store.embedding_model,
438
- output_dimensionality=store.dimension,
439
  )
440
- search_k = max(1, min(top_k, len(store.records)))
441
- scores, indices = store.index.search(query_vector, search_k)
442
-
443
- matches: list[dict[str, Any]] = []
444
- for score, index in zip(scores[0], indices[0]):
445
- if index < 0:
446
- continue
447
- record = store.records[int(index)]
448
- matches.append(
449
- {
450
- "question": record.question,
451
- "answer": _safe_excerpt(record.answer),
452
- "score": round(float(score), 6),
453
- "source_file": record.source_file,
454
- "metadata": record.metadata,
455
- }
456
- )
 
 
 
 
 
 
 
 
 
457
 
458
  return {
459
  "query": query,
 
39
  EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
40
  EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
41
  FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
42
+ FAISS_QA_INDEX_FILENAME = os.getenv(
43
+ "MEGUMIN_FAISS_QA_INDEX_FILENAME",
44
+ "megumin_question_answer.faiss",
45
+ )
46
  FAISS_METADATA_FILENAME = os.getenv(
47
  "MEGUMIN_FAISS_METADATA_FILENAME",
48
  "megumin_questions_meta.json",
 
70
  return normalized
71
 
72
 
73
+ def _record_search_text(record: "QaRecord", mode: str) -> str:
74
+ if mode == "question_answer":
75
+ return f"{record.question}\n{record.answer}".strip()
76
+ return record.question
77
+
78
+
79
  @dataclass(frozen=True)
80
  class QaRecord:
81
  question: str
 
264
  )
265
 
266
 
267
def _build_index_from_records(
    records: tuple[QaRecord, ...],
    *,
    embedding_model: str,
    output_dimensionality: int,
    mode: str,
) -> faiss.IndexFlatIP:
    """Embed ``records`` and return a flat inner-product FAISS index.

    Args:
        records: Records to index; row ``i`` of the index corresponds to
            ``records[i]``.
        embedding_model: Model name forwarded to ``_embed_texts``.
        output_dimensionality: Requested embedding size forwarded to
            ``_embed_texts``.
        mode: Selects the text embedded per record via
            ``_record_search_text`` (``"question"`` or ``"question_answer"``).

    Returns:
        A ``faiss.IndexFlatIP`` containing one vector per record.

    Raises:
        RuntimeError: if no embeddings were produced, or if the number of
            embeddings differs from the number of records.
    """
    search_texts = [_record_search_text(record, mode) for record in records]
    vectors = _embed_texts(
        search_texts,
        task_type="RETRIEVAL_DOCUMENT",
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
    )
    if vectors.size == 0:
        raise RuntimeError("No embeddings were generated for the dataset records.")

    # Search results are mapped back to records by row position, so a short
    # embedding matrix (e.g. a batch that came back empty) would silently
    # return the wrong record. Fail loudly instead.
    if int(vectors.shape[0]) != len(search_texts):
        raise RuntimeError(
            f"Embedding count ({int(vectors.shape[0])}) does not match "
            f"record count ({len(search_texts)})."
        )

    index = faiss.IndexFlatIP(int(vectors.shape[1]))
    index.add(vectors)
    return index
287
+
288
+
289
  def build_and_save_faiss_index(
290
  dataset_dir: str | Path,
291
  *,
292
  embedding_model: str = EMBEDDING_MODEL_NAME,
293
  output_dimensionality: int = EMBEDDING_DIMENSION,
294
  index_filename: str = FAISS_INDEX_FILENAME,
295
+ qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
296
  metadata_filename: str = FAISS_METADATA_FILENAME,
297
  include_patterns: Iterable[str] | None = None,
298
+ ) -> tuple[Path, Path, Path]:
299
  root = Path(dataset_dir)
300
  records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
301
  if not records:
302
  raise FileNotFoundError(f"No JSON records found under {root}")
303
 
304
+ question_index = _build_index_from_records(
305
+ records,
 
 
306
  embedding_model=embedding_model,
307
  output_dimensionality=output_dimensionality,
308
+ mode="question",
309
+ )
310
+ qa_index = _build_index_from_records(
311
+ records,
312
+ embedding_model=embedding_model,
313
+ output_dimensionality=output_dimensionality,
314
+ mode="question_answer",
315
  )
 
 
 
 
 
 
316
  index_path = root / index_filename
317
+ qa_index_path = root / qa_index_filename
318
  metadata_path = root / metadata_filename
319
+ faiss.write_index(question_index, str(index_path))
320
+ faiss.write_index(qa_index, str(qa_index_path))
321
  metadata_payload = {
322
  "items": [
323
  {
 
333
  json.dumps(metadata_payload, ensure_ascii=False, indent=2),
334
  encoding="utf-8",
335
  )
336
+ return index_path, qa_index_path, metadata_path
337
 
338
 
339
  @lru_cache(maxsize=8)
 
343
  output_dimensionality: int,
344
  include_patterns: tuple[str, ...] = (),
345
  index_filename: str | None = FAISS_INDEX_FILENAME,
346
+ qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
347
  metadata_filename: str | None = FAISS_METADATA_FILENAME,
348
+ mode: str = "question",
349
  ) -> VectorStore:
350
+ selected_index_filename = index_filename if mode == "question" else qa_index_filename
351
+ if selected_index_filename and metadata_filename:
352
+ index_path = Path(dataset_dir) / selected_index_filename
353
  metadata_path = Path(dataset_dir) / metadata_filename
354
  else:
355
  index_path = metadata_path = None
 
378
  dimension=output_dimensionality,
379
  )
380
 
381
+ index = _build_index_from_records(
382
+ records,
 
 
383
  embedding_model=embedding_model,
384
  output_dimensionality=output_dimensionality,
385
+ mode=mode,
386
  )
 
 
 
387
  return VectorStore(
388
  records=records,
389
  index=index,
390
  embedding_model=embedding_model,
391
+ dimension=index.d,
392
  )
393
 
394
 
 
401
  output_dimensionality: int = EMBEDDING_DIMENSION,
402
  include_patterns: Iterable[str] | None = None,
403
  index_filename: str | None = FAISS_INDEX_FILENAME,
404
+ qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
405
  metadata_filename: str | None = FAISS_METADATA_FILENAME,
406
  ):
407
  self.dataset_dir = Path(dataset_dir)
 
409
  self.output_dimensionality = output_dimensionality
410
  self.include_patterns = _normalize_patterns(include_patterns)
411
  self.index_filename = index_filename
412
+ self.qa_index_filename = qa_index_filename
413
  self.metadata_filename = metadata_filename
414
 
415
  def warmup(self) -> None:
 
419
  self.output_dimensionality,
420
  self.include_patterns,
421
  self.index_filename,
422
+ self.qa_index_filename,
423
+ self.metadata_filename,
424
+ "question",
425
+ )
426
+ _load_vector_store(
427
+ str(self.dataset_dir.resolve()),
428
+ self.embedding_model,
429
+ self.output_dimensionality,
430
+ self.include_patterns,
431
+ self.index_filename,
432
+ self.qa_index_filename,
433
  self.metadata_filename,
434
+ "question_answer",
435
  )
436
 
437
  def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
 
460
  return notes
461
 
462
  def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
463
+ question_store = _load_vector_store(
464
  str(self.dataset_dir.resolve()),
465
  self.embedding_model,
466
  self.output_dimensionality,
467
  self.include_patterns,
468
  self.index_filename,
469
+ self.qa_index_filename,
470
  self.metadata_filename,
471
+ "question",
472
  )
473
+ qa_store = _load_vector_store(
474
+ str(self.dataset_dir.resolve()),
475
+ self.embedding_model,
476
+ self.output_dimensionality,
477
+ self.include_patterns,
478
+ self.index_filename,
479
+ self.qa_index_filename,
480
+ self.metadata_filename,
481
+ "question_answer",
482
+ )
483
+ if not question_store.records:
484
  return {
485
  "query": query,
486
  "match_count": 0,
 
493
  query_vector = _embed_texts(
494
  [_normalize_text(query) or query],
495
  task_type="RETRIEVAL_QUERY",
496
+ embedding_model=question_store.embedding_model,
497
+ output_dimensionality=question_store.dimension,
498
  )
499
+ search_k = max(1, min(top_k, len(question_store.records)))
500
+
501
+ candidates: dict[int, dict[str, Any]] = {}
502
+ for store_name, store in (("question", question_store), ("question_answer", qa_store)):
503
+ scores, indices = store.index.search(query_vector, search_k)
504
+ for score, index in zip(scores[0], indices[0]):
505
+ if index < 0:
506
+ continue
507
+ record = store.records[int(index)]
508
+ current = candidates.get(int(index))
509
+ score_value = round(float(score), 6)
510
+ if current is None or score_value > current["score"]:
511
+ candidates[int(index)] = {
512
+ "question": record.question,
513
+ "answer": _safe_excerpt(record.answer),
514
+ "score": score_value,
515
+ "source_file": record.source_file,
516
+ "metadata": record.metadata,
517
+ "matched_via": store_name,
518
+ }
519
+
520
+ matches = sorted(
521
+ candidates.values(),
522
+ key=lambda item: item["score"],
523
+ reverse=True,
524
+ )[:top_k]
525
 
526
  return {
527
  "query": query,