import ast
import json
import random
import urllib.request
from collections import Counter
from pathlib import Path
from typing import Any
|
|
|
|
# Output location for the aggregated fact database (written next to this module).
FACTS_PATH = Path(__file__).parent / "facts.json"
# Raw-file GitHub URL for FaithEval's counterfactual split.
FAITHEVAL_COUNTERFACTUAL_URL = (
    "https://raw.githubusercontent.com/SalesforceAIResearch/FaithEval/main/"
    "data/counterfactual.json"
)
|
|
|
|
def _load_dataset(*args: Any, **kwargs: Any) -> Any:
    """Proxy to :func:`datasets.load_dataset`, imported lazily.

    Importing at call time keeps ``datasets`` an optional dependency:
    the module can be imported without it installed.
    """
    from datasets import load_dataset as _hf_load_dataset

    return _hf_load_dataset(*args, **kwargs)
|
|
|
|
| def _first_text(value: Any) -> str | None: |
| """Extract the first useful text value from nested dataset fields.""" |
| if value is None: |
| return None |
| if isinstance(value, str): |
| text = value.strip() |
| if text.startswith("[") and text.endswith("]"): |
| try: |
| parsed = ast.literal_eval(text) |
| except (SyntaxError, ValueError): |
| parsed = None |
| parsed_text = _first_text(parsed) |
| if parsed_text: |
| return parsed_text |
| return text or None |
| if isinstance(value, (int, float)): |
| return str(value) |
| if isinstance(value, dict): |
| for key in ("text", "answer", "answers", "value"): |
| text = _first_text(value.get(key)) |
| if text: |
| return text |
| return None |
| if isinstance(value, (list, tuple)): |
| for item in value: |
| text = _first_text(item) |
| if text: |
| return text |
| return None |
|
|
|
|
| def _word_count(text: str) -> int: |
| return len(text.split()) |
|
|
|
|
def _clean_question(text: Any) -> str | None:
    """Normalize a raw question field into a trimmed, '?'-terminated string.

    Returns ``None`` when no usable text can be extracted from *text*.
    """
    cleaned = _first_text(text)
    if not cleaned:
        return None
    cleaned = cleaned.strip()
    return cleaned if cleaned.endswith("?") else f"{cleaned}?"
|
|
|
|
def _natural_questions_answer(row: dict[str, Any]) -> str | None:
    """Extract a short answer (at most 5 words) from an NQ row, if present."""
    annotations = row.get("annotations") or {}
    candidate = _first_text(annotations.get("short_answers"))
    # Long answers are rejected: only concise factual answers are useful here.
    if not candidate or _word_count(candidate) > 5:
        return None
    return candidate
|
|
|
|
def load_natural_questions(n: int = 300) -> list[dict[str, str]]:
    """Stream up to *n* short-answer QA pairs from Natural Questions.

    Uses streaming mode so the (very large) dataset is never fully
    downloaded; iteration stops as soon as *n* usable facts are found.
    """
    stream = _load_dataset(
        "google-research-datasets/natural_questions",
        split="train",
        streaming=True,
    )

    collected: list[dict[str, str]] = []
    for record in stream:
        question = _clean_question(record.get("question") or record.get("question_text"))
        answer = _natural_questions_answer(record)
        if question and answer:
            collected.append(
                {
                    "question": question,
                    "answer": answer,
                    "source": "natural_questions",
                    "conflict_type": "entity",
                }
            )
            if len(collected) >= n:
                break

    return collected
|
|
|
|
def load_popqa(n: int = 150) -> list[dict[str, str]]:
    """Collect up to *n* entity-centric QA pairs from the PopQA test split."""
    dataset = _load_dataset("akariasai/PopQA", split="test")

    collected: list[dict[str, str]] = []
    for record in dataset:
        question = _clean_question(record.get("question"))
        answer = _first_text(record.get("possible_answers"))
        if not question or not answer:
            continue

        collected.append(
            {
                "question": question,
                "answer": answer,
                "source": "popqa",
                "conflict_type": "entity",
                # Keep the subject entity and relation so downstream code can
                # build entity-substitution conflicts; empty string if absent.
                "entity": _first_text(record.get("subj") or record.get("entity")) or "",
                "relation": _first_text(record.get("prop") or record.get("relation")) or "",
            }
        )
        if len(collected) >= n:
            break

    return collected
|
|
|
|
| def _iter_faitheval_items(payload: Any) -> list[dict[str, Any]]: |
| if isinstance(payload, list): |
| return [item for item in payload if isinstance(item, dict)] |
| if isinstance(payload, dict): |
| for key in ("data", "examples", "items", "counterfactual"): |
| items = payload.get(key) |
| if isinstance(items, list): |
| return [item for item in items if isinstance(item, dict)] |
| return [] |
|
|
|
|
def load_faitheval_counterfactual(n: int = 100) -> list[dict[str, str]]:
    """Fetch up to *n* counterfactual QA items from the FaithEval repository.

    Download is best-effort: any network or parse failure returns an empty
    list so the overall fact database can still be built without this source.
    """
    try:
        with urllib.request.urlopen(FAITHEVAL_COUNTERFACTUAL_URL, timeout=20) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except Exception:
        # Deliberate broad catch: missing data here is acceptable.
        return []

    collected: list[dict[str, str]] = []
    for entry in _iter_faitheval_items(payload):
        question = _clean_question(
            entry.get("question") or entry.get("query") or entry.get("claim")
        )
        answer = _first_text(
            entry.get("answer")
            or entry.get("gold_answer")
            or entry.get("label")
            or entry.get("target")
        )
        if not question or not answer:
            continue

        # The conflicting passage may live under several field names
        # depending on the dump's schema; empty string when absent.
        context = _first_text(
            entry.get("provided_context")
            or entry.get("context")
            or entry.get("evidence")
        )
        collected.append(
            {
                "question": question,
                "answer": answer,
                "source": "faitheval",
                "conflict_type": "counterfactual",
                "provided_context": context or "",
            }
        )
        if len(collected) >= n:
            break

    return collected
|
|
|
|
def build_fact_database() -> list[dict[str, str]]:
    """Aggregate facts from all sources, shuffle, and write them to FACTS_PATH.

    Returns the shuffled list of fact records and prints a per-source count
    summary for quick sanity checking. Note: `random.shuffle` is unseeded,
    so the on-disk ordering differs between runs.
    """
    facts = (
        load_natural_questions()
        + load_popqa()
        + load_faitheval_counterfactual()
    )
    # Shuffle so no single source dominates any prefix of the file.
    random.shuffle(facts)

    FACTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(FACTS_PATH, "w", encoding="utf-8") as f:
        json.dump(facts, f, indent=2, ensure_ascii=False)

    # Counter replaces the hand-rolled counting loop; dict() keeps the
    # printed representation identical to a plain dict.
    counts = dict(Counter(fact.get("source", "unknown") for fact in facts))

    print(f"Wrote {len(facts)} facts to {FACTS_PATH}")
    print(f"Source counts: {counts}")
    return facts
|
|
|
|
# Allow running this module directly to (re)build facts.json.
if __name__ == "__main__":
    build_fact_database()
|
|