|
|
"""
|
|
|
Minimal dataset loaders/adapters for evaluation benchmarks (GSM8K, HumanEval, MMLU).
|
|
|
|
|
|
These return lightweight Python iterables that yield dicts in the shapes expected by
|
|
|
`src/evaluation/benchmarks.py` evaluators.
|
|
|
|
|
|
Heavy datasets are optional: each loader fetches its dataset eagerly and lets
loading failures raise (load_mmlu wraps them in RuntimeError), so the caller
can catch the exception and skip that benchmark.
|
|
|
"""
|
|
|
from itertools import islice
from typing import Iterable, Dict, Any, Optional

from datasets import load_dataset
|
|
|
|
|
|
|
|
|
def load_gsm8k(
    split: str = "test",
    subset: str = "main",
    max_samples: Optional[int] = None,
) -> Iterable[Dict[str, Any]]:
    """Load GSM8K problems as a lazy stream of single-item batches.

    The dataset is fetched eagerly, so a missing or unreachable dataset makes
    this call raise (whatever ``datasets.load_dataset`` raises) rather than
    failing on first iteration.

    Args:
        split: Dataset split passed to ``load_dataset`` (default ``"test"``).
        subset: GSM8K configuration name passed to ``load_dataset``
            (default ``"main"``).
        max_samples: Maximum number of rows to yield. ``None`` yields every
            row; ``0`` yields nothing.

    Returns:
        A lazy iterator of dicts ``{"question": [...], "answer": [...]}``
        where each value is the row's field wrapped in a one-element list.

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be None or >= 0, got {max_samples}")
    ds = load_dataset("gsm8k", subset, split=split)
    # islice treats None as "no limit" and 0 as "take nothing", which is the
    # intended meaning of max_samples (a truthiness test would treat 0 as
    # unlimited).
    return (
        {"question": [row["question"]], "answer": [row["answer"]]}
        for row in islice(ds, max_samples)
    )
|
|
|
|
|
|
|
|
|
# Row fields that load_humaneval copies into each yielded item (each value is
# wrapped in a one-element list there).
essential_humaneval_fields = [
    "prompt",
    "test",
    "canonical_solution",
]
|
|
|
|
|
|
def load_humaneval(
    split: str = "test",
    max_samples: Optional[int] = None,
) -> Iterable[Dict[str, Any]]:
    """Load HumanEval tasks as a lazy stream of single-item batches.

    The dataset is fetched eagerly, so a missing or unreachable dataset makes
    this call raise (whatever ``datasets.load_dataset`` raises) rather than
    failing on first iteration.

    Args:
        split: Dataset split passed to ``load_dataset`` (default ``"test"``).
        max_samples: Maximum number of rows to yield. ``None`` yields every
            row; ``0`` yields nothing.

    Returns:
        A lazy iterator of dicts keyed by ``essential_humaneval_fields``.
        Each value is the row's field wrapped in a one-element list; a field
        missing from the row becomes ``[None]``.

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be None or >= 0, got {max_samples}")
    ds = load_dataset("openai_humaneval", split=split)
    # islice: None means "no limit", 0 means "take nothing".
    return (
        {field: [row.get(field)] for field in essential_humaneval_fields}
        for row in islice(ds, max_samples)
    )
|
|
|
|
|
|
|
|
|
def load_mmlu(
    split: str = "validation",
    subject: Optional[str] = None,
    max_samples: Optional[int] = None,
) -> Iterable[Dict[str, Any]]:
    """
    Load MMLU-like multiple-choice QA. Try widely used sources in order.

    ``cais/mmlu`` is tried first; if that fails, the legacy
    ``hendrycks_test`` dataset id is tried. The dataset is fetched eagerly,
    so loading failures surface from this call, not on first iteration.

    Args:
        split: Dataset split passed to ``load_dataset``
            (default ``"validation"``).
        subject: Subject/configuration name; ``None`` (or empty) selects
            ``"abstract_algebra"``.
        max_samples: Maximum number of rows to yield. ``None`` yields every
            row; ``0`` yields nothing.

    Returns:
        A lazy iterator of dicts with fields ``question``, ``choices`` and
        ``answer``, each wrapped in a one-element list. ``choices`` is the
        row's ``choices`` field, or ``[A, B, C, D]`` built from per-letter
        columns when that field is missing or empty.

    Raises:
        ValueError: If ``max_samples`` is negative.
        RuntimeError: If neither source can be loaded. The message includes
            both underlying errors; the fallback error is chained as the cause.
    """
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be None or >= 0, got {max_samples}")

    config = subject or "abstract_algebra"
    try:
        ds = load_dataset("cais/mmlu", config, split=split)
    except Exception as primary_err:
        # Broad catch is deliberate: any failure of the primary source should
        # trigger the fallback rather than abort.
        try:
            ds = load_dataset("hendrycks_test", config, split=split)
        except Exception as fallback_err:
            raise RuntimeError(
                f"Unable to load MMLU dataset: {fallback_err} "
                f"(cais/mmlu failed first with: {primary_err})"
            ) from fallback_err

    def _to_item(row: Dict[str, Any]) -> Dict[str, Any]:
        # Convert one dataset row into a single-item batch dict.
        choices = row.get("choices") or [row.get(letter) for letter in "ABCD"]
        return {
            "question": [row.get("question", "")],
            "choices": [choices],
            "answer": [row.get("answer", "")],
        }

    # islice: None means "no limit", 0 means "take nothing".
    return (_to_item(row) for row in islice(ds, max_samples))
|
|
|
|