| | """ |
| | Benchmarking a language model on the GSM8K dataset with batch processing. |
| | |
| | This script benchmarks a pre-trained causal language model on the GSM8K test set. |
| | It is adapted from https://github.com/Guangxuan-Xiao/GSM8K-eval. |
| | """ |
| |
|
| | import argparse |
| | import gzip |
| | import json |
| | import os |
| | import os.path as osp |
| | import random |
| | import re |
| | import ssl |
| | import urllib.request |
| | from typing import List, Dict |
| |
|
| | import numpy as np |
| | import torch |
| | import transformers |
| | from tqdm import tqdm |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| |
|
| | |
# Only show errors from transformers (same as numeric level 40 == logging.ERROR),
# using the named helper instead of a magic number.
transformers.logging.set_verbosity_error()

# GSM8K reference solutions end with "#### <number>"; capture that number.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# First "\boxed{...}" span; non-greedy, so it stops at the first closing brace.
BOXED_PATTERN = re.compile(r"\\boxed\{(.*?)\}")
# Contents of <answer>...</answer> tags (case-insensitive, may span lines).
ANSWER_TAG_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL)
# Whole "<think>...</think> <answer>...</answer>" response shape
# (not referenced by the code below; kept for external/format checks).
THINK_ANSWER_PATTERN = re.compile(
    r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL
)

# When True, main() prints prompt/completion/answer for every example.
DEBUG = True
# When True, build_prompt() prepends the shuffled in-context demonstrations.
USE_COT = False
# Not referenced by the code below.
ANSWER_TRIGGER = "The answer is"
| |
|
| |
|
def download_url(url: str, folder: str = "folder") -> str:
    """
    Download a file from a URL and save it to the specified folder.

    The local file name is the last path component of ``url`` with any query
    string removed (unless the component itself starts with "?"). If a file
    with that name already exists in ``folder`` it is reused and nothing is
    downloaded.

    Args:
        url (str): The URL of the target file.
        folder (str): The folder where the file will be saved; created if missing.

    Returns:
        str: The file path of the downloaded (or existing) file.

    Raises:
        ValueError: If no file name can be derived from ``url`` (e.g. it ends in "/").
    """
    file = url.rpartition("/")[2]
    if not file:
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")
    # Keep a component that *starts* with "?" as-is; otherwise drop the query string.
    file = file if file[0] == "?" else file.split("?")[0]
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, using existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # NOTE(review): certificate verification is disabled (inherited from the
    # upstream script). Prefer ssl.create_default_context() where a CA bundle
    # is available.
    ctx = ssl._create_unverified_context()
    # Write to a temporary name first so an interrupted download never leaves
    # a truncated file at `path` that the exists-check above would reuse.
    tmp_path = path + ".part"
    # Context managers close both the HTTP response and the output file.
    with urllib.request.urlopen(url, context=ctx) as response, open(tmp_path, "wb") as f:
        f.write(response.read())
    os.replace(tmp_path, path)

    return path
| |
|
| |
|
def load_jsonl(
    file_path: str,
    instruction: str = "instruction",
    inp: str = "input",
    output: str = "output",
    category: str = "category",
    is_gzip: bool = False,
) -> list:
    """
    Load a JSONL file into a list of dictionaries.

    Each non-blank line of the file must be a JSON object. For every object the
    function reads the fields named by ``instruction``, ``inp``, ``output`` and
    ``category`` and stores them under the fixed keys "instruction", "input",
    "output" and "category". A missing field becomes None. Blank lines
    (e.g. a trailing newline) are skipped.

    Args:
        file_path (str): The path to the JSONL file.
        instruction (str): The key for the instruction/question text.
        inp (str): The key for the input (if any).
        output (str): The key for the expected output/answer.
        category (str): The key for the category.
        is_gzip (bool): Whether the file is gzip-compressed.

    Returns:
        list: A list of dictionaries with the extracted keys.
    """
    data_list = []
    open_func = gzip.open if is_gzip else open
    # Text mode with an explicit encoding for both plain and gzip files, so
    # decoding does not depend on the platform locale.
    with open_func(file_path, "rt", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            item = json.loads(line)
            data_list.append(
                {
                    "instruction": item.get(instruction),
                    "input": item.get(inp),
                    "output": item.get(output),
                    "category": item.get(category),
                }
            )
    return data_list
| |
|
| |
|
def clean_answer(answer: str) -> str:
    """
    Normalize an answer string for comparison.

    Lower-cases the text, deletes commas (thousands separators), trims
    surrounding whitespace and removes trailing periods, so that
    ``" 1,234. "`` and ``"1234"`` normalize to the same value.

    Args:
        answer (str): Raw answer text.

    Returns:
        str: The normalized answer.
    """
    answer = answer.lower().replace(",", "")
    # Trim whitespace before AND after removing periods, so "5. " and "5 ."
    # both reduce to "5" (stripping periods first would miss the dot in "5. ").
    return answer.strip().rstrip(".").strip()
| |
|
| |
|
def _extract_boxed(text: str):
    """Return the contents of the first brace-balanced ``\\boxed{...}`` in *text*, or None."""
    marker = "\\boxed{"
    start = text.find(marker)
    while start != -1:
        depth = 0
        begin = start + len(marker)
        for pos in range(begin, len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                if depth == 0:
                    return text[begin:pos]
                depth -= 1
        # This occurrence is never closed; try the next one.
        start = text.find(marker, start + 1)
    return None


def extract_answer(completion: str) -> str:
    """
    Extract the answer from a formatted output string.

    The function tries, in order:
      1. The first ``\\boxed{...}`` span, honouring nested braces
         (e.g. ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}``).
      2. The number following ``####`` (GSM8K reference format).
      3. The last whitespace-separated token inside the last
         ``<answer>...</answer>`` pair.

    The answer is returned as found; no normalization is applied here.

    Args:
        completion (str): The text output (from the model or dataset) containing the answer.

    Returns:
        str: The extracted answer, or "" if none of the formats is present.
    """
    boxed = _extract_boxed(completion)
    if boxed is not None:
        return boxed

    answer_match = ANS_RE.search(completion)
    if answer_match:
        return answer_match.group(1)

    answer_tags = ANSWER_TAG_PATTERN.findall(completion)
    if answer_tags:
        tokens = answer_tags[-1].split()
        if tokens:
            return tokens[-1]
    return ""
| |
|
| |
|
def check(ground_truth: str, completion: str) -> bool:
    """
    Decide whether a candidate answer equals the reference answer.

    Both arguments are normalized with ``clean_answer`` and then compared for
    string equality. No answer extraction happens here: callers pass an
    already-extracted answer as ``completion``.

    Args:
        ground_truth (str): The reference answer.
        completion (str): The candidate answer text.

    Returns:
        bool: True when the normalized strings are identical.
    """
    expected = clean_answer(ground_truth)
    predicted = clean_answer(completion)
    return expected == predicted
| |
|
| |
|
| | |
| |
|
| | |
# Few-shot demonstrations passed to build_prompt (used there only when USE_COT
# is True). Each entry is a dict with "question", "think" (reasoning wrapped in
# <think> tags) and "answer" (a \boxed{} value wrapped in <answer> tags).
demo_examples = [
    {
        "question": question,
        "think": f"<think>{reasoning}</think>",
        "answer": f"<answer>\\boxed{{{value}}}</answer>",
    }
    for question, reasoning, value in (
        (
            "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
            "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
            "5",
        ),
        (
            "If you have 4 apples and you buy 3 more, how many apples do you have?",
            "You start with 4 apples and then add 3 more. 4 + 3 = 7.",
            "7",
        ),
        (
            "John had 10 candies and gave 3 to his friend. How many candies does John have left?",
            "John had 10 candies and after giving away 3, he is left with 10 - 3 = 7 candies.",
            "7",
        ),
        (
            "There are 5 birds on a tree. If 2 fly away, how many birds remain?",
            "5 birds minus 2 that fly away leaves 3 birds remaining.",
            "3",
        ),
        (
            "A basket has 6 oranges. If 4 oranges are taken out, how many oranges are left in the basket?",
            "6 oranges minus 4 equals 2 oranges remaining.",
            "2",
        ),
        (
            "There are 8 books on the shelf. If 5 are removed, how many books remain?",
            "8 books minus 5 removed gives 3 books remaining.",
            "3",
        ),
        (
            "If a car travels 60 miles in 1 hour, how far does it travel in 2 hours?",
            "At 60 miles per hour, in 2 hours the car travels 60 x 2 = 120 miles.",
            "120",
        ),
        (
            "If a cake is cut into 8 pieces and you eat 3, how many pieces remain?",
            "8 pieces minus 3 eaten equals 5 remaining pieces.",
            "5",
        ),
    )
]
| | |
| |
|
| |
|
def seed_everything(seed: int) -> None:
    """
    Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed (int): The seed value to be used.
    """
    random.seed(seed)
    # Only influences child processes: hash randomization of the running
    # interpreter is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (a no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run, which
    # defeats `deterministic = True`; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
| |
|
| |
|
def load(model_name_or_path: str):
    """
    Load a pre-trained causal language model and its tokenizer.

    The model is loaded in bfloat16 with ``device_map="auto"`` and switched to
    evaluation mode. The tokenizer gets a pad token if it lacks one
    (eos_token_id, or 0 if there is no eos token) and is set to pad on the
    left, as required for batched generation with decoder-only models.

    Args:
        model_name_or_path (str): The path or identifier of the model checkpoint.

    Returns:
        tuple: A tuple containing the model and tokenizer.
    """
    print(f"Loading model from {model_name_or_path} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=False
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=False,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = (
            tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
        )
    # With right padding, generate() would continue shorter prompts after
    # their pad tokens; left padding keeps every prompt flush with the
    # generation position.
    tokenizer.padding_side = "left"

    model.eval()
    return model, tokenizer
| |
|
| |
|
def parse_args():
    """
    Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: Attributes ``model_name_or_path``, ``data_root``,
        ``seed``, ``output_dir``, ``load`` (optional path to a quantized
        checkpoint) and ``batch_size``.
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default, help) -- registered in this order.
    option_specs = (
        (
            "--model_name_or_path",
            str,
            "/dataset/llama2/llama-2-7b-hf",
            "The model checkpoint for weights initialization.",
        ),
        ("--data_root", str, "./data", "The root folder of the data."),
        ("--seed", int, 42, "Random seed for reproducibility."),
        (
            "--output_dir",
            str,
            "./output",
            "The directory where model predictions and checkpoints will be saved.",
        ),
        ("--load", str, None, "Path to a quantized model to load."),
        (
            "--batch_size",
            int,
            8,
            "Batch size for processing multiple examples at once.",
        ),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
| |
|
| |
|
def generate_batch(
    model, tokenizer, input_texts: List[str], generate_kwargs: Dict
) -> List[str]:
    """
    Generate responses from the model for a batch of input prompts.

    The prompts are tokenized together (padded to a common width, truncated
    to 32k tokens) and passed to ``model.generate``. Only the newly generated
    tokens of each row are decoded.

    Args:
        model: The language model.
        tokenizer: The tokenizer corresponding to the model.
        input_texts (List[str]): List of prompt texts.
        generate_kwargs (Dict): Additional keyword arguments for the model.generate() method.

    Returns:
        List[str]: One decoded response per prompt, in input order, with
        special tokens removed.
    """
    encoded_inputs = tokenizer(
        input_texts,
        padding=True,
        add_special_tokens=True,
        return_tensors="pt",
        truncation=True,
        max_length=32 * 1024,
    )

    # Follow the model's placement instead of hard-coding .cuda().
    input_ids = encoded_inputs.input_ids.to(model.device)
    attention_mask = encoded_inputs.attention_mask.to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs
        )

    # For a decoder-only model every output row is [padded prompt | new tokens],
    # so generation starts at the padded prompt width for all rows. Counting
    # non-zero ids per row is wrong whenever the pad id is not 0 (e.g. eos used
    # as pad) or a prompt legitimately contains token id 0.
    prompt_width = input_ids.shape[1]
    return [
        tokenizer.decode(output_seq[prompt_width:], skip_special_tokens=True)
        for output_seq in output_ids
    ]
| |
|
| |
|
def build_prompt(question_text: str, demo_examples: List[Dict]) -> str:
    """
    Assemble the instruction prompt for one question.

    The prompt is: a fixed system-style preamble; when the module flag
    ``USE_COT`` is True, every demonstration in a random order (drawn from the
    global ``random`` state); the formatting guidelines; and finally the
    question text.

    Args:
        question_text (str): The target question.
        demo_examples (List[Dict]): Demonstrations with "question", "think"
            and "answer" keys; the list itself is not modified.

    Returns:
        str: The formatted prompt.
    """
    parts = [
        "You are a thoughtful assistant. You first think about solution step by step in mind. And then provides user with succinct answer.\n"
    ]

    if USE_COT:
        parts.append("Examples:\n")
        # Sampling the full length yields a shuffled copy of the demos.
        ordering = random.sample(demo_examples, len(demo_examples))
        parts.extend(
            f"Question: {ex['question']}\nResponse: {ex['think']}\n{ex['answer']}\n\n"
            for ex in ordering
        )

    parts.append(
        "Guidelines:\n"
        "- Show your thinking between opening <think> and closing </think> tags.\n"
        "- Provide the answer between opening <answer> and closing </answer> tags.\n"
        "- Include the specific value of the answer using \\boxed{ } format.\n\n"
        "Task: Think step by step and solve the question given below.\n"
        "Question:\n"
    )
    parts.append(question_text)
    return "".join(parts)
| |
|
| |
|
def main():
    """
    Benchmark the model on the GSM8K test set using batch processing.

    Steps:
      1. Parse CLI arguments and seed the RNGs.
      2. Download the GSM8K test split into ``args.data_root`` if it is absent.
      3. Load the model and tokenizer. If ``args.load`` is given, also load a
         state dict from that path.
      4. Generate completions batch by batch, then extract and check the answers.
      5. Write per-question 0/1 results and a summary score file to
         ``args.output_dir``.
    """
    args = parse_args()
    seed_everything(args.seed)
    # rstrip so a checkpoint path with a trailing slash still yields a name.
    model_name = args.model_name_or_path.rstrip("/").split("/")[-1]
    print(model_name)

    test_filepath = os.path.join(args.data_root, "gsm8k_test.jsonl")
    if not os.path.exists(test_filepath):
        downloaded_path = download_url(
            "https://raw.githubusercontent.com/openai/"
            "grade-school-math/2909d34ef28520753df82a2234c357259d254aa8/"
            "grade_school_math/data/test.jsonl",
            args.data_root,
        )
        os.rename(downloaded_path, test_filepath)

    list_data_dict = load_jsonl(test_filepath, instruction="question", output="answer")

    model, tokenizer = load(args.model_name_or_path)

    if args.load:
        print("Loading quantized model from:", args.load)
        # SECURITY NOTE: torch.load unpickles arbitrary objects; only pass
        # checkpoint files from a trusted source.
        model_state = torch.load(args.load, map_location="cpu")
        # strict=False tolerates key mismatches; report them rather than hide them.
        load_result = model.load_state_dict(model_state, strict=False)
        print(
            f"Missing keys: {len(load_result.missing_keys)}, "
            f"unexpected keys: {len(load_result.unexpected_keys)}"
        )
        model.half().cuda()

    batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError(f"--batch_size must be a positive integer, got {batch_size}")
    num_examples = len(list_data_dict)
    num_batches = (num_examples + batch_size - 1) // batch_size  # ceil division
    answers = []

    # max_new_tokens bounds the generation length; a separate max_length would
    # be overridden by it, so it is not passed.
    generate_kwargs = dict(
        max_new_tokens=2048,
        min_p=0.01,
        temperature=0.5,
        do_sample=True,
    )

    print(
        f"Processing {num_examples} examples in {num_batches} batches of size {batch_size}"
    )

    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * batch_size
        batch_samples = list_data_dict[start_idx:start_idx + batch_size]

        batch_prompts = [
            build_prompt(sample["instruction"], demo_examples) for sample in batch_samples
        ]
        batch_ground_truths = [extract_answer(sample["output"]) for sample in batch_samples]

        batch_completions = generate_batch(
            model, tokenizer, batch_prompts, generate_kwargs
        )

        for i, (prompt, completion, ground_truth) in enumerate(
            zip(batch_prompts, batch_completions, batch_ground_truths)
        ):
            model_answer = extract_answer(completion)
            is_correct = check(ground_truth, model_answer)
            answers.append(is_correct)

            # Always show the first three examples; show all of them in DEBUG mode.
            if DEBUG or (batch_idx == 0 and i < 3):
                print(
                    f"Full Prompt: {prompt}\n\n"
                    f"Model Completion: {completion}\n\n"
                    f"Expected Answer: {ground_truth}\n\n"
                    f"Model Answer: {model_answer}\n\n"
                    f"Correct: {is_correct}\n\n"
                )

        # Running accuracy after every batch (answers is non-empty here).
        print(
            f"Processed {min(start_idx + batch_size, num_examples)}/{num_examples} questions, "
            f"Correct: {sum(answers)}, "
            f"Current Accuracy: {float(sum(answers)) / len(answers):.4f}"
        )

    os.makedirs(args.output_dir, exist_ok=True)

    with open(
        os.path.join(args.output_dir, f"{model_name}_results.txt"),
        "w",
        encoding="utf-8",
    ) as f:
        f.writelines(f"{int(answer)}\n" for answer in answers)

    num_correct = sum(answers)
    # Guard against an empty dataset instead of dividing by zero.
    accuracy = float(num_correct) / len(answers) if answers else 0.0
    with open(
        os.path.join(args.output_dir, f"{model_name}_scores.txt"), "w", encoding="utf-8"
    ) as f:
        print(
            f"Total questions: {len(answers)}, Correct: {num_correct}, "
            f"Final Accuracy: {accuracy:.4f}",
            file=f,
        )
| |
|
| |
|
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|