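# Uploaded by boxin-wbx via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 3cdba69 (verified).
"""Math benchmark evaluation utilities.

This module provides evaluation functions for math benchmarks (GSM8K,
MATH-500, Minerva Math, GaoKao 2023 En, OlympiadBench, CollegeMath,
AMC23/AIME24/AIME25, Omni-MATH) and for GPQA.
It includes answer extraction, cleaning, and grading logic for math problems.
"""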
import argparse
import json
import os
import re
import signal
import sys
import numpy as np
from sympy import simplify
from sympy.parsing.latex import parse_latex
from tqdm import tqdm
from tools.grader import math_equal
def read_text_data(datapath):
    """Read the ``'output'`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file in which each non-blank line is a JSON
            object with an ``'output'`` key holding a model output.

    Returns:
        list: The ``'output'`` values, in file order.
    """
    print("reading from %s" % datapath)
    data_list = []
    # JSON text is UTF-8; do not depend on the platform's default locale.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing with a JSONDecodeError.
                continue
            data_list.append(json.loads(line)['output'])
    return data_list
def read_jsonl_data(datapath):
    """Read the ``'output'`` field of every record in a JSONL file.

    Same contract as ``read_text_data``; kept as a separate entry point for
    existing callers.

    Args:
        datapath: Path to a JSONL file whose non-blank lines are JSON objects
            with an ``'output'`` key.

    Returns:
        list: The ``'output'`` values, in file order.
    """
    print("reading from %s" % datapath)
    data_list = []
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Skip blank lines rather than crash in json.loads("").
                continue
            data_item = json.loads(stripped)
            data_list.append(data_item['output'])
    return data_list
def read_json_data(datapath):
    """Load and return the whole JSON document stored at ``datapath``.

    Args:
        datapath: Path to a JSON file.

    Returns:
        The decoded JSON value.
    """
    print("reading from %s" % datapath)
    # JSON text is UTF-8; do not depend on the platform's default locale.
    with open(datapath, "r", encoding="utf-8") as f:
        data_list = json.load(f)
    return data_list
def evaluate_gsm8k_zeroshot(input_datapath, test_datapath):
    """Score GSM8K zero-shot model outputs against the reference answers.

    The predicted answer is the last match of the first extraction pattern
    (in priority order) that matches the output at all. Outputs with no
    match are counted in ``count_none`` and scored as wrong. No timeout is
    applied to the comparisons.

    Args:
        input_datapath: JSONL file of model outputs (``'output'`` field).
        test_datapath: GSM8K test JSON file; the final answer is the text
            after the last ``"#### "`` in each item's ``'answer'``.

    Returns:
        float: Fraction of test items answered correctly.
    """
    outputs = read_text_data(input_datapath)
    references = read_json_data(test_datapath)

    # Highest priority first: boxed (up to two nested brace levels), bold,
    # display math on its own lines, "is \( ... \)", and display math written
    # with a literal backslash-n.
    extractors = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    num_samples = len(references)
    assert len(outputs) == len(references) == num_samples

    missing = 0
    hits = 0
    for output, reference in zip(outputs, references):
        target = reference['answer'].split("#### ")[-1]

        candidate = None
        for extractor in extractors:
            found = extractor.findall(output)
            if found:
                candidate = found[-1]
                break

        if candidate is None:
            missing += 1
            continue

        candidate = math_answer_cleaning(candidate)
        target = math_answer_cleaning(target)
        if math_equal(candidate, target) or is_equal_after_calculation(candidate, target):
            hits += 1

    accuracy = hits / num_samples
    print("count_none:", missing)
    print("accuracy:", accuracy)
    return accuracy
def is_completely_wrapped_by_text(input_string):
    r"""Return the body of ``input_string`` if it is one single ``\text{...}`` group.

    The string counts as wrapped only when it starts with ``\text{`` and the
    brace opened there is the one closed by the final ``}``. For example,
    ``\text{a} + \text{b}`` is NOT wrapped, even though it starts with
    ``\text{`` and ends with ``}``. Escaped braces (``\{``, ``\}``) in the body
    are ignored when matching braces.

    Args:
        input_string: LaTeX string to check.

    Returns:
        str or None: If wrapped, the body with ``(``, ``)`` and ``,`` removed;
        otherwise None.
    """
    match = re.match(r'^\\text{(.*)}$', input_string)
    if not match:
        return None
    extracted_content = match.group(1)
    # The greedy (.*) also accepts strings whose leading \text{ is closed
    # before the end, e.g. "\text{a} + \text{b}". Reject those by requiring
    # the body's own (unescaped) braces to stay balanced.
    depth = 0
    escaped = False
    for char in extracted_content:
        if escaped:
            escaped = False
        elif char == "\\":
            escaped = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth < 0:
                return None
    if depth != 0:
        return None
    return extracted_content.replace("(", "").replace(")", "").replace(",", "")
def math_answer_cleaning(answer):
    r"""Clean and normalize a math answer string for comparison.

    Steps, in order:

    - unwrap an answer that is entirely ``\text{...}`` (via
      ``is_completely_wrapped_by_text``);
    - drop ``,\!``, ``{,}``, ``\$``, degree marks and ``\quad``; map
      ``\dfrac``/``\tfrac`` to ``\frac``;
    - remove ``\text{...}`` groups (e.g. units) and whitespace followed by
      ``^{-n}``;
    - remove spaces, newlines and literal ``\n``;
    - rewrite ``a\times10^{b}`` / ``a\times10^b`` as ``aeb``, ``n^{m}`` as
      ``n^m`` and ``10^{-k}`` as ``1e-k``;
    - remove commas and lower-case the result;
    - drop one trailing backslash;
    - for ``lhs=rhs`` where ``lhs`` looks like ``f(x)`` or has at most 3
      characters, keep only ``rhs``.

    Args:
        answer: Raw answer string.

    Returns:
        str: Cleaned answer string.
    """
    extracted_content = is_completely_wrapped_by_text(answer)
    answer = extracted_content if extracted_content else answer
    # Raw strings: "\$", ",\!", "\c", "\q" are invalid escape sequences in
    # ordinary literals (SyntaxWarning on Python 3.12+); the values are the same.
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    answer = answer.replace(r"^\circ", "").replace(r"^{\circ}", "")
    answer = answer.replace(r"\quad", "")
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    answer = re.sub(r'(\s\^\{-\d+\})', '', answer)
    answer = answer.replace(" ", "").replace("\n", "").replace("\\n", "")
    # Scientific notation: 3\times10^{5} / 3\times10^5 -> 3e5.
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    answer = re.sub(r"10\^\{(-?\d+)\}", r"1e\1", answer)
    answer = answer.replace(",", "").lower()
    if answer.endswith("\\"):
        answer = answer[:-1]
    # Drop a short left-hand side such as "x=" or "f(x)=".
    func_pattern = r'^[a-zA-Z_]\w*\([a-zA-Z_]\w*\)$'
    if "=" in answer and (re.match(func_pattern, answer.split("=")[0]) or len(answer.split("=")[0]) <= 3):
        answer = answer.split("=", 1)[1]
    return answer
def round_number(answer):
    """Round numbers of magnitude below 1 to two significant figures.

    Used as a lenient comparison key: two small decimals that agree to two
    significant figures compare equal after this transformation.

    Args:
        answer: Answer string.

    Returns:
        str: ``f"{float(answer):.2g}"`` when ``answer`` parses as a float
        whose absolute value is below 1; otherwise ``answer`` unchanged.
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        # Not a number (narrowed from a bare ``except:``).
        return answer
    # Compare the magnitude: a plain ``value < 1`` also matched every negative
    # number, so e.g. "-1234" and "-1235" both became "-1.2e+03" and were
    # judged equal.
    if abs(value) < 1:
        return f"{value:.2g}"
    return answer
def evaluate_math500_zeroshot(input_datapath, test_datapath):
    """Evaluate MATH-500 zero-shot performance with timeout protection.

    For every output line, the predicted answer is the last match of the
    first extraction pattern (in priority order) that matches. It is cleaned
    with ``math_answer_cleaning`` and compared with the reference answer.
    Each comparison runs under a 5-second ``signal.alarm`` (SIGALRM) timeout.
    Any exception raised while grading, the timeout included, is counted in
    ``count_timeout`` and the item is scored as wrong.

    Args:
        input_datapath: Path to model output JSONL file (``'output'`` field),
            one line per test item, in test-file order.
        test_datapath: Path to MATH-500 test JSONL file (``'answer'`` field).

    Returns:
        float: Accuracy score (correct / number of test items).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    # Extraction patterns in priority order.
    extraction_res = [
        re.compile(pattern, re.DOTALL)
        for pattern in (pattern1, pattern2, pattern3, pattern4, pattern5)
    ]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            gold_list.append(item['answer'])

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = json.loads(line.strip())['output']
            extracted_answer = None
            for pattern_re in extraction_res:
                matches = pattern_re.findall(line)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                # Give up on this item if grading takes longer than 5 seconds.
                signal.alarm(5)
                try:
                    if math_equal(extracted_answer, gold):
                        correct += 1
                    elif is_equal_after_calculation(extracted_answer, gold):
                        correct += 1
                    elif check_after_fraction_mapping(extracted_answer, gold):
                        correct += 1
                finally:
                    # Disarm the alarm whether grading finished or raised.
                    signal.alarm(0)
            except Exception:
                # Previously a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit. Any grading error (not only
                # the timeout) is counted here and scored as wrong.
                count_timeout += 1

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)
    return acc
def evaluate_minerva_math_zeroshot(input_datapath, test_datapath):
    """Evaluate Minerva Math zero-shot performance with timeout protection.

    The reference answer is the last boxed expression in each test item's
    ``'solution'``; items without one are counted in ``count_answer_none``.
    Model answers are extracted with the boxed, display-math and "is (...)"
    patterns; the bold pattern is not used for this benchmark. A trailing
    ``\\hbar^{4}`` unit is stripped from the cleaned model answer. Each
    comparison runs under a 5-second ``signal.alarm`` (SIGALRM) timeout. Any
    exception raised while grading is counted in ``count_timeout``.

    Args:
        input_datapath: Path to model output JSONL file (``'output'`` field),
            one line per test item, in test-file order.
        test_datapath: Path to Minerva Math test JSONL file (``'solution'``
            field).

    Returns:
        float: Accuracy score (correct / number of test items).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    # Model-answer patterns in priority order (no "**bold**" pattern here).
    extraction_res = [pattern1_re] + [
        re.compile(pattern, re.DOTALL) for pattern in (pattern3, pattern4, pattern5)
    ]
    # Units stripped from the end of the cleaned model answer.
    unit_list = ["\\hbar^{4}"]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            # The reference is the last \boxed{...} of the reference solution.
            boxed = pattern1_re.findall(item['solution'])
            gold_list.append(boxed[-1] if boxed else None)

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = json.loads(line.strip())['output']
            extracted_answer = None
            for pattern_re in extraction_res:
                matches = pattern_re.findall(line)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            for unit in unit_list:
                if extracted_answer.endswith(unit):
                    extracted_answer = extracted_answer[:-len(unit)]
            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                # Give up on this item if grading takes longer than 5 seconds.
                signal.alarm(5)
                try:
                    if math_equal(extracted_answer, gold):
                        correct += 1
                    elif round_number(extracted_answer) == round_number(gold):
                        correct += 1
                    elif is_equal_after_calculation(extracted_answer, gold):
                        correct += 1
                    elif check_after_fraction_mapping(extracted_answer, gold):
                        correct += 1
                finally:
                    # Disarm the alarm whether grading finished or raised.
                    signal.alarm(0)
            except Exception:
                # Previously a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit. Any grading error (not only
                # the timeout) is counted here and scored as wrong.
                count_timeout += 1

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)
    return acc
def calculate_numbers(input_string):
    """Evaluate an arithmetic expression string with ``eval``.

    SECURITY: ``input_string`` is derived from model output, i.e. untrusted
    text, and ``eval`` executes it as Python. Evaluation is restricted to an
    empty namespace with a few arithmetic builtins. This blocks direct use of
    ``__import__``, ``open`` and this module's globals. It is NOT a sandbox:
    attribute-walking tricks can still escape, and expressions such as
    ``9**9**9`` can consume unbounded CPU or memory. Only run it on outputs
    you are prepared to execute.

    Args:
        input_string: Expression such as ``"(1/2)+3"``.

    Returns:
        The evaluated value, or None if evaluation raises an ``Exception``.
    """
    # Only names an arithmetic answer might plausibly use; nothing that grants
    # I/O, imports or access to module state.
    allowed_builtins = {"abs": abs, "min": min, "max": max, "pow": pow, "round": round}
    try:
        return eval(input_string, {"__builtins__": allowed_builtins}, {})
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # propagate. A caller's grading-timeout exception (an Exception
        # subclass) is still caught here and turned into None.
        return None
def is_equal_after_calculation(extracted_answer, gold):
    r"""Check whether two answers evaluate to the same Python value.

    ``\frac{a}{b}`` is rewritten to ``(a/b)`` on both sides (non-greedy, so
    nested braces are not handled). Both strings are then evaluated with
    ``calculate_numbers``.

    Args:
        extracted_answer: Extracted answer string.
        gold: Gold standard answer string.

    Returns:
        bool: True if the reference evaluates successfully (to a value other
        than None) and the extracted answer evaluates to an equal value.
    """
    gold = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', gold)
    extracted_answer = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', extracted_answer)
    gold_result = calculate_numbers(gold)
    extracted_answer_result = calculate_numbers(extracted_answer)
    # Test against None (evaluation failure), not truthiness: a reference
    # that evaluates to 0 must still be able to match.
    if gold_result is None:
        return False
    return bool(gold_result == extracted_answer_result)
def is_math_formula_equal(extracted_answer, gold):
    """Decide whether two LaTeX expressions are equal according to SymPy.

    Both strings are parsed with ``parse_latex``, and their difference is
    simplified and compared with 0. Any failure (parse error, SymPy error)
    is printed and treated as "not equal".

    Args:
        extracted_answer: Extracted answer string (LaTeX).
        gold: Gold standard answer string (LaTeX).

    Returns:
        bool: True if the simplified difference is zero.
    """
    try:
        difference = parse_latex(extracted_answer) - parse_latex(gold)
        return simplify(difference) == 0
    except Exception as err:
        print("error:", err)
        return False
def check_after_fraction_mapping(extracted_answer, gold):
    r"""Compare two answers after rewriting ``\frac{a}{b}`` as ``a/b``.

    The rewrite is purely textual (non-greedy regex, no parentheses added),
    so ``\frac{1}{2}`` and ``1/2`` compare equal.

    Args:
        extracted_answer: Extracted answer string.
        gold: Gold standard answer string.

    Returns:
        bool: True if both rewritten strings are identical.
    """
    fraction_pattern = r'\\frac{(.*?)}{(.*?)}'

    def _as_division(text):
        return re.sub(fraction_pattern, r'\1/\2', text)

    left = _as_division(extracted_answer)
    right = _as_division(gold)
    return left == right
def evaluate_gaokao2023en_zeroshot(input_datapath, test_datapath):
    """Evaluate GaoKao 2023 (English) zero-shot performance with timeouts.

    When the reference is a single option letter (e.g. ``"B"``), the text of
    that option is also accepted. That text is cut out of the question
    between ``"(B)"`` and the next (or next-but-one) option label. A fixed
    list of unit words is stripped from both sides before comparison. Each
    comparison runs under a 5-second ``signal.alarm`` (SIGALRM) timeout. Any
    exception raised while grading is counted in ``count_timeout``.

    Args:
        input_datapath: Path to model output JSONL file (``'output'`` field),
            one line per question, in test-file order.
        test_datapath: Path to the test JSONL file (``'question'`` and
            ``'answer'`` fields).

    Returns:
        float: Accuracy score (correct / number of questions).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    # Extraction patterns in priority order.
    extraction_res = [
        re.compile(pattern, re.DOTALL)
        for pattern in (pattern1, pattern2, pattern3, pattern4, pattern5)
    ]
    # Units are removed after math_answer_cleaning, which lower-cases its
    # result, so every entry must be lower case to ever match (the former
    # 'degreesontheBreadusscale' never could).
    unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters',
                    'degreesonthebreadusscale', '$', 'a.m.', 'am', 'minutes']

    gold_list = []
    question_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            question_list.append(item['question'].strip())
            # Strip one pair of surrounding "$" delimiters.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = json.loads(line.strip())['output']
            extracted_answer = None
            for pattern_re in extraction_res:
                matches = pattern_re.findall(line)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]
            question = question_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            if len(gold) == 1 and gold.isalpha():
                # gold is an option letter like A, B, C, D: also accept the
                # content of that option as written in the question.
                option = "(" + gold + ")"
                assert option in question
                start_option_idx = question.index(option)
                next_option = "(" + chr(ord(gold) + 1) + ")"
                next_option2 = "(" + chr(ord(gold) + 2) + ")"
                if next_option in question:
                    end_option_idx = question.index(next_option)
                    gold2 = question[start_option_idx:end_option_idx]
                elif next_option2 in question:
                    end_option_idx = question.index(next_option2)
                    gold2 = question[start_option_idx:end_option_idx]
                else:
                    gold2 = question[start_option_idx:]
                gold2 = gold2.replace(option, "").strip().replace("\\n", "")
            else:
                gold2 = None
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            if gold2:
                gold2 = math_answer_cleaning(gold2)
            for unit in unit_strings:
                extracted_answer = extracted_answer.replace(unit, "")
                gold = gold.replace(unit, "")
                if gold2:
                    gold2 = gold2.replace(unit, "")
            if "=" in extracted_answer and not gold2:
                # Convert "x=3" into "3".
                extracted_answer = extracted_answer.split("=", 1)[1]
            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                # Give up on this item if grading takes longer than 5 seconds.
                signal.alarm(5)
                try:
                    if math_equal(extracted_answer, gold):
                        correct += 1
                    elif gold2 and math_equal(extracted_answer, gold2):
                        correct += 1
                    elif is_equal_after_calculation(extracted_answer, gold):
                        correct += 1
                    elif check_after_fraction_mapping(extracted_answer, gold):
                        correct += 1
                finally:
                    # Disarm the alarm whether grading finished or raised.
                    signal.alarm(0)
            except Exception:
                # Previously a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit. Any grading error (not only
                # the timeout) is counted here and scored as wrong.
                count_timeout += 1

    acc = correct / len(gold_list)
    print("benchmark size:", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def evaluate_olympiadbench_zeroshot(input_datapath, test_datapath):
    """Evaluate OlympiadBench zero-shot performance with timeout protection.

    Every test item must have exactly one ``'final_answer'``; surrounding
    ``$`` delimiters are stripped from it. Each comparison runs under a
    5-second ``signal.alarm`` (SIGALRM) timeout. Any exception raised while
    grading is counted in ``count_timeout`` and the item is scored as wrong.

    Args:
        input_datapath: Path to model output JSONL file (``'output'`` field),
            one line per test item, in test-file order.
        test_datapath: Path to the test JSONL file (``'final_answer'``
            field, a one-element list).

    Returns:
        float: Accuracy score (correct / number of test items).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    # Extraction patterns in priority order.
    extraction_res = [
        re.compile(pattern, re.DOTALL)
        for pattern in (pattern1, pattern2, pattern3, pattern4, pattern5)
    ]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            # Only single-answer items are supported.
            assert len(item['final_answer']) == 1
            answer = item['final_answer'][0]
            # Strip one pair of surrounding "$" delimiters.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', answer))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = json.loads(line.strip())['output']
            extracted_answer = None
            for pattern_re in extraction_res:
                matches = pattern_re.findall(line)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                # Give up on this item if grading takes longer than 5 seconds.
                signal.alarm(5)
                try:
                    if math_equal(extracted_answer, gold):
                        correct += 1
                    elif is_equal_after_calculation(extracted_answer, gold):
                        correct += 1
                    elif check_after_fraction_mapping(extracted_answer, gold):
                        correct += 1
                finally:
                    # Disarm the alarm whether grading finished or raised.
                    signal.alarm(0)
            except Exception:
                # Previously a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit. Any grading error (not only
                # the timeout) is counted here and scored as wrong.
                count_timeout += 1

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def evaluate_collegemath_zeroshot(input_datapath, test_datapath):
    r"""Evaluate CollegeMath zero-shot performance with timeout protection.

    Model answers are extracted with the boxed, display-math and
    "is \(...\)" patterns, plus a generic ``\[ ... \]`` pattern tried last.
    The bold ``**...**`` pattern is not used.

    Normalization applied before comparing:

    - both sides are cleaned with ``math_answer_cleaning``;
    - ``\mathrm{...}`` and ``$(...)`` fragments are removed;
    - a trailing ``-``, ``.`` or ``hours`` is dropped from the reference,
      and a trailing ``.`` from the answer;
    - for the main comparisons, only the text after the first ``=`` and
      then after the first ``:`` is kept.

    Each item's comparisons run under a 5-second ``signal.alarm`` (SIGALRM)
    timeout. Any exception raised while grading is counted in
    ``count_timeout`` and the item is scored as wrong.

    Args:
        input_datapath: Path to model output JSONL file (``'output'`` field),
            one line per test item, in test-file order.
        test_datapath: Path to the test JSONL file (``'answer'`` field).

    Returns:
        float: Accuracy score (correct / number of test items).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern6 = r"\\\[\s*(.*?)\s*\\\]"

    def _extra_content_from_gold(input_string):
        """Return every ``$...$`` inline-math fragment of ``input_string``."""
        return re.findall(r'\$(.*?)\$', input_string)

    def _extra_cleaning(input_string):
        r"""Drop ``\mathrm{...}`` units (558\mathrm{ft} -> 558) and ``$(...)`` parts."""
        input_string = re.sub(r'\\mathrm\{.*?\}', '', input_string)
        input_string = re.sub(r'\$\([^)]*\)', '', input_string)
        return input_string

    def _check_after_equal(extracted_answer_ori, gold_ori, matches):
        """Lenient comparison of the answer and reference around ``=``.

        True when the two strings agree once "," and "$" are removed, or
        when their parts after the last "=" agree and the reference does
        not look like a list of answers. If the reference contains ",",
        "or" or "and", it is split in two. The final answer and the
        second-to-last entry of ``matches`` must then match the two parts,
        in either order.
        """
        if extracted_answer_ori.replace(",", "").replace("$", "") == gold_ori.replace(",", "").replace("$", ""):
            return True
        if not extracted_answer_ori.split("=")[-1] == gold_ori.split("=")[-1].replace("$", ""):
            return False
        if "," in gold_ori or "or" in gold_ori or "and" in gold_ori:
            # There are multiple answers.
            if len(matches) <= 1:
                return False
            answer1 = extracted_answer_ori.split("=")[-1]
            answer2 = math_answer_cleaning(matches[-2].split("=")[-1])
            gold_ori = gold_ori.replace("$", "")
            if "or" in gold_ori:
                gold_parts = gold_ori.split("or", 1)
            elif "and" in gold_ori:
                gold_parts = gold_ori.split("and", 1)
            else:
                gold_parts = gold_ori.split(",", 1)
            gold_ori1 = gold_parts[-1].split("=")[-1]
            gold_ori2 = gold_parts[-2].split("=")[-1]
            if math_equal(answer1, gold_ori1) and math_equal(answer2, gold_ori2):
                return True
            if math_equal(answer2, gold_ori1) and math_equal(answer1, gold_ori2):
                return True
            return False
        return True

    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    pattern6_re = re.compile(pattern6, re.DOTALL)

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            # Strip one pair of surrounding "$" delimiters.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = json.loads(line.strip())['output']
            # matches1 is kept separately because _check_after_equal uses it.
            matches1 = pattern1_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            matches6 = pattern6_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            elif len(matches6) >= 1:
                extracted_answer = matches6[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            extracted_answer = _extra_cleaning(extracted_answer)
            gold = _extra_cleaning(gold)
            if gold.endswith("-"):
                gold = gold[:-1]
            if gold.endswith("."):
                gold = gold[:-1]
            if gold.endswith("hours"):
                gold = gold[:-len("hours")]
            if extracted_answer.endswith("."):
                extracted_answer = extracted_answer[:-1]
            extracted_answer_ori = extracted_answer
            gold_ori = gold
            if "=" in gold:
                gold = gold.split("=", 1)[1]
            if ":" in gold:
                gold = gold.split(":", 1)[1]
            if "=" in extracted_answer:
                extracted_answer = extracted_answer.split("=", 1)[1]
            if ":" in extracted_answer:
                extracted_answer = extracted_answer.split(":", 1)[1]
            # \emptyset and \oslash both represent the empty set in LaTeX.
            extracted_answer = extracted_answer.replace("\\emptyset", "\\oslash")
            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                # Give up on this item if grading takes longer than 5 seconds.
                signal.alarm(5)
                try:
                    if math_equal(extracted_answer, gold):
                        correct += 1
                    elif _check_after_equal(extracted_answer_ori, gold_ori, matches1):
                        correct += 1
                    elif _check_after_equal(extracted_answer, gold, matches1):
                        correct += 1
                    elif check_after_fraction_mapping(extracted_answer, gold):
                        correct += 1
                    elif round_number(extracted_answer) == round_number(gold):
                        correct += 1
                    elif is_equal_after_calculation(extracted_answer, gold):
                        correct += 1
                    else:
                        # Last resort: compare with the concatenated $...$
                        # fragments of the reference (the part after
                        # \approx if present), ignoring thousands separators.
                        gold_matches = _extra_content_from_gold(gold)
                        if len(gold_matches) >= 1:
                            gold2 = "".join(gold_matches)
                            if "\\approx" in gold2:
                                gold2 = gold2.split("\\approx", 1)[1]
                            # Convert 1,500 into 1500.
                            gold2 = gold2.replace(",", "")
                            extracted_answer2 = extracted_answer.replace(",", "")
                            if gold2 != "" and extracted_answer2 == gold2:
                                correct += 1
                finally:
                    # Disarm the alarm whether grading finished or raised.
                    signal.alarm(0)
            except Exception:
                # Previously a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit. Any grading error (not only
                # the timeout) is counted here and scored as wrong.
                count_timeout += 1

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath):
    """Score AMC23 or AIME24/25 zero-shot model outputs.

    The predicted answer is the last match of the first extraction pattern
    (in priority order) that matches. Both it and the reference are cleaned
    with math_answer_cleaning. The answer is accepted when math_equal, the
    round_number comparison, or is_equal_after_calculation reports a match.
    No timeout is applied.

    Args:
        input_datapath: JSONL file of model outputs (``'output'`` field), one
            line per problem, in test-file order.
        test_datapath: JSONL test file with ``'answer'`` and ``'problem'``
            fields.

    Returns:
        float: Fraction of problems answered correctly.
    """
    extractors = tuple(
        re.compile(source, re.DOTALL)
        for source in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    )

    references = []
    problems = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as handle:
        for raw in handle:
            record = json.loads(raw)
            references.append(str(record['answer']))
            problems.append(record['problem'])

    no_output = 0
    no_reference = 0
    hits = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as handle:
        for index, raw in enumerate(handle):
            text = json.loads(raw.strip())['output']
            candidate = next(
                (found[-1] for found in (rx.findall(text) for rx in extractors) if found),
                None,
            )
            reference = references[index]
            if candidate is None:
                no_output += 1
                continue
            if reference is None:
                no_reference += 1
                continue
            candidate = math_answer_cleaning(candidate)
            reference = math_answer_cleaning(reference)
            matched = (
                math_equal(candidate, reference)
                or round_number(candidate) == round_number(reference)
                or is_equal_after_calculation(candidate, reference)
            )
            if matched:
                hits += 1

    accuracy = hits / len(references)
    print("count_output_none:", no_output)
    print("count_answer_none:", no_reference)
    print("accuracy:", accuracy)
    return accuracy
def evaluate_omnimath_zeroshot(input_datapath, test_datapath):
    """Evaluate Omni-MATH zero-shot performance with timeout protection.

    The predicted answer is the last match of the first extraction pattern
    (in priority order) that matches. It is cleaned and compared with the
    stringified reference ``'answer'``. Each comparison runs under a
    5-second ``signal.alarm`` (SIGALRM) timeout. Any exception raised while
    grading is counted in ``count_timeout`` and the item is scored as wrong.

    Args:
        input_datapath: Path to model output JSONL file (``'output'`` field),
            one line per problem, in test-file order.
        test_datapath: Path to the test JSONL file (``'answer'`` field).

    Returns:
        float: Accuracy score (correct / number of problems).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    # Extraction patterns in priority order.
    extraction_res = [
        re.compile(pattern, re.DOTALL)
        for pattern in (pattern1, pattern2, pattern3, pattern4, pattern5)
    ]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            gold_list.append(str(item['answer']))

    count_output_none = 0
    count_answer_none = 0
    # Was never initialised: the first grading error raised
    # UnboundLocalError instead of being counted.
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = json.loads(line.strip())['output']
            extracted_answer = None
            for pattern_re in extraction_res:
                matches = pattern_re.findall(line)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                # Give up on this item if grading takes longer than 5 seconds.
                signal.alarm(5)
                try:
                    if math_equal(extracted_answer, gold):
                        correct += 1
                    elif check_after_fraction_mapping(extracted_answer, gold):
                        correct += 1
                    elif round_number(extracted_answer) == round_number(gold):
                        correct += 1
                    elif is_equal_after_calculation(extracted_answer, gold):
                        correct += 1
                finally:
                    # Disarm the alarm whether grading finished or raised.
                    signal.alarm(0)
            except Exception:
                # Previously a bare ``except:``, which also swallowed
                # KeyboardInterrupt/SystemExit. Any grading error (not only
                # the timeout) is counted here and scored as wrong.
                count_timeout += 1

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def get_answer_by_marjority_voting(output_list):
    """Pick the most common final answer among several model outputs.

    Each output's answer is extracted: the last match of the first pattern,
    in priority order, that matches. It is cleaned with math_answer_cleaning
    and merged into the first previously seen answer it is judged equal to.
    Equality means an exact string match, math_equal, fraction mapping,
    round_number agreement, or an equal evaluated value. Outputs with no
    extractable answer are ignored.

    Args:
        output_list: List of model output strings.

    Returns:
        dict or None: ``{"count": int, "original_output": str}`` for the
        answer with the highest count; ties go to the answer seen first.
        ``original_output`` is the first output that produced that answer.
        Returns None when no output contains an extractable answer
        (previously this raised IndexError).
    """
    pattern_res = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    def _same_answer(candidate, previous):
        # Checked in this order; cheaper/stricter tests first.
        return (
            candidate == previous
            or math_equal(candidate, previous)
            or check_after_fraction_mapping(candidate, previous)
            or round_number(candidate) == round_number(previous)
            or is_equal_after_calculation(candidate, previous)
        )

    answer_dict = {}
    for output in output_list:
        extracted_answer = None
        for pattern_re in pattern_res:
            matches = pattern_re.findall(output)
            if matches:
                extracted_answer = matches[-1]
                break
        if extracted_answer is None:
            continue
        extracted_answer = math_answer_cleaning(extracted_answer)
        for prev_ans in answer_dict:
            if _same_answer(extracted_answer, prev_ans):
                answer_dict[prev_ans]['count'] += 1
                break
        else:
            answer_dict[extracted_answer] = {"count": 1, "original_output": output}

    if not answer_dict:
        return None
    # max() returns the first maximal entry in insertion order, matching the
    # previous stable ``sorted(..., reverse=True)[0]`` tie-breaking.
    return max(answer_dict.values(), key=lambda entry: entry["count"])
def evaluate_gpqa(input_datapath, test_datapath):
    """Evaluate GPQA (Graduate-Level Google-Proof Q&A) benchmark.

    The predicted answer is the last boxed expression or, failing that, the
    last "Answer: X" style letter. It counts as correct in three cases:

    - it is the correct option letter;
    - math_equal accepts it against the correct option's text;
    - it contains the correct letter in parentheses, e.g. "(b)".

    Each comparison runs under a 5-second ``signal.alarm`` (SIGALRM)
    timeout. Any exception raised while grading is counted in
    ``count_timeout``.

    Args:
        input_datapath: Path to model output JSONL file.
        test_datapath: Path to GPQA test JSON file (items with ``choice_A`` ..
            ``choice_D`` and ``correct_answer``).

    Returns:
        float: Accuracy score.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    output_list = read_text_data(input_datapath)
    gold_list = read_json_data(test_datapath)
    num_samples = len(gold_list)
    assert len(output_list) == len(gold_list) == num_samples
    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    # Matches e.g. "Answer: (B)" or "Final Answer: **C**" and captures the letter.
    pattern2_re = re.compile(r'\b(?:Answer|Final Answer|ANSWER)\b[:\s\*]*\(?([A-D])\)?')
    count_none = 0
    count_timeout = 0
    correct = 0
    for output, gold in zip(output_list, gold_list):
        choices = [gold['choice_A'], gold['choice_B'], gold['choice_C'], gold['choice_D']]
        correct_answer = gold["correct_answer"]
        correct_choice = "ABCD"[choices.index(correct_answer)]
        matches1 = pattern1_re.findall(output)
        matches2 = pattern2_re.findall(output)
        if len(matches1) >= 1:
            extracted_answer = matches1[-1]
        elif len(matches2) >= 1:
            extracted_answer = matches2[-1]
        else:
            extracted_answer = None
        if extracted_answer is None:
            count_none += 1
            continue
        correct_answer = math_answer_cleaning(correct_answer)
        extracted_answer = math_answer_cleaning(extracted_answer)
        # math_answer_cleaning lower-cases its result, so the letter must be
        # compared in lower case too; "(A)" could never match "(a)" before.
        choice_lower = correct_choice.lower()
        try:
            signal.signal(signal.SIGALRM, _timeout_handler)
            # Give up on this item if grading takes longer than 5 seconds.
            signal.alarm(5)
            try:
                if extracted_answer.lower() == choice_lower:
                    correct += 1
                elif math_equal(extracted_answer, correct_answer):
                    correct += 1
                elif "(" + choice_lower + ")" in extracted_answer:
                    # "(a)"/"(b)"/... appears inside the extracted answer.
                    correct += 1
            finally:
                # Disarm the alarm whether grading finished or raised.
                signal.alarm(0)
        except Exception:
            # Previously a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit. Any grading error (not only the
            # timeout) is counted here and scored as wrong.
            count_timeout += 1
    acc = correct / num_samples
    print("num_samples:", num_samples)
    print("count_none:", count_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def get_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous behavior; passing a list allows programmatic use and tests.

    Returns:
        argparse.Namespace: Parsed arguments with ``modelfolder`` and
        ``testfolder`` attributes.
    """
    parser = argparse.ArgumentParser(description="Math Benchmark Evaluation")
    parser.add_argument("--modelfolder", type=str, required=True, help="Path to model output folder")
    parser.add_argument("--testfolder", type=str, required=True, help="Path to test data folder")
    return parser.parse_args(argv)
def check_finish(input_datapath):
    """Compute the finish rate of model outputs.

    A record counts as finished when its ``reason`` field is truthy and its
    ``output`` field is non-empty; each record contributes exactly one entry
    to the average. Blank lines in the file are ignored.

    Args:
        input_datapath: Path to model output JSONL file; each record must
            have ``reason`` and ``output`` keys.

    Returns:
        float: Proportion of finished records, or NaN if the file contains
        no records.
    """
    finish_flags = []
    with open(input_datapath, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            # A record without a finish reason counts once, as unfinished,
            # regardless of its output.
            finished = bool(item['reason']) and bool(item['output'])
            finish_flags.append(1 if finished else 0)
    if not finish_flags:
        # Avoid np.mean([]) RuntimeWarning; NaN keeps the "no data" meaning.
        return float("nan")
    return np.mean(finish_flags)
def _evaluate_aime_split(model_folder, test_datapath, split_name):
    """Evaluate every ``outputs_*/<split_name>.jsonl`` run found under *model_folder*.

    Prints the mean accuracy, its standard deviation across runs, and the mean
    finish rate for the split.

    Args:
        model_folder: Folder containing one ``outputs_*`` subfolder per run.
        test_datapath: Path to the gold test file for this split.
        split_name: Benchmark name, used as the output file stem and in logs.

    Returns:
        tuple: (mean accuracy, accuracy std across runs, mean finish rate).
        All three are NaN when no output files are found.
    """
    import glob

    # Sorted so runs are processed in a deterministic order.
    input_datapaths = sorted(glob.glob(os.path.join(model_folder, "outputs_*", split_name + ".jsonl")))
    if not input_datapaths:
        print("-"*80)
        print("no output files found for %s under %s" % (split_name, model_folder))
        nan = float("nan")
        return nan, nan, nan

    accs = []
    finishes = []
    for input_datapath in input_datapaths:
        accs.append(evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath))
        finishes.append(check_finish(input_datapath))

    acc = sum(accs) / len(accs)
    std = np.std(accs) if len(accs) > 1 else 0
    finish = np.mean(finishes)
    print("-"*80)
    print("avg acc for %s:" % split_name, acc, "std:", std)
    print("avg finish for %s:" % split_name, finish)
    return acc, std, finish


def main():
    """Evaluate AIME24 and AIME25 outputs for every run in the model folder.

    Reads ``--modelfolder`` and ``--testfolder`` from the command line, prints
    per-benchmark accuracy/finish statistics, then a combined summary.
    """
    args = get_args()
    model_folder = args.modelfolder
    test_datafolder = args.testfolder

    aime24_acc, aime24_std, aime24_finish = _evaluate_aime_split(
        model_folder, os.path.join(test_datafolder, "qwen2_math/aime24/test.jsonl"), "aime24"
    )
    aime25_acc, aime25_std, aime25_finish = _evaluate_aime_split(
        model_folder, os.path.join(test_datafolder, "aime25/test.jsonl"), "aime25"
    )

    print("="*80)
    print("average acc across AIME24:", aime24_acc, "±", aime24_std, ", AIME25:", aime25_acc, "±", aime25_std)
    print("average finish across AIME24:", aime24_finish, ", AIME25:", aime25_finish)
# Run the AIME evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()