"""
Benchmark data loaders for MMLU, BIG-Bench, and BIG-Bench Hard datasets.

Loads questions from Hugging Face datasets and formats them for agent evaluation.
Each sample contains: question text, choices, correct answer, subject/task.
"""
|
|
|
import logging
import random
from dataclasses import dataclass, field

from datasets import load_dataset
|
|
|
|
|
@dataclass
class BenchmarkSample:
    """One benchmark question in the common format shared by all loaders.

    Attributes:
        dataset_name: Short benchmark identifier; the loaders in this module
            use "mmlu", "bbh" and "bigbench".
        subject: MMLU subject or BIG-Bench / BBH task the question came from.
        question: Prompt text. For multiple-choice questions the loaders
            append the lettered choices to it.
        choices: Answer options in presentation order; empty for free-form
            questions.
        correct_answer: Expected answer: a choice letter for multiple-choice
            questions, otherwise the target text.
        sample_id: Identifier built by the loaders as
            "<dataset>_<subject>_<index>".
        raw: Copy of the source dataset row; left out of ``repr`` to keep
            debug output short.
    """

    dataset_name: str
    subject: str
    question: str
    choices: list[str]
    correct_answer: str
    sample_id: str
    raw: dict = field(default_factory=dict, repr=False)
|
|
|
|
|
|
|
|
|
# Subject (config) names passed to ``load_dataset`` by ``load_mmlu``;
# used as the default when the caller does not pick subjects.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Letter labels for the four MMLU options, indexed by answer position.
ANSWER_LETTERS = list("ABCD")
|
|
|
|
|
def load_mmlu(
    subjects: list[str] | None = None,
    max_samples_per_subject: int | None = None,
    split: str = "test",
    seed: int = 42,
) -> list[BenchmarkSample]:
    """
    Load MMLU questions from HuggingFace as ``BenchmarkSample`` objects.

    Each subject is fetched from ``cais/mmlu`` first and from
    ``lukaemon/mmlu`` if that fails. A subject that cannot be loaded from
    either repository is skipped with a logged warning instead of aborting
    the whole load.

    Args:
        subjects: Subjects to load; ``None`` or an empty list selects every
            entry of ``MMLU_SUBJECTS``.
        max_samples_per_subject: Upper bound on questions per subject. When a
            subject has more rows, a random subset of this size is drawn.
            ``None`` (or 0) keeps every row.
        split: Dataset split to use.
        seed: Seed of the private random generator used for subsampling.

    Returns:
        List of BenchmarkSample objects, grouped by subject in the order
        requested. ``question`` contains the question followed by one
        "<letter>. <choice>" line per option; ``correct_answer`` is the
        answer letter. The index in ``sample_id`` is the row's position in
        the (possibly subsampled) list for its subject.

    Raises:
        KeyError: If a row has neither ``question``/``input`` nor
            ``answer``/``target``.
        IndexError: If a row has more choices than ``ANSWER_LETTERS`` or an
            integer answer outside its range.
    """
    logger = logging.getLogger(__name__)
    subjects = subjects or MMLU_SUBJECTS
    samples: list[BenchmarkSample] = []
    rng = random.Random(seed)

    for subj in subjects:
        try:
            ds = load_dataset("cais/mmlu", subj, split=split)
        except Exception as primary_err:
            try:
                ds = load_dataset("lukaemon/mmlu", subj, split=split)
            except Exception as fallback_err:
                # Best effort: keep loading the other subjects, but make the
                # gap visible so results are not silently computed on a
                # reduced benchmark.
                logger.warning(
                    "Skipping MMLU subject %r (split %r): cais/mmlu failed (%s); "
                    "lukaemon/mmlu failed (%s)",
                    subj,
                    split,
                    primary_err,
                    fallback_err,
                )
                continue

        items = list(ds)
        if max_samples_per_subject and len(items) > max_samples_per_subject:
            items = rng.sample(items, max_samples_per_subject)

        for idx, row in enumerate(items):
            # NOTE(review): the fallback repository appears to use
            # input / A / B / C / D / target columns instead of
            # question / choices / answer - confirm against its dataset card.
            question_text = row["question"] if "question" in row else row["input"]
            if "choices" in row:
                choices = row["choices"]
            else:
                choices = [row.get(letter, "") for letter in ANSWER_LETTERS]

            answer = row["answer"] if "answer" in row else row["target"]
            if isinstance(answer, int):
                # Integer answers are positions into the choice list.
                correct = ANSWER_LETTERS[answer]
            else:
                correct = str(answer).strip().upper()

            prompt_lines = [f"{question_text}"]
            prompt_lines.extend(
                f"{ANSWER_LETTERS[pos]}. {choice}" for pos, choice in enumerate(choices)
            )

            samples.append(
                BenchmarkSample(
                    dataset_name="mmlu",
                    subject=subj,
                    question="\n".join(prompt_lines).strip(),
                    choices=choices,
                    correct_answer=correct,
                    sample_id=f"mmlu_{subj}_{idx}",
                    raw=dict(row),
                )
            )

    return samples
|
|
|
|
|
|
|
|
|
# Task (config) names passed to ``load_dataset`` by ``load_bbh``;
# used as the default when the caller does not pick tasks.
BBH_TASKS = [
    "boolean_expressions",
    "causal_judgement",
    "date_understanding",
    "disambiguation_qa",
    "dyck_languages",
    "formal_fallacies",
    "geometric_shapes",
    "hyperbaton",
    "logical_deduction_five_objects",
    "logical_deduction_seven_objects",
    "logical_deduction_three_objects",
    "movie_recommendation",
    "multistep_arithmetic_two",
    "navigate",
    "object_counting",
    "penguins_in_a_table",
    "reasoning_about_colored_objects",
    "ruin_names",
    "salient_translation_error_detection",
    "snarks",
    "sports_understanding",
    "temporal_sequences",
    "tracking_shuffled_objects_five_objects",
    "tracking_shuffled_objects_seven_objects",
    "tracking_shuffled_objects_three_objects",
    "web_of_lies",
    "word_sorting",
]
|
|
|
|
|
def load_bbh(
    tasks: list[str] | None = None,
    max_samples_per_task: int | None = None,
    seed: int = 42,
) -> list[BenchmarkSample]:
    """
    Load BIG-Bench Hard questions from HuggingFace (``lukaemon/bbh``).

    Each task is read from the "test" split, falling back to "train" if that
    fails. A task that cannot be loaded from either split is skipped with a
    logged warning instead of aborting the whole load.

    Args:
        tasks: Tasks to load; ``None`` or an empty list selects every entry
            of ``BBH_TASKS``.
        max_samples_per_task: Upper bound on questions per task. When a task
            has more rows, a random subset of this size is drawn. ``None``
            (or 0) keeps every row.
        seed: Seed of the private random generator used for subsampling.

    Returns:
        List of BenchmarkSample objects, grouped by task in the order
        requested. Samples are free-form: ``choices`` is empty and
        ``correct_answer`` is the stripped target text. The index in
        ``sample_id`` is the row's position in the (possibly subsampled)
        list for its task.
    """
    logger = logging.getLogger(__name__)
    tasks = tasks or BBH_TASKS
    samples: list[BenchmarkSample] = []
    rng = random.Random(seed)

    for task in tasks:
        try:
            data = list(load_dataset("lukaemon/bbh", task, split="test"))
        except Exception as test_err:
            try:
                data = list(load_dataset("lukaemon/bbh", task, split="train"))
            except Exception as train_err:
                # Best effort: keep loading the other tasks, but make the gap
                # visible so results are not silently computed on a reduced
                # benchmark.
                logger.warning(
                    "Skipping BBH task %r: split 'test' failed (%s); "
                    "split 'train' failed (%s)",
                    task,
                    test_err,
                    train_err,
                )
                continue

        if max_samples_per_task and len(data) > max_samples_per_task:
            data = rng.sample(data, max_samples_per_task)

        for idx, row in enumerate(data):
            # Accept either input/target or question/answer column names.
            question_text = row.get("input", row.get("question", ""))
            target = row.get("target", row.get("answer", ""))

            samples.append(
                BenchmarkSample(
                    dataset_name="bbh",
                    subject=task,
                    question=question_text.strip(),
                    choices=[],
                    correct_answer=str(target).strip(),
                    sample_id=f"bbh_{task}_{idx}",
                    raw=dict(row),
                )
            )

    return samples
|
|
|
|
|
|
|
|
|
# Default subset of task (config) names passed to ``load_dataset`` by
# ``load_bigbench`` when the caller does not pick tasks.
BIGBENCH_TASKS = [
    "abstract_narrative_understanding",
    "anachronisms",
    "analogical_similarity",
    "causal_judgment",
    "cause_and_effect",
    "elementary_math_qa",
    "epistemic_reasoning",
    "general_knowledge",
    "logical_args",
    "logical_fallacy_detection",
    "logical_sequence",
    "movie_dialog_same_or_different",
    "novel_concepts",
    "odd_one_out",
    "play_dialog_same_or_different",
    "presuppositions_as_nli",
    "riddle_sense",
    "strange_stories",
    "strategyqa",
    "vitaminc_fact_verification",
]
|
|
|
|
|
def _parse_bigbench_row(row: dict) -> tuple[str, list, str] | None:
    """Convert one raw BIG-Bench row to ``(question, choices, answer)``, or ``None`` if it has no usable answer."""
    question_text = row.get("inputs", "")
    targets = row.get("targets", [])
    mc_targets = row.get("multiple_choice_targets", [])
    mc_scores = row.get("multiple_choice_scores", [])

    # First reference answer, or None when the row carries no target at all.
    if not targets:
        first_target = None
    elif isinstance(targets, list):
        first_target = targets[0]
    else:
        first_target = str(targets)

    if not (mc_targets and mc_scores):
        # Free-form question: the answer is the target text itself.
        if first_target is None:
            return None
        return question_text.strip(), [], str(first_target).strip()

    # Multiple choice: the correct option is the first one scored 1.
    correct_idx = next(
        (pos for pos, score in enumerate(mc_scores) if score == 1), None
    )
    if correct_idx is not None and correct_idx >= len(mc_targets):
        # A score with no matching choice cannot be turned into a letter.
        correct_idx = None

    if correct_idx is None:
        if first_target is None:
            return None
        # Fall back to the reference target; report it as a letter whenever
        # it is one of the listed choices so the answer matches the prompt.
        if first_target in mc_targets:
            correct_idx = mc_targets.index(first_target)

    prompt_lines = [f"{question_text}"]
    prompt_lines.extend(
        f"{chr(65 + pos)}. {choice}" for pos, choice in enumerate(mc_targets)
    )

    if correct_idx is not None:
        answer = chr(65 + correct_idx)
    else:
        answer = str(first_target)
    return "\n".join(prompt_lines).strip(), mc_targets, answer.strip()


def load_bigbench(
    tasks: list[str] | None = None,
    max_samples_per_task: int | None = None,
    seed: int = 42,
) -> list[BenchmarkSample]:
    """
    Load regular BIG-Bench questions from HuggingFace (tasksource/bigbench).

    Unlike BBH (free-form), regular BIG-Bench has multiple-choice format
    with 'inputs', 'targets', 'multiple_choice_targets', 'multiple_choice_scores'.
    Rows with both multiple-choice fields become lettered questions
    ("A. ...", "B. ...") whose ``correct_answer`` is the letter of the first
    choice scored 1. If no in-range choice is scored 1, the first target is
    used instead: as a letter when it matches a choice, as text otherwise.
    Rows without multiple-choice fields are free-form and use the first
    target as the answer. Rows with no usable answer are dropped.

    Each task is read from the "train" split, falling back to "validation"
    if that fails. A task that cannot be loaded from either split is skipped
    with a logged warning instead of aborting the whole load.

    Args:
        tasks: Tasks to load; ``None`` or an empty list selects every entry
            of ``BIGBENCH_TASKS``.
        max_samples_per_task: Upper bound on rows per task. When a task has
            more rows, a random subset of this size is drawn. ``None``
            (or 0) keeps every row.
        seed: Seed of the private random generator used for subsampling.

    Returns:
        List of BenchmarkSample objects, grouped by task in the order
        requested. The index in ``sample_id`` is the row's position in the
        (possibly subsampled) list for its task, so dropped rows leave gaps.
    """
    logger = logging.getLogger(__name__)
    tasks = tasks or BIGBENCH_TASKS
    samples: list[BenchmarkSample] = []
    rng = random.Random(seed)

    for task in tasks:
        try:
            data = list(load_dataset("tasksource/bigbench", task, split="train"))
        except Exception as train_err:
            try:
                data = list(
                    load_dataset("tasksource/bigbench", task, split="validation")
                )
            except Exception as validation_err:
                # Best effort: keep loading the other tasks, but make the gap
                # visible so results are not silently computed on a reduced
                # benchmark.
                logger.warning(
                    "Skipping BIG-Bench task %r: split 'train' failed (%s); "
                    "split 'validation' failed (%s)",
                    task,
                    train_err,
                    validation_err,
                )
                continue

        if max_samples_per_task and len(data) > max_samples_per_task:
            data = rng.sample(data, max_samples_per_task)

        for idx, row in enumerate(data):
            parsed = _parse_bigbench_row(row)
            if parsed is None:
                continue
            question, choices, answer = parsed

            samples.append(
                BenchmarkSample(
                    dataset_name="bigbench",
                    subject=task,
                    question=question,
                    choices=choices,
                    correct_answer=answer,
                    sample_id=f"bigbench_{task}_{idx}",
                    raw=dict(row),
                )
            )

    return samples
|
|
|
|
|
def load_all_benchmarks(
    mmlu_subjects: list[str] | None = None,
    bbh_tasks: list[str] | None = None,
    bigbench_tasks: list[str] | None = None,
    max_samples_per_subject: int | None = None,
    seed: int = 42,
) -> dict[str, list[BenchmarkSample]]:
    """
    Load the MMLU, BBH and BIG-Bench benchmarks in one call.

    Args:
        mmlu_subjects: Subjects forwarded to ``load_mmlu``.
        bbh_tasks: Tasks forwarded to ``load_bbh``.
        bigbench_tasks: Tasks forwarded to ``load_bigbench``.
        max_samples_per_subject: Sample cap forwarded to all three loaders;
            it limits rows per subject for MMLU and per task for BBH and
            BIG-Bench.
        seed: Random seed forwarded to all three loaders.

    Returns:
        Dict with keys "mmlu", "bbh" and "bigbench", each mapping to the
        list of samples returned by the matching loader.
    """
    # Loaders run in this fixed order: MMLU, then BBH, then BIG-Bench.
    mmlu_samples = load_mmlu(mmlu_subjects, max_samples_per_subject, seed=seed)
    bbh_samples = load_bbh(bbh_tasks, max_samples_per_subject, seed=seed)
    bigbench_samples = load_bigbench(
        bigbench_tasks, max_samples_per_subject, seed=seed
    )

    benchmarks: dict[str, list[BenchmarkSample]] = {}
    benchmarks["mmlu"] = mmlu_samples
    benchmarks["bbh"] = bbh_samples
    benchmarks["bigbench"] = bigbench_samples
    return benchmarks
|
|
|