HarshadAbleCredit commited on
Commit
b46a1aa
·
verified ·
1 Parent(s): 1f7ddb3

Upload benchmark_gsm8k.py

Browse files

Benchmarking script. Use 0 shot setting!

Files changed (1) hide show
  1. benchmark_gsm8k.py +512 -0
benchmark_gsm8k.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmarking a language model on the GSM8K dataset with batch processing.
3
+
4
+ This script benchmarks a pre-trained causal language model on the GSM8K test set.
5
+ It is adapted from https://github.com/Guangxuan-Xiao/GSM8K-eval.
6
+ """
7
+
8
+ import argparse
9
+ import gzip
10
+ import json
11
+ import os
12
+ import os.path as osp
13
+ import random
14
+ import re
15
+ import ssl
16
+ import urllib.request
17
+ from typing import List, Dict
18
+
19
+ import numpy as np
20
+ import torch
21
+ import transformers
22
+ from tqdm import tqdm
23
+ from transformers import AutoTokenizer, AutoModelForCausalLM
24
+
25
+ # Set logging verbosity to error-level (suppress warnings/info)
26
+ transformers.logging.set_verbosity(40)
27
+
28
+ # Regular expressions for answer extraction
29
+ ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
30
+ BOXED_PATTERN = re.compile(r"\\boxed\{(.*?)\}")
31
+ ANSWER_TAG_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL)
32
+ THINK_ANSWER_PATTERN = re.compile(
33
+ r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL
34
+ ) # simplified pattern
35
+
36
+ DEBUG = True
37
+ USE_COT = False
38
+ ANSWER_TRIGGER = "The answer is"
39
+
40
+
41
+ def download_url(url: str, folder: str = "folder") -> str:
42
+ """
43
+ Download a file from a URL and save it to the specified folder.
44
+
45
+ If the file already exists in the folder, it will not be re-downloaded.
46
+
47
+ Args:
48
+ url (str): The URL of the target file.
49
+ folder (str): The folder where the file will be saved.
50
+
51
+ Returns:
52
+ str: The file path of the downloaded (or existing) file.
53
+ """
54
+ file = url.rpartition("/")[2]
55
+ file = file if file[0] == "?" else file.split("?")[0]
56
+ path = osp.join(folder, file)
57
+ if osp.exists(path):
58
+ print(f"File {file} exists, using existing file.")
59
+ return path
60
+
61
+ print(f"Downloading {url}")
62
+ os.makedirs(folder, exist_ok=True)
63
+ ctx = ssl._create_unverified_context() # pylint:disable=protected-access
64
+ data = urllib.request.urlopen(url, context=ctx)
65
+ with open(path, "wb") as f:
66
+ f.write(data.read())
67
+
68
+ return path
69
+
70
+
71
+ def load_jsonl(
72
+ file_path: str,
73
+ instruction: str = "instruction",
74
+ inp: str = "input",
75
+ output: str = "output",
76
+ category: str = "category",
77
+ is_gzip: bool = False,
78
+ ) -> list:
79
+ """
80
+ Load a JSONL file into a list of dictionaries.
81
+
82
+ Each line of the file should be a JSON object. The function extracts
83
+ the values for keys corresponding to instruction, input, output, and category.
84
+ If a key is missing, its value is set to None.
85
+
86
+ Args:
87
+ file_path (str): The path to the JSONL file.
88
+ instruction (str): The key for the instruction/question text.
89
+ inp (str): The key for the input (if any).
90
+ output (str): The key for the expected output/answer.
91
+ category (str): The key for the category.
92
+ is_gzip (bool): Whether the file is gzip-compressed.
93
+
94
+ Returns:
95
+ list: A list of dictionaries with the extracted keys.
96
+ """
97
+ data_list = []
98
+ open_func = open if not is_gzip else gzip.open
99
+ with open_func(file_path, "r") as f:
100
+ for line in f:
101
+ item = json.loads(line)
102
+ new_item = {
103
+ "instruction": item.get(instruction),
104
+ "input": item.get(inp),
105
+ "output": item.get(output),
106
+ "category": item.get(category),
107
+ }
108
+ data_list.append(new_item)
109
+ return data_list
110
+
111
+
112
+ def clean_answer(answer: str) -> str:
113
+ """standard cleanups"""
114
+ answer = answer.lower()
115
+ answer = answer.rstrip(".").replace(",", "")
116
+ answer = answer.strip()
117
+ return answer
118
+
119
+
120
+ def extract_answer(completion: str) -> str:
121
+ """
122
+ Extract the answer from a formatted output string.
123
+
124
+ The function attempts to find an answer using:
125
+ 1. A boxed format (e.g., \boxed{...}).
126
+ 2. A pattern matching using '####' followed by the answer.
127
+ 3. A fallback using <answer>...</answer> tags (taking the last token of the last pair).
128
+
129
+ The extracted answer is cleaned by removing trailing periods and commas.
130
+
131
+ Args:
132
+ completion (str): The text output (from the model or dataset) containing the answer.
133
+
134
+ Returns:
135
+ str: The cleaned answer extracted from the text.
136
+ """
137
+ answer = ""
138
+ boxed_match = BOXED_PATTERN.search(completion)
139
+ answer_match = ANS_RE.search(completion)
140
+ if boxed_match:
141
+ answer = boxed_match.group(1)
142
+ elif answer_match:
143
+ answer = answer_match.group(1)
144
+ else:
145
+ # Fallback: extract all matches within <answer> tags and process the last one.
146
+ answer_tags = ANSWER_TAG_PATTERN.findall(completion)
147
+ if answer_tags:
148
+ last_answer_content = answer_tags[-1]
149
+ tokens = last_answer_content.split()
150
+ if tokens:
151
+ answer = tokens[-1]
152
+ return answer
153
+
154
+
155
+ def check(ground_truth: str, completion: str) -> bool:
156
+ """
157
+ Compare the ground truth answer with the answer extracted from the given completion text.
158
+
159
+ Args:
160
+ ground_truth (str): The expected (correct) answer.
161
+ completion (str): The text from which to extract the answer.
162
+
163
+ Returns:
164
+ bool: True if the extracted answer matches the ground truth; False otherwise.
165
+ """
166
+
167
+ return clean_answer(ground_truth) == clean_answer(completion)
168
+
169
+
170
+ # In-context examples for demonstration (8-shot prompting)
171
+
172
+ # pylint:disable=line-too-long
173
+ demo_examples = [
174
+ {
175
+ "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
176
+ "think": "<think>There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.</think>",
177
+ "answer": "<answer>\\boxed{5}</answer>",
178
+ },
179
+ {
180
+ "question": "If you have 4 apples and you buy 3 more, how many apples do you have?",
181
+ "think": "<think>You start with 4 apples and then add 3 more. 4 + 3 = 7.</think>",
182
+ "answer": "<answer>\\boxed{7}</answer>",
183
+ },
184
+ {
185
+ "question": "John had 10 candies and gave 3 to his friend. How many candies does John have left?",
186
+ "think": "<think>John had 10 candies and after giving away 3, he is left with 10 - 3 = 7 candies.</think>",
187
+ "answer": "<answer>\\boxed{7}</answer>",
188
+ },
189
+ {
190
+ "question": "There are 5 birds on a tree. If 2 fly away, how many birds remain?",
191
+ "think": "<think>5 birds minus 2 that fly away leaves 3 birds remaining.</think>",
192
+ "answer": "<answer>\\boxed{3}</answer>",
193
+ },
194
+ {
195
+ "question": "A basket has 6 oranges. If 4 oranges are taken out, how many oranges are left in the basket?",
196
+ "think": "<think>6 oranges minus 4 equals 2 oranges remaining.</think>",
197
+ "answer": "<answer>\\boxed{2}</answer>",
198
+ },
199
+ {
200
+ "question": "There are 8 books on the shelf. If 5 are removed, how many books remain?",
201
+ "think": "<think>8 books minus 5 removed gives 3 books remaining.</think>",
202
+ "answer": "<answer>\\boxed{3}</answer>",
203
+ },
204
+ {
205
+ "question": "If a car travels 60 miles in 1 hour, how far does it travel in 2 hours?",
206
+ "think": "<think>At 60 miles per hour, in 2 hours the car travels 60 x 2 = 120 miles.</think>",
207
+ "answer": "<answer>\\boxed{120}</answer>",
208
+ },
209
+ {
210
+ "question": "If a cake is cut into 8 pieces and you eat 3, how many pieces remain?",
211
+ "think": "<think>8 pieces minus 3 eaten equals 5 remaining pieces.</think>",
212
+ "answer": "<answer>\\boxed{5}</answer>",
213
+ },
214
+ ]
215
+ # pylint:enable=line-too-long
216
+
217
+
218
+ def seed_everything(seed: int) -> None:
219
+ """
220
+ Set seeds for various random number generators to ensure reproducibility.
221
+
222
+ Args:
223
+ seed (int): The seed value to be used.
224
+ """
225
+ random.seed(seed)
226
+ os.environ["PYTHONHASHSEED"] = str(seed)
227
+ np.random.seed(seed)
228
+ torch.manual_seed(seed)
229
+ torch.cuda.manual_seed(seed)
230
+ torch.backends.cudnn.deterministic = True
231
+ torch.backends.cudnn.benchmark = True
232
+
233
+
234
+ def load(model_name_or_path: str):
235
+ """
236
+ Load a pre-trained causal language model and its tokenizer.
237
+
238
+ The function downloads the model/tokenizer from the given checkpoint path.
239
+ It also ensures that the tokenizer has a pad token (defaulting to eos_token_id or 0 if missing)
240
+ and sets the model to evaluation mode.
241
+
242
+ Args:
243
+ model_name_or_path (str): The path or identifier of the model checkpoint.
244
+
245
+ Returns:
246
+ tuple: A tuple containing the model and tokenizer.
247
+ """
248
+ print(f"Loading model from {model_name_or_path} ...")
249
+ tokenizer = AutoTokenizer.from_pretrained(
250
+ model_name_or_path, trust_remote_code=False
251
+ )
252
+ model = AutoModelForCausalLM.from_pretrained(
253
+ model_name_or_path,
254
+ device_map="auto",
255
+ torch_dtype=torch.bfloat16,
256
+ trust_remote_code=False,
257
+ )
258
+ if tokenizer.pad_token_id is None:
259
+ tokenizer.pad_token_id = (
260
+ tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
261
+ )
262
+
263
+ model.eval()
264
+ return model, tokenizer
265
+
266
+
267
+ def parse_args():
268
+ """
269
+ Parse command-line arguments.
270
+
271
+ Returns:
272
+ argparse.Namespace: Parsed command-line arguments including model checkpoint path,
273
+ data root, seed, output directory, and an optional quantized model load path.
274
+ """
275
+ parser = argparse.ArgumentParser()
276
+ parser.add_argument(
277
+ "--model_name_or_path",
278
+ type=str,
279
+ default="/dataset/llama2/llama-2-7b-hf",
280
+ help="The model checkpoint for weights initialization.",
281
+ )
282
+ parser.add_argument(
283
+ "--data_root",
284
+ type=str,
285
+ default="./data",
286
+ help="The root folder of the data.",
287
+ )
288
+ parser.add_argument(
289
+ "--seed",
290
+ type=int,
291
+ default=42,
292
+ help="Random seed for reproducibility.",
293
+ )
294
+ parser.add_argument(
295
+ "--output_dir",
296
+ type=str,
297
+ default="./output",
298
+ help="The directory where model predictions and checkpoints will be saved.",
299
+ )
300
+ parser.add_argument(
301
+ "--load", type=str, default=None, help="Path to a quantized model to load."
302
+ )
303
+ parser.add_argument(
304
+ "--batch_size",
305
+ type=int,
306
+ default=8,
307
+ help="Batch size for processing multiple examples at once.",
308
+ )
309
+ return parser.parse_args()
310
+
311
+
312
+ def generate_batch(
313
+ model, tokenizer, input_texts: List[str], generate_kwargs: Dict
314
+ ) -> List[str]:
315
+ """
316
+ Generate responses from the model for a batch of input prompts.
317
+
318
+ Args:
319
+ model: The language model.
320
+ tokenizer: The tokenizer corresponding to the model.
321
+ input_texts (List[str]): List of prompt texts.
322
+ generate_kwargs (Dict): Additional keyword arguments for the model.generate() method.
323
+
324
+ Returns:
325
+ List[str]: List of generated responses.
326
+ """
327
+ # Tokenize all inputs with padding
328
+ encoded_inputs = tokenizer(
329
+ input_texts,
330
+ padding=True,
331
+ add_special_tokens=True,
332
+ return_tensors="pt",
333
+ truncation=True,
334
+ max_length=32 * 1024,
335
+ )
336
+
337
+ input_ids = encoded_inputs.input_ids.cuda()
338
+ attention_mask = encoded_inputs.attention_mask.cuda()
339
+
340
+ with torch.no_grad():
341
+ output_ids = model.generate(
342
+ input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs
343
+ )
344
+
345
+ responses = []
346
+ for i, output_seq in enumerate(output_ids):
347
+ # Get the length of the input sequence for this example
348
+ input_length = len(input_ids[i].nonzero())
349
+ # Decode only the generated part (skip the input)
350
+ responses.append(
351
+ tokenizer.decode(
352
+ output_seq[input_length:],
353
+ skip_special_tokens=True,
354
+ ignore_tokenization_spaces=True,
355
+ )
356
+ )
357
+
358
+ return responses
359
+
360
+
361
+ def build_prompt(question_text: str, demo_examples: List[Dict]) -> str:
362
+ """
363
+ Build a prompt with in-context examples and the target question.
364
+
365
+ Args:
366
+ question_text (str): The target question.
367
+ demo_examples (List[Dict]): List of demonstration examples.
368
+
369
+ Returns:
370
+ str: The formatted prompt.
371
+ """
372
+ # Build an 8-shot in-context prompt using shuffled demo examples
373
+ prompt = "You are a thoughtful assistant. You first think about solution step by step in mind. And then provides user with succinct answer.\n"
374
+
375
+ if USE_COT:
376
+ prompt += "Examples:\n"
377
+ shuffled_demos = random.sample(demo_examples, len(demo_examples))
378
+ for demo in shuffled_demos:
379
+ prompt += f"Question: {demo['question']}\nResponse: {demo['think']}\n{demo['answer']}\n\n"
380
+ # Append guidelines and the actual question
381
+ guidelines = (
382
+ "Guidelines:\n"
383
+ "- Show your thinking between opening <think> and closing </think> tags.\n"
384
+ "- Provide the answer between opening <answer> and closing </answer> tags.\n"
385
+ "- Include the specific value of the answer using \\boxed{ } format.\n\n"
386
+ "Task: Think step by step and solve the question given below.\n"
387
+ "Question:\n"
388
+ )
389
+ prompt += guidelines + question_text
390
+ return prompt
391
+
392
+
393
+ def main():
394
+ """
395
+ Main function to benchmark the model on the GSM8K test set using batch processing.
396
+ """
397
+ args = parse_args()
398
+ seed_everything(args.seed)
399
+ model_name = args.model_name_or_path.split("/")[-1]
400
+ print(model_name)
401
+ # Prepare test file path and download if needed
402
+ test_filepath = os.path.join(args.data_root, "gsm8k_test.jsonl")
403
+ if not os.path.exists(test_filepath):
404
+ download_url(
405
+ "https://raw.githubusercontent.com/openai/"
406
+ "grade-school-math/2909d34ef28520753df82a2234c357259d254aa8/"
407
+ "grade_school_math/data/test.jsonl",
408
+ args.data_root,
409
+ )
410
+ os.rename(os.path.join(args.data_root, "test.jsonl"), test_filepath)
411
+
412
+ # Load test data (mapping "question" to instruction and "answer" to output)
413
+ list_data_dict = load_jsonl(test_filepath, instruction="question", output="answer")
414
+
415
+ # Load the model and tokenizer
416
+ model, tokenizer = load(args.model_name_or_path)
417
+
418
+ # Optionally load a quantized model state
419
+ if args.load:
420
+ print("Loading quantized model from:", args.load)
421
+ model_state = torch.load(args.load, map_location="cpu")
422
+ model.load_state_dict(model_state, strict=False)
423
+ model.half().cuda()
424
+
425
+ # Process data in batches
426
+ batch_size = args.batch_size
427
+ answers = []
428
+ num_batches = (
429
+ len(list_data_dict) + batch_size - 1
430
+ ) // batch_size # Ceiling division
431
+
432
+ generate_kwargs = dict(
433
+ max_new_tokens=2048,
434
+ min_p=0.01,
435
+ temperature=0.5,
436
+ max_length=32 * 1024,
437
+ do_sample=True,
438
+ )
439
+
440
+ print(
441
+ f"Processing {len(list_data_dict)} examples in {num_batches} batches of size {batch_size}"
442
+ )
443
+
444
+ for batch_idx in tqdm(range(num_batches)):
445
+ start_idx = batch_idx * batch_size
446
+ end_idx = min((batch_idx + 1) * batch_size, len(list_data_dict))
447
+ batch_samples = list_data_dict[start_idx:end_idx]
448
+
449
+ # Prepare batch prompts
450
+ batch_prompts = []
451
+ batch_ground_truths = []
452
+
453
+ for sample in batch_samples:
454
+ prompt = build_prompt(sample["instruction"], demo_examples)
455
+ batch_prompts.append(prompt)
456
+ batch_ground_truths.append(extract_answer(sample["output"]))
457
+
458
+ # Generate completions for the batch
459
+ batch_completions = generate_batch(
460
+ model, tokenizer, batch_prompts, generate_kwargs
461
+ )
462
+
463
+ # Process results
464
+ for i, (prompt, completion, ground_truth, sample) in enumerate(
465
+ zip(batch_prompts, batch_completions, batch_ground_truths, batch_samples)
466
+ ):
467
+ model_answer = extract_answer(completion)
468
+ is_correct = check(ground_truth, model_answer)
469
+ answers.append(is_correct)
470
+
471
+ if DEBUG or (
472
+ batch_idx == 0 and i < 3
473
+ ): # Show first few examples of first batch
474
+ print(
475
+ f"Full Prompt: {prompt}\n\n"
476
+ f"Model Completion: {completion}\n\n"
477
+ f"Expected Answer: {ground_truth}\n\n"
478
+ f"Model Answer: {model_answer}\n\n"
479
+ f"Correct: {is_correct}\n\n"
480
+ )
481
+
482
+ # Print progress update
483
+ if (batch_idx + 1) % 1 == 0 or batch_idx == num_batches - 1:
484
+ print(
485
+ f"Processed {min((batch_idx + 1) * batch_size, len(list_data_dict))}/{len(list_data_dict)} questions, "
486
+ f"Correct: {sum(answers)}, "
487
+ f"Current Accuracy: {float(sum(answers)) / len(answers):.4f}"
488
+ )
489
+
490
+ # Save results
491
+ os.makedirs(args.output_dir, exist_ok=True)
492
+
493
+ with open(
494
+ os.path.join(args.output_dir, f"{model_name}_results.txt"),
495
+ "w",
496
+ encoding="utf-8",
497
+ ) as f:
498
+ for answer in answers:
499
+ print(int(answer), file=f)
500
+
501
+ with open(
502
+ os.path.join(args.output_dir, f"{model_name}_scores.txt"), "w", encoding="utf-8"
503
+ ) as f:
504
+ print(
505
+ f"Total questions: {len(answers)}, Correct: {sum(answers)}, "
506
+ f"Final Accuracy: {float(sum(answers)) / len(answers):.4f}",
507
+ file=f,
508
+ )
509
+
510
+
511
+ if __name__ == "__main__":
512
+ main()