# llm-agent-factory / experiments / benchmark_data.py
# Committed by bridges-optimal-55 in 505aa09 ("Initial commit").
"""
Benchmark data loaders for MMLU, BIG-Bench, and BIG-Bench Hard datasets.
Loads questions from HuggingFace datasets and formats them for agent evaluation.
Each sample contains: question text, choices, correct answer, subject/task.
"""
import logging
import random
from dataclasses import dataclass, field

from datasets import load_dataset
@dataclass
class BenchmarkSample:
    """A single benchmark question, normalized across MMLU, BBH, and BIG-Bench."""

    dataset_name: str  # "mmlu", "bbh", or "bigbench"
    subject: str  # MMLU subject or BBH/BIG-Bench task, e.g. "abstract_algebra" or "boolean_expressions"
    question: str  # Full prompt text; multiple-choice questions include lettered "A. ..." option lines
    choices: list[str]  # Option texts in letter order (A, B, ...); empty for free-form tasks
    correct_answer: str  # Normally the option letter for multiple-choice samples, else the target text
    sample_id: str  # Unique identifier for checkpoint/resume, formatted "<dataset>_<subject>_<n>"
    raw: dict = field(default_factory=dict, repr=False)  # Original dataset row; excluded from repr()
# ── MMLU ──────────────────────────────────────────────────────────────────────
# Every MMLU subject configuration (57 in total), in alphabetical order.
MMLU_SUBJECTS = [
    "abstract_algebra", "anatomy", "astronomy", "business_ethics",
    "clinical_knowledge", "college_biology", "college_chemistry",
    "college_computer_science", "college_mathematics", "college_medicine",
    "college_physics", "computer_security", "conceptual_physics",
    "econometrics", "electrical_engineering", "elementary_mathematics",
    "formal_logic", "global_facts",
    "high_school_biology", "high_school_chemistry",
    "high_school_computer_science", "high_school_european_history",
    "high_school_geography", "high_school_government_and_politics",
    "high_school_macroeconomics", "high_school_mathematics",
    "high_school_microeconomics", "high_school_physics",
    "high_school_psychology", "high_school_statistics",
    "high_school_us_history", "high_school_world_history",
    "human_aging", "human_sexuality", "international_law", "jurisprudence",
    "logical_fallacies", "machine_learning", "management", "marketing",
    "medical_genetics", "miscellaneous", "moral_disputes", "moral_scenarios",
    "nutrition", "philosophy", "prehistory",
    "professional_accounting", "professional_law",
    "professional_medicine", "professional_psychology",
    "public_relations", "security_studies", "sociology",
    "us_foreign_policy", "virology", "world_religions",
]
# Letters labelling MMLU's four answer options, indexed by answer position 0-3.
ANSWER_LETTERS = list("ABCD")
def load_mmlu(
    subjects: list[str] | None = None,
    max_samples_per_subject: int | None = None,
    split: str = "test",
    seed: int = 42,
) -> list[BenchmarkSample]:
    """
    Load MMLU dataset from HuggingFace.

    Each subject is read from ``cais/mmlu`` and, if that fails, from
    ``lukaemon/mmlu``. The two mirrors use different column layouts
    (``question``/``choices``/``answer`` with an integer answer index vs.
    ``input``/``A``-``D``/``target`` with an answer letter); both are accepted.
    A subject that neither source can provide is skipped and a warning is
    logged.

    Args:
        subjects: List of subjects to load (None = all 57; an empty list
            loads nothing).
        max_samples_per_subject: Limit samples per subject for faster testing
            (None = no limit).
        split: Dataset split to use.
        seed: Random seed for sampling. One RNG is shared across all subjects,
            so the rows drawn for a subject also depend on the subjects
            sampled before it.

    Returns:
        List of BenchmarkSample objects. The numeric suffix of ``sample_id``
        is the row's index in the full split, so an ID always names the same
        question regardless of how (or whether) the subject was sub-sampled.
    """
    log = logging.getLogger(__name__)
    if subjects is None:
        subjects = MMLU_SUBJECTS
    samples: list[BenchmarkSample] = []
    rng = random.Random(seed)
    for subj in subjects:
        try:
            ds = load_dataset("cais/mmlu", subj, split=split)
        except Exception as primary_err:
            try:
                ds = load_dataset("lukaemon/mmlu", subj, split=split)
            except Exception as fallback_err:
                # Best-effort: one missing subject must not abort the whole run,
                # but it should not disappear silently either.
                log.warning(
                    "Skipping MMLU subject %r: cais/mmlu failed (%s); lukaemon/mmlu failed (%s)",
                    subj,
                    primary_err,
                    fallback_err,
                )
                continue
        items = list(ds)
        # Sample positions instead of rows so each sample keeps its index in the
        # full split (used in sample_id). In CPython, random.sample's selection
        # depends only on len(population) and k, so this draws the same rows,
        # in the same order, as sampling `items` directly.
        positions = list(range(len(items)))
        if max_samples_per_subject is not None and len(items) > max_samples_per_subject:
            positions = rng.sample(positions, max_samples_per_subject)
        for pos in positions:
            row = items[pos]
            # cais/mmlu uses "question"; lukaemon/mmlu uses "input".
            question_text = row["question"] if "question" in row else row["input"]
            choices = (
                row["choices"]
                if "choices" in row
                else [row.get("A", ""), row.get("B", ""), row.get("C", ""), row.get("D", "")]
            )
            # cais/mmlu: int index (0-3) in "answer"; lukaemon/mmlu: letter in "target".
            answer = row["answer"] if "answer" in row else row["target"]
            correct = ANSWER_LETTERS[answer] if isinstance(answer, int) else str(answer).strip().upper()
            # Question followed by one "X. option" line per choice.
            option_lines = [f"{ANSWER_LETTERS[i]}. {ch}" for i, ch in enumerate(choices)]
            formatted_q = "\n".join([question_text, *option_lines])
            samples.append(
                BenchmarkSample(
                    dataset_name="mmlu",
                    subject=subj,
                    question=formatted_q.strip(),
                    choices=choices,
                    correct_answer=correct,
                    sample_id=f"mmlu_{subj}_{pos}",
                    raw=dict(row),
                )
            )
    return samples
# ── BIG-Bench Hard ────────────────────────────────────────────────────────────
# The 27 BIG-Bench Hard task configurations, in alphabetical order.
BBH_TASKS = [
    "boolean_expressions", "causal_judgement", "date_understanding",
    "disambiguation_qa", "dyck_languages", "formal_fallacies",
    "geometric_shapes", "hyperbaton",
    "logical_deduction_five_objects",
    "logical_deduction_seven_objects",
    "logical_deduction_three_objects",
    "movie_recommendation", "multistep_arithmetic_two", "navigate",
    "object_counting", "penguins_in_a_table",
    "reasoning_about_colored_objects", "ruin_names",
    "salient_translation_error_detection", "snarks",
    "sports_understanding", "temporal_sequences",
    "tracking_shuffled_objects_five_objects",
    "tracking_shuffled_objects_seven_objects",
    "tracking_shuffled_objects_three_objects",
    "web_of_lies", "word_sorting",
]
def load_bbh(
    tasks: list[str] | None = None,
    max_samples_per_task: int | None = None,
    seed: int = 42,
) -> list[BenchmarkSample]:
    """
    Load BIG-Bench Hard dataset from HuggingFace.

    Each task is read from the ``test`` split of ``lukaemon/bbh``, falling back
    to its ``train`` split. A task that neither split can provide is skipped
    and a warning is logged.

    Args:
        tasks: List of tasks to load (None = all; an empty list loads nothing).
        max_samples_per_task: Limit samples per task (None = no limit).
        seed: Random seed for sampling. One RNG is shared across all tasks.

    Returns:
        List of BenchmarkSample objects with free-form answers (``choices`` is
        empty). The numeric suffix of ``sample_id`` is the row's index in the
        loaded split, so it is stable under sub-sampling.
    """
    log = logging.getLogger(__name__)
    if tasks is None:
        tasks = BBH_TASKS
    samples: list[BenchmarkSample] = []
    rng = random.Random(seed)
    for task in tasks:
        try:
            # Primary source: lukaemon/bbh (maveriq/bigbenchhard is deprecated)
            data = list(load_dataset("lukaemon/bbh", task, split="test"))
        except Exception as test_err:
            try:
                data = list(load_dataset("lukaemon/bbh", task, split="train"))
            except Exception as train_err:
                # Best-effort: skip the task, but leave a trace of why.
                log.warning(
                    "Skipping BBH task %r: test split failed (%s); train split failed (%s)",
                    task,
                    test_err,
                    train_err,
                )
                continue
        # Sample positions (not rows) so sample_id can carry the original index;
        # the selection matches sampling `data` directly.
        positions = list(range(len(data)))
        if max_samples_per_task is not None and len(data) > max_samples_per_task:
            positions = rng.sample(positions, max_samples_per_task)
        for pos in positions:
            row = data[pos]
            question_text = row.get("input", row.get("question", ""))
            target = row.get("target", row.get("answer", ""))
            samples.append(
                BenchmarkSample(
                    dataset_name="bbh",
                    subject=task,
                    question=question_text.strip(),
                    choices=[],  # BBH is mostly free-form
                    correct_answer=str(target).strip(),
                    sample_id=f"bbh_{task}_{pos}",
                    raw=dict(row),
                )
            )
    return samples
# ── BIG-Bench (regular) ───────────────────────────────────────────────────────
# Default subset of regular BIG-Bench tasks (tasksource/bigbench configs), alphabetical.
BIGBENCH_TASKS = [
    "abstract_narrative_understanding", "anachronisms",
    "analogical_similarity", "causal_judgment", "cause_and_effect",
    "elementary_math_qa", "epistemic_reasoning", "general_knowledge",
    "logical_args", "logical_fallacy_detection", "logical_sequence",
    "movie_dialog_same_or_different", "novel_concepts", "odd_one_out",
    "play_dialog_same_or_different", "presuppositions_as_nli",
    "riddle_sense", "strange_stories", "strategyqa",
    "vitaminc_fact_verification",
]
def load_bigbench(
    tasks: list[str] | None = None,
    max_samples_per_task: int | None = None,
    seed: int = 42,
) -> list[BenchmarkSample]:
    """
    Load regular BIG-Bench dataset from HuggingFace (tasksource/bigbench).

    Unlike BBH (free-form), regular BIG-Bench has multiple-choice format
    with 'inputs', 'targets', 'multiple_choice_targets', 'multiple_choice_scores'.
    Each task is read from the ``train`` split, falling back to ``validation``;
    a task neither split can provide is skipped and a warning is logged.

    Multiple-choice rows are rendered as a question followed by lettered
    options, and ``correct_answer`` is the letter of the first option scored 1.
    If no option is scored 1, the first reference target is used instead,
    mapped to its letter when it is one of the options, otherwise kept as text.
    Rows with no usable answer are skipped.

    Args:
        tasks: List of tasks to load (None = default subset; an empty list
            loads nothing).
        max_samples_per_task: Limit samples per task (None = no limit).
        seed: Random seed for sampling. One RNG is shared across all tasks.

    Returns:
        List of BenchmarkSample objects. The numeric suffix of ``sample_id``
        is the row's index in the loaded split, so it is stable under
        sub-sampling.
    """
    log = logging.getLogger(__name__)
    if tasks is None:
        tasks = BIGBENCH_TASKS
    samples: list[BenchmarkSample] = []
    rng = random.Random(seed)
    for task in tasks:
        try:
            data = list(load_dataset("tasksource/bigbench", task, split="train"))
        except Exception as train_err:
            try:
                data = list(load_dataset("tasksource/bigbench", task, split="validation"))
            except Exception as val_err:
                # Best-effort: skip the task, but leave a trace of why.
                log.warning(
                    "Skipping BIG-Bench task %r: train split failed (%s); validation split failed (%s)",
                    task,
                    train_err,
                    val_err,
                )
                continue
        # Sample positions (not rows) so sample_id can carry the original index;
        # the selection matches sampling `data` directly.
        positions = list(range(len(data)))
        if max_samples_per_task is not None and len(data) > max_samples_per_task:
            positions = rng.sample(positions, max_samples_per_task)
        for pos in positions:
            row = data[pos]
            question_text = row.get("inputs", "")
            targets = row.get("targets", [])
            mc_targets = row.get("multiple_choice_targets", [])
            mc_scores = row.get("multiple_choice_scores", [])
            # First reference answer as text, or None when the row has none
            # (an empty list must not become the literal answer "[]").
            if isinstance(targets, list):
                first_target = str(targets[0]) if targets else None
            else:
                first_target = str(targets) if targets else None

            if mc_targets and mc_scores:
                # Gold option = first one scored 1; zip() keeps the index within
                # mc_targets so the letter always names a displayed option.
                correct_idx = next(
                    (j for j, (_, score) in enumerate(zip(mc_targets, mc_scores)) if score == 1),
                    None,
                )
                if correct_idx is None and first_target is not None and first_target in mc_targets:
                    # No option scored 1, but the reference answer is one of the options.
                    correct_idx = mc_targets.index(first_target)
                if correct_idx is not None:
                    correct_answer = chr(65 + correct_idx)  # A, B, C, ...
                elif first_target is not None:
                    # Reference answer is not among the options; keep its text.
                    correct_answer = first_target
                else:
                    log.debug("Skipping BIG-Bench %s row %d: no usable answer", task, pos)
                    continue
                choices = list(mc_targets)
                option_lines = [f"{chr(65 + j)}. {ch}" for j, ch in enumerate(choices)]
                formatted_q = "\n".join([question_text, *option_lines])
            else:
                # Free-form row.
                if first_target is None:
                    log.debug("Skipping BIG-Bench %s row %d: no usable answer", task, pos)
                    continue
                choices = []
                formatted_q = question_text
                correct_answer = first_target

            samples.append(
                BenchmarkSample(
                    dataset_name="bigbench",
                    subject=task,
                    question=formatted_q.strip(),
                    choices=choices,
                    correct_answer=correct_answer.strip(),
                    sample_id=f"bigbench_{task}_{pos}",
                    raw=dict(row),
                )
            )
    return samples
def load_all_benchmarks(
    mmlu_subjects: list[str] | None = None,
    bbh_tasks: list[str] | None = None,
    bigbench_tasks: list[str] | None = None,
    max_samples_per_subject: int | None = None,
    seed: int = 42,
) -> dict[str, list[BenchmarkSample]]:
    """
    Load every supported benchmark and return the samples keyed by dataset.

    Loaders run in the order MMLU, BBH, BIG-Bench. ``max_samples_per_subject``
    is the per-subject cap for MMLU and the per-task cap for BBH and
    BIG-Bench; ``seed`` is forwarded unchanged to each loader.

    Returns:
        A dict with keys ``"mmlu"``, ``"bbh"`` and ``"bigbench"``, each mapped
        to that loader's list of BenchmarkSample objects.
    """
    benchmarks: dict[str, list[BenchmarkSample]] = {}
    benchmarks["mmlu"] = load_mmlu(
        subjects=mmlu_subjects,
        max_samples_per_subject=max_samples_per_subject,
        seed=seed,
    )
    benchmarks["bbh"] = load_bbh(
        tasks=bbh_tasks,
        max_samples_per_task=max_samples_per_subject,
        seed=seed,
    )
    benchmarks["bigbench"] = load_bigbench(
        tasks=bigbench_tasks,
        max_samples_per_task=max_samples_per_subject,
        seed=seed,
    )
    return benchmarks