# loader/data.py
from typing import List, Dict
from pathlib import Path
import csv
# -------------------------------------------------------------------
# 1. Mapping from dataset key → CSV filename
# -------------------------------------------------------------------
_DATASET_FILES: Dict[str, str] = {
"bar_exam": "BarExam_qa.csv",
"causal_judgment": "bbh_causal_judgement.csv",
"snarks": "bbh_snarks.csv",
"bbq_disamb": "BBQ_disamb.csv",
"cnn_dailymail": "CNN_dailymail.csv",
"drop": "drop.csv",
"esnli": "eSNLI.csv",
"fever": "fever.csv",
"hotpot_qa": "hotpot_qa.csv",
"medical_qa": "medical_qa.csv",
}
# -------------------------------------------------------------------
# 1b. Human-readable display names for UI
# -------------------------------------------------------------------
_DATASET_DISPLAY_NAMES: Dict[str, str] = {
"bar_exam": "Bar Exam Questions",
"causal_judgment": "Causal Judgment",
"snarks": "Snarks",
"bbq_disamb": "BBQ Disambiguation",
"cnn_dailymail": "CNN / DailyMail Summaries",
"drop": "DROP Reading Comprehension",
"esnli": "e-SNLI Natural Language Inference",
"fever": "FEVER Fact Checking",
"hotpot_qa": "HotpotQA Multi-hop Questions",
"medical_qa": "Medical Questions",
}
# -------------------------------------------------------------------
# 2. Where the CSVs live (loader/../datasets/)
# -------------------------------------------------------------------
def _datasets_dir() -> Path:
return (Path(__file__).resolve().parent.parent / "datasets").resolve()
# -------------------------------------------------------------------
# 3. Pick the first non-empty column among several candidates
# -------------------------------------------------------------------
def _pick_first_nonempty(raw: Dict[str, str], candidates: List[str]) -> str:
for c in candidates:
val = raw.get(c)
if val is not None and str(val).strip() != "":
return str(val)
return ""
# -------------------------------------------------------------------
# 4. Load a single CSV file and normalize it to our schema
# -------------------------------------------------------------------
def _load_one_dataset(
    name: str, filename: str, limit: int = 10
) -> List[Dict[str, str]]:
    """Read a CSV file and normalize its first rows to the standard schema.

    Each returned example has the form::

        {
            "id": "example_1",
            "context": "...",
            "prompt": "...",
            "answer": "...",   # only present when a non-empty answer exists
        }

    Args:
        name: Dataset key. Currently unused; kept so the signature stays
            stable for existing callers.
        filename: CSV file name, resolved against ``_datasets_dir()``.
        limit: Maximum number of rows to keep (default 10). Reading stops
            once this many rows were collected, so the rest of the file is
            never parsed. A value of 0 or less yields no rows.

    Returns:
        Up to ``limit`` normalized examples, or an empty list when the file
        is missing or cannot be read or parsed.
    """
    path = _datasets_dir() / filename
    rows: List[Dict[str, str]] = []
    # Built outside the ``try`` so that an invalid ``limit`` fails loudly
    # instead of being mistaken for an unreadable file.
    row_numbers = range(1, limit + 1)
    try:
        # "utf-8-sig" drops a leading byte-order mark (common in CSVs saved
        # by spreadsheet tools); with plain "utf-8" the BOM would become part
        # of the first header name and that column would never be matched.
        # errors="replace" avoids Unicode crashes for imperfect CSVs.
        with path.open(
            "r", encoding="utf-8-sig", errors="replace", newline=""
        ) as f:
            reader = csv.DictReader(f)
            # zip() asks ``row_numbers`` first, so once the limit is reached
            # no further row is pulled from the reader.
            for i, raw in zip(row_numbers, reader):
                ex_id = (
                    raw.get("id")
                    or raw.get("example_id")
                    or raw.get("uid")
                    or f"example_{i}"
                )
                context = _pick_first_nonempty(raw, [
                    "Context", "context",
                    "passage", "article", "story", "premise",
                    "paragraph", "document", "sentence1", "sent1",
                    "background",
                ])
                prompt = _pick_first_nonempty(raw, [
                    "Prompt", "prompt",
                    "question", "input", "query",
                    "sentence2", "sent2", "hypothesis",
                    "qa_question", "title",
                ])
                answer = _pick_first_nonempty(raw, [
                    "Answer", "answer",
                    "target", "gold", "label", "output", "reference",
                    "highlights",
                ])
                ex = {
                    "id": str(ex_id),
                    "context": context,
                    "prompt": prompt,
                }
                if answer:
                    ex["answer"] = answer
                rows.append(ex)
    except Exception:
        # Deliberately broad (this also covers a missing file): the loader
        # runs at import time and must stay resilient in constrained
        # environments (e.g., Spaces). A failure yields an empty dataset.
        return []
    return rows
# -------------------------------------------------------------------
# 5. Load all datasets ONCE when the module is imported
# -------------------------------------------------------------------
def _load_all_datasets() -> Dict[str, List[Dict[str, str]]]:
    """Load every dataset registered in ``_DATASET_FILES``.

    Returns a dict that maps each dataset key to the list of examples
    produced by ``_load_one_dataset`` for the corresponding CSV file.
    """
    loaded: Dict[str, List[Dict[str, str]]] = {}
    for key, csv_name in _DATASET_FILES.items():
        loaded[key] = _load_one_dataset(key, csv_name)
    return loaded
# Module-level cache: filled once by _load_all_datasets() when this module is
# imported; the public functions in this file read from it.
_DATA: Dict[str, List[Dict[str, str]]] = _load_all_datasets()
# -------------------------------------------------------------------
# 6. Public Functions — these are used by the app
# -------------------------------------------------------------------
def list_datasets() -> List[str]:
    """Return the keys of all loaded datasets in ascending order."""
    keys = list(_DATA)
    keys.sort()
    return keys
def get_dataset_display_name(dataset: str) -> str:
    """Look up the UI label for a dataset key.

    When no label is registered for ``dataset`` the key itself is returned
    unchanged.
    """
    if dataset in _DATASET_DISPLAY_NAMES:
        return _DATASET_DISPLAY_NAMES[dataset]
    return dataset
def get_dataset_key_from_display_name(display_name: str) -> str:
    """Translate a UI label back into its internal dataset key.

    The display-name table is scanned in insertion order and the key of the
    first entry whose label equals ``display_name`` is returned. If no label
    matches, ``display_name`` is returned as-is on the assumption that it is
    already a key.
    """
    matching_keys = (
        key
        for key, label in _DATASET_DISPLAY_NAMES.items()
        if label == display_name
    )
    # No label matched: hand the input back unchanged.
    return next(matching_keys, display_name)
def list_datasets_with_display_names() -> List[tuple[str, str]]:
    """Return ``(key, display_name)`` tuples ordered by display name.

    Datasets without a registered display name use their key as the label.
    """
    labelled = []
    for key in _DATA:
        labelled.append((key, _DATASET_DISPLAY_NAMES.get(key, key)))
    # Order by the label (second tuple element); the sort is stable.
    labelled.sort(key=lambda pair: pair[1])
    return labelled
def list_dataset_display_names() -> List[str]:
    """Return the display names of all loaded datasets in ascending order.

    Datasets without a registered display name contribute their key instead.
    """
    return sorted(_DATASET_DISPLAY_NAMES.get(key, key) for key in _DATA)
def get_examples(dataset: str, n: int = 10) -> List[Dict[str, str]]:
    """Return up to ``n`` examples for a dataset.

    Args:
        dataset: Internal dataset key.
        n: Maximum number of examples to return. Negative values are treated
            as 0, so the result never contains more than ``n`` items.

    Returns:
        A new list holding at most ``n`` of the cached examples, in their
        stored order.

    Raises:
        KeyError: If ``dataset`` is not a loaded dataset key.
    """
    if dataset not in _DATA:
        raise KeyError(f"Unknown dataset: {dataset}")
    # A negative stop index would slice relative to the END of the list
    # (n=-1 -> everything but the last example), which contradicts
    # "up to n". ``None`` is left untouched so an unbounded slice still works.
    if n is not None and n < 0:
        n = 0
    return _DATA[dataset][:n]
def get_example_by_id(dataset: str, ex_id: str) -> Dict[str, str]:
    """Return the first example of ``dataset`` whose ``"id"`` equals ``ex_id``.

    Raises:
        KeyError: If ``dataset`` is not a loaded dataset key, or if none of
            its examples has the requested id.
    """
    if dataset not in _DATA:
        raise KeyError(f"Unknown dataset: {dataset}")
    # Examples are dicts, so ``None`` is a safe "nothing found" marker.
    found = next((row for row in _DATA[dataset] if row["id"] == ex_id), None)
    if found is None:
        raise KeyError(f"Example id '{ex_id}' not found in dataset '{dataset}'")
    return found
|