"""
Standard NLP benchmark task definitions.

Each task loads from HuggingFace datasets and formats examples for
log-likelihood scoring.

Supported tasks:
- HellaSwag (commonsense NLI, 4-choice)
- ARC-Easy / ARC-Challenge (science QA, 3-5 choices)
- PIQA (physical intuition, 2-choice)
- WinoGrande (coreference, 2-choice)
- MMLU (multi-domain knowledge, 4-choice)
- HaluEval-QA (hallucination detection, 2-choice)
"""

import logging
import random
from abc import ABC, abstractmethod
from typing import List, Dict, Tuple, Optional, Any, Callable

from datasets import load_dataset

logger = logging.getLogger(__name__)


def _shuffled_head(ds, k: int, seed: int):
    """Shuffle ``ds`` with ``seed`` and keep at most ``k`` rows.

    ``k`` is clamped to ``len(ds)`` so that asking for more rows than the
    split contains yields the whole (shuffled) split instead of an error.
    """
    return ds.shuffle(seed=seed).select(range(min(k, len(ds))))


class BenchmarkTask(ABC):
    """Base class for benchmark tasks."""

    name: str = "base"
    num_few_shot: int = 0

    @abstractmethod
    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """
        Load and format examples.

        Args:
            n: If given, the evaluation split is shuffled with ``seed`` and
                truncated to at most ``n`` rows; if None, every row is used.
            seed: Seed used for shuffling.

        Returns list of dicts, each with:
        - "context": str — the prompt/context
        - "continuations": List[str] — possible completions
        - "gold_idx": int — index of the correct continuation
        """
        ...

    def format_few_shot(self, examples: List[Dict], train_examples: List[Dict]) -> List[Dict]:
        """Prepend few-shot examples to each test example's context.

        The first ``self.num_few_shot`` entries of ``train_examples`` are
        rendered as ``context + gold continuation`` separated by blank lines,
        and that prefix is put in front of every example's "context".

        Note: the dicts in ``examples`` are modified in place and the same
        list object is returned, so calling this twice on the same list
        prepends the prefix twice.

        Returns:
            ``examples`` (unchanged when there are no train examples or
            ``num_few_shot`` is 0).
        """
        if not train_examples or self.num_few_shot == 0:
            return examples

        # Build the few-shot prefix once; it is identical for every example.
        shots = train_examples[:self.num_few_shot]
        prefix = "".join(
            shot["context"] + shot["continuations"][shot["gold_idx"]] + "\n\n"
            for shot in shots
        )

        for ex in examples:
            ex["context"] = prefix + ex["context"]

        return examples


class HellaSwag(BenchmarkTask):
    """
    HellaSwag: Can a Machine Really Finish Your Sentence?
    4-choice commonsense NLI.
    Dataset: Rowan/hellaswag
    """

    name = "hellaswag"
    num_few_shot = 5

    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """Load validation rows, with few-shot demonstrations drawn from train."""
        ds = load_dataset("Rowan/hellaswag", split="validation")
        if n is not None:
            ds = _shuffled_head(ds, n, seed)

        few_shot_ds = load_dataset("Rowan/hellaswag", split="train")
        few_shot_ds = _shuffled_head(few_shot_ds, self.num_few_shot, seed)

        def format_row(row):
            # Every ending is scored as a continuation of the context, so make
            # sure it starts with exactly the separating space it needs.
            endings = [
                ending if ending.startswith(" ") else f" {ending}"
                for ending in row["endings"]
            ]
            return {
                "context": row["ctx"],
                "continuations": endings,
                "gold_idx": int(row["label"]),
            }

        train_examples = [format_row(row) for row in few_shot_ds]
        examples = [format_row(row) for row in ds]

        return self.format_few_shot(examples, train_examples)


class ARC(BenchmarkTask):
    """
    AI2 Reasoning Challenge.
    Dataset: allenai/ai2_arc
    Subsets: ARC-Easy, ARC-Challenge
    """

    name = "arc"
    num_few_shot = 5

    def __init__(self, subset: str = "ARC-Easy"):
        """Select the dataset config; ``name`` becomes "arc-easy" or "arc-challenge"."""
        self.subset = subset
        self.name = f"arc-{'easy' if 'Easy' in subset else 'challenge'}"

    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """Load test rows, with few-shot demonstrations drawn from train.

        Rows whose ``answerKey`` is not one of their own choice labels are
        skipped (with a warning), so the result may hold fewer than ``n``
        examples.
        """
        ds = load_dataset("allenai/ai2_arc", self.subset, split="test")
        if n is not None:
            ds = _shuffled_head(ds, n, seed)

        # Few-shot from train
        train_ds = load_dataset("allenai/ai2_arc", self.subset, split="train")
        train_ds = _shuffled_head(train_ds, self.num_few_shot, seed)

        def format_row(row) -> Optional[Dict]:
            question = row["question"]
            labels = row["choices"]["label"]
            texts = row["choices"]["text"]
            answer_key = row["answerKey"]

            # Map answer key to index. A key that matches no label cannot be
            # scored; defaulting to choice 0 would silently grade the example
            # against an arbitrary answer, so drop the row instead.
            if answer_key not in labels:
                logger.warning(
                    "%s: answer key %r not among labels %r; skipping example",
                    self.name, answer_key, labels,
                )
                return None
            gold_idx = labels.index(answer_key)

            # Format as "Question: ...\nA) ... B) ...\nAnswer:"
            choice_str = " ".join(
                f"{label}) {text}" for label, text in zip(labels, texts)
            )
            context = f"Question: {question}\n{choice_str}\nAnswer:"
            continuations = [f" {label}" for label in labels]

            return {
                "context": context,
                "continuations": continuations,
                "gold_idx": gold_idx,
            }

        train_examples = [ex for ex in map(format_row, train_ds) if ex is not None]
        examples = [ex for ex in map(format_row, ds) if ex is not None]

        return self.format_few_shot(examples, train_examples)


class PIQA(BenchmarkTask):
    """
    Physical Intuition QA. 2-choice.
    Dataset: gimmaru/piqa (parquet mirror — ybisk/piqa loading script no longer supported)
    """

    name = "piqa"
    num_few_shot = 0  # No train split in mirror; use 0-shot

    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """Load validation rows as zero-shot "Solution 1 / Solution 2" prompts."""
        ds = load_dataset("gimmaru/piqa", split="validation")
        if n is not None:
            ds = _shuffled_head(ds, n, seed)

        def format_row(row):
            goal = row["goal"]
            sol1 = row["sol1"]
            sol2 = row["sol2"]
            gold = row["label"]  # 0 or 1

            context = f"Goal: {goal}\nSolution 1: {sol1}\nSolution 2: {sol2}\nThe better solution is Solution"
            continuations = [" 1", " 2"]

            return {
                "context": context,
                "continuations": continuations,
                "gold_idx": gold,
            }

        return [format_row(row) for row in ds]


class WinoGrande(BenchmarkTask):
    """
    WinoGrande: Winograd-style coreference. 2-choice.
    Dataset: allenai/winogrande (winogrande_xl)
    """

    name = "winogrande"
    num_few_shot = 5

    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """Load validation rows, with few-shot demonstrations drawn from train."""
        ds = load_dataset("allenai/winogrande", "winogrande_xl", split="validation")
        if n is not None:
            ds = _shuffled_head(ds, n, seed)

        train_ds = load_dataset("allenai/winogrande", "winogrande_xl", split="train")
        train_ds = _shuffled_head(train_ds, self.num_few_shot, seed)

        def format_row(row):
            sentence = row["sentence"]
            answer = int(row["answer"]) - 1  # 1-indexed -> 0-indexed

            # Replace _ with each option
            sent1 = sentence.replace("_", row["option1"])
            sent2 = sentence.replace("_", row["option2"])

            context = f"Which makes more sense?\nA) {sent1}\nB) {sent2}\nAnswer:"
            continuations = [" A", " B"]

            return {
                "context": context,
                "continuations": continuations,
                "gold_idx": answer,
            }

        train_examples = [format_row(row) for row in train_ds]
        examples = [format_row(row) for row in ds]

        return self.format_few_shot(examples, train_examples)


class MMLU(BenchmarkTask):
    """
    Massive Multitask Language Understanding. 4-choice.
    Dataset: cais/mmlu (all subjects)
    """

    name = "mmlu"
    num_few_shot = 5

    def __init__(self, subject: Optional[str] = None):
        """If subject is None, sample across all subjects."""
        self.subject = subject
        if subject:
            self.name = f"mmlu-{subject}"

    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """Load test rows, with few-shot demonstrations drawn from validation."""
        # A falsy subject selects the combined "all" config.
        config = self.subject if self.subject else "all"
        ds = load_dataset("cais/mmlu", config, split="test")
        train_ds = load_dataset("cais/mmlu", config, split="validation")

        if n is not None:
            ds = _shuffled_head(ds, n, seed)

        train_ds = _shuffled_head(train_ds, self.num_few_shot, seed)

        labels = ["A", "B", "C", "D"]

        def format_row(row):
            question = row["question"]
            choices = row["choices"]
            answer = row["answer"]  # 0-3

            choice_str = "\n".join(
                f"{label}) {choice}" for label, choice in zip(labels, choices)
            )
            context = f"Question: {question}\n{choice_str}\nAnswer:"
            continuations = [f" {label}" for label in labels]

            return {
                "context": context,
                "continuations": continuations,
                "gold_idx": answer,
            }

        train_examples = [format_row(row) for row in train_ds]
        examples = [format_row(row) for row in ds]

        return self.format_few_shot(examples, train_examples)


class HaluEval(BenchmarkTask):
    """
    HaluEval: Hallucination Evaluation.
    Dataset: pminervini/HaluEval (qa_samples)

    Tests whether the model can identify hallucinated answers.
    """

    name = "halueval"
    # NOTE: load_examples() below never calls format_few_shot(), so no
    # demonstrations are actually prepended despite this value.
    num_few_shot = 2

    def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
        """Build A/B prompts pairing the right answer with the hallucinated one.

        Rows missing either answer are skipped, so the result may hold fewer
        than ``n`` examples.
        """
        ds = load_dataset("pminervini/HaluEval", "qa_samples", split="data")
        if n is not None:
            ds = _shuffled_head(ds, n, seed)

        examples = []
        for row in ds:
            question = row["question"]
            knowledge = row.get("knowledge", "")
            right_answer = row.get("right_answer", "")
            hallucinated_answer = row.get("hallucinated_answer", "")

            if not right_answer or not hallucinated_answer:
                continue

            # Randomly order the options. The generator is re-seeded per
            # example from the number of examples kept so far, which makes the
            # A/B order a deterministic function of (seed, position).
            rng = random.Random(seed + len(examples))
            options = [(right_answer, 0), (hallucinated_answer, 1)]
            if rng.random() > 0.5:
                options = options[::-1]

            # Gold is the correct (non-hallucinated) answer
            gold_idx = 0 if options[0][1] == 0 else 1

            context_parts = [f"Question: {question}"]
            if knowledge:
                # Knowledge goes first and is truncated to 300 characters.
                context_parts.insert(0, f"Knowledge: {knowledge[:300]}")
            # Each answer is truncated to 200 characters.
            context_parts.append(f"Answer A: {options[0][0][:200]}")
            context_parts.append(f"Answer B: {options[1][0][:200]}")
            context_parts.append("Which answer is correct? Answer:")

            examples.append({
                "context": "\n".join(context_parts),
                "continuations": [" A", " B"],
                "gold_idx": gold_idx,
            })

        return examples


# Task registry for easy lookup. Every value is a zero-argument callable that
# returns a task instance.
TASK_REGISTRY: Dict[str, Callable[[], BenchmarkTask]] = {
    "hellaswag": HellaSwag,
    "arc-easy": lambda: ARC("ARC-Easy"),
    "arc-challenge": lambda: ARC("ARC-Challenge"),
    "piqa": PIQA,
    "winogrande": WinoGrande,
    "mmlu": MMLU,
    "halueval": HaluEval,
}