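"""Build a local fact database of question/answer pairs.

Samples short-answer questions from Natural Questions, entity-centric
questions from PopQA, and counterfactual items from FaithEval, then
shuffles the combined records and writes them to ``facts.json`` next to
this module. Each record carries a ``source`` and a ``conflict_type``
tag so downstream code can filter by origin.
"""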
import json
import random
import urllib.request
import ast
from pathlib import Path
from typing import Any


FACTS_PATH = Path(__file__).parent / "facts.json"
FAITHEVAL_COUNTERFACTUAL_URL = (
    "https://raw.githubusercontent.com/SalesforceAIResearch/FaithEval/main/"
    "data/counterfactual.json"
)


def _load_dataset(*args: Any, **kwargs: Any) -> Any:
    """Import `datasets` lazily so this module loads without it installed."""
    from datasets import load_dataset

    return load_dataset(*args, **kwargs)


def _first_text(value: Any) -> str | None:
    """Extract the first useful text value from nested dataset fields.

    Handles plain strings, numbers, dicts keyed by common answer fields,
    lists/tuples, and strings that are themselves Python-literal lists
    (e.g. PopQA's ``possible_answers`` column).
    """
    if value is None:
        return None
    if isinstance(value, str):
        text = value.strip()
        if text.startswith("[") and text.endswith("]"):
            try:
                parsed = ast.literal_eval(text)
            except (SyntaxError, ValueError):
                parsed = None
            parsed_text = _first_text(parsed)
            if parsed_text:
                return parsed_text
        return text or None
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, dict):
        for key in ("text", "answer", "answers", "value"):
            text = _first_text(value.get(key))
            if text:
                return text
        return None
    if isinstance(value, (list, tuple)):
        for item in value:
            text = _first_text(item)
            if text:
                return text
    return None


def _word_count(text: str) -> int:
    return len(text.split())


def _clean_question(text: Any) -> str | None:
    """Normalize a question string and ensure it ends with a question mark."""
    question = _first_text(text)
    if not question:
        return None
    question = question.strip()
    if not question.endswith("?"):
        question = f"{question}?"
    return question


def _natural_questions_answer(row: dict[str, Any]) -> str | None:
    """Return the row's first short answer of at most five words, if any."""
    annotations = row.get("annotations") or {}
    short_answers = annotations.get("short_answers")
    answer = _first_text(short_answers)
    if answer and _word_count(answer) <= 5:
        return answer
    return None


def load_natural_questions(n: int = 300) -> list[dict[str, str]]:
    """Collect up to ``n`` short-answer QA pairs from Natural Questions."""
    facts: list[dict[str, str]] = []
    # Stream the split: the full train set is tens of gigabytes, and only
    # the first few hundred usable rows are needed.
    dataset = _load_dataset(
        "google-research-datasets/natural_questions",
        split="train",
        streaming=True,
    )

    for row in dataset:
        question = _clean_question(row.get("question") or row.get("question_text"))
        answer = _natural_questions_answer(row)
        if not question or not answer:
            continue

        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "natural_questions",
                "conflict_type": "entity",
            }
        )
        if len(facts) >= n:
            break

    return facts


def load_popqa(n: int = 150) -> list[dict[str, str]]:
    """Collect up to ``n`` entity-centric QA pairs from the PopQA test split."""
    facts: list[dict[str, str]] = []
    dataset = _load_dataset("akariasai/PopQA", split="test")

    for row in dataset:
        question = _clean_question(row.get("question"))
        answer = _first_text(row.get("possible_answers"))
        if not question or not answer:
            continue

        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "popqa",
                "conflict_type": "entity",
                "entity": _first_text(row.get("subj") or row.get("entity")) or "",
                "relation": _first_text(row.get("prop") or row.get("relation")) or "",
            }
        )
        if len(facts) >= n:
            break

    return facts


def _iter_faitheval_items(payload: Any) -> list[dict[str, Any]]:
    """Normalize the FaithEval payload: a bare list, or a list under a known key."""
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, dict)]
    if isinstance(payload, dict):
        for key in ("data", "examples", "items", "counterfactual"):
            items = payload.get(key)
            if isinstance(items, list):
                return [item for item in items if isinstance(item, dict)]
    return []


def load_faitheval_counterfactual(n: int = 100) -> list[dict[str, str]]:
    """Collect up to ``n`` counterfactual QA items from the FaithEval repo."""
    try:
        with urllib.request.urlopen(FAITHEVAL_COUNTERFACTUAL_URL, timeout=20) as response:
            payload = json.loads(response.read().decode("utf-8"))
    except Exception:
        # FaithEval items are optional: on any network or parse failure,
        # degrade gracefully instead of failing the whole build.
        return []

    facts: list[dict[str, str]] = []
    for item in _iter_faitheval_items(payload):
        question = _clean_question(
            item.get("question") or item.get("query") or item.get("claim")
        )
        answer = _first_text(
            item.get("answer")
            or item.get("gold_answer")
            or item.get("label")
            or item.get("target")
        )
        if not question or not answer:
            continue

        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "faitheval",
                "conflict_type": "counterfactual",
                "provided_context": _first_text(
                    item.get("provided_context")
                    or item.get("context")
                    or item.get("evidence")
                )
                or "",
            }
        )
        if len(facts) >= n:
            break

    return facts


def build_fact_database() -> list[dict[str, str]]:
    """Assemble facts from all sources, shuffle them, and write ``facts.json``."""
    facts = (
        load_natural_questions()
        + load_popqa()
        + load_faitheval_counterfactual()
    )
    # Shuffle so sources are interleaved rather than grouped in the output.
    random.shuffle(facts)

    FACTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(FACTS_PATH, "w", encoding="utf-8") as f:
        json.dump(facts, f, indent=2, ensure_ascii=False)

    counts: dict[str, int] = {}
    for fact in facts:
        source = fact.get("source", "unknown")
        counts[source] = counts.get(source, 0) + 1

    print(f"Wrote {len(facts)} facts to {FACTS_PATH}")
    print(f"Source counts: {counts}")
    return facts
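
# A record in facts.json looks like this (illustrative values; PopQA rows
# additionally carry "entity" and "relation", FaithEval rows a
# "provided_context"):
#
#   {
#     "question": "who wrote the declaration of independence?",
#     "answer": "Thomas Jefferson",
#     "source": "natural_questions",
#     "conflict_type": "entity"
#   }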


if __name__ == "__main__":
    build_fact_database()