"""
Benchmark a causal language model on the GSM8K test set with batched generation.

This script prompts a pre-trained causal language model with the GSM8K test
questions, extracts the final answer from every completion, and reports the
accuracy. It is adapted from https://github.com/Guangxuan-Xiao/GSM8K-eval.
"""

import argparse
import gzip
import json
import os
import os.path as osp
import random
import re
import ssl
import urllib.request
from typing import Dict, List

import numpy as np
import torch
import transformers
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Only show error-level messages from transformers (40 == logging.ERROR).
transformers.logging.set_verbosity(40)

# Regular expressions for answer extraction.
# GSM8K reference answers end with "#### <number>".
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Non-greedy, so nested braces such as \boxed{\frac{1}{2}} are cut short;
# GSM8K answers are plain numbers, so this is acceptable here.
BOXED_PATTERN = re.compile(r"\\boxed\{(.*?)\}")
# Content of <answer>...</answer> blocks (used as the last-resort fallback).
ANSWER_TAG_PATTERN = re.compile(
    r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL
)
# A <think>...</think> block followed by an <answer>...</answer> block.
# Not referenced by the scoring code in this file.
THINK_ANSWER_PATTERN = re.compile(
    r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL
)  # simplified pattern

DEBUG = True  # print every prompt/completion pair, not only the first few
USE_COT = False  # prepend the few-shot demonstrations to every prompt
ANSWER_TRIGGER = "The answer is"

GSM8K_TEST_URL = (
    "https://raw.githubusercontent.com/openai/"
    "grade-school-math/2909d34ef28520753df82a2234c357259d254aa8/"
    "grade_school_math/data/test.jsonl"
)


def download_url(url: str, folder: str = "folder") -> str:
    """
    Download a file from a URL and save it to the specified folder.

    If the file already exists in the folder, it is not downloaded again.
    The download is written to a temporary ``.part`` file first and renamed
    only on success, so an interrupted download is never mistaken for a
    complete file on the next run.

    Args:
        url (str): The URL of the target file.
        folder (str): The folder where the file will be saved.

    Returns:
        str: The file path of the downloaded (or existing) file.

    Raises:
        ValueError: If no file name can be derived from ``url``.
    """
    filename = url.rpartition("/")[2]
    if not filename:
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")
    # Strip a query string ("test.jsonl?raw=1" -> "test.jsonl") unless the
    # whole last path segment is itself a query string.
    if filename[0] != "?":
        filename = filename.split("?")[0]

    path = osp.join(folder, filename)
    if osp.exists(path):
        print(f"File {filename} exists, using existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # NOTE(security): certificate verification is disabled (kept from the
    # original script); the download is open to man-in-the-middle tampering.
    ctx = ssl._create_unverified_context()  # pylint:disable=protected-access
    tmp_path = path + ".part"
    with urllib.request.urlopen(url, context=ctx) as response, open(
        tmp_path, "wb"
    ) as f:
        f.write(response.read())
    os.replace(tmp_path, path)
    return path


def load_jsonl(
    file_path: str,
    instruction: str = "instruction",
    inp: str = "input",
    output: str = "output",
    category: str = "category",
    is_gzip: bool = False,
) -> list:
    """
    Load a UTF-8 JSONL file into a list of dictionaries.

    Each non-blank line must be a JSON object. The values stored under the
    given source keys are copied to the fixed keys "instruction", "input",
    "output" and "category"; a missing source key yields None.

    Args:
        file_path (str): The path to the JSONL file.
        instruction (str): The source key for the instruction/question text.
        inp (str): The source key for the input (if any).
        output (str): The source key for the expected output/answer.
        category (str): The source key for the category.
        is_gzip (bool): Whether the file is gzip-compressed.

    Returns:
        list: One dictionary per record, with the four fixed keys.
    """
    data_list = []
    open_func = gzip.open if is_gzip else open
    # "rt" gives decoded text lines for both plain and gzip files.
    with open_func(file_path, "rt", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of crashing in json
            item = json.loads(line)
            data_list.append(
                {
                    "instruction": item.get(instruction),
                    "input": item.get(inp),
                    "output": item.get(output),
                    "category": item.get(category),
                }
            )
    return data_list


def clean_answer(answer: str) -> str:
    """
    Normalize an answer string for comparison.

    Lower-cases it, removes trailing periods, removes all commas (thousands
    separators), and strips surrounding whitespace.
    """
    answer = answer.lower()
    answer = answer.rstrip(".").replace(",", "")
    return answer.strip()


def extract_answer(completion: str) -> str:
    r"""
    Extract the raw answer from a model completion or a GSM8K reference.

    The sources are tried in order:
      1. The content of the last ``\boxed{...}``.
      2. The number after the first ``#### `` marker (GSM8K reference format).
      3. The last whitespace-separated token inside the last
         ``<answer>...</answer>`` block.

    The result is NOT normalized; pass it through ``clean_answer`` (as
    ``check`` does) before comparing.

    Args:
        completion (str): The text output (from the model or dataset).

    Returns:
        str: The extracted answer, or "" if none of the sources matched.
    """
    boxed_matches = BOXED_PATTERN.findall(completion)
    if boxed_matches:
        # The final boxed value is the model's final answer.
        return boxed_matches[-1]

    answer_match = ANS_RE.search(completion)
    if answer_match:
        return answer_match.group(1)

    answer_tags = ANSWER_TAG_PATTERN.findall(completion)
    if answer_tags:
        tokens = answer_tags[-1].split()
        if tokens:
            return tokens[-1]
    return ""


def check(ground_truth: str, completion: str) -> bool:
    """
    Compare an expected answer with an already-extracted model answer.

    Both strings are normalized with ``clean_answer``. If they differ as text
    but both parse as numbers, they are compared numerically so that, e.g.,
    "18" and "18.00" count as a match.

    Args:
        ground_truth (str): The expected (correct) answer.
        completion (str): The extracted model answer (not the full completion).

    Returns:
        bool: True if the answers match, False otherwise.
    """
    truth = clean_answer(ground_truth)
    predicted = clean_answer(completion)
    if truth == predicted:
        return True
    try:
        return float(truth) == float(predicted)
    except ValueError:
        return False


# In-context examples for demonstration (8-shot prompting; only used when
# USE_COT is True).
# pylint:disable=line-too-long
demo_examples = [
    {
        "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        "think": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
        "answer": "\\boxed{5}",
    },
    {
        "question": "If you have 4 apples and you buy 3 more, how many apples do you have?",
        "think": "You start with 4 apples and then add 3 more. 4 + 3 = 7.",
        "answer": "\\boxed{7}",
    },
    {
        "question": "John had 10 candies and gave 3 to his friend. How many candies does John have left?",
        "think": "John had 10 candies and after giving away 3, he is left with 10 - 3 = 7 candies.",
        "answer": "\\boxed{7}",
    },
    {
        "question": "There are 5 birds on a tree. If 2 fly away, how many birds remain?",
        "think": "5 birds minus 2 that fly away leaves 3 birds remaining.",
        "answer": "\\boxed{3}",
    },
    {
        "question": "A basket has 6 oranges. If 4 oranges are taken out, how many oranges are left in the basket?",
        "think": "6 oranges minus 4 equals 2 oranges remaining.",
        "answer": "\\boxed{2}",
    },
    {
        "question": "There are 8 books on the shelf. If 5 are removed, how many books remain?",
        "think": "8 books minus 5 removed gives 3 books remaining.",
        "answer": "\\boxed{3}",
    },
    {
        "question": "If a car travels 60 miles in 1 hour, how far does it travel in 2 hours?",
        "think": "At 60 miles per hour, in 2 hours the car travels 60 x 2 = 120 miles.",
        "answer": "\\boxed{120}",
    },
    {
        "question": "If a cake is cut into 8 pieces and you eat 3, how many pieces remain?",
        "think": "8 pieces minus 3 eaten equals 5 remaining pieces.",
        "answer": "\\boxed{5}",
    },
]
# pylint:enable=line-too-long


def seed_everything(seed: int) -> None:
    """
    Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed (int): The seed value to be used.
    """
    random.seed(seed)
    # Only affects subprocesses started later; the hash seed of the running
    # interpreter is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune (possibly nondeterministic) kernels,
    # which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False


def load(model_name_or_path: str):
    """
    Load a pre-trained causal language model and its tokenizer.

    The model is loaded in bfloat16 with ``device_map="auto"`` and put in
    evaluation mode. The tokenizer is configured for batched generation:
    left padding, and a pad token (eos, or id 0 if there is no eos) when the
    checkpoint does not define one.

    Args:
        model_name_or_path (str): The path or identifier of the model checkpoint.

    Returns:
        tuple: A tuple containing the model and tokenizer.
    """
    print(f"Loading model from {model_name_or_path} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=False
    )
    # Decoder-only models must be left-padded for batched generation;
    # otherwise new tokens are generated after the pad tokens.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=False,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = (
            tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
        )
    model.eval()
    return model, tokenizer


def parse_args():
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace: The model checkpoint path, data root, seed,
        output directory, optional quantized-state path, and batch size.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="/dataset/llama2/llama-2-7b-hf",
        help="The model checkpoint for weights initialization.",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default="./data",
        help="The root folder of the data.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="The directory where model predictions and checkpoints will be saved.",
    )
    parser.add_argument(
        "--load", type=str, default=None, help="Path to a quantized model to load."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for processing multiple examples at once.",
    )
    return parser.parse_args()


def generate_batch(
    model, tokenizer, input_texts: List[str], generate_kwargs: Dict
) -> List[str]:
    """
    Generate responses from the model for a batch of input prompts.

    Args:
        model: The language model.
        tokenizer: The tokenizer corresponding to the model (expected to pad
            on the left; ``load`` configures this).
        input_texts (List[str]): List of prompt texts.
        generate_kwargs (Dict): Additional keyword arguments for
            ``model.generate()``. ``pad_token_id`` defaults to the
            tokenizer's pad id unless given here.

    Returns:
        List[str]: The generated continuations (prompt excluded), in input order.
    """
    encoded_inputs = tokenizer(
        input_texts,
        padding=True,
        add_special_tokens=True,
        return_tensors="pt",
        truncation=True,
        max_length=32 * 1024,
    )
    input_ids = encoded_inputs.input_ids.to(model.device)
    attention_mask = encoded_inputs.attention_mask.to(model.device)
    kwargs = {"pad_token_id": tokenizer.pad_token_id, **generate_kwargs}

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

    # generate() returns the padded prompt followed by the new tokens, so every
    # row's continuation starts at the padded prompt width (counting non-zero
    # ids instead breaks whenever the pad id or a real token id is 0).
    prompt_width = input_ids.shape[1]
    return tokenizer.batch_decode(
        output_ids[:, prompt_width:], skip_special_tokens=True
    )


def build_prompt(question_text: str, demo_examples: List[Dict]) -> str:
    """
    Build the prompt for one question.

    When ``USE_COT`` is True, the demonstrations are prepended in a random
    order (drawn from the ``random`` module's global RNG).

    Args:
        question_text (str): The target question.
        demo_examples (List[Dict]): Demonstrations with "question", "think"
            and "answer" keys.

    Returns:
        str: The formatted prompt.
    """
    prompt = "You are a thoughtful assistant. You first think about solution step by step in mind. And then provides user with succinct answer.\n"
    if USE_COT:
        prompt += "Examples:\n"
        shuffled_demos = random.sample(demo_examples, len(demo_examples))
        for demo in shuffled_demos:
            prompt += f"Question: {demo['question']}\nResponse: {demo['think']}\n{demo['answer']}\n\n"

    # The tag names must be spelled out: the answer-extraction fallback looks
    # for <answer>...</answer>.
    guidelines = (
        "Guidelines:\n"
        "- Show your thinking between <think> and </think> tags.\n"
        "- Provide the answer between <answer> and </answer> tags.\n"
        "- Include the specific value of the answer using \\boxed{ } format.\n\n"
        "Task: Think step by step and solve the question given below.\n"
        "Question:\n"
    )
    return prompt + guidelines + question_text


def _accuracy(answers: List[bool]) -> float:
    """Return the fraction of True entries in ``answers`` (0.0 when empty)."""
    return float(sum(answers)) / len(answers) if answers else 0.0


def main():
    """
    Benchmark the model on the GSM8K test set using batch processing.
    """
    args = parse_args()
    seed_everything(args.seed)
    # normpath drops a trailing "/" that would otherwise yield an empty name.
    model_name = osp.basename(osp.normpath(args.model_name_or_path))
    print(model_name)

    # Prepare the test file, downloading it on first use.
    test_filepath = osp.join(args.data_root, "gsm8k_test.jsonl")
    if not osp.exists(test_filepath):
        downloaded_path = download_url(GSM8K_TEST_URL, args.data_root)
        os.replace(downloaded_path, test_filepath)

    # Map GSM8K's "question" -> instruction and "answer" -> output.
    list_data_dict = load_jsonl(
        test_filepath, instruction="question", output="answer"
    )

    model, tokenizer = load(args.model_name_or_path)

    # Optionally overwrite the weights with a quantized state dict.
    if args.load:
        print("Loading quantized model from:", args.load)
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        model_state = torch.load(args.load, map_location="cpu")
        model.load_state_dict(model_state, strict=False)
        model.half().cuda()

    batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError(f"--batch_size must be >= 1, got {batch_size}")
    num_examples = len(list_data_dict)
    num_batches = (num_examples + batch_size - 1) // batch_size  # ceiling division
    # max_length is deliberately omitted: generate() lets max_new_tokens take
    # precedence over it and only warns when both are given.
    generate_kwargs = dict(
        max_new_tokens=2048,
        min_p=0.01,
        temperature=0.5,
        do_sample=True,
    )

    answers: List[bool] = []
    print(
        f"Processing {num_examples} examples in {num_batches} batches of size {batch_size}"
    )
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * batch_size
        batch_samples = list_data_dict[start_idx : start_idx + batch_size]

        batch_prompts = [
            build_prompt(sample["instruction"], demo_examples)
            for sample in batch_samples
        ]
        batch_ground_truths = [
            extract_answer(sample["output"]) for sample in batch_samples
        ]

        batch_completions = generate_batch(
            model, tokenizer, batch_prompts, generate_kwargs
        )

        for i, (prompt, completion, ground_truth) in enumerate(
            zip(batch_prompts, batch_completions, batch_ground_truths)
        ):
            model_answer = extract_answer(completion)
            is_correct = check(ground_truth, model_answer)
            answers.append(is_correct)

            # Show everything in DEBUG mode, else the first few of batch 0.
            if DEBUG or (batch_idx == 0 and i < 3):
                print(
                    f"Full Prompt: {prompt}\n\n"
                    f"Model Completion: {completion}\n\n"
                    f"Expected Answer: {ground_truth}\n\n"
                    f"Model Answer: {model_answer}\n\n"
                    f"Correct: {is_correct}\n\n"
                )

        print(
            f"Processed {len(answers)}/{num_examples} questions, "
            f"Correct: {sum(answers)}, "
            f"Current Accuracy: {_accuracy(answers):.4f}"
        )

    # Save per-question correctness (1/0 per line) and the summary score.
    os.makedirs(args.output_dir, exist_ok=True)
    with open(
        osp.join(args.output_dir, f"{model_name}_results.txt"),
        "w",
        encoding="utf-8",
    ) as f:
        for answer in answers:
            print(int(answer), file=f)

    with open(
        osp.join(args.output_dir, f"{model_name}_scores.txt"), "w", encoding="utf-8"
    ) as f:
        print(
            f"Total questions: {len(answers)}, Correct: {sum(answers)}, "
            f"Final Accuracy: {_accuracy(answers):.4f}",
            file=f,
        )


if __name__ == "__main__":
    main()