# Source: uploaded by HarshadAbleCredit — "Upload benchmark_gsm8k.py" (commit b46a1aa, verified).
"""
Benchmarking a language model on the GSM8K dataset with batch processing.
This script benchmarks a pre-trained causal language model on the GSM8K test set.
It is adapted from https://github.com/Guangxuan-Xiao/GSM8K-eval.
"""
import argparse
import gzip
import json
import os
import os.path as osp
import random
import re
import ssl
import urllib.request
from typing import List, Dict
import numpy as np
import torch
import transformers
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
# Silence transformers below ERROR (numeric level 40) so its warnings and
# info messages do not clutter the benchmark output.
transformers.logging.set_verbosity_error()
# ---- Answer-extraction patterns --------------------------------------------
# "#### <number>" marker used by the GSM8K reference solutions.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# \boxed{...}: non-greedy and without DOTALL, so the captured content stops at
# the first "}" on the same line (nested braces are not supported).
BOXED_PATTERN = re.compile(r"\\boxed\{(.*?)\}")
# <answer>...</answer> contents; case-insensitive and allowed to span lines.
ANSWER_TAG_PATTERN = re.compile(r"<answer>(.*?)</answer>", flags=re.I | re.S)
# Whole "<think>...</think> <answer>...</answer>" layout (simplified pattern).
THINK_ANSWER_PATTERN = re.compile(
    r"<think>.*?</think>\s*<answer>.*?</answer>", flags=re.S
)

# ---- Module-level switches -------------------------------------------------
DEBUG = True  # main() prints every prompt/completion pair when True
USE_COT = False  # build_prompt() prepends the few-shot demos when True
ANSWER_TRIGGER = "The answer is"
def download_url(url: str, folder: str = "folder", *, verify_ssl: bool = True) -> str:
    """
    Download a file from a URL and save it to the specified folder.

    If a file with the same name already exists in the folder, it is reused
    and nothing is downloaded. The local file name is the last path component
    of ``url``; a query string (``?...``) is dropped unless that component
    itself starts with ``?``.

    Args:
        url (str): The URL of the target file.
        folder (str): The folder where the file will be saved (created if missing).
        verify_ssl (bool): Verify the server's TLS certificate (default).
            Pass ``False`` only in environments with a broken CA store.

    Returns:
        str: The file path of the downloaded (or existing) file.

    Raises:
        ValueError: If no file name can be derived from ``url``.
    """
    import shutil

    file = url.rpartition("/")[2]
    if not file:
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")
    file = file if file[0] == "?" else file.split("?")[0]
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, using existing file.")
        return path
    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    if verify_ssl:
        ctx = ssl.create_default_context()
    else:
        ctx = ssl._create_unverified_context()  # pylint:disable=protected-access
    # Stream into a temporary file and move it into place only on success, so
    # an interrupted download never leaves a truncated file that the
    # "already exists" check above would silently reuse on the next run.
    tmp_path = path + ".part"
    try:
        with urllib.request.urlopen(url, context=ctx) as response, open(
            tmp_path, "wb"
        ) as f:
            shutil.copyfileobj(response, f)
        os.replace(tmp_path, path)
    except BaseException:
        if osp.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return path
def load_jsonl(
    file_path: str,
    instruction: str = "instruction",
    inp: str = "input",
    output: str = "output",
    category: str = "category",
    is_gzip: bool = False,
) -> list:
    """
    Load a JSONL file into a list of dictionaries.

    Each non-blank line of the file must be a JSON object. The function
    extracts the values for keys corresponding to instruction, input, output,
    and category and stores them under the fixed keys "instruction", "input",
    "output" and "category". A missing key yields None. Blank lines (e.g. a
    trailing newline at end of file) are skipped.

    Args:
        file_path (str): The path to the JSONL file (UTF-8 encoded).
        instruction (str): The key for the instruction/question text.
        inp (str): The key for the input (if any).
        output (str): The key for the expected output/answer.
        category (str): The key for the category.
        is_gzip (bool): Whether the file is gzip-compressed.

    Returns:
        list: A list of dictionaries with the extracted keys.
    """
    data_list = []
    # Open in text mode with an explicit encoding: bare ``open`` would use the
    # platform default codec, and ``gzip.open(path, "r")`` yields bytes.
    if is_gzip:
        f = gzip.open(file_path, "rt", encoding="utf-8")
    else:
        f = open(file_path, "r", encoding="utf-8")
    with f:
        for line in f:
            if not line.strip():
                continue  # json.loads would fail on an empty line
            item = json.loads(line)
            data_list.append(
                {
                    "instruction": item.get(instruction),
                    "input": item.get(inp),
                    "output": item.get(output),
                    "category": item.get(category),
                }
            )
    return data_list
def clean_answer(answer: str) -> str:
    """Normalize an answer string for comparison.

    Lower-cases, trims surrounding whitespace, removes commas (thousands
    separators) and trailing periods, e.g. ``" 1,000. "`` -> ``"1000"``.

    Args:
        answer (str): Raw answer text.

    Returns:
        str: The normalized answer.
    """
    # Trim whitespace *before* removing trailing periods: in the previous
    # order, "42. " or "42.\n" kept its period because rstrip(".") saw the
    # whitespace first.
    answer = answer.strip().lower().replace(",", "")
    return answer.rstrip(".").strip()
def extract_answer(completion: str) -> str:
    r"""
    Pull the raw final answer out of ``completion``.

    Strategies are tried in order and the first that matches wins:
      1. the content of the first ``\boxed{...}`` (single line, no nested braces);
      2. the numeric text following the first ``#### `` marker;
      3. the last whitespace-separated token inside the last
         ``<answer>...</answer>`` pair.

    The returned text is not normalized; ``clean_answer`` does that.

    Args:
        completion (str): Model output or GSM8K reference solution.

    Returns:
        str: The extracted answer, or ``""`` when no strategy matches.
    """
    boxed = BOXED_PATTERN.search(completion)
    if boxed:
        return boxed.group(1)

    marker = ANS_RE.search(completion)
    if marker:
        return marker.group(1)

    tagged_blocks = ANSWER_TAG_PATTERN.findall(completion)
    if not tagged_blocks:
        return ""
    words = tagged_blocks[-1].split()
    return words[-1] if words else ""
def check(ground_truth: str, completion: str) -> bool:
    """
    Compare two answers after ``clean_answer`` normalization.

    Despite its name, ``completion`` should already be an extracted answer
    (``main`` passes the output of ``extract_answer``); no extraction is
    performed here, only normalization of both sides.

    Args:
        ground_truth (str): The expected (correct) answer.
        completion (str): The candidate answer.

    Returns:
        bool: True if both normalize to the same string; False otherwise.
    """
    expected = clean_answer(ground_truth)
    predicted = clean_answer(completion)
    return predicted == expected
# In-context demonstration examples (8 question/response pairs) for few-shot
# prompting. build_prompt() inserts them, in a freshly shuffled order, only
# when USE_COT is True; with USE_COT = False the prompt is zero-shot and this
# list is passed in but not used.
# Each entry holds the question, a <think>...</think> rationale, and an
# <answer>\boxed{...}</answer> block in the format the prompt guidelines ask for.
# pylint:disable=line-too-long
demo_examples = [
    {
        "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        "think": "<think>There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.</think>",
        "answer": "<answer>\\boxed{5}</answer>",
    },
    {
        "question": "If you have 4 apples and you buy 3 more, how many apples do you have?",
        "think": "<think>You start with 4 apples and then add 3 more. 4 + 3 = 7.</think>",
        "answer": "<answer>\\boxed{7}</answer>",
    },
    {
        "question": "John had 10 candies and gave 3 to his friend. How many candies does John have left?",
        "think": "<think>John had 10 candies and after giving away 3, he is left with 10 - 3 = 7 candies.</think>",
        "answer": "<answer>\\boxed{7}</answer>",
    },
    {
        "question": "There are 5 birds on a tree. If 2 fly away, how many birds remain?",
        "think": "<think>5 birds minus 2 that fly away leaves 3 birds remaining.</think>",
        "answer": "<answer>\\boxed{3}</answer>",
    },
    {
        "question": "A basket has 6 oranges. If 4 oranges are taken out, how many oranges are left in the basket?",
        "think": "<think>6 oranges minus 4 equals 2 oranges remaining.</think>",
        "answer": "<answer>\\boxed{2}</answer>",
    },
    {
        "question": "There are 8 books on the shelf. If 5 are removed, how many books remain?",
        "think": "<think>8 books minus 5 removed gives 3 books remaining.</think>",
        "answer": "<answer>\\boxed{3}</answer>",
    },
    {
        "question": "If a car travels 60 miles in 1 hour, how far does it travel in 2 hours?",
        "think": "<think>At 60 miles per hour, in 2 hours the car travels 60 x 2 = 120 miles.</think>",
        "answer": "<answer>\\boxed{120}</answer>",
    },
    {
        "question": "If a cake is cut into 8 pieces and you eat 3, how many pieces remain?",
        "think": "<think>8 pieces minus 3 eaten equals 5 remaining pieces.</think>",
        "answer": "<answer>\\boxed{5}</answer>",
    },
]
# pylint:enable=line-too-long
def seed_everything(seed: int) -> None:
    """
    Seed Python, NumPy and PyTorch RNGs and configure cuDNN for determinism.

    Args:
        seed (int): The seed value to be used.
    """
    random.seed(seed)
    # Only affects child processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU; the model may be sharded with device_map="auto".
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels at run time and may choose different,
    # non-deterministic algorithms, which defeats ``deterministic`` above.
    torch.backends.cudnn.benchmark = False
def load(model_name_or_path: str):
    """
    Load a pre-trained causal language model and its tokenizer.

    The model is loaded in bfloat16 with ``device_map="auto"``. The tokenizer
    is configured for batched generation: it gets a pad token (eos_token_id,
    or 0 if there is no eos token) and left padding. The model is put in
    evaluation mode.

    Args:
        model_name_or_path (str): The path or identifier of the model checkpoint.

    Returns:
        tuple: A tuple ``(model, tokenizer)``.
    """
    print(f"Loading model from {model_name_or_path} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=False
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=False,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = (
            tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
        )
    # Decoder-only models continue from the last position of each row, so
    # batched prompts must be padded on the left. With right padding, shorter
    # prompts would be followed by pad tokens before the generated text.
    tokenizer.padding_side = "left"
    model.eval()
    return model, tokenizer
def parse_args():
    """
    Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: attributes ``model_name_or_path``, ``data_root``,
        ``seed``, ``output_dir``, ``load`` (optional quantized state dict path)
        and ``batch_size``.
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default, help) — declared in the order shown by --help.
    option_specs = (
        (
            "--model_name_or_path",
            str,
            "/dataset/llama2/llama-2-7b-hf",
            "The model checkpoint for weights initialization.",
        ),
        ("--data_root", str, "./data", "The root folder of the data."),
        ("--seed", int, 42, "Random seed for reproducibility."),
        (
            "--output_dir",
            str,
            "./output",
            "The directory where model predictions and checkpoints will be saved.",
        ),
        ("--load", str, None, "Path to a quantized model to load."),
        (
            "--batch_size",
            int,
            8,
            "Batch size for processing multiple examples at once.",
        ),
    )
    for flag, value_type, default_value, description in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=description)
    return cli.parse_args()
def generate_batch(
    model, tokenizer, input_texts: List[str], generate_kwargs: Dict
) -> List[str]:
    """
    Generate one completion per prompt with a single batched ``generate`` call.

    Args:
        model: The language model.
        tokenizer: The tokenizer corresponding to the model.
        input_texts (List[str]): List of prompt texts.
        generate_kwargs (Dict): Additional keyword arguments for model.generate().

    Returns:
        List[str]: The decoded newly generated text for each prompt (prompt
        tokens removed, special tokens skipped), in input order.
    """
    # Decoder-only models must be left-padded for batched generation. Enforce
    # that here, restoring the caller's setting afterwards, rather than relying
    # on the tokenizer default (which is "right" for many models).
    previous_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        encoded_inputs = tokenizer(
            input_texts,
            padding=True,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
            max_length=32 * 1024,
        )
    finally:
        tokenizer.padding_side = previous_side

    # Put the inputs on the model's device instead of hard-coding .cuda().
    input_ids = encoded_inputs.input_ids.to(model.device)
    attention_mask = encoded_inputs.attention_mask.to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs
        )

    # generate() returns the full padded prompt followed by the new tokens, so
    # every row's continuation starts at the padded prompt width. The previous
    # "count non-zero ids" approach sliced mid-prompt whenever padding used
    # token id 0 or a real prompt token had id 0.
    # (The old ``ignore_tokenization_spaces`` kwarg is not a decode() option
    # and was silently ignored, so dropping it does not change output.)
    prompt_width = input_ids.shape[1]
    return tokenizer.batch_decode(
        output_ids[:, prompt_width:], skip_special_tokens=True
    )
def build_prompt(question_text: str, demo_examples: List[Dict]) -> str:
    """
    Assemble the model prompt for one question.

    The prompt is a fixed system-style preamble, then (only when USE_COT is
    True) every demo example in a random order, then the formatting
    guidelines, and finally the question text itself.

    Args:
        question_text (str): The target question.
        demo_examples (List[Dict]): Demonstrations with "question", "think"
            and "answer" keys.

    Returns:
        str: The formatted prompt.
    """
    pieces = [
        "You are a thoughtful assistant. You first think about solution step by step in mind. And then provides user with succinct answer.\n"
    ]
    if USE_COT:
        pieces.append("Examples:\n")
        # random.sample over the full length yields a shuffled copy.
        for example in random.sample(demo_examples, len(demo_examples)):
            pieces.append(
                f"Question: {example['question']}\nResponse: {example['think']}\n{example['answer']}\n\n"
            )
    pieces.append(
        "Guidelines:\n"
        "- Show your thinking between opening <think> and closing </think> tags.\n"
        "- Provide the answer between opening <answer> and closing </answer> tags.\n"
        "- Include the specific value of the answer using \\boxed{ } format.\n\n"
        "Task: Think step by step and solve the question given below.\n"
        "Question:\n"
    )
    pieces.append(question_text)
    return "".join(pieces)
def main():
    """
    Benchmark the model on the GSM8K test set using batch processing.

    The GSM8K test split is downloaded into ``--data_root`` if needed. One
    completion is generated per question, in batches, and its extracted answer
    is compared with the "#### <number>" answer of the reference solution.
    Running accuracy is printed after each batch. Two files are written to
    ``--output_dir``: ``<model>_results.txt``, with one 0/1 line per question,
    and ``<model>_scores.txt``, with a one-line summary.

    Raises:
        ValueError: If ``--batch_size`` is not a positive integer.
    """
    args = parse_args()
    seed_everything(args.seed)
    # normpath drops a trailing separator, so ".../llama-2-7b-hf/" still yields
    # "llama-2-7b-hf" instead of an empty name (which produced "_results.txt").
    model_name = os.path.basename(os.path.normpath(args.model_name_or_path))
    print(model_name)

    # Prepare test file path and download if needed.
    test_filepath = os.path.join(args.data_root, "gsm8k_test.jsonl")
    if not os.path.exists(test_filepath):
        downloaded_path = download_url(
            "https://raw.githubusercontent.com/openai/"
            "grade-school-math/2909d34ef28520753df82a2234c357259d254aa8/"
            "grade_school_math/data/test.jsonl",
            args.data_root,
        )
        os.rename(downloaded_path, test_filepath)

    # Load test data (mapping "question" to instruction and "answer" to output).
    list_data_dict = load_jsonl(test_filepath, instruction="question", output="answer")

    model, tokenizer = load(args.model_name_or_path)

    # Optionally load a quantized model state.
    if args.load:
        print("Loading quantized model from:", args.load)
        # NOTE(review): torch.load unpickles the file; only pass trusted
        # checkpoints here (consider weights_only=True on recent PyTorch).
        model_state = torch.load(args.load, map_location="cpu")
        model.load_state_dict(model_state, strict=False)
        # NOTE(review): this casts the bfloat16 model to float16 and moves it
        # to the current CUDA device, overriding the device_map="auto" placement.
        model.half().cuda()

    batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError(f"--batch_size must be a positive integer, got {batch_size}")
    total = len(list_data_dict)
    num_batches = (total + batch_size - 1) // batch_size  # ceiling division
    answers = []
    # When both are set, transformers gives max_new_tokens precedence over
    # max_length.
    generate_kwargs = dict(
        max_new_tokens=2048,
        min_p=0.01,
        temperature=0.5,
        max_length=32 * 1024,
        do_sample=True,
    )
    print(
        f"Processing {total} examples in {num_batches} batches of size {batch_size}"
    )

    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total)
        batch_samples = list_data_dict[start_idx:end_idx]

        batch_prompts = [
            build_prompt(sample["instruction"], demo_examples)
            for sample in batch_samples
        ]
        batch_ground_truths = [
            extract_answer(sample["output"]) for sample in batch_samples
        ]
        batch_completions = generate_batch(
            model, tokenizer, batch_prompts, generate_kwargs
        )

        for i, (prompt, completion, ground_truth) in enumerate(
            zip(batch_prompts, batch_completions, batch_ground_truths)
        ):
            model_answer = extract_answer(completion)
            is_correct = check(ground_truth, model_answer)
            answers.append(is_correct)
            # DEBUG prints every example; otherwise only the first three of
            # the first batch are shown.
            if DEBUG or (batch_idx == 0 and i < 3):
                print(
                    f"Full Prompt: {prompt}\n\n"
                    f"Model Completion: {completion}\n\n"
                    f"Expected Answer: {ground_truth}\n\n"
                    f"Model Answer: {model_answer}\n\n"
                    f"Correct: {is_correct}\n\n"
                )

        # Progress after every batch (the old "% 1 == 0" test was always true).
        # answers is non-empty here because every batch holds >= 1 sample.
        print(
            f"Processed {end_idx}/{total} questions, "
            f"Correct: {sum(answers)}, "
            f"Current Accuracy: {float(sum(answers)) / len(answers):.4f}"
        )

    # Save results; guard the division so an empty dataset reports 0.0
    # instead of raising ZeroDivisionError.
    num_correct = sum(answers)
    accuracy = float(num_correct) / len(answers) if answers else 0.0
    os.makedirs(args.output_dir, exist_ok=True)
    with open(
        os.path.join(args.output_dir, f"{model_name}_results.txt"),
        "w",
        encoding="utf-8",
    ) as f:
        for answer in answers:
            print(int(answer), file=f)
    with open(
        os.path.join(args.output_dir, f"{model_name}_scores.txt"), "w", encoding="utf-8"
    ) as f:
        print(
            f"Total questions: {len(answers)}, Correct: {num_correct}, "
            f"Final Accuracy: {accuracy:.4f}",
            file=f,
        )
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()