# loader/data.py
import csv
from pathlib import Path
from typing import Dict, List

# -------------------------------------------------------------------
# 1. Dataset registry: key → (CSV filename, human-readable UI label)
# -------------------------------------------------------------------
# One row per dataset so a CSV filename and its display name are declared
# together and cannot drift apart. The two lookup dicts below are derived
# from this table and keep its row order.
_DATASET_TABLE = (
    ("bar_exam", "BarExam_qa.csv", "Bar Exam Questions"),
    ("causal_judgment", "bbh_causal_judgement.csv", "Causal Judgment"),
    ("snarks", "bbh_snarks.csv", "Snarks"),
    ("bbq_disamb", "BBQ_disamb.csv", "BBQ Disambiguation"),
    ("cnn_dailymail", "CNN_dailymail.csv", "CNN / DailyMail Summaries"),
    ("drop", "drop.csv", "DROP Reading Comprehension"),
    ("esnli", "eSNLI.csv", "e-SNLI Natural Language Inference"),
    ("fever", "fever.csv", "FEVER Fact Checking"),
    ("hotpot_qa", "hotpot_qa.csv", "HotpotQA Multi-hop Questions"),
    ("medical_qa", "medical_qa.csv", "Medical Questions"),
)

# Dataset key → CSV filename (relative to the datasets directory).
_DATASET_FILES: Dict[str, str] = {
    key: csv_name for key, csv_name, _label in _DATASET_TABLE
}

# Dataset key → human-readable name shown in the UI.
_DATASET_DISPLAY_NAMES: Dict[str, str] = {
    key: label for key, _csv_name, label in _DATASET_TABLE
}


# -------------------------------------------------------------------
# 2. Where the CSVs live (loader/../datasets/)
# -------------------------------------------------------------------
def _datasets_dir() -> Path:
    """Return the absolute path of the ``datasets`` folder.

    It sits one level above the folder containing this module,
    i.e. ``loader/../datasets``.
    """
    module_path = Path(__file__).resolve()
    package_root = module_path.parent.parent
    return (package_root / "datasets").resolve()


# -------------------------------------------------------------------
# 3. Pick the first non-empty column among several candidates
# -------------------------------------------------------------------
def _pick_first_nonempty(raw: Dict[str, str], candidates: List[str]) -> str:
    """Return the first candidate column of ``raw`` holding non-blank text.

    Columns are tried in ``candidates`` order. Missing keys, ``None``
    values and whitespace-only values are skipped. The winning value is
    returned as ``str`` without stripping. Returns ``""`` when no
    candidate qualifies.
    """
    present = (raw.get(column) for column in candidates)
    return next(
        (str(value) for value in present
         if value is not None and str(value).strip()),
        "",
    )
# -------------------------------------------------------------------
# 4. Load a single CSV file and normalize it to our schema
# -------------------------------------------------------------------
# Candidate column names for each field, tried in order; the first
# non-blank one wins. The source datasets use different headers for the
# same role, hence the long lists.
_ID_COLUMNS: List[str] = ["id", "example_id", "uid"]
_CONTEXT_COLUMNS: List[str] = [
    "Context", "context", "passage", "article", "story", "premise",
    "paragraph", "document", "sentence1", "sent1", "background",
]
_PROMPT_COLUMNS: List[str] = [
    "Prompt", "prompt", "question", "input", "query", "sentence2",
    "sent2", "hypothesis", "qa_question", "title",
]
_ANSWER_COLUMNS: List[str] = [
    "Answer", "answer", "target", "gold", "label", "output",
    "reference", "highlights",
]


def _load_one_dataset(
    name: str, filename: str, limit: int = 10
) -> List[Dict[str, str]]:
    """
    Read up to ``limit`` rows of a dataset CSV and convert each one to
    our standard format:

        {
            "id": "example_1",
            "context": "...",
            "prompt": "...",
            "answer": "..."   # only present when an answer column is non-blank
        }

    Args:
        name: Dataset key; used only to label log messages.
        filename: CSV file name inside the datasets directory.
        limit: Maximum number of rows to keep (default 10). Values <= 0
            yield an empty list.

    Returns:
        The first ``limit`` normalized rows (fewer if the file is shorter).
        Returns ``[]`` if the file is missing or reading it fails. Failures
        are logged, not raised, so importing this module never crashes.
    """
    # Local import keeps this section self-contained.
    import logging

    logger = logging.getLogger(__name__)

    if limit <= 0:
        return []

    path = _datasets_dir() / filename
    rows: List[Dict[str, str]] = []
    try:
        # "utf-8-sig" drops a leading byte-order mark (common in
        # spreadsheet exports). With plain "utf-8" the BOM would be glued
        # onto the first header (e.g. "\ufeffContext"), and that column
        # would never match. errors="replace" avoids Unicode crashes on
        # imperfect files. newline="" is what the csv module expects.
        with path.open(
            "r", encoding="utf-8-sig", errors="replace", newline=""
        ) as f:
            for i, raw in enumerate(csv.DictReader(f), start=1):
                ex: Dict[str, str] = {
                    # Fall back to a positional id when no id column is filled.
                    "id": _pick_first_nonempty(raw, _ID_COLUMNS)
                    or f"example_{i}",
                    "context": _pick_first_nonempty(raw, _CONTEXT_COLUMNS),
                    "prompt": _pick_first_nonempty(raw, _PROMPT_COLUMNS),
                }
                answer = _pick_first_nonempty(raw, _ANSWER_COLUMNS)
                if answer:
                    ex["answer"] = answer
                rows.append(ex)
                # Stop as soon as we have enough rows. This avoids parsing
                # large files in full at import time, and a malformed row
                # further down can no longer wipe out the rows already read.
                if len(rows) >= limit:
                    break
    except FileNotFoundError:
        logger.warning("Dataset %r: file not found: %s", name, path)
        return []
    except Exception:
        # Keep import resilient in constrained environments (e.g. Spaces),
        # but record why this dataset ended up empty.
        logger.warning(
            "Dataset %r: failed to read %s", name, path, exc_info=True
        )
        return []
    return rows


# -------------------------------------------------------------------
# 5. Load all datasets ONCE when the module is imported
# -------------------------------------------------------------------
def _load_all_datasets() -> Dict[str, List[Dict[str, str]]]:
    """Load every dataset listed in ``_DATASET_FILES``, keyed by dataset key.

    A dataset that cannot be read maps to an empty list; it is not omitted.
    """
    return {
        name: _load_one_dataset(name, filename)
        for name, filename in _DATASET_FILES.items()
    }


# Cached at import time; the public accessors below read from this.
_DATA: Dict[str, List[Dict[str, str]]] = _load_all_datasets()
# -------------------------------------------------------------------
# 6. Public Functions — these are used by the app
# -------------------------------------------------------------------
def list_datasets() -> List[str]:
    """Return all dataset keys, sorted alphabetically."""
    return sorted(_DATA.keys())


def get_dataset_display_name(dataset: str) -> str:
    """Return the human-readable display name for a dataset key.

    Falls back to the key itself when no display name is registered.
    """
    return _DATASET_DISPLAY_NAMES.get(dataset, dataset)


def get_dataset_key_from_display_name(display_name: str) -> str:
    """Convert a display name back to its internal dataset key.

    If ``display_name`` is not a known display name, it is returned
    unchanged on the assumption that it is already a key.
    """
    for key, name in _DATASET_DISPLAY_NAMES.items():
        if name == display_name:
            return key
    return display_name


def list_datasets_with_display_names() -> List[tuple[str, str]]:
    """Return ``(key, display_name)`` tuples for every loaded dataset.

    The result is sorted by display name.
    """
    pairs = [(key, get_dataset_display_name(key)) for key in _DATA]
    return sorted(pairs, key=lambda pair: pair[1])


def list_dataset_display_names() -> List[str]:
    """Return the display names of all loaded datasets, sorted alphabetically."""
    return sorted(get_dataset_display_name(key) for key in _DATA)


def get_examples(dataset: str, n: int = 10) -> List[Dict[str, str]]:
    """Return up to ``n`` examples for a dataset (fewer if fewer were loaded).

    Each example is a shallow copy, so callers may modify the returned
    dicts without corrupting the module-level cache. ``n <= 0`` returns
    an empty list.

    Raises:
        KeyError: if ``dataset`` is not a known dataset key.
    """
    if dataset not in _DATA:
        raise KeyError(f"Unknown dataset: {dataset}")
    # Clamp at 0: a negative n would otherwise slice from the end
    # (e.g. n=-1 would return every example except the last).
    return [dict(ex) for ex in _DATA[dataset][:max(n, 0)]]


def get_example_by_id(dataset: str, ex_id: str) -> Dict[str, str]:
    """Return (a shallow copy of) the example whose ``"id"`` equals ``ex_id``.

    Raises:
        KeyError: if ``dataset`` is unknown or no example has that id.
    """
    if dataset not in _DATA:
        raise KeyError(f"Unknown dataset: {dataset}")
    for ex in _DATA[dataset]:
        if ex["id"] == ex_id:
            # Copy so callers cannot mutate the cached example in place.
            return dict(ex)
    raise KeyError(f"Example id '{ex_id}' not found in dataset '{dataset}'")