theapemachine committed
Commit e014dff · verified · 1 Parent(s): 81ff944

Add benchmark harness: tasks.py - Standard NLP task definitions

Files changed (1)
  1. benchmark/tasks.py +336 -0
benchmark/tasks.py ADDED
@@ -0,0 +1,336 @@
+ """
+ Standard NLP benchmark task definitions.
+
+ Each task loads from HuggingFace datasets and formats examples for
+ log-likelihood scoring.
+
+ Supported tasks:
+ - HellaSwag (commonsense NLI, 4-choice)
+ - ARC-Easy / ARC-Challenge (science QA, 3-5 choices)
+ - PIQA (physical commonsense, 2-choice)
+ - WinoGrande (coreference, 2-choice)
+ - MMLU (multi-domain knowledge, 4-choice)
+ - HaluEval-QA (hallucination detection, 2-choice)
+ """
+
+ import random
+ from abc import ABC, abstractmethod
+ from typing import List, Dict, Optional
+ from datasets import load_dataset
+
+
+ class BenchmarkTask(ABC):
+     """Base class for benchmark tasks."""
+
+     name: str = "base"
+     num_few_shot: int = 0
+
+     @abstractmethod
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         """
+         Load and format examples.
+
+         Returns a list of dicts, each with:
+         - "context": str — the prompt/context
+         - "continuations": List[str] — possible completions
+         - "gold_idx": int — index of the correct continuation
+         """
+         ...
+
+     def format_few_shot(self, examples: List[Dict], train_examples: List[Dict]) -> List[Dict]:
+         """Prepend few-shot examples to each test example's context."""
+         if not train_examples or self.num_few_shot == 0:
+             return examples
+
+         # Build few-shot prefix
+         shots = train_examples[:self.num_few_shot]
+         prefix = ""
+         for shot in shots:
+             prefix += shot["context"] + shot["continuations"][shot["gold_idx"]] + "\n\n"
+
+         for ex in examples:
+             ex["context"] = prefix + ex["context"]
+
+         return examples
+
+
+ class HellaSwag(BenchmarkTask):
+     """
+     HellaSwag: Can a Machine Really Finish Your Sentence?
+     4-choice commonsense NLI. Dataset: Rowan/hellaswag
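+     Uses the raw `ctx` field; note that some harnesses (e.g. lm-evaluation-harness)
+     prepend `activity_label` to the context, so scores may not be directly comparable.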
+     """
+     name = "hellaswag"
+     num_few_shot = 5
+
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         ds = load_dataset("Rowan/hellaswag", split="validation")
+
+         if n is not None:
+             ds = ds.shuffle(seed=seed).select(range(min(n, len(ds))))
+
+         # Few-shot from train
+         few_shot_ds = load_dataset("Rowan/hellaswag", split="train")
+         few_shot_ds = few_shot_ds.shuffle(seed=seed).select(range(self.num_few_shot))
+
+         def format_row(row):
+             return {
+                 "context": row["ctx"],
+                 "continuations": row["endings"],
+                 "gold_idx": int(row["label"]),
+             }
+
+         train_examples = [format_row(row) for row in few_shot_ds]
+         examples = [format_row(row) for row in ds]
+
+         return self.format_few_shot(examples, train_examples)
+
+
+ class ARC(BenchmarkTask):
+     """
+     AI2 Reasoning Challenge. Dataset: allenai/ai2_arc
+     Subsets: ARC-Easy, ARC-Challenge
+     """
+     name = "arc"
+     num_few_shot = 5
+
+     def __init__(self, subset: str = "ARC-Easy"):
+         self.subset = subset
+         self.name = f"arc-{'easy' if 'Easy' in subset else 'challenge'}"
+
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         ds = load_dataset("allenai/ai2_arc", self.subset, split="test")
+
+         if n is not None:
+             ds = ds.shuffle(seed=seed).select(range(min(n, len(ds))))
+
+         # Few-shot from train
+         train_ds = load_dataset("allenai/ai2_arc", self.subset, split="train")
+         train_ds = train_ds.shuffle(seed=seed).select(range(self.num_few_shot))
+
+         def format_row(row):
+             question = row["question"]
+             choices = row["choices"]
+             labels = choices["label"]
+             texts = choices["text"]
+             answer_key = row["answerKey"]
+
+             # Map answer key to index
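+             # (Some ARC items label choices "1"-"4" instead of "A"-"D";
+             #  answerKey matches the row's own label set, so index() covers both.)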
+             gold_idx = labels.index(answer_key) if answer_key in labels else 0
+
+             # Format as "Question: ...\nA) ... B) ...\nAnswer:"
+             choice_str = " ".join(f"{l}) {t}" for l, t in zip(labels, texts))
+             context = f"Question: {question}\n{choice_str}\nAnswer:"
+
+             continuations = [f" {t}" for t in texts]
+
+             return {
+                 "context": context,
+                 "continuations": continuations,
+                 "gold_idx": gold_idx,
+             }
+
+         train_examples = [format_row(row) for row in train_ds]
+         examples = [format_row(row) for row in ds]
+
+         return self.format_few_shot(examples, train_examples)
+
+
+ class PIQA(BenchmarkTask):
+     """
+     Physical Interaction QA (physical commonsense). 2-choice.
+     Dataset: gimmaru/piqa (parquet mirror — ybisk/piqa loading script no longer supported)
+     """
+     name = "piqa"
+     num_few_shot = 0  # No train split in mirror; use 0-shot
+
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         ds = load_dataset("gimmaru/piqa", split="validation")
+
+         if n is not None:
+             ds = ds.shuffle(seed=seed).select(range(min(n, len(ds))))
+
+         def format_row(row):
+             goal = row["goal"]
+             sol1 = row["sol1"]
+             sol2 = row["sol2"]
+             gold = row["label"]  # 0 or 1
+
+             context = f"Goal: {goal}\nSolution 1: {sol1}\nSolution 2: {sol2}\nThe better solution is Solution"
+             continuations = [" 1", " 2"]
+
+             return {
+                 "context": context,
+                 "continuations": continuations,
+                 "gold_idx": gold,
+             }
+
+         examples = [format_row(row) for row in ds]
+         return examples
+
+
+ class WinoGrande(BenchmarkTask):
+     """
+     WinoGrande: Winograd-style coreference. 2-choice.
+     Dataset: allenai/winogrande (winogrande_xl)
+     """
+     name = "winogrande"
+     num_few_shot = 5
+
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         ds = load_dataset("allenai/winogrande", "winogrande_xl", split="validation")
+
+         if n is not None:
+             ds = ds.shuffle(seed=seed).select(range(min(n, len(ds))))
+
+         train_ds = load_dataset("allenai/winogrande", "winogrande_xl", split="train")
+         train_ds = train_ds.shuffle(seed=seed).select(range(self.num_few_shot))
+
+         def format_row(row):
+             sentence = row["sentence"]
+             option1 = row["option1"]
+             option2 = row["option2"]
+             answer = int(row["answer"]) - 1  # 1-indexed -> 0-indexed
+
+             # Replace _ with each option
+             sent1 = sentence.replace("_", option1)
+             sent2 = sentence.replace("_", option2)
+
+             context = f"Which makes more sense?\nA) {sent1}\nB) {sent2}\nAnswer:"
+             continuations = [" A", " B"]
+
+             return {
+                 "context": context,
+                 "continuations": continuations,
+                 "gold_idx": answer,
+             }
+
+         train_examples = [format_row(row) for row in train_ds]
+         examples = [format_row(row) for row in ds]
+
+         return self.format_few_shot(examples, train_examples)
+
+
+ class MMLU(BenchmarkTask):
+     """
+     Massive Multitask Language Understanding. 4-choice.
+     Dataset: cais/mmlu (all subjects)
+     """
+     name = "mmlu"
+     num_few_shot = 5
+
+     def __init__(self, subject: Optional[str] = None):
+         """If subject is None, sample across all subjects."""
+         self.subject = subject
+         if subject:
+             self.name = f"mmlu-{subject}"
+
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         if self.subject:
+             ds = load_dataset("cais/mmlu", self.subject, split="test")
+             train_ds = load_dataset("cais/mmlu", self.subject, split="validation")
+         else:
+             ds = load_dataset("cais/mmlu", "all", split="test")
+             train_ds = load_dataset("cais/mmlu", "all", split="validation")
+
+         if n is not None:
+             ds = ds.shuffle(seed=seed).select(range(min(n, len(ds))))
+
+         train_ds = train_ds.shuffle(seed=seed).select(range(min(self.num_few_shot, len(train_ds))))
+
+         def format_row(row):
+             question = row["question"]
+             choices = row["choices"]
+             answer = row["answer"]  # 0-3
+
+             labels = ["A", "B", "C", "D"]
+             choice_str = "\n".join(f"{l}) {c}" for l, c in zip(labels, choices))
+             context = f"Question: {question}\n{choice_str}\nAnswer:"
+             continuations = [f" {l}" for l in labels]
+
+             return {
+                 "context": context,
+                 "continuations": continuations,
+                 "gold_idx": answer,
+             }
+
+         train_examples = [format_row(row) for row in train_ds]
+         examples = [format_row(row) for row in ds]
+
+         return self.format_few_shot(examples, train_examples)
+
+
+ class HaluEval(BenchmarkTask):
+     """
+     HaluEval: Hallucination Evaluation.
+     Dataset: pminervini/HaluEval (qa_samples)
+
+     Tests whether the model can identify hallucinated answers.
+     """
+     name = "halueval"
+     num_few_shot = 0  # load_examples builds no few-shot prefix; effectively 0-shot
+
+     def load_examples(self, n: Optional[int] = None, seed: int = 42) -> List[Dict]:
+         ds = load_dataset("pminervini/HaluEval", "qa_samples", split="data")
+
+         if n is not None:
+             ds = ds.shuffle(seed=seed).select(range(min(n, len(ds))))
+
+         examples = []
+         for row in ds:
+             question = row["question"]
+             knowledge = row.get("knowledge", "")
+             right_answer = row.get("right_answer", "")
+             hallucinated_answer = row.get("hallucinated_answer", "")
+
+             if not right_answer or not hallucinated_answer:
+                 continue
+
+             # Randomly order the options
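+             # (Seeded per example, so the A/B ordering is reproducible across runs.)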
+             rng = random.Random(seed + len(examples))
+             options = [(right_answer, 0), (hallucinated_answer, 1)]
+             if rng.random() > 0.5:
+                 options = options[::-1]
+
+             # Gold is the correct (non-hallucinated) answer
+             gold_idx = 0 if options[0][1] == 0 else 1
+
+             context_parts = [f"Question: {question}"]
+             if knowledge:
+                 context_parts.insert(0, f"Knowledge: {knowledge[:300]}")
+             context_parts.append(f"Answer A: {options[0][0][:200]}")
+             context_parts.append(f"Answer B: {options[1][0][:200]}")
+             context_parts.append("Which answer is correct? Answer:")
+
+             context = "\n".join(context_parts)
+             continuations = [" A", " B"]
+
+             examples.append({
+                 "context": context,
+                 "continuations": continuations,
+                 "gold_idx": gold_idx,
+             })
+
+         return examples
+
+
+ # Task registry for easy lookup
+ TASK_REGISTRY = {
+     "hellaswag": HellaSwag,
+     "arc-easy": lambda: ARC("ARC-Easy"),
+     "arc-challenge": lambda: ARC("ARC-Challenge"),
+     "piqa": PIQA,
+     "winogrande": WinoGrande,
+     "mmlu": MMLU,
+     "halueval": HaluEval,
+ }
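
A minimal consumer sketch of the harness above (not part of the commit). It assumes only what tasks.py defines: TASK_REGISTRY values are classes or zero-argument factories, and each example dict carries "context", "continuations", and "gold_idx". The score_continuation function is a hypothetical placeholder for the model-side log-likelihood scorer that the rest of the benchmark harness would supply.

    # Hypothetical usage sketch — not part of this commit.
    from benchmark.tasks import TASK_REGISTRY

    def score_continuation(context: str, continuation: str) -> float:
        # Placeholder: return the model's total log-likelihood of
        # `continuation` given `context`. Supplied elsewhere in the harness.
        raise NotImplementedError

    def evaluate(task_name: str, n: int = 200) -> float:
        # Registry values are either task classes or zero-arg lambdas;
        # calling either yields a task instance.
        task = TASK_REGISTRY[task_name]()
        examples = task.load_examples(n=n)

        correct = 0
        for ex in examples:
            scores = [score_continuation(ex["context"], c) for c in ex["continuations"]]
            pred = max(range(len(scores)), key=scores.__getitem__)
            correct += int(pred == ex["gold_idx"])
        return correct / len(examples)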