Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import math | |
| import os | |
| import re | |
| import unicodedata | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| from typing import Iterable | |
| import faiss | |
| import numpy as np | |
| from google import genai | |
| from google.genai import types | |
# Candidate JSON keys (matched case-insensitively) that may hold the
# question text of a QA record.
QUESTION_KEYS = (
    "question",
    "query",
    "q",
    "prompt",
    "user",
    "instruction",
    "input",
)
# Candidate JSON keys that may hold the answer text of a QA record.
ANSWER_KEYS = (
    "answer",
    "response",
    "a",
    "output",
    "assistant",
    "completion",
)
# Keys under which a JSON object may nest its list of records.
COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
# Embedding configuration, overridable via environment variables.
EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
# On-disk artifact names for the persisted FAISS indexes and their metadata.
FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
FAISS_QA_INDEX_FILENAME = os.getenv(
    "MEGUMIN_FAISS_QA_INDEX_FILENAME",
    "megumin_question_answer.faiss",
)
FAISS_METADATA_FILENAME = os.getenv(
    "MEGUMIN_FAISS_METADATA_FILENAME",
    "megumin_questions_meta.json",
)
# Default glob patterns selecting persona vs. fact dataset files.
PERSONA_DATASET_PATTERNS = ("megumin_qa_dataset.json",)
FACT_DATASET_PATTERNS = ("namuwiki*.json",)
| def _normalize_text(value: Any) -> str: | |
| text = str(value or "") | |
| text = unicodedata.normalize("NFKC", text).strip() | |
| text = re.sub(r"\s+", " ", text) | |
| return text | |
| def _safe_excerpt(text: str, limit: int = 220) -> str: | |
| compact = re.sub(r"\s+", " ", str(text or "")).strip() | |
| if len(compact) <= limit: | |
| return compact | |
| return compact[: limit - 3].rstrip() + "..." | |
| def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]: | |
| normalized = tuple(pattern.strip() for pattern in (patterns or ()) if pattern.strip()) | |
| return normalized | |
| def _record_search_text(record: "QaRecord", mode: str) -> str: | |
| if mode == "question_answer": | |
| return f"{record.question}\n{record.answer}".strip() | |
| return record.question | |
@dataclass
class QaRecord:
    """One question/answer pair loaded from a dataset file.

    The ``@dataclass`` decorator is required: records are constructed with
    keyword arguments (see ``_record_from_mapping``), which needs the
    generated ``__init__`` — without it instantiation raises TypeError.
    """

    # Raw question/answer text as found in the source JSON.
    question: str
    answer: str
    # Basename of the JSON file the record came from.
    source_file: str
    # Leftover keys from the source mapping (question/answer keys excluded).
    metadata: dict[str, Any]

    def normalized_question(self) -> str:
        """Return the question in NFKC-normalized, whitespace-collapsed form."""
        return _normalize_text(self.question)
@dataclass
class VectorStore:
    """In-memory snapshot pairing dataset records with their FAISS index.

    The ``@dataclass`` decorator is required: stores are constructed with
    keyword arguments in ``_load_vector_store``, which needs the generated
    ``__init__`` — without it instantiation raises TypeError.
    """

    # Records aligned index-for-index with the vectors stored in ``index``.
    records: tuple[QaRecord, ...]
    index: faiss.Index
    # Embedding model name used when building/querying the index.
    embedding_model: str
    # Dimensionality of the stored embedding vectors.
    dimension: int
| def _extract_collection(payload: Any) -> list[Any]: | |
| if isinstance(payload, list): | |
| return payload | |
| if isinstance(payload, dict): | |
| for key in COLLECTION_KEYS: | |
| value = payload.get(key) | |
| if isinstance(value, list): | |
| return value | |
| return [] | |
| def _pick_first(mapping: dict[str, Any], keys: tuple[str, ...]) -> str | None: | |
| lowered = {str(key).lower(): value for key, value in mapping.items()} | |
| for key in keys: | |
| if key in lowered and lowered[key] not in (None, ""): | |
| return str(lowered[key]).strip() | |
| return None | |
def _record_from_mapping(item: dict[str, Any], source_file: str) -> QaRecord | None:
    """Build a QaRecord from one JSON mapping.

    Returns None when no question or answer field can be located under the
    recognized key aliases.
    """
    question_text = _pick_first(item, QUESTION_KEYS)
    answer_text = _pick_first(item, ANSWER_KEYS)
    if not (question_text and answer_text):
        return None
    reserved = QUESTION_KEYS + ANSWER_KEYS
    # Everything except the recognized QA fields rides along as metadata.
    extras = {
        key: value
        for key, value in item.items()
        if str(key).lower() not in reserved
    }
    return QaRecord(
        question=question_text,
        answer=answer_text,
        source_file=source_file,
        metadata=extras,
    )
def _load_json_records(path: Path) -> list[QaRecord]:
    """Parse *path* as a JSON document, falling back to JSON Lines.

    The whole file is tried as one JSON payload first; if that yields no
    records, each non-empty line is parsed individually (unparseable lines
    are skipped). Returns every mapping convertible to a QaRecord.
    """
    content = path.read_text(encoding="utf-8").strip()
    if not content:
        return []

    def _records_from(items: Iterable[Any]) -> list[QaRecord]:
        # Convert each dict item to a record, dropping failures.
        found: list[QaRecord] = []
        for entry in items:
            if isinstance(entry, dict):
                parsed = _record_from_mapping(entry, path.name)
                if parsed:
                    found.append(parsed)
        return found

    try:
        document = json.loads(content)
    except json.JSONDecodeError:
        document = None
    if document is not None:
        whole_file_records = _records_from(_extract_collection(document))
        if whole_file_records:
            return whole_file_records

    # JSON Lines fallback: one JSON value per non-empty line.
    line_items: list[Any] = []
    for raw_line in content.splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        try:
            line_items.append(json.loads(candidate))
        except json.JSONDecodeError:
            continue
    return _records_from(line_items)
def _load_metadata_records(path: Path) -> tuple[QaRecord, ...]:
    """Parse the metadata JSON saved alongside a FAISS index into records."""
    document = json.loads(path.read_text(encoding="utf-8"))
    converted = (
        _record_from_mapping(entry, path.name)
        for entry in _extract_collection(document)
        if isinstance(entry, dict)
    )
    return tuple(record for record in converted if record)
| def _iter_matching_paths(root: Path, include_patterns: tuple[str, ...]) -> list[Path]: | |
| if not include_patterns: | |
| return sorted(root.glob("*.json")) | |
| seen: set[Path] = set() | |
| paths: list[Path] = [] | |
| for pattern in include_patterns: | |
| for path in sorted(root.glob(pattern)): | |
| if path in seen or path.suffix.lower() != ".json": | |
| continue | |
| seen.add(path) | |
| paths.append(path) | |
| return paths | |
| def _load_records(dataset_dir: str, include_patterns: tuple[str, ...] = ()) -> tuple[QaRecord, ...]: | |
| root = Path(dataset_dir) | |
| if not root.exists(): | |
| return tuple() | |
| all_records: list[QaRecord] = [] | |
| for path in _iter_matching_paths(root, include_patterns): | |
| try: | |
| all_records.extend(_load_json_records(path)) | |
| except OSError: | |
| continue | |
| except UnicodeDecodeError: | |
| continue | |
| return tuple(all_records) | |
@lru_cache(maxsize=1)
def _get_genai_client() -> genai.Client:
    """Return a shared Gemini API client.

    Cached with ``lru_cache`` (imported at the top of this file but
    previously unused) so repeated embedding batches reuse one client
    instead of constructing a new one on every call.
    """
    return genai.Client()
def _embed_texts(
    texts: list[str],
    *,
    task_type: str,
    embedding_model: str,
    output_dimensionality: int,
) -> np.ndarray:
    """Embed *texts* via the Gemini API as L2-normalized float32 vectors.

    Returns one row per input text; when there is nothing to embed (or the
    API yields no embeddings) the result has shape (0, output_dimensionality).
    """
    if not texts:
        return np.zeros((0, output_dimensionality), dtype="float32")
    # Requests are capped at 100 texts per call — presumably the API's
    # per-request limit; confirm against the SDK docs.
    per_request = max(1, min(EMBEDDING_BATCH_SIZE, 100))
    collected: list[np.ndarray] = []
    for offset in range(0, len(texts), per_request):
        batch = texts[offset : offset + per_request]
        result = _get_genai_client().models.embed_content(
            model=embedding_model,
            contents=batch,
            config=types.EmbedContentConfig(
                task_type=task_type,
                output_dimensionality=output_dimensionality,
            ),
        )
        matrix = np.array(
            [item.values for item in result.embeddings],
            dtype="float32",
        )
        if matrix.size == 0:
            continue
        # Unit-length rows make inner-product search equivalent to cosine.
        faiss.normalize_L2(matrix)
        collected.append(matrix)
    if not collected:
        return np.zeros((0, output_dimensionality), dtype="float32")
    return np.vstack(collected)
def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
    """Return the (question index, metadata) artifact paths for *dataset_dir*."""
    base = Path(dataset_dir)
    index_path = base / FAISS_INDEX_FILENAME
    metadata_path = base / FAISS_METADATA_FILENAME
    return (index_path, metadata_path)
def _build_index_from_records(
    records: tuple[QaRecord, ...],
    *,
    embedding_model: str,
    output_dimensionality: int,
    mode: str,
) -> faiss.IndexFlatIP:
    """Embed every record and pack the vectors into an inner-product index.

    *mode* decides whether question-only or question+answer text is
    embedded (see ``_record_search_text``).

    Raises RuntimeError when embedding produced no vectors at all.
    """
    search_texts = [_record_search_text(entry, mode) for entry in records]
    matrix = _embed_texts(
        search_texts,
        task_type="RETRIEVAL_DOCUMENT",
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
    )
    if matrix.size == 0:
        raise RuntimeError("No embeddings were generated for the dataset records.")
    built = faiss.IndexFlatIP(int(matrix.shape[1]))
    built.add(matrix)
    return built
def build_and_save_faiss_index(
    dataset_dir: str | Path,
    *,
    embedding_model: str = EMBEDDING_MODEL_NAME,
    output_dimensionality: int = EMBEDDING_DIMENSION,
    index_filename: str = FAISS_INDEX_FILENAME,
    qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
    metadata_filename: str = FAISS_METADATA_FILENAME,
    include_patterns: Iterable[str] | None = None,
) -> tuple[Path, Path, Path]:
    """Build and persist both FAISS indexes plus the record metadata JSON.

    Loads every QA record under *dataset_dir*, builds a question-only index
    and a question+answer index, writes both to disk, and serializes the
    records next to them so later loads can skip re-embedding.

    Returns the (question index, QA index, metadata) paths.
    Raises FileNotFoundError when no JSON records are found.
    """
    base = Path(dataset_dir)
    records = _load_records(str(base.resolve()), _normalize_patterns(include_patterns))
    if not records:
        raise FileNotFoundError(f"No JSON records found under {base}")

    built = {
        mode: _build_index_from_records(
            records,
            embedding_model=embedding_model,
            output_dimensionality=output_dimensionality,
            mode=mode,
        )
        for mode in ("question", "question_answer")
    }

    index_path = base / index_filename
    qa_index_path = base / qa_index_filename
    metadata_path = base / metadata_filename
    faiss.write_index(built["question"], str(index_path))
    faiss.write_index(built["question_answer"], str(qa_index_path))

    # Serialize records so a later load can validate and reuse the indexes.
    serialized_items: list[dict[str, Any]] = []
    for entry in records:
        item: dict[str, Any] = {
            "question": entry.question,
            "answer": entry.answer,
            "source_file": entry.source_file,
        }
        item.update(entry.metadata)
        serialized_items.append(item)
    metadata_path.write_text(
        json.dumps({"items": serialized_items}, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return index_path, qa_index_path, metadata_path
@lru_cache(maxsize=8)
def _load_vector_store(
    dataset_dir: str,
    embedding_model: str,
    output_dimensionality: int,
    include_patterns: tuple[str, ...] = (),
    index_filename: str | None = FAISS_INDEX_FILENAME,
    qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
    metadata_filename: str | None = FAISS_METADATA_FILENAME,
    mode: str = "question",
) -> VectorStore:
    """Return the VectorStore for *mode*, preferring persisted artifacts.

    Resolution order:
      1. A saved FAISS index plus metadata JSON under *dataset_dir*, when
         both files exist (validated so index size matches record count).
      2. Otherwise records are loaded from the JSON files and embedded on
         the fly; an empty dataset yields an empty (but searchable) store.

    Cached with ``lru_cache`` (all parameters are hashable; ``lru_cache``
    was imported but previously unused): without the cache every
    ``JsonQaRetriever.retrieve`` call re-reads the index from disk — or
    re-embeds the whole dataset through the API when no artifacts exist —
    and ``warmup()``, which calls this purely for the side effect and
    discards the result, would accomplish nothing. Note the cache means
    on-disk changes are not picked up within a process lifetime.

    Raises ValueError when a persisted index disagrees with its metadata.
    """
    # Question-only and question+answer modes persist to different files.
    selected_index_filename = index_filename if mode == "question" else qa_index_filename
    if selected_index_filename and metadata_filename:
        index_path = Path(dataset_dir) / selected_index_filename
        metadata_path = Path(dataset_dir) / metadata_filename
    else:
        # Either filename being None disables the persisted-artifact path.
        index_path = metadata_path = None
    if index_path and metadata_path and index_path.exists() and metadata_path.exists():
        index = faiss.read_index(str(index_path))
        records = _load_metadata_records(metadata_path)
        if index.ntotal != len(records):
            raise ValueError(
                f"FAISS index size ({index.ntotal}) does not match metadata size ({len(records)})."
            )
        return VectorStore(
            records=records,
            index=index,
            embedding_model=embedding_model,
            dimension=index.d,
        )
    records = _load_records(dataset_dir, include_patterns)
    if not records:
        # No data at all: return an empty store instead of failing.
        empty_index = faiss.IndexFlatIP(output_dimensionality)
        return VectorStore(
            records=tuple(),
            index=empty_index,
            embedding_model=embedding_model,
            dimension=output_dimensionality,
        )
    index = _build_index_from_records(
        records,
        embedding_model=embedding_model,
        output_dimensionality=output_dimensionality,
        mode=mode,
    )
    return VectorStore(
        records=records,
        index=index,
        embedding_model=embedding_model,
        dimension=index.d,
    )
class JsonQaRetriever:
    """Retrieves persona QA examples for a query via two FAISS searches.

    Two vector stores are consulted per query — one indexing question text
    only and one indexing question+answer text; when both hit the same
    record, the higher similarity score is kept.
    """

    def __init__(
        self,
        dataset_dir: str | Path,
        *,
        embedding_model: str = EMBEDDING_MODEL_NAME,
        output_dimensionality: int = EMBEDDING_DIMENSION,
        include_patterns: Iterable[str] | None = None,
        index_filename: str | None = FAISS_INDEX_FILENAME,
        qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
        metadata_filename: str | None = FAISS_METADATA_FILENAME,
    ) -> None:
        # Directory holding dataset JSON files and (optionally) saved indexes.
        self.dataset_dir = Path(dataset_dir)
        self.embedding_model = embedding_model
        self.output_dimensionality = output_dimensionality
        # Glob patterns restricting which JSON files are loaded.
        self.include_patterns = _normalize_patterns(include_patterns)
        # Artifact filenames; passing None disables the persisted-index path.
        self.index_filename = index_filename
        self.qa_index_filename = qa_index_filename
        self.metadata_filename = metadata_filename

    def warmup(self) -> None:
        """Eagerly load both vector stores, discarding the results.

        NOTE(review): this only helps if ``_load_vector_store`` caches its
        result; otherwise the work is simply repeated on the next call.
        """
        _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question",
        )
        _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question_answer",
        )

    def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
        """Build persona-style guidance strings tailored to *matches*."""
        if not matches:
            return [
                "No strong example was retrieved, so stay in Megumin's persona without inventing unsupported canon facts.",
            ]
        notes = [
            "Answer in first person as Megumin, with respectful but dramatic confidence.",
            "Use the retrieved cases to mirror tone and answer shape, but do not copy them verbatim.",
            "Prefer the retrieved answers as evidence for facts, relationships, and recurring phrasing.",
        ]
        # Answers here are already excerpted by retrieve() (<= 220 chars),
        # so >= 180 chars counts as a "long" answer.
        long_answers = sum(
            1 for match in matches if len(match.get("answer", "")) >= 180
        )
        # A majority of long answers permits narrative style; otherwise terse.
        if long_answers >= max(1, math.ceil(len(matches) / 2)):
            notes.append(
                "The retrieved examples skew narrative, so a short anecdotal lead-in is acceptable."
            )
        else:
            notes.append(
                "The retrieved examples are compact, so keep the answer concise and pointed."
            )
        return notes

    def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
        """Return up to *top_k* best-matching QA examples plus style notes.

        The query is embedded once and searched against both the
        question-only and question+answer indexes; per-record results are
        merged keeping the higher score, then sorted descending.
        """
        question_store = _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question",
        )
        qa_store = _load_vector_store(
            str(self.dataset_dir.resolve()),
            self.embedding_model,
            self.output_dimensionality,
            self.include_patterns,
            self.index_filename,
            self.qa_index_filename,
            self.metadata_filename,
            "question_answer",
        )
        if not question_store.records:
            # Nothing indexed: report an empty result rather than erroring.
            return {
                "query": query,
                "match_count": 0,
                "matches": [],
                "style_notes": [
                    "No processed JSON dataset was found for retrieval.",
                ],
            }
        # Fall back to the raw query if normalization strips it to "".
        query_vector = _embed_texts(
            [_normalize_text(query) or query],
            task_type="RETRIEVAL_QUERY",
            embedding_model=question_store.embedding_model,
            output_dimensionality=question_store.dimension,
        )
        search_k = max(1, min(top_k, len(question_store.records)))
        # Best result per record index, merged across both stores.
        candidates: dict[int, dict[str, Any]] = {}
        for store_name, store in (("question", question_store), ("question_answer", qa_store)):
            scores, indices = store.index.search(query_vector, search_k)
            for score, index in zip(scores[0], indices[0]):
                if index < 0:
                    # FAISS pads with -1 when fewer than k neighbors exist.
                    continue
                record = store.records[int(index)]
                current = candidates.get(int(index))
                score_value = round(float(score), 6)
                # Keep whichever store gave this record the higher score.
                if current is None or score_value > current["score"]:
                    candidates[int(index)] = {
                        "question": record.question,
                        "answer": _safe_excerpt(record.answer),
                        "score": score_value,
                        "source_file": record.source_file,
                        "metadata": record.metadata,
                        "matched_via": store_name,
                    }
        matches = sorted(
            candidates.values(),
            key=lambda item: item["score"],
            reverse=True,
        )[:top_k]
        return {
            "query": query,
            "match_count": len(matches),
            "matches": matches,
            "style_notes": self._style_notes(matches),
        }