# Build facts.json by combining Natural Questions, PopQA, and FaithEval
# counterfactual data into a single shuffled fact database.
import ast
import json
import random
import urllib.request
from collections import Counter
from pathlib import Path
from typing import Any
# Output location for the generated fact database, next to this module.
FACTS_PATH = Path(__file__).parent / "facts.json"
# Raw-file URL for FaithEval's counterfactual split on GitHub.
FAITHEVAL_COUNTERFACTUAL_URL = (
    "https://raw.githubusercontent.com/SalesforceAIResearch/FaithEval/main/"
    "data/counterfactual.json"
)
def _load_dataset(*args: Any, **kwargs: Any) -> Any:
    """Proxy to ``datasets.load_dataset``, imported lazily so the heavy
    dependency is only paid for when a dataset is actually requested."""
    from datasets import load_dataset as _hf_load_dataset

    return _hf_load_dataset(*args, **kwargs)
def _first_text(value: Any) -> str | None:
"""Extract the first useful text value from nested dataset fields."""
if value is None:
return None
if isinstance(value, str):
text = value.strip()
if text.startswith("[") and text.endswith("]"):
try:
parsed = ast.literal_eval(text)
except (SyntaxError, ValueError):
parsed = None
parsed_text = _first_text(parsed)
if parsed_text:
return parsed_text
return text or None
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, dict):
for key in ("text", "answer", "answers", "value"):
text = _first_text(value.get(key))
if text:
return text
return None
if isinstance(value, (list, tuple)):
for item in value:
text = _first_text(item)
if text:
return text
return None
def _word_count(text: str) -> int:
return len(text.split())
def _clean_question(text: Any) -> str | None:
    """Extract a question string from *text*, guaranteeing a trailing '?'.

    Returns ``None`` when no usable text can be extracted.
    """
    candidate = _first_text(text)
    if not candidate:
        return None
    candidate = candidate.strip()
    return candidate if candidate.endswith("?") else f"{candidate}?"
def _natural_questions_answer(row: dict[str, Any]) -> str | None:
    """Pull a short answer (at most 5 words) from an NQ row, if present."""
    short_answer = _first_text((row.get("annotations") or {}).get("short_answers"))
    # Long answers make poor single-fact targets, so cap the word count.
    if short_answer and _word_count(short_answer) <= 5:
        return short_answer
    return None
def load_natural_questions(n: int = 300) -> list[dict[str, str]]:
    """Stream Natural Questions and collect up to *n* short-answer facts.

    Args:
        n: Maximum number of facts to collect. Non-positive values yield
           an empty list (previously one fact was still collected because
           the limit was only checked after appending).

    Returns:
        Fact dicts with ``question``/``answer``/``source``/``conflict_type``.
    """
    facts: list[dict[str, str]] = []
    if n <= 0:
        return facts
    # Streaming avoids downloading the (very large) NQ train split up front.
    dataset = _load_dataset(
        "google-research-datasets/natural_questions",
        split="train",
        streaming=True,
    )
    for row in dataset:
        question = _clean_question(row.get("question") or row.get("question_text"))
        answer = _natural_questions_answer(row)
        if not question or not answer:
            continue
        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "natural_questions",
                "conflict_type": "entity",
            }
        )
        if len(facts) >= n:
            break
    return facts
def load_popqa(n: int = 150) -> list[dict[str, str]]:
    """Load PopQA's test split and collect up to *n* entity facts.

    Args:
        n: Maximum number of facts to collect. Non-positive values yield
           an empty list (previously one fact was still collected because
           the limit was only checked after appending).

    Returns:
        Fact dicts including the subject entity and relation when available.
    """
    facts: list[dict[str, str]] = []
    if n <= 0:
        return facts
    dataset = _load_dataset("akariasai/PopQA", split="test")
    for row in dataset:
        question = _clean_question(row.get("question"))
        answer = _first_text(row.get("possible_answers"))
        if not question or not answer:
            continue
        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "popqa",
                "conflict_type": "entity",
                "entity": _first_text(row.get("subj") or row.get("entity")) or "",
                "relation": _first_text(row.get("prop") or row.get("relation")) or "",
            }
        )
        if len(facts) >= n:
            break
    return facts
def _iter_faitheval_items(payload: Any) -> list[dict[str, Any]]:
if isinstance(payload, list):
return [item for item in payload if isinstance(item, dict)]
if isinstance(payload, dict):
for key in ("data", "examples", "items", "counterfactual"):
items = payload.get(key)
if isinstance(items, list):
return [item for item in items if isinstance(item, dict)]
return []
def load_faitheval_counterfactual(n: int = 100) -> list[dict[str, str]]:
    """Fetch FaithEval counterfactual QA pairs, returning at most *n* facts.

    Network or parse failures degrade to an empty list so the overall build
    can proceed with the remaining sources.

    Args:
        n: Maximum number of facts to collect. Non-positive values return
           immediately without performing the network fetch (previously the
           payload was downloaded and one fact was still collected).

    Returns:
        Fact dicts, each carrying the (possibly empty) provided context.
    """
    facts: list[dict[str, str]] = []
    if n <= 0:
        return facts
    try:
        with urllib.request.urlopen(FAITHEVAL_COUNTERFACTUAL_URL, timeout=20) as response:
            payload = json.loads(response.read().decode("utf-8"))
    except Exception:
        # Deliberate best-effort: FaithEval is optional; skip when unreachable.
        return []
    for item in _iter_faitheval_items(payload):
        question = _clean_question(
            item.get("question") or item.get("query") or item.get("claim")
        )
        answer = _first_text(
            item.get("answer")
            or item.get("gold_answer")
            or item.get("label")
            or item.get("target")
        )
        if not question or not answer:
            continue
        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "faitheval",
                "conflict_type": "counterfactual",
                "provided_context": _first_text(
                    item.get("provided_context")
                    or item.get("context")
                    or item.get("evidence")
                )
                or "",
            }
        )
        if len(facts) >= n:
            break
    return facts
def build_fact_database() -> list[dict[str, str]]:
    """Assemble, shuffle, and persist the combined fact database.

    Concatenates facts from Natural Questions, PopQA, and FaithEval,
    shuffles them, writes the result to ``FACTS_PATH`` as pretty-printed
    JSON, and prints per-source counts.

    Returns:
        The shuffled list of fact records that was written to disk.
    """
    facts = (
        load_natural_questions()
        + load_popqa()
        + load_faitheval_counterfactual()
    )
    # Shuffle so no single source dominates any prefix of the file.
    random.shuffle(facts)
    FACTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(FACTS_PATH, "w", encoding="utf-8") as f:
        json.dump(facts, f, indent=2, ensure_ascii=False)
    # Counter replaces the hand-rolled frequency dict; dict() keeps the
    # printed representation identical to the original output format.
    counts = Counter(fact.get("source", "unknown") for fact in facts)
    print(f"Wrote {len(facts)} facts to {FACTS_PATH}")
    print(f"Source counts: {dict(counts)}")
    return facts
# Script entry point: rebuild facts.json from all upstream sources.
if __name__ == "__main__":
    build_fact_database()