File size: 17,186 Bytes
b46a1aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 | """
Benchmarking a language model on the GSM8K dataset with batch processing.
This script benchmarks a pre-trained causal language model on the GSM8K test set.
It is adapted from https://github.com/Guangxuan-Xiao/GSM8K-eval.
"""
import argparse
import gzip
import json
import os
import os.path as osp
import random
import re
import ssl
import urllib.request
from typing import List, Dict
import numpy as np
import torch
import transformers
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
# Silence transformers' info/warning output; only ERROR-level (40) records are shown.
transformers.logging.set_verbosity(transformers.logging.ERROR)
# --- Answer-extraction patterns -------------------------------------------
# "#### <number>" marker: an optional minus sign, then digits, dots and commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Non-greedy: captures up to the FIRST closing brace, so nested braces such as
# \boxed{\frac{1}{2}} are truncated after "\frac{1".
BOXED_PATTERN = re.compile(r"\\boxed\{(.*?)\}")
# Tags are matched case-insensitively and the content may span several lines.
ANSWER_TAG_PATTERN = re.compile(
    r"<answer>(.*?)</answer>",
    flags=re.IGNORECASE | re.DOTALL,
)
# Simplified whole-response shape: a <think> block followed by an <answer> block.
# NOTE(review): not referenced by the code visible in this file.
THINK_ANSWER_PATTERN = re.compile(
    r"<think>.*?</think>\s*<answer>.*?</answer>",
    flags=re.DOTALL,
)

# --- Run-time switches ----------------------------------------------------
DEBUG = True  # main() prints every prompt/completion pair when set
USE_COT = False  # build_prompt() prepends the few-shot demos when set
# NOTE(review): not referenced by the code visible in this file.
ANSWER_TRIGGER = "The answer is"
def download_url(url: str, folder: str = "folder") -> str:
    """
    Download a file from a URL and save it to the specified folder.
    If the file already exists in the folder, it will not be re-downloaded.
    The payload is streamed to a temporary ``.part`` file and renamed into
    place only after the transfer completes. An interrupted download
    therefore never leaves a truncated file that a later call would reuse.
    Args:
        url (str): The URL of the target file.
        folder (str): The folder where the file will be saved.
    Returns:
        str: The file path of the downloaded (or existing) file.
    Raises:
        ValueError: If no file name can be derived from ``url``
            (e.g. it ends with ``/``).
    """
    import shutil  # local import: only needed on the download path

    file = url.rpartition("/")[2]
    # Drop a query string unless the whole last segment is a query.
    file = file if file.startswith("?") else file.split("?")[0]
    if not file:
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, using existing file.")
        return path
    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # NOTE(review): TLS certificate verification is deliberately disabled here;
    # this is only acceptable for trusted, integrity-insensitive sources.
    ctx = ssl._create_unverified_context()  # pylint:disable=protected-access
    tmp_path = path + ".part"
    try:
        # Both the HTTP response and the output file are closed deterministically.
        with urllib.request.urlopen(url, context=ctx) as response, open(
            tmp_path, "wb"
        ) as f:
            shutil.copyfileobj(response, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Never leave a partial file behind: the osp.exists() check above
        # would otherwise treat it as a complete cached download.
        if osp.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return path
def load_jsonl(
    file_path: str,
    instruction: str = "instruction",
    inp: str = "input",
    output: str = "output",
    category: str = "category",
    is_gzip: bool = False,
) -> list:
    """
    Load a JSONL file into a list of dictionaries.
    Each non-blank line of the file should be a JSON object. The function
    extracts the values for keys corresponding to instruction, input, output,
    and category. If a key is missing, its value is set to None. Blank lines
    (including a trailing empty line) are skipped.
    Args:
        file_path (str): The path to the JSONL file.
        instruction (str): The key for the instruction/question text.
        inp (str): The key for the input (if any).
        output (str): The key for the expected output/answer.
        category (str): The key for the category.
        is_gzip (bool): Whether the file is gzip-compressed.
    Returns:
        list: A list of dictionaries with the keys "instruction", "input",
        "output" and "category".
    """
    data_list = []
    open_func = gzip.open if is_gzip else open
    # Text mode with explicit UTF-8 so decoding does not depend on the locale;
    # gzip.open needs "rt" to yield str lines rather than bytes.
    with open_func(file_path, "rt", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # json.loads would raise on an empty line
            item = json.loads(line)
            data_list.append(
                {
                    "instruction": item.get(instruction),
                    "input": item.get(inp),
                    "output": item.get(output),
                    "category": item.get(category),
                }
            )
    return data_list
def clean_answer(answer: str) -> str:
    """
    Normalise an answer string so two answers can be compared.
    The text is lower-cased, trailing periods are dropped, every comma is
    removed (so "1,234" becomes "1234"), and surrounding whitespace is trimmed.
    Args:
        answer (str): Raw answer text.
    Returns:
        str: The normalised answer.
    """
    normalized = answer.lower().rstrip(".")
    normalized = normalized.replace(",", "")
    return normalized.strip()
def extract_answer(completion: str) -> str:
    r"""
    Extract the raw answer string from a model completion or a GSM8K solution.
    Candidates are tried in priority order and the first hit wins:
        1. The first ``\boxed{...}`` group, taken up to its first closing brace.
        2. The first ``#### <number>`` marker (GSM8K reference format).
        3. The last whitespace-separated token inside the last
           ``<answer>...</answer>`` pair (tags matched case-insensitively).
    No normalisation is done here; main() compares answers through check(),
    which applies clean_answer().
    Args:
        completion (str): The text output (from the model or dataset) containing the answer.
    Returns:
        str: The extracted answer, or "" when nothing matches.
    """
    boxed = BOXED_PATTERN.search(completion)
    if boxed:
        return boxed.group(1)

    marked = ANS_RE.search(completion)
    if marked:
        return marked.group(1)

    tagged = ANSWER_TAG_PATTERN.findall(completion)
    if not tagged:
        return ""
    words = tagged[-1].split()
    return words[-1] if words else ""
def check(ground_truth: str, completion: str) -> bool:
    """
    Decide whether a predicted answer matches the expected one.
    Both strings are passed through clean_answer() and compared for equality.
    Despite its name, ``completion`` is compared as-is: no extraction is
    performed, so callers should pass an already-extracted answer (main()
    passes the result of extract_answer()).
    Args:
        ground_truth (str): The expected (correct) answer.
        completion (str): The predicted answer string.
    Returns:
        bool: True if the normalised strings are equal; False otherwise.
    """
    expected = clean_answer(ground_truth)
    predicted = clean_answer(completion)
    return predicted == expected
# In-context examples for demonstration (8-shot prompting).
# Each row is (question, <think> rationale, <answer> with \boxed{value}).
# pylint:disable=line-too-long
_DEMO_ROWS = (
    (
        "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        "<think>There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.</think>",
        "<answer>\\boxed{5}</answer>",
    ),
    (
        "If you have 4 apples and you buy 3 more, how many apples do you have?",
        "<think>You start with 4 apples and then add 3 more. 4 + 3 = 7.</think>",
        "<answer>\\boxed{7}</answer>",
    ),
    (
        "John had 10 candies and gave 3 to his friend. How many candies does John have left?",
        "<think>John had 10 candies and after giving away 3, he is left with 10 - 3 = 7 candies.</think>",
        "<answer>\\boxed{7}</answer>",
    ),
    (
        "There are 5 birds on a tree. If 2 fly away, how many birds remain?",
        "<think>5 birds minus 2 that fly away leaves 3 birds remaining.</think>",
        "<answer>\\boxed{3}</answer>",
    ),
    (
        "A basket has 6 oranges. If 4 oranges are taken out, how many oranges are left in the basket?",
        "<think>6 oranges minus 4 equals 2 oranges remaining.</think>",
        "<answer>\\boxed{2}</answer>",
    ),
    (
        "There are 8 books on the shelf. If 5 are removed, how many books remain?",
        "<think>8 books minus 5 removed gives 3 books remaining.</think>",
        "<answer>\\boxed{3}</answer>",
    ),
    (
        "If a car travels 60 miles in 1 hour, how far does it travel in 2 hours?",
        "<think>At 60 miles per hour, in 2 hours the car travels 60 x 2 = 120 miles.</think>",
        "<answer>\\boxed{120}</answer>",
    ),
    (
        "If a cake is cut into 8 pieces and you eat 3, how many pieces remain?",
        "<think>8 pieces minus 3 eaten equals 5 remaining pieces.</think>",
        "<answer>\\boxed{5}</answer>",
    ),
)
demo_examples = [
    {"question": question, "think": think, "answer": answer}
    for question, think, answer in _DEMO_ROWS
]
# pylint:enable=line-too-long
def seed_everything(seed: int) -> None:
    """
    Set seeds for various random number generators to ensure reproducibility.
    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    and configures cuDNN for deterministic kernel selection.
    Args:
        seed (int): The seed value to be used.
    """
    random.seed(seed)
    # Only affects subprocesses started after this point; hash randomisation
    # of the running interpreter is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs: device_map="auto" may shard the model across several.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick possibly non-deterministic
    # algorithms, which would defeat `deterministic` above.
    torch.backends.cudnn.benchmark = False
def load(model_name_or_path: str):
    """
    Load a pre-trained causal language model and its tokenizer.
    Both are loaded with ``from_pretrained`` from the given path or hub
    identifier. The model is loaded in bfloat16 with ``device_map="auto"``.
    The tokenizer is configured for batched generation:
    it pads on the left, and it gets a pad token (eos_token_id, or 0 if there
    is no eos token) when it has none. The model is put in evaluation mode.
    Args:
        model_name_or_path (str): The path or identifier of the model checkpoint.
    Returns:
        tuple: A tuple containing the model and tokenizer.
    """
    print(f"Loading model from {model_name_or_path} ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=False
    )
    # A decoder-only model continues generating from the last position of each
    # row. With right padding, shorter prompts would be continued after pad
    # tokens, so batched prompts must be left-padded.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=False,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = (
            tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
        )
    model.eval()
    return model, tokenizer
def parse_args():
    """
    Build the command-line interface and parse ``sys.argv``.
    Returns:
        argparse.Namespace: The parsed options: model_name_or_path, data_root,
        seed, output_dir, load (an optional quantized state-dict path) and
        batch_size.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add(
        "--model_name_or_path",
        type=str,
        default="/dataset/llama2/llama-2-7b-hf",
        help="The model checkpoint for weights initialization.",
    )
    add("--data_root", type=str, default="./data", help="The root folder of the data.")
    add("--seed", type=int, default=42, help="Random seed for reproducibility.")
    add(
        "--output_dir",
        type=str,
        default="./output",
        help="The directory where model predictions and checkpoints will be saved.",
    )
    add("--load", type=str, default=None, help="Path to a quantized model to load.")
    add(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for processing multiple examples at once.",
    )
    namespace = parser.parse_args()
    return namespace
def generate_batch(
    model, tokenizer, input_texts: List[str], generate_kwargs: Dict
) -> List[str]:
    """
    Generate responses from the model for a batch of input prompts.
    Prompts are tokenized together with padding, moved to the model's device,
    and passed to ``model.generate``. Only the newly generated tokens are
    decoded for each row.
    Args:
        model: The language model.
        tokenizer: The tokenizer corresponding to the model.
        input_texts (List[str]): List of prompt texts.
        generate_kwargs (Dict): Additional keyword arguments for the model.generate() method.
    Returns:
        List[str]: List of generated responses, one per prompt and in the same order.
    """
    if not input_texts:
        return []
    # Tokenize all inputs with padding
    encoded_inputs = tokenizer(
        input_texts,
        padding=True,
        add_special_tokens=True,
        return_tensors="pt",
        truncation=True,
        max_length=32 * 1024,
    )
    # Use the model's own device instead of assuming the default CUDA device.
    input_ids = encoded_inputs.input_ids.to(model.device)
    attention_mask = encoded_inputs.attention_mask.to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs
        )
    # For a causal LM, generate() returns the full *padded* prompt followed by
    # the new tokens. Every row's continuation therefore starts at the same
    # column, whatever the pad token id is and whichever side was padded.
    prompt_width = input_ids.shape[1]
    return [
        tokenizer.decode(output_seq[prompt_width:], skip_special_tokens=True)
        for output_seq in output_ids
    ]
def build_prompt(question_text: str, demo_examples: List[Dict]) -> str:
    """
    Assemble the full prompt for one target question.
    The prompt starts with a fixed system-style instruction. When USE_COT is
    enabled, it continues with every demonstration example in a random order
    (drawn from the global ``random`` module). It ends with the formatting
    guidelines and the question itself.
    Args:
        question_text (str): The target question.
        demo_examples (List[Dict]): Demonstrations with "question", "think"
            and "answer" keys.
    Returns:
        str: The formatted prompt.
    """
    # pylint:disable=line-too-long
    pieces = [
        "You are a thoughtful assistant. You first think about solution step by step in mind. And then provides user with succinct answer.\n"
    ]
    if USE_COT:
        pieces.append("Examples:\n")
        # random.sample with k == len(population) yields a shuffled copy and
        # leaves the caller's list untouched.
        for example in random.sample(demo_examples, len(demo_examples)):
            pieces.append(
                f"Question: {example['question']}\nResponse: {example['think']}\n{example['answer']}\n\n"
            )
    pieces.append(
        "Guidelines:\n"
        "- Show your thinking between opening <think> and closing </think> tags.\n"
        "- Provide the answer between opening <answer> and closing </answer> tags.\n"
        "- Include the specific value of the answer using \\boxed{ } format.\n\n"
        "Task: Think step by step and solve the question given below.\n"
        "Question:\n"
    )
    pieces.append(question_text)
    # pylint:enable=line-too-long
    return "".join(pieces)
def _accuracy(results: List[bool]) -> float:
    """Return the fraction of True entries in ``results`` (0.0 when empty)."""
    return float(sum(results)) / len(results) if results else 0.0


def main():
    """
    Main function to benchmark the model on the GSM8K test set using batch processing.
    The function downloads the test split if needed and loads the model (plus an
    optional state dict). It then generates completions batch by batch and
    writes the per-question correctness and the final accuracy to
    ``args.output_dir``.
    Raises:
        ValueError: If ``--batch_size`` is smaller than 1.
    """
    args = parse_args()
    seed_everything(args.seed)
    # rstrip so ".../model/" still yields "model" instead of an empty name.
    model_name = args.model_name_or_path.rstrip("/").split("/")[-1] or "model"
    print(model_name)

    # Prepare test file path and download if needed
    test_filepath = os.path.join(args.data_root, "gsm8k_test.jsonl")
    if not os.path.exists(test_filepath):
        downloaded_path = download_url(
            "https://raw.githubusercontent.com/openai/"
            "grade-school-math/2909d34ef28520753df82a2234c357259d254aa8/"
            "grade_school_math/data/test.jsonl",
            args.data_root,
        )
        os.rename(downloaded_path, test_filepath)

    # Load test data (mapping "question" to instruction and "answer" to output)
    list_data_dict = load_jsonl(test_filepath, instruction="question", output="answer")

    # Load the model and tokenizer
    model, tokenizer = load(args.model_name_or_path)
    # Optionally load a quantized model state
    if args.load:
        print("Loading quantized model from:", args.load)
        # SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
        model_state = torch.load(args.load, map_location="cpu")
        model.load_state_dict(model_state, strict=False)
        model.half().cuda()

    batch_size = args.batch_size
    if batch_size < 1:
        raise ValueError(f"--batch_size must be >= 1, got {batch_size}")
    num_examples = len(list_data_dict)
    num_batches = (num_examples + batch_size - 1) // batch_size  # ceiling division
    # max_length is not passed: when max_new_tokens is set, generate() derives
    # max_length from it and ignores any explicit value (with a warning).
    generate_kwargs = dict(
        max_new_tokens=2048,
        min_p=0.01,
        temperature=0.5,
        do_sample=True,
    )
    print(
        f"Processing {num_examples} examples in {num_batches} batches of size {batch_size}"
    )

    answers: List[bool] = []
    for batch_idx in tqdm(range(num_batches)):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_examples)
        batch_samples = list_data_dict[start_idx:end_idx]

        # Prompts are built in sample order, so RNG use (when USE_COT) is stable.
        batch_prompts = [
            build_prompt(sample["instruction"], demo_examples) for sample in batch_samples
        ]
        batch_ground_truths = [extract_answer(sample["output"]) for sample in batch_samples]

        batch_completions = generate_batch(
            model, tokenizer, batch_prompts, generate_kwargs
        )

        for i, (prompt, completion, ground_truth) in enumerate(
            zip(batch_prompts, batch_completions, batch_ground_truths)
        ):
            model_answer = extract_answer(completion)
            is_correct = check(ground_truth, model_answer)
            answers.append(is_correct)
            # Always show the first few examples of the first batch.
            if DEBUG or (batch_idx == 0 and i < 3):
                print(
                    f"Full Prompt: {prompt}\n\n"
                    f"Model Completion: {completion}\n\n"
                    f"Expected Answer: {ground_truth}\n\n"
                    f"Model Answer: {model_answer}\n\n"
                    f"Correct: {is_correct}\n\n"
                )

        # Progress update after every batch.
        print(
            f"Processed {end_idx}/{num_examples} questions, "
            f"Correct: {sum(answers)}, "
            f"Current Accuracy: {_accuracy(answers):.4f}"
        )

    # Save results
    os.makedirs(args.output_dir, exist_ok=True)
    with open(
        os.path.join(args.output_dir, f"{model_name}_results.txt"),
        "w",
        encoding="utf-8",
    ) as f:
        f.writelines(f"{int(answer)}\n" for answer in answers)
    with open(
        os.path.join(args.output_dir, f"{model_name}_scores.txt"), "w", encoding="utf-8"
    ) as f:
        print(
            f"Total questions: {len(answers)}, Correct: {sum(answers)}, "
            f"Final Accuracy: {_accuracy(answers):.4f}",
            file=f,
        )


if __name__ == "__main__":
    main()
|