# megumin_agent/retrieval.py — part of the Megumin-chat project.
# (Hugging Face upload header: uploader Junhoee, "Upload 6 files", commit 06b2015 verified.)
from __future__ import annotations
import json
import math
import os
import re
import unicodedata
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Iterable
import faiss
import numpy as np
from google import genai
from google.genai import types
# Mapping keys (matched case-insensitively) that may hold the question text in a dataset item.
QUESTION_KEYS = (
    "question",
    "query",
    "q",
    "prompt",
    "user",
    "instruction",
    "input",
)
# Mapping keys (matched case-insensitively) that may hold the answer text in a dataset item.
ANSWER_KEYS = (
    "answer",
    "response",
    "a",
    "output",
    "assistant",
    "completion",
)
# Top-level dict keys under which a JSON payload may nest its list of Q/A items.
COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
# Embedding model name and output dimensionality; overridable via environment.
EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
# Texts sent per embed_content request (clamped to at most 100 at the call site).
EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
# Filenames for the persisted FAISS artifacts inside the dataset directory:
# question-only index, question+answer index, and the JSON metadata sidecar.
FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
FAISS_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_FAISS_QA_INDEX_FILENAME",
    "megumin_question_answer.faiss",
)
FAISS_METADATA_FILENAME = os.getenv(
    "MEGUMIN_FAISS_METADATA_FILENAME",
    "megumin_questions_meta.json",
)
# Glob patterns for the persona Q/A dataset vs. wiki fact dumps.
# NOTE(review): not referenced in this module — presumably imported by callers; confirm.
PERSONA_DATASET_PATTERNS = ("megumin_qa_dataset.json",)
FACT_DATASET_PATTERNS = ("namuwiki*.json",)
def _normalize_text(value: Any) -> str:
text = str(value or "")
text = unicodedata.normalize("NFKC", text).strip()
text = re.sub(r"\s+", " ", text)
return text
def _safe_excerpt(text: str, limit: int = 220) -> str:
compact = re.sub(r"\s+", " ", str(text or "")).strip()
if len(compact) <= limit:
return compact
return compact[: limit - 3].rstrip() + "..."
def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
normalized = tuple(pattern.strip() for pattern in (patterns or ()) if pattern.strip())
return normalized
def _record_search_text(record: "QaRecord", mode: str) -> str:
if mode == "question_answer":
return f"{record.question}\n{record.answer}".strip()
return record.question
@dataclass(frozen=True)
class QaRecord:
    """One question/answer pair loaded from a JSON dataset file."""

    # Raw question text as it appeared in the dataset item.
    question: str
    # Raw answer text paired with the question.
    answer: str
    # Basename of the JSON file the pair was read from.
    source_file: str
    # Remaining keys from the source mapping (everything that is not a
    # question/answer key — see _record_from_mapping).
    metadata: dict[str, Any]

    @property
    def normalized_question(self) -> str:
        """NFKC-normalized, whitespace-collapsed form of ``question``."""
        return _normalize_text(self.question)
@dataclass(frozen=True)
class VectorStore:
    """An in-memory FAISS index paired with the records it was built from."""

    # Records in the same order as the vectors stored in ``index``; the FAISS
    # result id is used directly as an index into this tuple.
    records: tuple[QaRecord, ...]
    # FAISS index over the record embeddings (vectors are L2-normalized in
    # _embed_texts before insertion).
    index: faiss.Index
    # Name of the embedding model used to build and query the index.
    embedding_model: str
    # Embedding dimensionality of the index.
    dimension: int
def _extract_collection(payload: Any) -> list[Any]:
    """Pull the list of raw items out of a parsed JSON payload.

    A top-level list is returned as-is; for a mapping, the first key from
    COLLECTION_KEYS whose value is a list wins. Anything else yields [].
    """
    if isinstance(payload, list):
        return payload
    if not isinstance(payload, dict):
        return []
    for collection_key in COLLECTION_KEYS:
        candidate = payload.get(collection_key)
        if isinstance(candidate, list):
            return candidate
    return []
def _pick_first(mapping: dict[str, Any], keys: tuple[str, ...]) -> str | None:
lowered = {str(key).lower(): value for key, value in mapping.items()}
for key in keys:
if key in lowered and lowered[key] not in (None, ""):
return str(lowered[key]).strip()
return None
def _record_from_mapping(item: dict[str, Any], source_file: str) -> QaRecord | None:
    """Build a QaRecord from one raw JSON mapping.

    Returns None when either the question or the answer field is missing or
    empty. Every key that is not a recognized question/answer key survives
    in ``metadata``.
    """
    question = _pick_first(item, QUESTION_KEYS)
    if not question:
        return None
    answer = _pick_first(item, ANSWER_KEYS)
    if not answer:
        return None
    reserved = QUESTION_KEYS + ANSWER_KEYS
    extras = {
        raw_key: raw_value
        for raw_key, raw_value in item.items()
        if str(raw_key).lower() not in reserved
    }
    return QaRecord(
        question=question,
        answer=answer,
        source_file=source_file,
        metadata=extras,
    )
def _load_json_records(path: Path) -> list[QaRecord]:
    """Parse *path* into QaRecords, accepting plain JSON or JSON-Lines.

    The whole file is first parsed as a single JSON document; if that fails
    (or yields no usable records) each non-empty line is tried as standalone
    JSON. Items that are not mappings with both a question and an answer are
    skipped silently.
    """

    def collect(items: Iterable[Any]) -> list[QaRecord]:
        # Shared filter: keep only mappings that convert to a full record.
        out: list[QaRecord] = []
        for item in items:
            if isinstance(item, dict):
                parsed = _record_from_mapping(item, path.name)
                if parsed is not None:
                    out.append(parsed)
        return out

    stripped = path.read_text(encoding="utf-8").strip()
    if not stripped:
        return []

    try:
        document = json.loads(stripped)
    except json.JSONDecodeError:
        document = None
    if document is not None:
        whole_file_records = collect(_extract_collection(document))
        if whole_file_records:
            return whole_file_records

    # Fallback: treat each line as its own JSON value (JSONL), ignoring
    # blank and unparseable lines.
    line_items: list[Any] = []
    for raw_line in stripped.splitlines():
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        try:
            line_items.append(json.loads(raw_line))
        except json.JSONDecodeError:
            continue
    return collect(line_items)
def _load_metadata_records(path: Path) -> tuple[QaRecord, ...]:
    """Read the JSON metadata sidecar written next to a FAISS index and
    rebuild the QaRecords it describes (order preserved)."""
    payload = json.loads(path.read_text(encoding="utf-8"))
    parsed = (
        _record_from_mapping(item, path.name)
        for item in _extract_collection(payload)
        if isinstance(item, dict)
    )
    return tuple(record for record in parsed if record is not None)
def _iter_matching_paths(root: Path, include_patterns: tuple[str, ...]) -> list[Path]:
if not include_patterns:
return sorted(root.glob("*.json"))
seen: set[Path] = set()
paths: list[Path] = []
for pattern in include_patterns:
for path in sorted(root.glob(pattern)):
if path in seen or path.suffix.lower() != ".json":
continue
seen.add(path)
paths.append(path)
return paths
@lru_cache(maxsize=16)
def _load_records(dataset_dir: str, include_patterns: tuple[str, ...] = ()) -> tuple[QaRecord, ...]:
    """Load (and cache) every QaRecord found under *dataset_dir*.

    Files that cannot be read or are not valid UTF-8 are skipped silently.
    Results are cached per (directory, patterns) pair, which is why both
    arguments must be hashable (str + tuple).
    """
    root = Path(dataset_dir)
    if not root.exists():
        return ()
    collected: list[QaRecord] = []
    for candidate in _iter_matching_paths(root, include_patterns):
        try:
            collected.extend(_load_json_records(candidate))
        except (OSError, UnicodeDecodeError):
            continue
    return tuple(collected)
@lru_cache(maxsize=2)
def _get_genai_client() -> genai.Client:
    """Return a cached Google GenAI client.

    Client() is constructed with no arguments, so configuration (API key,
    project) presumably comes from the environment — confirm deployment setup.
    """
    return genai.Client()
def _embed_texts(
    texts: list[str],
    *,
    task_type: str,
    embedding_model: str,
    output_dimensionality: int,
) -> np.ndarray:
    """Embed *texts* with the Gemini embedding API and L2-normalize the rows.

    Requests are issued in batches (EMBEDDING_BATCH_SIZE, clamped to 1..100).
    Returns a float32 matrix with one row per embedded text; an empty input
    (or an API response with no embeddings) yields a (0, dim) matrix.
    """
    empty = np.zeros((0, output_dimensionality), dtype="float32")
    if not texts:
        return empty
    client = _get_genai_client()
    request_config = types.EmbedContentConfig(
        task_type=task_type,
        output_dimensionality=output_dimensionality,
    )
    step = max(1, min(EMBEDDING_BATCH_SIZE, 100))
    collected: list[np.ndarray] = []
    for offset in range(0, len(texts), step):
        response = client.models.embed_content(
            model=embedding_model,
            contents=texts[offset : offset + step],
            config=request_config,
        )
        matrix = np.array(
            [item.values for item in response.embeddings],
            dtype="float32",
        )
        if matrix.size == 0:
            continue
        # Normalize in place so inner-product search behaves like cosine similarity.
        faiss.normalize_L2(matrix)
        collected.append(matrix)
    if not collected:
        return empty
    return np.vstack(collected)
def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
    """Return the default on-disk locations of the question index and its
    metadata sidecar under *dataset_dir*.

    NOTE(review): the QA index path is not included here — confirm callers
    outside this module expect only the two-tuple.
    """
    base = Path(dataset_dir)
    return base / FAISS_INDEX_FILENAME, base / FAISS_METADATA_FILENAME
def _build_index_from_records(
    records: tuple[QaRecord, ...],
    *,
    embedding_model: str,
    output_dimensionality: int,
    mode: str,
) -> faiss.IndexFlatIP:
    """Embed every record (text chosen by *mode*) and pack the vectors into a
    flat inner-product FAISS index.

    Raises RuntimeError when embedding produced no vectors at all.
    """
    corpus = [_record_search_text(record, mode) for record in records]
    matrix = _embed_texts(
        corpus,
        task_type="RETRIEVAL_DOCUMENT",
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
    )
    if matrix.size == 0:
        raise RuntimeError("No embeddings were generated for the dataset records.")
    built = faiss.IndexFlatIP(int(matrix.shape[1]))
    built.add(matrix)
    return built
def build_and_save_faiss_index(
    dataset_dir: str | Path,
    *,
    embedding_model: str = EMBEDDING_MODEL_NAME,
    output_dimensionality: int = EMBEDDING_DIMENSION,
    index_filename: str = FAISS_INDEX_FILENAME,
    qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
    metadata_filename: str = FAISS_METADATA_FILENAME,
    include_patterns: Iterable[str] | None = None,
) -> tuple[Path, Path, Path]:
    """Build the question and question+answer FAISS indexes for a dataset
    directory and persist them, plus a JSON metadata sidecar, next to the data.

    Returns the (question index, QA index, metadata) paths.
    Raises FileNotFoundError when the directory holds no usable JSON records.
    """
    root = Path(dataset_dir)
    records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
    if not records:
        raise FileNotFoundError(f"No JSON records found under {root}")

    # Build both indexes before writing anything, so a failed build does not
    # leave a partial set of artifacts on disk.
    question_index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode="question",
    )
    qa_index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode="question_answer",
    )

    index_path = root / index_filename
    qa_index_path = root / qa_index_filename
    metadata_path = root / metadata_filename
    faiss.write_index(question_index, str(index_path))
    faiss.write_index(qa_index, str(qa_index_path))

    # Flatten each record back into one JSON object, carrying extra metadata
    # keys alongside the canonical question/answer/source fields.
    items: list[dict[str, Any]] = []
    for record in records:
        entry: dict[str, Any] = {
            "question": record.question,
            "answer": record.answer,
            "source_file": record.source_file,
        }
        entry.update(record.metadata)
        items.append(entry)
    metadata_path.write_text(
        json.dumps({"items": items}, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return index_path, qa_index_path, metadata_path
@lru_cache(maxsize=8)
def _load_vector_store(
    dataset_dir: str,
    embedding_model: str,
    output_dimensionality: int,
    include_patterns: tuple[str, ...] = (),
    index_filename: str | None = FAISS_INDEX_FILENAME,
    qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
    metadata_filename: str | None = FAISS_METADATA_FILENAME,
    mode: str = "question",
) -> VectorStore:
    """Load (and cache) the VectorStore for one search mode.

    Resolution order:
    1. Prebuilt index + metadata files on disk, when both exist (their sizes
       must agree, otherwise ValueError).
    2. Otherwise, raw JSON records are loaded and embedded on the fly.
    3. With no records at all, an empty store is returned so retrieval can
       degrade gracefully instead of raising.

    Cached on all arguments (maxsize=8), so every argument must be hashable.
    """
    # "question" mode reads the question-only index file; any other mode uses
    # the question+answer index file.
    selected_index_filename = index_filename if mode == "question" else qa_index_filename
    if selected_index_filename and metadata_filename:
        index_path = Path(dataset_dir) / selected_index_filename
        metadata_path = Path(dataset_dir) / metadata_filename
    else:
        # A None filename disables the on-disk fast path entirely.
        index_path = metadata_path = None
    if index_path and metadata_path and index_path.exists() and metadata_path.exists():
        index = faiss.read_index(str(index_path))
        records = _load_metadata_records(metadata_path)
        # A stale sidecar would silently mis-align scores with records; fail loudly.
        if index.ntotal != len(records):
            raise ValueError(
                f"FAISS index size ({index.ntotal}) does not match metadata size ({len(records)})."
            )
        return VectorStore(
            records=records,
            index=index,
            embedding_model=embedding_model,
            dimension=index.d,
        )
    records = _load_records(dataset_dir, include_patterns)
    if not records:
        # Empty store: searches will simply find nothing.
        empty_index = faiss.IndexFlatIP(output_dimensionality)
        return VectorStore(
            records=tuple(),
            index=empty_index,
            embedding_model=embedding_model,
            dimension=output_dimensionality,
        )
    index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode=mode,
    )
    return VectorStore(
        records=records,
        index=index,
        embedding_model=embedding_model,
        dimension=index.d,
    )
class JsonQaRetriever:
    """Semantic retriever over JSON Q/A datasets using Gemini embeddings + FAISS.

    Each query is searched against two indexes — one over questions only and
    one over question+answer text — and each record keeps its best score
    across the two.

    Fix: the 8-positional-argument call to ``_load_vector_store`` was
    copy-pasted four times across ``warmup`` and ``retrieve``; it is now
    centralized in the private ``_store`` helper so the call sites cannot
    drift out of sync. Public interface is unchanged.
    """

    def __init__(
        self,
        dataset_dir: str | Path,
        *,
        embedding_model: str = EMBEDDING_MODEL_NAME,
        output_dimensionality: int = EMBEDDING_DIMENSION,
        include_patterns: Iterable[str] | None = None,
        index_filename: str | None = FAISS_INDEX_FILENAME,
        qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
        metadata_filename: str | None = FAISS_METADATA_FILENAME,
    ):
        self.dataset_dir = Path(dataset_dir)
        self.embedding_model = embedding_model
        self.output_dimensionality = output_dimensionality
        self.include_patterns = _normalize_patterns(include_patterns)
        self.index_filename = index_filename
        self.qa_index_filename = qa_index_filename
        self.metadata_filename = metadata_filename

    def _store(self, mode: str) -> VectorStore:
        """Load (or lazily build) the cached vector store for *mode*
        ("question" or "question_answer")."""
        return _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            mode,
        )

    def warmup(self) -> None:
        """Eagerly load both vector stores so the first retrieve() call is fast."""
        self._store("question")
        self._store("question_answer")

    def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
        """Derive prompt-style guidance strings from the retrieved matches."""
        if not matches:
            return [
                "No strong example was retrieved, so stay in Megumin's persona without inventing unsupported canon facts.",
            ]
        notes = [
            "Answer in first person as Megumin, with respectful but dramatic confidence.",
            "Use the retrieved cases to mirror tone and answer shape, but do not copy them verbatim.",
            "Prefer the retrieved answers as evidence for facts, relationships, and recurring phrasing.",
        ]
        # Answers in matches were excerpted (capped ~220 chars), so 180+ means
        # "near full-length"; a majority of long answers flips the tone advice.
        long_answers = sum(
            1 for match in matches if len(match.get("answer", "")) >= 180
        )
        if long_answers >= max(1, math.ceil(len(matches) / 2)):
            notes.append(
                "The retrieved examples skew narrative, so a short anecdotal lead-in is acceptable."
            )
        else:
            notes.append(
                "The retrieved examples are compact, so keep the answer concise and pointed."
            )
        return notes

    def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
        """Return up to *top_k* best-matching Q/A examples for *query*.

        Searches both indexes, keeps each record's best score, and returns a
        dict with the query, match count, matches (answers excerpted), and
        style notes for downstream prompting.
        """
        question_store = self._store("question")
        qa_store = self._store("question_answer")
        if not question_store.records:
            return {
                "query": query,
                "match_count": 0,
                "matches": [],
                "style_notes": [
                    "No processed JSON dataset was found for retrieval.",
                ],
            }
        # Fall back to the raw query if normalization strips it to nothing.
        query_vector = _embed_texts(
            [_normalize_text(query) or query],
            task_type="RETRIEVAL_QUERY",
            embedding_model=question_store.embedding_model,
            output_dimensionality=question_store.dimension,
        )
        search_k = max(1, min(top_k, len(question_store.records)))
        # Keyed by record index; both stores share the same record ordering.
        candidates: dict[int, dict[str, Any]] = {}
        for store_name, store in (("question", question_store), ("question_answer", qa_store)):
            scores, indices = store.index.search(query_vector, search_k)
            for score, index in zip(scores[0], indices[0]):
                if index < 0:
                    # FAISS pads results with -1 when fewer than k vectors exist.
                    continue
                record = store.records[int(index)]
                current = candidates.get(int(index))
                score_value = round(float(score), 6)
                if current is None or score_value > current["score"]:
                    candidates[int(index)] = {
                        "question": record.question,
                        "answer": _safe_excerpt(record.answer),
                        "score": score_value,
                        "source_file": record.source_file,
                        "metadata": record.metadata,
                        "matched_via": store_name,
                    }
        matches = sorted(
            candidates.values(),
            key=lambda item: item["score"],
            reverse=True,
        )[:top_k]
        return {
            "query": query,
            "match_count": len(matches),
            "matches": matches,
            "style_notes": self._style_notes(matches),
        }