"""FAISS-backed retrieval over JSON question/answer persona datasets.

Loads Q/A records from heterogeneous JSON / JSON-Lines files, embeds them
with the Gemini embedding API, and serves nearest-neighbour lookups through
two inner-product FAISS indexes (question-only and question+answer).
Prebuilt indexes and metadata can be persisted to disk and reloaded.
"""

from __future__ import annotations

import json
import math
import os
import re
import unicodedata
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Iterable

import faiss
import numpy as np
from google import genai
from google.genai import types

# Candidate field names probed (case-insensitively) when extracting the
# question and answer from heterogeneous JSON records.
QUESTION_KEYS = (
    "question",
    "query",
    "q",
    "prompt",
    "user",
    "instruction",
    "input",
)
ANSWER_KEYS = (
    "answer",
    "response",
    "a",
    "output",
    "assistant",
    "completion",
)
# Keys under which a JSON object may nest its list of records.
COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")

EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
FAISS_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_FAISS_QA_INDEX_FILENAME",
    "megumin_question_answer.faiss",
)
FAISS_METADATA_FILENAME = os.getenv(
    "MEGUMIN_FAISS_METADATA_FILENAME",
    "megumin_questions_meta.json",
)
PERSONA_DATASET_PATTERNS = ("megumin_qa_dataset.json",)
FACT_DATASET_PATTERNS = ("namuwiki*.json",)

# Hoisted once: the Q/A field names excluded from a record's metadata dict.
# (A frozenset gives O(1) membership vs. rebuilding a tuple per record.)
_QA_FIELD_KEYS = frozenset(QUESTION_KEYS + ANSWER_KEYS)


def _normalize_text(value: Any) -> str:
    """Return *value* as NFKC-normalized text with whitespace collapsed."""
    text = str(value or "")
    text = unicodedata.normalize("NFKC", text).strip()
    text = re.sub(r"\s+", " ", text)
    return text


def _safe_excerpt(text: str, limit: int = 220) -> str:
    """Collapse whitespace in *text* and truncate it to at most *limit* chars.

    Truncated output ends in "..." so readers can tell it was cut short.
    """
    compact = re.sub(r"\s+", " ", str(text or "")).strip()
    if len(compact) <= limit:
        return compact
    # max() guards pathological limits (< 3) against a negative slice,
    # which would otherwise return an over-limit excerpt.
    return compact[: max(0, limit - 3)].rstrip() + "..."


def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
    """Strip each glob pattern, drop empties; a tuple keeps it lru_cache-hashable."""
    normalized = tuple(pattern.strip() for pattern in (patterns or ()) if pattern.strip())
    return normalized


def _record_search_text(record: "QaRecord", mode: str) -> str:
    """Return the text that gets embedded for *record* under index *mode*."""
    if mode == "question_answer":
        return f"{record.question}\n{record.answer}".strip()
    return record.question


@dataclass(frozen=True)
class QaRecord:
    """One question/answer pair loaded from a JSON dataset file."""

    question: str
    answer: str
    source_file: str
    metadata: dict[str, Any]

    @property
    def normalized_question(self) -> str:
        """NFKC-normalized, whitespace-collapsed form of the question."""
        return _normalize_text(self.question)


@dataclass(frozen=True)
class VectorStore:
    """A FAISS index bundled with the records it was built from."""

    records: tuple[QaRecord, ...]
    index: faiss.Index
    embedding_model: str
    dimension: int


def _extract_collection(payload: Any) -> list[Any]:
    """Return the list of raw items inside *payload* (bare list or wrapped dict)."""
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        for key in COLLECTION_KEYS:
            value = payload.get(key)
            if isinstance(value, list):
                return value
    return []


def _pick_first(mapping: dict[str, Any], keys: tuple[str, ...]) -> str | None:
    """Return the first non-empty value for any of *keys*, matching case-insensitively."""
    lowered = {str(key).lower(): value for key, value in mapping.items()}
    for key in keys:
        if key in lowered and lowered[key] not in (None, ""):
            return str(lowered[key]).strip()
    return None


def _record_from_mapping(item: dict[str, Any], source_file: str) -> QaRecord | None:
    """Build a QaRecord from one raw JSON object, or None if Q or A is missing."""
    question = _pick_first(item, QUESTION_KEYS)
    answer = _pick_first(item, ANSWER_KEYS)
    if not question or not answer:
        return None
    # Every non-question/answer field is preserved as record metadata.
    metadata = {
        key: value
        for key, value in item.items()
        if str(key).lower() not in _QA_FIELD_KEYS
    }
    return QaRecord(
        question=question,
        answer=answer,
        source_file=source_file,
        metadata=metadata,
    )


def _load_json_records(path: Path) -> list[QaRecord]:
    """Parse *path* as a single JSON document first, then fall back to JSON Lines."""
    raw_text = path.read_text(encoding="utf-8")
    stripped = raw_text.strip()
    if not stripped:
        return []
    records: list[QaRecord] = []
    try:
        payload = json.loads(stripped)
    except json.JSONDecodeError:
        payload = None
    if payload is not None:
        for item in _extract_collection(payload):
            if isinstance(item, dict):
                record = _record_from_mapping(item, path.name)
                if record:
                    records.append(record)
        if records:
            return records
    # JSONL fallback: one JSON object per line; unparsable lines are skipped.
    for line in stripped.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            item = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(item, dict):
            record = _record_from_mapping(item, path.name)
            if record:
                records.append(record)
    return records


def _load_metadata_records(path: Path) -> tuple[QaRecord, ...]:
    """Load records from a previously persisted metadata JSON file."""
    payload = json.loads(path.read_text(encoding="utf-8"))
    records: list[QaRecord] = []
    for item in _extract_collection(payload):
        if isinstance(item, dict):
            record = _record_from_mapping(item, path.name)
            if record:
                records.append(record)
    return tuple(records)


def _iter_matching_paths(root: Path, include_patterns: tuple[str, ...]) -> list[Path]:
    """List .json files under *root* matching *include_patterns*, deduped, in stable order."""
    if not include_patterns:
        return sorted(root.glob("*.json"))
    seen: set[Path] = set()
    paths: list[Path] = []
    for pattern in include_patterns:
        for path in sorted(root.glob(pattern)):
            if path in seen or path.suffix.lower() != ".json":
                continue
            seen.add(path)
            paths.append(path)
    return paths


@lru_cache(maxsize=16)
def _load_records(dataset_dir: str, include_patterns: tuple[str, ...] = ()) -> tuple[QaRecord, ...]:
    """Load every QaRecord under *dataset_dir*; cached per (dir, patterns).

    Unreadable or mis-encoded files are skipped so one bad file cannot
    abort the whole dataset load.
    """
    root = Path(dataset_dir)
    if not root.exists():
        return tuple()
    all_records: list[QaRecord] = []
    for path in _iter_matching_paths(root, include_patterns):
        try:
            all_records.extend(_load_json_records(path))
        except OSError:
            continue
        except UnicodeDecodeError:
            continue
    return tuple(all_records)


@lru_cache(maxsize=2)
def _get_genai_client() -> genai.Client:
    """Lazily create and reuse a single Gemini API client."""
    return genai.Client()


def _embed_texts(
    texts: list[str],
    *,
    task_type: str,
    embedding_model: str,
    output_dimensionality: int,
) -> np.ndarray:
    """Embed *texts* in batches; return L2-normalized float32 vectors.

    Normalizing makes inner-product search (IndexFlatIP) equivalent to
    cosine similarity.
    """
    if not texts:
        return np.zeros((0, output_dimensionality), dtype="float32")
    batches: list[np.ndarray] = []
    # The Gemini embedding endpoint caps each request at 100 inputs.
    batch_size = max(1, min(EMBEDDING_BATCH_SIZE, 100))
    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        response = _get_genai_client().models.embed_content(
            model=embedding_model,
            contents=chunk,
            config=types.EmbedContentConfig(
                task_type=task_type,
                output_dimensionality=output_dimensionality,
            ),
        )
        vectors = np.array(
            [embedding.values for embedding in response.embeddings],
            dtype="float32",
        )
        if vectors.size == 0:
            continue
        faiss.normalize_L2(vectors)
        batches.append(vectors)
    if not batches:
        return np.zeros((0, output_dimensionality), dtype="float32")
    return np.vstack(batches)


def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
    """Default on-disk locations of the question index and its metadata.

    NOTE(review): unused in this module and does not cover the QA index
    path — presumably consumed by callers elsewhere; confirm before removal.
    """
    root = Path(dataset_dir)
    return (
        root / FAISS_INDEX_FILENAME,
        root / FAISS_METADATA_FILENAME,
    )


def _build_index_from_records(
    records: tuple[QaRecord, ...],
    *,
    embedding_model: str,
    output_dimensionality: int,
    mode: str,
) -> faiss.IndexFlatIP:
    """Embed *records* (per *mode*) and build an inner-product FAISS index.

    Raises:
        RuntimeError: if embedding produced no vectors at all.
    """
    search_texts = [_record_search_text(record, mode) for record in records]
    vectors = _embed_texts(
        search_texts,
        task_type="RETRIEVAL_DOCUMENT",
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
    )
    if vectors.size == 0:
        raise RuntimeError("No embeddings were generated for the dataset records.")
    index = faiss.IndexFlatIP(int(vectors.shape[1]))
    index.add(vectors)
    return index


def build_and_save_faiss_index(
    dataset_dir: str | Path,
    *,
    embedding_model: str = EMBEDDING_MODEL_NAME,
    output_dimensionality: int = EMBEDDING_DIMENSION,
    index_filename: str = FAISS_INDEX_FILENAME,
    qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
    metadata_filename: str = FAISS_METADATA_FILENAME,
    include_patterns: Iterable[str] | None = None,
) -> tuple[Path, Path, Path]:
    """Build both indexes from the dataset and persist all three artifacts.

    Returns:
        (question index path, question+answer index path, metadata JSON path).

    Raises:
        FileNotFoundError: if no usable JSON records exist under *dataset_dir*.
    """
    root = Path(dataset_dir)
    records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
    if not records:
        raise FileNotFoundError(f"No JSON records found under {root}")
    question_index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode="question",
    )
    qa_index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode="question_answer",
    )
    index_path = root / index_filename
    qa_index_path = root / qa_index_filename
    metadata_path = root / metadata_filename
    faiss.write_index(question_index, str(index_path))
    faiss.write_index(qa_index, str(qa_index_path))
    # Metadata rows are stored in the same order as the index vectors, so a
    # row's position doubles as its FAISS id.
    metadata_payload = {
        "items": [
            {
                "question": record.question,
                "answer": record.answer,
                "source_file": record.source_file,
                **record.metadata,
            }
            for record in records
        ]
    }
    metadata_path.write_text(
        json.dumps(metadata_payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return index_path, qa_index_path, metadata_path


@lru_cache(maxsize=8)
def _load_vector_store(
    dataset_dir: str,
    embedding_model: str,
    output_dimensionality: int,
    include_patterns: tuple[str, ...] = (),
    index_filename: str | None = FAISS_INDEX_FILENAME,
    qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
    metadata_filename: str | None = FAISS_METADATA_FILENAME,
    mode: str = "question",
) -> VectorStore:
    """Return the VectorStore for *mode*, preferring persisted artifacts.

    Falls back to building an in-memory index from the raw dataset when no
    on-disk index/metadata pair is available.

    Raises:
        ValueError: if a persisted index disagrees in size with its metadata.
    """
    selected_index_filename = index_filename if mode == "question" else qa_index_filename
    if selected_index_filename and metadata_filename:
        index_path = Path(dataset_dir) / selected_index_filename
        metadata_path = Path(dataset_dir) / metadata_filename
    else:
        index_path = metadata_path = None
    if index_path and metadata_path and index_path.exists() and metadata_path.exists():
        index = faiss.read_index(str(index_path))
        records = _load_metadata_records(metadata_path)
        if index.ntotal != len(records):
            raise ValueError(
                f"FAISS index size ({index.ntotal}) does not match metadata size ({len(records)})."
            )
        return VectorStore(
            records=records,
            index=index,
            embedding_model=embedding_model,
            dimension=index.d,
        )
    records = _load_records(dataset_dir, include_patterns)
    if not records:
        # Empty placeholder store: searches simply return nothing, which
        # keeps callers free of None checks.
        empty_index = faiss.IndexFlatIP(output_dimensionality)
        return VectorStore(
            records=tuple(),
            index=empty_index,
            embedding_model=embedding_model,
            dimension=output_dimensionality,
        )
    index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode=mode,
    )
    return VectorStore(
        records=records,
        index=index,
        embedding_model=embedding_model,
        dimension=index.d,
    )


class JsonQaRetriever:
    """Retrieves persona Q/A examples for a query via FAISS similarity search."""

    def __init__(
        self,
        dataset_dir: str | Path,
        *,
        embedding_model: str = EMBEDDING_MODEL_NAME,
        output_dimensionality: int = EMBEDDING_DIMENSION,
        include_patterns: Iterable[str] | None = None,
        index_filename: str | None = FAISS_INDEX_FILENAME,
        qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
        metadata_filename: str | None = FAISS_METADATA_FILENAME,
    ):
        self.dataset_dir = Path(dataset_dir)
        self.embedding_model = embedding_model
        self.output_dimensionality = output_dimensionality
        self.include_patterns = _normalize_patterns(include_patterns)
        self.index_filename = index_filename
        self.qa_index_filename = qa_index_filename
        self.metadata_filename = metadata_filename

    def warmup(self) -> None:
        """Eagerly load both vector stores so the first query pays no build cost."""
        _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question",
        )
        _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question_answer",
        )

    def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
        """Derive persona/style guidance from the shape of the retrieved matches."""
        if not matches:
            return [
                "No strong example was retrieved, so stay in Megumin's persona without inventing unsupported canon facts.",
            ]
        notes = [
            "Answer in first person as Megumin, with respectful but dramatic confidence.",
            "Use the retrieved cases to mirror tone and answer shape, but do not copy them verbatim.",
            "Prefer the retrieved answers as evidence for facts, relationships, and recurring phrasing.",
        ]
        # Note: "answer" here is the excerpt (capped at 220 chars), so 180+
        # effectively means "near full-length" answers.
        long_answers = sum(
            1 for match in matches if len(match.get("answer", "")) >= 180
        )
        if long_answers >= max(1, math.ceil(len(matches) / 2)):
            notes.append(
                "The retrieved examples skew narrative, so a short anecdotal lead-in is acceptable."
            )
        else:
            notes.append(
                "The retrieved examples are compact, so keep the answer concise and pointed."
            )
        return notes

    def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
        """Return up to *top_k* matches for *query*, merged across both indexes.

        Each candidate keeps its best score from either the question-only or
        the question+answer index; results are sorted by score descending.
        """
        question_store = _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question",
        )
        qa_store = _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question_answer",
        )
        if not question_store.records:
            return {
                "query": query,
                "match_count": 0,
                "matches": [],
                "style_notes": [
                    "No processed JSON dataset was found for retrieval.",
                ],
            }
        # Fall back to the raw query if normalization strips it to nothing.
        query_vector = _embed_texts(
            [_normalize_text(query) or query],
            task_type="RETRIEVAL_QUERY",
            embedding_model=question_store.embedding_model,
            output_dimensionality=question_store.dimension,
        )
        search_k = max(1, min(top_k, len(question_store.records)))
        # Dedup by record position; both stores are built from the same record
        # sequence, so the FAISS id identifies the same record in each.
        candidates: dict[int, dict[str, Any]] = {}
        for store_name, store in (("question", question_store), ("question_answer", qa_store)):
            scores, indices = store.index.search(query_vector, search_k)
            for score, index in zip(scores[0], indices[0]):
                if index < 0:  # FAISS pads missing neighbours with -1
                    continue
                record = store.records[int(index)]
                current = candidates.get(int(index))
                score_value = round(float(score), 6)
                if current is None or score_value > current["score"]:
                    candidates[int(index)] = {
                        "question": record.question,
                        "answer": _safe_excerpt(record.answer),
                        "score": score_value,
                        "source_file": record.source_file,
                        "metadata": record.metadata,
                        "matched_via": store_name,
                    }
        matches = sorted(
            candidates.values(),
            key=lambda item: item["score"],
            reverse=True,
        )[:top_k]
        return {
            "query": query,
            "match_count": len(matches),
            "matches": matches,
            "style_notes": self._style_notes(matches),
        }