Siddh12334's picture
feat: training space with manual start UI
204fa23 verified
import json
import random
import urllib.request
import ast
from pathlib import Path
from typing import Any
# Location of the generated fact database: a "facts.json" file that sits in
# the same directory as this module (written by build_fact_database).
FACTS_PATH = Path(__file__).parent / "facts.json"
# Raw GitHub URL of the FaithEval counterfactual JSON file that
# load_faitheval_counterfactual downloads.
FAITHEVAL_COUNTERFACTUAL_URL = (
    "https://raw.githubusercontent.com/SalesforceAIResearch/FaithEval/main/"
    "data/counterfactual.json"
)
def _load_dataset(*args: Any, **kwargs: Any) -> Any:
    """Forward all arguments to ``datasets.load_dataset``.

    The ``datasets`` package is imported inside the function, so importing
    this module does not require it; the import only happens on first call.
    """
    from datasets import load_dataset as hf_load_dataset

    return hf_load_dataset(*args, **kwargs)
def _first_text(value: Any) -> str | None:
"""Extract the first useful text value from nested dataset fields."""
if value is None:
return None
if isinstance(value, str):
text = value.strip()
if text.startswith("[") and text.endswith("]"):
try:
parsed = ast.literal_eval(text)
except (SyntaxError, ValueError):
parsed = None
parsed_text = _first_text(parsed)
if parsed_text:
return parsed_text
return text or None
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, dict):
for key in ("text", "answer", "answers", "value"):
text = _first_text(value.get(key))
if text:
return text
return None
if isinstance(value, (list, tuple)):
for item in value:
text = _first_text(item)
if text:
return text
return None
def _word_count(text: str) -> int:
return len(text.split())
def _clean_question(text: Any) -> str | None:
    """Turn a raw question field into stripped text that ends with ``?``.

    The text is pulled out of *text* with ``_first_text``; ``None`` is
    returned when nothing usable is found. A trailing ``?`` is appended only
    if the question does not already end with one.
    """
    extracted = _first_text(text)
    if extracted:
        cleaned = extracted.strip()
        return cleaned if cleaned.endswith("?") else f"{cleaned}?"
    return None
def _natural_questions_answer(row: dict[str, Any]) -> str | None:
    """Return the row's first short answer if it is at most five words long.

    The answer is read from ``row["annotations"]["short_answers"]``; a missing
    or empty ``annotations`` entry is treated as an empty mapping. ``None`` is
    returned when there is no answer text or it is longer than five words.
    """
    annotations = row.get("annotations") or {}
    candidate = _first_text(annotations.get("short_answers"))
    if not candidate or _word_count(candidate) > 5:
        return None
    return candidate
def load_natural_questions(n: int = 300) -> list[dict[str, str]]:
    """Collect up to *n* question/answer facts from Natural Questions.

    Streams the ``train`` split of
    ``google-research-datasets/natural_questions`` and keeps rows that have
    both a question (from ``"question"`` or ``"question_text"``) and a short
    answer accepted by ``_natural_questions_answer``.

    Args:
        n: Maximum number of facts to return. Values ``<= 0`` return an empty
            list without opening the dataset.

    Returns:
        A list of dicts with the keys ``"question"``, ``"answer"``,
        ``"source"`` (``"natural_questions"``) and ``"conflict_type"``
        (``"entity"``).
    """
    facts: list[dict[str, str]] = []
    if n <= 0:
        # The limit check below runs only after an append, so without this
        # guard a non-positive n would still yield one fact.
        return facts

    dataset = _load_dataset(
        "google-research-datasets/natural_questions",
        split="train",
        streaming=True,
    )
    for row in dataset:
        question = _clean_question(row.get("question") or row.get("question_text"))
        answer = _natural_questions_answer(row)
        if not question or not answer:
            continue
        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "natural_questions",
                "conflict_type": "entity",
            }
        )
        # Stop pulling from the stream as soon as enough facts are collected.
        if len(facts) >= n:
            break
    return facts
def load_popqa(n: int = 150) -> list[dict[str, str]]:
    """Collect up to *n* question/answer facts from PopQA.

    Loads the ``test`` split of ``akariasai/PopQA`` and keeps rows that have
    both a question and at least one answer under ``"possible_answers"``.

    Args:
        n: Maximum number of facts to return. Values ``<= 0`` return an empty
            list without loading the dataset.

    Returns:
        A list of dicts with the keys ``"question"``, ``"answer"``,
        ``"source"`` (``"popqa"``), ``"conflict_type"`` (``"entity"``),
        ``"entity"`` (from ``"subj"`` or ``"entity"``) and ``"relation"``
        (from ``"prop"`` or ``"relation"``); the last two default to ``""``.
    """
    facts: list[dict[str, str]] = []
    if n <= 0:
        # The limit check below runs only after an append, so without this
        # guard a non-positive n would still yield one fact.
        return facts

    dataset = _load_dataset("akariasai/PopQA", split="test")
    for row in dataset:
        question = _clean_question(row.get("question"))
        answer = _first_text(row.get("possible_answers"))
        if not question or not answer:
            continue
        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "popqa",
                "conflict_type": "entity",
                "entity": _first_text(row.get("subj") or row.get("entity")) or "",
                "relation": _first_text(row.get("prop") or row.get("relation")) or "",
            }
        )
        if len(facts) >= n:
            break
    return facts
def _iter_faitheval_items(payload: Any) -> list[dict[str, Any]]:
if isinstance(payload, list):
return [item for item in payload if isinstance(item, dict)]
if isinstance(payload, dict):
for key in ("data", "examples", "items", "counterfactual"):
items = payload.get(key)
if isinstance(items, list):
return [item for item in items if isinstance(item, dict)]
return []
def load_faitheval_counterfactual(n: int = 100) -> list[dict[str, str]]:
    """Collect up to *n* counterfactual facts from the FaithEval JSON file.

    Downloads ``FAITHEVAL_COUNTERFACTUAL_URL`` (20 second timeout), decodes
    it as UTF-8 JSON and keeps records that have both a question (from
    ``"question"``, ``"query"`` or ``"claim"``) and an answer (from
    ``"answer"``, ``"gold_answer"``, ``"label"`` or ``"target"``).

    This source is best-effort: any failure while downloading or decoding is
    logged as a warning and results in an empty list rather than an exception.

    Args:
        n: Maximum number of facts to return. Values ``<= 0`` return an empty
            list without making a network request.

    Returns:
        A list of dicts with the keys ``"question"``, ``"answer"``,
        ``"source"`` (``"faitheval"``), ``"conflict_type"``
        (``"counterfactual"``) and ``"provided_context"`` (from
        ``"provided_context"``, ``"context"`` or ``"evidence"``; ``""`` when
        absent).
    """
    if n <= 0:
        # The limit check in the loop runs only after an append, so without
        # this guard a non-positive n could still yield one fact.
        return []

    try:
        with urllib.request.urlopen(FAITHEVAL_COUNTERFACTUAL_URL, timeout=20) as response:
            payload = json.loads(response.read().decode("utf-8"))
    except Exception as exc:
        # Deliberately broad: this source is optional, so network, HTTP and
        # decoding failures must not abort the build. Log the failure so it
        # is not silently lost.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load FaithEval counterfactual data: %s", exc
        )
        return []

    facts: list[dict[str, str]] = []
    for item in _iter_faitheval_items(payload):
        question = _clean_question(
            item.get("question") or item.get("query") or item.get("claim")
        )
        answer = _first_text(
            item.get("answer")
            or item.get("gold_answer")
            or item.get("label")
            or item.get("target")
        )
        if not question or not answer:
            continue
        facts.append(
            {
                "question": question,
                "answer": answer,
                "source": "faitheval",
                "conflict_type": "counterfactual",
                "provided_context": _first_text(
                    item.get("provided_context")
                    or item.get("context")
                    or item.get("evidence")
                )
                or "",
            }
        )
        if len(facts) >= n:
            break
    return facts
def build_fact_database() -> list[dict[str, str]]:
    """Build the combined fact list, write it to ``FACTS_PATH`` and return it.

    Facts from the Natural Questions, PopQA and FaithEval loaders (each
    called with its default limit) are concatenated and shuffled in place.
    The result is written as indented, non-ASCII-escaped JSON, and a summary
    with the total and the per-source counts is printed.
    """
    facts = [
        *load_natural_questions(),
        *load_popqa(),
        *load_faitheval_counterfactual(),
    ]
    random.shuffle(facts)

    FACTS_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(FACTS_PATH, "w", encoding="utf-8") as handle:
        json.dump(facts, handle, indent=2, ensure_ascii=False)

    # Tally facts per source, keeping first-seen order for the printed dict.
    counts: dict[str, int] = {}
    for entry in facts:
        label = entry.get("source", "unknown")
        counts.setdefault(label, 0)
        counts[label] += 1

    print(f"Wrote {len(facts)} facts to {FACTS_PATH}")
    print(f"Source counts: {counts}")
    return facts
# Script entry point: rebuild the fact database when this file is run
# directly (not when it is imported).
if __name__ == "__main__":
    build_fact_database()