# Vedisasi's picture
# Upload folder using huggingface_hub
# 54c5666 verified
"""
Minimal dataset loaders/adapters for evaluation benchmarks (GSM8K, HumanEval, MMLU).
These return lightweight Python iterables that yield dicts in the shapes expected by
`src/evaluation/benchmarks.py` evaluators.
Heavy datasets are optional; loaders handle missing datasets gracefully by
raising a clear exception the caller can catch and skip.
"""
from typing import Iterable, Dict, Any, Optional
from datasets import load_dataset # type: ignore
def load_gsm8k(split: str = "test", subset: str = "main", max_samples: Optional[int] = None) -> Iterable[Dict[str, Any]]:
    """
    Load GSM8K and adapt each row to the evaluator's dict shape.

    The dataset is loaded when this function is called; rows are adapted
    lazily as the returned generator is consumed.

    Args:
        split: Split name passed to ``load_dataset``.
        subset: Configuration name passed to ``load_dataset``.
        max_samples: Stop after this many rows. A falsy value (``None`` or
            ``0``) disables the limit.

    Returns:
        A generator of dicts with keys ``question`` and ``answer``, each value
        wrapped in a single-element list.
    """
    dataset = load_dataset("gsm8k", subset, split=split)

    def generate():
        for seen, record in enumerate(dataset, start=1):
            yield {"question": [record["question"]], "answer": [record["answer"]]}
            # The cap is checked after yielding; a falsy max_samples never stops the stream.
            if max_samples and seen >= max_samples:
                return

    return generate()
# Keys copied from each HumanEval row by load_humaneval; a key absent from a
# row is emitted as None.
essential_humaneval_fields = [
    "prompt",
    "test",
    "canonical_solution",
]
def load_humaneval(split: str = "test", max_samples: Optional[int] = None) -> Iterable[Dict[str, Any]]:
    """
    Load HumanEval and keep only the fields listed in ``essential_humaneval_fields``.

    The dataset is loaded when this function is called; rows are adapted
    lazily as the returned generator is consumed.

    Args:
        split: Split name passed to ``load_dataset``.
        max_samples: Stop after this many rows. A falsy value (``None`` or
            ``0``) disables the limit.

    Returns:
        A generator of dicts keyed by the essential field names, each value
        wrapped in a single-element list (``[None]`` when the row lacks the key).
    """
    dataset = load_dataset("openai_humaneval", split=split)

    def generate():
        for seen, record in enumerate(dataset, start=1):
            yield {field: [record.get(field)] for field in essential_humaneval_fields}
            # The cap is checked after yielding; a falsy max_samples never stops the stream.
            if max_samples and seen >= max_samples:
                return

    return generate()
def load_mmlu(split: str = "validation", subject: Optional[str] = None, max_samples: Optional[int] = None) -> Iterable[Dict[str, Any]]:
    """
    Load MMLU-like multiple-choice QA, trying widely used sources in order.

    The ``cais/mmlu`` mirror is tried first; if that fails for any reason,
    ``hendrycks_test`` is tried. The dataset is loaded when this function is
    called; rows are adapted lazily as the returned generator is consumed.

    Args:
        split: Split name passed to ``load_dataset``.
        subject: Configuration (subject) name. A falsy value falls back to
            ``"abstract_algebra"``.
        max_samples: Stop after this many rows. A falsy value (``None`` or
            ``0``) disables the limit.

    Returns:
        A generator of dicts with keys ``question``, ``choices`` and
        ``answer``, each value wrapped in a single-element list.

    Raises:
        RuntimeError: If neither source can be loaded. The error from the
            second source is attached as the cause.
    """
    config = subject or "abstract_algebra"
    try:
        # Newer mirror
        ds = load_dataset("cais/mmlu", config, split=split)
    except Exception:
        # Any failure on the first mirror (missing, offline, bad config) just
        # means "try the legacy name"; only the second failure is fatal.
        try:
            ds = load_dataset("hendrycks_test", config, split=split)
        except Exception as e:
            # Explicit chaining marks the loader error as the direct cause
            # instead of an incidental "during handling" context.
            raise RuntimeError(f"Unable to load MMLU dataset: {e}") from e

    def iterator():
        for count, row in enumerate(ds, start=1):
            # Some sources store options under "choices"; others use separate
            # A/B/C/D columns. An empty or missing "choices" uses the latter.
            choices = row.get("choices") or [row.get("A"), row.get("B"), row.get("C"), row.get("D")]
            yield {
                "question": [row.get("question", "")],
                "choices": [choices],
                "answer": [row.get("answer", "")],
            }
            # The cap is checked after yielding; a falsy max_samples never stops the stream.
            if max_samples and count >= max_samples:
                break

    return iterator()