Trouter-Library committed on
Commit 7440c87 · verified · 1 Parent(s): e2035b7

Create benchmark.py

Files changed (1)
  1. benchmark.py +377 -0
benchmark.py ADDED
@@ -0,0 +1,377 @@
+"""
+Benchmark script for evaluating Helion-V2 on standard benchmarks.
+Includes MMLU, HellaSwag, ARC, TruthfulQA, and GSM8K.
+"""
+
+import torch
+import json
+import numpy as np
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from datasets import load_dataset
+from tqdm import tqdm
+import argparse
+from typing import Dict, List, Tuple
+import re
+
+
+class BenchmarkEvaluator:
+    """Evaluator for running benchmarks on Helion-V2."""
+
+    def __init__(self, model_name: str, device: str = "cuda"):
+        """Initialize evaluator with model."""
+        print(f"Loading model: {model_name}")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device,
+        )
+        self.model.eval()
+        self.device = device
+
+    def evaluate_mmlu(self, num_shots: int = 5) -> float:
+        """
+        Evaluate on MMLU (Massive Multitask Language Understanding).
+
+        Args:
+            num_shots: Number of few-shot examples (currently unused; prompts are zero-shot)
+
+        Returns:
+            Average accuracy across all subjects
+        """
+        print("\n=== Evaluating MMLU ===")
+        dataset = load_dataset("cais/mmlu", "all", split="test")
+
+        correct = 0
+        total = 0
+
+        for item in tqdm(dataset, desc="MMLU"):
+            question = item["question"]
+            choices = item["choices"]
+            answer = item["answer"]
+
+            # Format prompt
+            prompt = f"Question: {question}\n"
+            for i, choice in enumerate(choices):
+                prompt += f"{chr(65+i)}. {choice}\n"
+            prompt += "Answer:"
+
+            # Get model prediction
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=1,
+                    temperature=0.0,
+                    do_sample=False,
+                )
+
+            response = self.tokenizer.decode(outputs[0][-1:], skip_special_tokens=True).strip()
+
+            # Check if correct
+            if response.upper() in ['A', 'B', 'C', 'D']:
+                predicted_idx = ord(response.upper()) - ord('A')
+                if predicted_idx == answer:
+                    correct += 1
+
+            total += 1
+
+            if total >= 1000:  # Limit for testing
+                break
+
+        accuracy = correct / total if total > 0 else 0
+        print(f"MMLU Accuracy: {accuracy:.2%} ({correct}/{total})")
+        return accuracy
+
+    def evaluate_hellaswag(self) -> float:
+        """
+        Evaluate on HellaSwag (commonsense reasoning).
+
+        Returns:
+            Accuracy on HellaSwag
+        """
+        print("\n=== Evaluating HellaSwag ===")
+        dataset = load_dataset("Rowan/hellaswag", split="validation")
+
+        correct = 0
+        total = 0
+
+        # Sample a subset for speed; use select() since slicing a Dataset returns columns
+        for item in tqdm(dataset.select(range(min(1000, len(dataset)))), desc="HellaSwag"):
+            context = item["ctx"]
+            endings = item["endings"]
+            label = int(item["label"])
+
+            # Calculate log-likelihood for each ending
+            best_score = float('-inf')
+            best_idx = -1
+
+            for idx, ending in enumerate(endings):
+                full_text = context + " " + ending
+                inputs = self.tokenizer(full_text, return_tensors="pt").to(self.device)
+
+                with torch.no_grad():
+                    outputs = self.model(**inputs, labels=inputs["input_ids"])
+                    score = -outputs.loss.item()
+
+                if score > best_score:
+                    best_score = score
+                    best_idx = idx
+
+            if best_idx == label:
+                correct += 1
+            total += 1
+
+        accuracy = correct / total if total > 0 else 0
+        print(f"HellaSwag Accuracy: {accuracy:.2%} ({correct}/{total})")
+        return accuracy
+
+    def evaluate_arc(self, challenge: bool = True) -> float:
+        """
+        Evaluate on ARC (AI2 Reasoning Challenge).
+
+        Args:
+            challenge: Use ARC-Challenge (harder) vs ARC-Easy
+
+        Returns:
+            Accuracy on ARC
+        """
+        subset = "ARC-Challenge" if challenge else "ARC-Easy"
+        print(f"\n=== Evaluating {subset} ===")
+
+        dataset = load_dataset("ai2_arc", subset, split="test")
+
+        correct = 0
+        total = 0
+
+        for item in tqdm(dataset, desc=subset):
+            question = item["question"]
+            choices = item["choices"]["text"]
+            labels = item["choices"]["label"]
+            answer_key = item["answerKey"]
+
+            # Format prompt
+            prompt = f"Question: {question}\n"
+            for label, choice in zip(labels, choices):
+                prompt += f"{label}. {choice}\n"
+            prompt += "Answer:"
+
+            # Get model prediction
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=5,
+                    temperature=0.0,
+                    do_sample=False,
+                )
+
+            response = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+
+            # Extract answer (first character of the generation)
+            predicted = response[0] if response else ""
+
+            if predicted.upper() == answer_key.upper():
+                correct += 1
+
+            total += 1
+
+        accuracy = correct / total if total > 0 else 0
+        print(f"{subset} Accuracy: {accuracy:.2%} ({correct}/{total})")
+        return accuracy
+
+    def evaluate_gsm8k(self) -> float:
+        """
+        Evaluate on GSM8K (grade school math).
+
+        Returns:
+            Accuracy on GSM8K
+        """
+        print("\n=== Evaluating GSM8K ===")
+        dataset = load_dataset("gsm8k", "main", split="test")
+
+        correct = 0
+        total = 0
+
+        # Sample a subset for speed; use select() since slicing a Dataset returns columns
+        for item in tqdm(dataset.select(range(min(500, len(dataset)))), desc="GSM8K"):
+            question = item["question"]
+            # Gold answer follows "####"; strip thousands separators so float() can parse it
+            answer = item["answer"].split("####")[-1].strip().replace(",", "")
+
+            # Format with chain-of-thought prompt
+            prompt = f"Question: {question}\nLet's solve this step by step:\n"
+
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=400,
+                    temperature=0.0,
+                    do_sample=False,
+                )
+
+            response = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+            # Extract numerical answer
+            numbers = re.findall(r'-?\d+\.?\d*', response)
+            if numbers:
+                predicted = numbers[-1]  # Take last number
+                if predicted.replace('.', '').replace('-', '').isdigit():
+                    if float(predicted) == float(answer):
+                        correct += 1
+
+            total += 1
+
+        accuracy = correct / total if total > 0 else 0
+        print(f"GSM8K Accuracy: {accuracy:.2%} ({correct}/{total})")
+        return accuracy
+
+    def evaluate_truthfulqa(self) -> float:
+        """
+        Evaluate on TruthfulQA (truthfulness and informativeness).
+
+        Returns:
+            Approximate accuracy over the MC2 targets (string-match heuristic,
+            not the official MC2 metric)
+        """
+        print("\n=== Evaluating TruthfulQA ===")
+        dataset = load_dataset("truthful_qa", "multiple_choice", split="validation")
+
+        correct = 0
+        total = 0
+
+        for item in tqdm(dataset, desc="TruthfulQA"):
+            question = item["question"]
+            mc2_targets = item["mc2_targets"]
+            choices = mc2_targets["choices"]
+            labels = mc2_targets["labels"]
+
+            # Format prompt
+            prompt = f"Question: {question}\n"
+            for i, choice in enumerate(choices):
+                prompt += f"{i+1}. {choice}\n"
+            prompt += "Select all correct answers:\n"
+
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=100,
+                    temperature=0.0,
+                    do_sample=False,
+                )
+
+            response = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+            # Simple scoring: count the item correct if any correct answer is mentioned
+            response_lower = response.lower()
+            found_correct = False
+            for idx, (choice, label) in enumerate(zip(choices, labels)):
+                if label == 1 and (choice.lower() in response_lower or str(idx+1) in response):
+                    found_correct = True
+                    break
+
+            if found_correct:
+                correct += 1
+            total += 1
+
+        accuracy = correct / total if total > 0 else 0
+        print(f"TruthfulQA Accuracy (heuristic): {accuracy:.2%} ({correct}/{total})")
+        return accuracy
+
+    def run_all_benchmarks(self) -> Dict[str, float]:
+        """
+        Run all available benchmarks.
+
+        Returns:
+            Dictionary of benchmark results
+        """
+        results = {}
+
+        try:
+            results["MMLU"] = self.evaluate_mmlu()
+        except Exception as e:
+            print(f"MMLU evaluation failed: {e}")
+            results["MMLU"] = 0.0
+
+        try:
+            results["HellaSwag"] = self.evaluate_hellaswag()
+        except Exception as e:
+            print(f"HellaSwag evaluation failed: {e}")
+            results["HellaSwag"] = 0.0
+
+        try:
+            results["ARC-Challenge"] = self.evaluate_arc(challenge=True)
+        except Exception as e:
+            print(f"ARC-Challenge evaluation failed: {e}")
+            results["ARC-Challenge"] = 0.0
+
+        try:
+            results["GSM8K"] = self.evaluate_gsm8k()
+        except Exception as e:
+            print(f"GSM8K evaluation failed: {e}")
+            results["GSM8K"] = 0.0
+
+        try:
+            results["TruthfulQA"] = self.evaluate_truthfulqa()
+        except Exception as e:
+            print(f"TruthfulQA evaluation failed: {e}")
+            results["TruthfulQA"] = 0.0
+
+        return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Benchmark Helion-V2")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="DeepXR/Helion-V2",
+        help="Model name or path"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device to use"
+    )
+    parser.add_argument(
+        "--benchmark",
+        type=str,
+        choices=["all", "mmlu", "hellaswag", "arc", "gsm8k", "truthfulqa"],
+        default="all",
+        help="Benchmark to run"
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="benchmark_results.json",
+        help="Output file for results"
+    )
+
+    args = parser.parse_args()
+
+    evaluator = BenchmarkEvaluator(args.model, args.device)
+
+    if args.benchmark == "all":
+        results = evaluator.run_all_benchmarks()
+    else:
+        benchmark_map = {
+            "mmlu": evaluator.evaluate_mmlu,
+            "hellaswag": evaluator.evaluate_hellaswag,
+            "arc": evaluator.evaluate_arc,
+            "gsm8k": evaluator.evaluate_gsm8k,
+            "truthfulqa": evaluator.evaluate_truthfulqa,
+        }
+        score = benchmark_map[args.benchmark]()
+        results = {args.benchmark: score}
+
+    # Save results
+    with open(args.output, 'w') as f:
+        json.dump(results, f, indent=2)
+
+    print("\n=== Final Results ===")
+    for benchmark, score in results.items():
+        print(f"{benchmark}: {score:.2%}")
+    print(f"\nResults saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
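
A minimal usage sketch for the new script, assuming benchmark.py is importable and the transformers, datasets, and tqdm dependencies are installed; the model identifier below is the script's own default:

# Hypothetical usage: run a single benchmark programmatically instead of via the CLI
from benchmark import BenchmarkEvaluator

evaluator = BenchmarkEvaluator("DeepXR/Helion-V2", device="cuda")
score = evaluator.evaluate_arc(challenge=True)  # ARC-Challenge accuracy as a float
print(f"ARC-Challenge: {score:.2%}")

# Equivalent CLI invocation handled by main():
#   python benchmark.py --benchmark arc --output arc_results.json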