# boxin-wbx's picture
# Upload folder using huggingface_hub
# 3cdba69 verified
"""GPQA (Graduate-Level Google-Proof Q&A) benchmark evaluation utilities.
This module provides evaluation functions for the GPQA benchmark, a challenging
multiple-choice science question dataset. It includes answer extraction from various
formats and mathematical equivalence checking.
"""
import argparse
import glob
import json
import os
import re
import signal
import numpy as np
from sympy import simplify
from sympy.parsing.latex import parse_latex
from tools.grader import math_equal
def get_option_char(s: str):
    """Extract a single-letter option from the last LaTeX \\boxed{} construct.

    Handles multiple formats:
    - \\boxed{B}
    - \\boxed{\\text{D}}
    - \\boxed{\\text{(E)}}
    - \\boxed{{A}}
    - \\boxed{(A)}

    When several matching constructs appear, the LAST one wins. This matches
    the other extractors in this module, which take ``matches[-1]``: a model
    may box a tentative option before its final answer.

    Args:
        s: String containing potential boxed answer.

    Returns:
        str or None: The extracted letter, case preserved. Any ASCII letter
        is accepted, not only A-D. Returns None if no construct matches.
    """
    pattern = r"""
        \\boxed\{                   # \boxed{
        \s*
        (?:                         # one of:
            \\text\{                #   \text{…}
                \s*\(?([A-Za-z])\)?\s*
            \}
          |                         # or
            \{([A-Za-z])\}          #   {A}
          |                         # or
            \(\s*([A-Za-z])\s*\)    #   (A)
          |                         # or
            ([A-Za-z])              #   B
        )
        \s*
        \}                          # }
    """
    matches = list(re.finditer(pattern, s, re.VERBOSE))
    if not matches:
        return None
    last = matches[-1]
    # Exactly one alternative's capture group participates in a match.
    return last.group(1) or last.group(2) or last.group(3) or last.group(4)
def read_text_data(datapath):
    """Read the ``'output'`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file with one JSON object per line, each
            containing an ``'output'`` key.

    Returns:
        list: The ``'output'`` values, in file order. Blank lines are skipped.
    """
    print("reading from %s" % datapath)
    data_list = []
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line (e.g. extra trailing newline) would make json.loads fail.
                continue
            data_list.append(json.loads(line)['output'])
    return data_list
def read_jsonl_data(datapath):
    """Read the ``'output'`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file with one JSON object per line, each
            containing an ``'output'`` key.

    Returns:
        list: The ``'output'`` values, in file order. Blank lines are skipped.
    """
    print("reading from %s" % datapath)
    data_list = []
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line would make json.loads fail.
                continue
            data_item = json.loads(line)
            data_list.append(data_item['output'])
    return data_list
def read_json_data(datapath):
    """Load a whole JSON file.

    Args:
        datapath: Path to a JSON file.

    Returns:
        The decoded JSON value (whatever top-level structure the file holds).
    """
    print("reading from %s" % datapath)
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        data_list = json.load(f)
    return data_list
def evaluate_gsm8k_zeroshot(input_datapath, test_datapath):
    """Score zero-shot GSM8K outputs against the reference answers.

    The reference answer is the text after the last ``"#### "`` in each
    record's ``'answer'``. The predicted answer is the last match of the
    first extraction pattern that matches at all: ``\\boxed{}`` content,
    ``**bold**`` text, display math, or ``is \\( ... \\)``. Both sides are
    normalized with ``math_answer_cleaning``. A sample is correct when
    ``math_equal`` or ``is_equal_after_calculation`` accepts it.

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per line).
        test_datapath: JSON file holding a list of records with an ``'answer'`` key.

    Returns:
        float: ``correct / len(gold_list)``. Outputs with no extractable
        answer count as wrong.
    """
    output_list = read_text_data(input_datapath)
    gold_list = read_json_data(test_datapath)
    # Tried in priority order.
    extraction_res = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]
    assert len(output_list) == len(gold_list)
    count_none = 0
    correct = 0
    for output, gold_item in zip(output_list, gold_list):
        gold = gold_item['answer'].split("#### ")[-1]
        extracted_answer = next(
            (found[-1] for found in (rx.findall(output) for rx in extraction_res) if found),
            None,
        )
        if extracted_answer is None:
            count_none += 1
            continue
        extracted_answer = math_answer_cleaning(extracted_answer)
        gold = math_answer_cleaning(gold)
        if math_equal(extracted_answer, gold) or is_equal_after_calculation(extracted_answer, gold):
            correct += 1
    acc = correct / len(gold_list)
    print("count_none:", count_none)
    print("accuracy:", acc)
    return acc
def is_completely_wrapped_by_text(input_string):
    """Unwrap a string that consists entirely of one ``\\text{...}``.

    Args:
        input_string: Candidate answer string.

    Returns:
        str or None: The content inside ``\\text{}`` with all ``(``, ``)``
        and ``,`` characters removed, which may be an empty string. Returns
        None when the string is not fully wrapped by ``\\text{}``.
    """
    wrapped = re.match(r'^\\text{(.*)}$', input_string)
    if wrapped is None:
        return None
    inner = wrapped.group(1)
    for unwanted in ("(", ")", ","):
        inner = inner.replace(unwanted, "")
    return inner
def math_answer_cleaning(answer):
    """Normalize an answer string so that equivalent answers compare equal.

    Steps, in order:
    - unwrap a string fully wrapped in ``\\text{}``;
    - drop LaTeX thousands separators, ``\\$``, degree marks, ``\\quad`` and
      ``\\text{...}`` units;
    - drop a whitespace-preceded negative exponent;
    - remove spaces and newlines;
    - rewrite ``a\\times10^{b}`` / ``10^{b}`` into ``e``-notation and
      ``n^{m}`` into ``n^m``;
    - drop commas, lowercase, and drop a trailing backslash;
    - strip a short or function-style left-hand side (``f(x)=``, ``z=``,
      ``t_r=``).

    Args:
        answer: Raw answer string, e.g. the content of a ``\\boxed{}``.

    Returns:
        str: The normalized answer.
    """
    ## remove irrelevant text and space to see whether it is exact match
    extracted_content = is_completely_wrapped_by_text(answer)
    answer = extracted_content if extracted_content else answer
    ## convert 5,\!460 into 5460; convert 14{,}916 into 14916; convert \$4 into 4
    ## (raw strings: "\!", "\$", "\c", "\q" are invalid escapes in plain string
    ## literals and trigger SyntaxWarning on Python 3.12+; the values are unchanged)
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    ## convert \dfrac{3}{2} / \tfrac{3}{2} into \frac{3}{2}
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    ## convert 121^\circ / 121^{\circ} into 121
    answer = answer.replace(r"^\circ", "")
    answer = answer.replace(r"^{\circ}", "")
    ## remove \quad
    answer = answer.replace(r"\quad", "")
    ## convert 558\,\text{nm} into 558
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    ## convert 558\text{nm} into 558
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    ## remove a negative exponent preceded by whitespace, e.g. "2.45e6 ^{-1}" -> "2.45e6".
    ## NOTE: an unspaced "2.45e6^{-1}" is NOT matched; the pattern requires the leading \s.
    answer = re.sub(r'(\s\^\{-\d+\})', '', answer)
    ## remove space
    answer = answer.replace(" ", "")
    ## remove real newlines and literal backslash-n sequences
    ## NOTE(review): removing the literal backslash-n also strips the "\n" prefix of
    ## LaTeX macros such as \neq, \nu, \nabla
    answer = answer.replace("\n", "").replace("\\n", "")
    ## convert 3.54\times10^{10} into 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    ## convert 3.54\times10^10 into 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    ## convert 2^{10} into 2^10
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    ## convert 10^{-5} into 1e-5; 10^{5} into 1e5
    answer = re.sub(r"10\^\{(-?\d+)\}", r"1e\1", answer)
    ## remove comma
    answer = answer.replace(",", "")
    ## lowercase
    answer = answer.lower()
    ## convert 7.04e5\ into 7.04e5
    if answer.endswith("\\"):
        answer = answer[:-1]
    ## convert f(x)=ax+b into ax+b; convert z=123 into 123; convert t_r=123 into 123
    ## (only when the left-hand side is function-like or at most 3 characters long)
    func_pattern = r'^[a-zA-Z_]\w*\([a-zA-Z_]\w*\)$'
    if "=" in answer and (re.match(func_pattern, answer.split("=")[0]) or len(answer.split("=")[0]) <= 3):
        answer = answer.split("=", 1)[1]
    return answer
def round_number(answer):
    """Round a small-magnitude numeric string to two significant figures.

    Used to compare answers like ``5.56e-10`` and ``5.6e-10``. Only values with
    ``abs(value) < 1`` are rounded. Previously any value ``< 1`` was rounded,
    including large negatives, which made e.g. -123 and -124 compare equal.

    Args:
        answer: Candidate answer, normally a string.

    Returns:
        A ``str`` formatted with ``.2g`` when ``answer`` parses as a float of
        magnitude below 1; otherwise ``answer`` unchanged.
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        return answer
    if abs(value) < 1:
        ## convert 5.56e-10 into 5.6e-10; still return a string type
        return f"{value:.2g}"
    return answer
def evaluate_math500_zeroshot(input_datapath, test_datapath):
    """Score zero-shot MATH500 outputs against the reference answers.

    The predicted answer is the last match of the first extraction pattern
    that matches: ``\\boxed{}`` content, ``**bold**`` text, display math, or
    ``is \\( ... \\)``. Both sides are normalized with
    ``math_answer_cleaning``. Each comparison is guarded by a 5-second
    SIGALRM timeout (Unix only).

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry an ``'answer'`` key.

    Returns:
        float: ``correct / len(gold_list)``. Missing answers and
        timed-out/errored comparisons count as wrong.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern2_re = re.compile(pattern2, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            gold_list.append(item['answer'])
    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            line = json.loads(line)['output']
            matches1 = pattern1_re.findall(line)
            matches2 = pattern2_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches2) >= 1:
                extracted_answer = matches2[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            try:
                # raise _TimeoutException after 5 seconds
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
            except Exception:
                ## Timeouts and grader errors are both counted here. Unlike a
                ## bare except, this lets KeyboardInterrupt stop the run.
                count_timeout += 1
            finally:
                ## Always disable the alarm
                signal.alarm(0)
    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)
    return acc
def evaluate_minerva_math_zeroshot(input_datapath, test_datapath):
    """Score zero-shot Minerva Math outputs against the reference solutions.

    The gold answer is the last ``\\boxed{}`` in each record's ``'solution'``;
    it is None when there is none. The predicted answer is the last match of
    the first extraction pattern that matches: ``\\boxed{}``, display math, or
    ``is \\( ... \\)``. Bold text is deliberately not used here. After
    ``math_answer_cleaning``, a trailing ``\\hbar^{4}`` unit is removed from
    the prediction. Each comparison is guarded by a 5-second SIGALRM timeout
    (Unix only).

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry a ``'solution'`` key.

    Returns:
        float: ``correct / len(gold_list)``.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            ## the gold answer is the last \boxed{} in the reference solution
            boxed = pattern1_re.findall(item['solution'])
            gold_list.append(boxed[-1] if boxed else None)
    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            line = json.loads(line)['output']
            matches1 = pattern1_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            ## remove unit
            unit_list = ["\\hbar^{4}"]
            for unit in unit_list:
                if extracted_answer.endswith(unit):
                    extracted_answer = extracted_answer[:-len(unit)]
            try:
                # raise _TimeoutException after 5 seconds
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
            except Exception:
                ## Timeouts and grader errors are both counted here. Unlike a
                ## bare except, this lets KeyboardInterrupt stop the run.
                count_timeout += 1
            finally:
                ## Always disable the alarm
                signal.alarm(0)
    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)
    return acc
def calculate_numbers(input_string):
    """Evaluate ``input_string`` as a Python arithmetic expression.

    Args:
        input_string: Expression text, e.g. ``"(3/2)"`` or ``"2**10"``.

    Returns:
        The evaluated value, or None if evaluation raises any ``Exception``.
    """
    # SECURITY(review): input_string comes from model outputs / benchmark files,
    # and eval() executes arbitrary code. Evaluating with empty globals and no
    # builtins keeps module names (os, signal, ...) and builtins such as
    # __import__ out of reach. This is NOT a real sandbox; replace it with a
    # safe arithmetic parser (e.g. ast-based) if inputs may be adversarial.
    try:
        return eval(input_string, {"__builtins__": {}}, {})
    except Exception:
        return None
def is_equal_after_calculation(extracted_answer, gold):
    """Compare two answers by numerically evaluating them.

    ``\\frac{a}{b}`` is first rewritten to ``(a/b)`` in both strings (the
    non-greedy pattern does not handle nested braces). Each string is then
    evaluated with ``calculate_numbers``.

    Args:
        extracted_answer: Normalized predicted answer.
        gold: Normalized reference answer.

    Returns:
        bool: True when the gold answer evaluates successfully and both
        results compare equal.
    """
    ## convert \frac{3}{2} into (3/2)
    gold = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', gold)
    extracted_answer = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', extracted_answer)
    gold_result = calculate_numbers(gold)
    extracted_answer_result = calculate_numbers(extracted_answer)
    # Compare against None explicitly. The former truthiness test rejected a
    # gold value of 0, so an answer of 0 could never match here.
    if gold_result is not None and gold_result == extracted_answer_result:
        return True
    return False
def is_math_formula_equal(extracted_answer, gold):
    """Check symbolic equality of two LaTeX expressions with SymPy.

    The predicted answer is parsed first, then the gold answer. They are
    considered equal when ``simplify(predicted - gold)`` is exactly zero.
    Any parsing or simplification error is printed and treated as
    "not equal".

    Args:
        extracted_answer: LaTeX string of the predicted answer.
        gold: LaTeX string of the reference answer.

    Returns:
        bool
    """
    try:
        difference = parse_latex(extracted_answer) - parse_latex(gold)
        return simplify(difference) == 0
    except Exception as err:
        print("error:", err)
        return False
def check_after_fraction_mapping(extracted_answer, gold):
    """Compare answers after rewriting every ``\\frac{a}{b}`` as ``a/b``.

    Args:
        extracted_answer: Normalized predicted answer.
        gold: Normalized reference answer.

    Returns:
        bool: True when both rewritten strings are identical.
    """
    frac_pattern = r'\\frac{(.*?)}{(.*?)}'

    def _flatten(text):
        return re.sub(frac_pattern, r'\1/\2', text)

    return _flatten(extracted_answer) == _flatten(gold)
def evaluate_gaokao2023en_zeroshot(input_datapath, test_datapath):
    """Score zero-shot Gaokao2023-EN outputs against the reference answers.

    Gold answers have one enclosing ``$...$`` pair stripped. When the gold
    answer is a single option letter, the text of that option is also cut
    out of the question (``gold2``) and accepted as a correct answer. Unit
    words are removed from both sides after ``math_answer_cleaning``. Each
    comparison is guarded by a 5-second SIGALRM timeout (Unix only).

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry ``'question'`` and
            ``'answer'`` keys.

    Returns:
        float: ``correct / len(gold_list)``.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern2_re = re.compile(pattern2, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    ## Unit words stripped from both sides AFTER math_answer_cleaning. That
    ## function lowercases, so every entry must be lowercase to ever match
    ## (previously 'degreesontheBreadusscale' never matched).
    ## NOTE(review): these are plain substring replacements, so e.g. "am" is
    ## also removed inside other tokens (\gamma -> \gma).
    unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesonthebreadusscale', '$', 'a.m.', 'am', 'minutes']
    gold_list = []
    question_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            question = item['question'].strip()
            question_list.append(question)
            answer = item['answer']
            ## strip one enclosing pair of $...$
            answer = re.sub(r'^\$(.*)\$$', r'\1', answer)
            gold_list.append(answer)
    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            line = json.loads(line)['output']
            matches1 = pattern1_re.findall(line)
            matches2 = pattern2_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches2) >= 1:
                extracted_answer = matches2[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            question = question_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            if len(gold) == 1 and gold.isalpha():
                ## gold is an option like A, B, C, D:
                ## also accept the content of this option, cut out of the question
                ## up to the next (or next-but-one) option marker
                option = "(" + gold + ")"
                assert option in question
                start_option_idx = question.index(option)
                next_option = "(" + chr(ord(gold) + 1) + ")"
                next_option2 = "(" + chr(ord(gold) + 2) + ")"
                if next_option in question:
                    end_option_idx = question.index(next_option)
                    gold2 = question[start_option_idx: end_option_idx]
                elif next_option2 in question:
                    end_option_idx = question.index(next_option2)
                    gold2 = question[start_option_idx: end_option_idx]
                else:
                    gold2 = question[start_option_idx:]
                gold2 = gold2.replace(option, "").strip().replace("\\n", "")
            else:
                gold2 = None
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            if gold2:
                gold2 = math_answer_cleaning(gold2)
            for unit in unit_strings:
                extracted_answer = extracted_answer.replace(unit, "")
                gold = gold.replace(unit, "")
                if gold2:
                    gold2 = gold2.replace(unit, "")
            if "=" in extracted_answer and not gold2:
                ## convert x=3 into 3
                extracted_answer = extracted_answer.split("=", 1)[1]
            try:
                # raise _TimeoutException after 5 seconds
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif gold2 and math_equal(extracted_answer, gold2):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
            except Exception:
                ## Timeouts and grader errors are both counted here. Unlike a
                ## bare except, this lets KeyboardInterrupt stop the run.
                count_timeout += 1
            finally:
                ## Always disable the alarm
                signal.alarm(0)
    acc = correct / len(gold_list)
    print("benchmark size:", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def evaluate_olympiadbench_zeroshot(input_datapath, test_datapath):
    """Score zero-shot OlympiadBench outputs against the reference answers.

    Each test record must hold exactly one ``'final_answer'`` (asserted). One
    enclosing ``$...$`` pair is stripped from it. The predicted answer is the
    last match of the first extraction pattern that matches. Each comparison
    is guarded by a 5-second SIGALRM timeout (Unix only).

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry a one-element
            ``'final_answer'`` list.

    Returns:
        float: ``correct / len(gold_list)``.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern2_re = re.compile(pattern2, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            assert len(item['final_answer']) == 1
            answer = item['final_answer'][0]
            ## strip one enclosing pair of $...$
            answer = re.sub(r'^\$(.*)\$$', r'\1', answer)
            gold_list.append(answer)
    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            line = json.loads(line)['output']
            matches1 = pattern1_re.findall(line)
            matches2 = pattern2_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches2) >= 1:
                extracted_answer = matches2[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            try:
                # raise _TimeoutException after 5 seconds
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
            except Exception:
                ## Timeouts and grader errors are both counted here. Unlike a
                ## bare except, this lets KeyboardInterrupt stop the run.
                count_timeout += 1
            finally:
                ## Always disable the alarm
                signal.alarm(0)
    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def evaluate_collegemath_zeroshot(input_datapath, test_datapath):
    """Score zero-shot CollegeMath outputs against the reference answers.

    Extraction tries, in order: the last ``\\boxed{}`` content, display math,
    ``is \\( ... \\)``, and finally any ``\\[ ... \\]`` block. Bold
    ``**...**`` text is deliberately not used. Both sides are normalized with
    ``math_answer_cleaning`` plus CollegeMath-specific cleanup:
    ``\\mathrm{}`` units, trailing "-"/"." and "hours", and everything up to
    the first "=" or ":". They are then compared with a cascade of checks.
    Each comparison is guarded by a 5-second SIGALRM timeout (Unix only).

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry an ``'answer'`` key.

    Returns:
        float: ``correct / len(gold_list)``.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern6 = r"\\\[\s*(.*?)\s*\\\]"

    def _extra_content_from_gold(input_string):
        """Return every ``$...$`` inline-math fragment of ``input_string``."""
        pattern_extra = r'\$(.*?)\$'
        matches = re.findall(pattern_extra, input_string)
        return matches

    def _extra_cleaning(input_string):
        """Drop ``\\mathrm{...}`` groups and ``$(...)`` groups."""
        ## convert 558\mathrm{ft} into 558
        input_string = re.sub(r'\\mathrm\{.*?\}', '', input_string)
        input_string = re.sub(r'\$\([^)]*\)', '', input_string)
        return input_string

    def _check_after_equal(extracted_answer_ori, gold_ori, matches):
        """Compare answers via their right-hand sides of "=".

        Returns True when the answers are identical after dropping "," and
        "$". Returns False when their last "=" right-hand sides differ. When
        the gold answer looks like several answers (contains ",", "or" or
        "and"), the second-to-last boxed answer in ``matches`` must also match
        the other gold part, in either order. Otherwise returns True.
        NOTE(review): "or"/"and" are substring tests and also hit words such
        as "for" or "random".
        """
        if extracted_answer_ori.replace(",", "").replace("$", "") == gold_ori.replace(",", "").replace("$", ""):
            return True
        if not extracted_answer_ori.split("=")[-1] == gold_ori.split("=")[-1].replace("$", ""):
            return False
        if "," in gold_ori or "or" in gold_ori or "and" in gold_ori:
            ## there are multiple answers
            if len(matches) <= 1:
                return False
            answer1 = extracted_answer_ori.split("=")[-1]
            answer2 = math_answer_cleaning(matches[-2].split("=")[-1])
            gold_ori = gold_ori.replace("$", "")
            if "or" in gold_ori:
                gold_list = gold_ori.split("or", 1)
            elif "and" in gold_ori:
                gold_list = gold_ori.split("and", 1)
            else:
                gold_list = gold_ori.split(",", 1)
            gold_ori1 = gold_list[-1].split("=")[-1]
            gold_ori2 = gold_list[-2].split("=")[-1]
            if math_equal(answer1, gold_ori1) and math_equal(answer2, gold_ori2):
                return True
            elif math_equal(answer2, gold_ori1) and math_equal(answer1, gold_ori2):
                return True
            return False
        else:
            return True

    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    pattern6_re = re.compile(pattern6, re.DOTALL)
    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            answer = item['answer']
            ## strip one enclosing pair of $...$
            answer = re.sub(r'^\$(.*)\$$', r'\1', answer)
            gold_list.append(answer)
    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            line = json.loads(line)['output']
            matches1 = pattern1_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            matches6 = pattern6_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            elif len(matches6) >= 1:
                extracted_answer = matches6[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            extracted_answer = _extra_cleaning(extracted_answer)
            gold = _extra_cleaning(gold)
            if gold.endswith("-"):
                gold = gold[:-1]
            if gold.endswith("."):
                gold = gold[:-1]
            if gold.endswith("hours"):
                gold = gold[:-len("hours")]
            if extracted_answer.endswith("."):
                extracted_answer = extracted_answer[:-1]
            ## keep the pre-split forms for _check_after_equal
            extracted_answer_ori = extracted_answer
            gold_ori = gold
            if "=" in gold:
                gold = gold.split("=", 1)[1]
            if ":" in gold:
                gold = gold.split(":", 1)[1]
            if "=" in extracted_answer:
                extracted_answer = extracted_answer.split("=", 1)[1]
            if ":" in extracted_answer:
                extracted_answer = extracted_answer.split(":", 1)[1]
            ## \emptyset and \oslash both represent the empty set in LaTeX
            extracted_answer = extracted_answer.replace("\\emptyset", "\\oslash")
            try:
                # raise _TimeoutException after 5 seconds
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif _check_after_equal(extracted_answer_ori, gold_ori, matches1):
                    correct += 1
                elif _check_after_equal(extracted_answer, gold, matches1):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                else:
                    ## last resort: compare against the $...$ fragments of the gold answer
                    gold_matches = _extra_content_from_gold(gold)
                    if len(gold_matches) >= 1:
                        gold2 = "".join(gold_matches)
                        if "\\approx" in gold2:
                            gold2 = gold2.split("\\approx", 1)[1]
                        ## convert 1,500 into 1500
                        gold2 = gold2.replace(",", "")
                        extracted_answer2 = extracted_answer.replace(",", "")
                        if gold2 != "" and extracted_answer2 == gold2:
                            correct += 1
            except Exception:
                ## Timeouts and grader errors are both counted here. Unlike a
                ## bare except, this lets KeyboardInterrupt stop the run.
                count_timeout += 1
            finally:
                ## Always disable the alarm
                signal.alarm(0)
    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath):
    """Score zero-shot AMC23 / AIME24 outputs against the reference answers.

    Gold answers are ``str(record['answer'])``. The predicted answer is the
    last match of the first extraction pattern that matches. Both are
    normalized with ``math_answer_cleaning``. A sample is correct when
    ``math_equal``, ``round_number`` equality, or
    ``is_equal_after_calculation`` accepts it. No timeout guard is used.

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry ``'answer'`` and
            ``'problem'`` keys.

    Returns:
        float: ``correct / len(gold_list)``.
    """
    # Tried in priority order; the last match of the first hit wins.
    extractors = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]
    gold_list = []
    question_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for record in map(json.loads, f):
            gold_list.append(str(record['answer']))
            question_list.append(record['problem'])
    count_output_none = 0
    count_answer_none = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as f:
        for idx, raw in enumerate(f):
            text = json.loads(raw.strip())['output']
            extracted_answer = None
            for extractor in extractors:
                found = extractor.findall(text)
                if found:
                    extracted_answer = found[-1]
                    break
            gold = gold_list[idx]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            if (math_equal(extracted_answer, gold)
                    or round_number(extracted_answer) == round_number(gold)
                    or is_equal_after_calculation(extracted_answer, gold)):
                correct += 1
    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)
    return acc
def evaluate_omnimath_zeroshot(input_datapath, test_datapath):
    """Score zero-shot Omni-MATH outputs against the reference answers.

    Gold answers are ``str(record['answer'])``. The predicted answer is the
    last match of the first extraction pattern that matches. Each comparison
    is guarded by a 5-second SIGALRM timeout (Unix only). ``count_timeout``
    was previously never initialized, so the first timed-out or failing
    comparison crashed with UnboundLocalError.

    Args:
        input_datapath: JSONL file of model outputs (``{"output": ...}`` per
            line), aligned line-by-line with ``test_datapath``.
        test_datapath: JSONL file whose records carry an ``'answer'`` key.

    Returns:
        float: ``correct / len(gold_list)``.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    pattern2_re = re.compile(pattern2, re.DOTALL)
    pattern3_re = re.compile(pattern3, re.DOTALL)
    pattern4_re = re.compile(pattern4, re.DOTALL)
    pattern5_re = re.compile(pattern5, re.DOTALL)
    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            gold_list.append(str(item['answer']))
    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            line = json.loads(line)['output']
            matches1 = pattern1_re.findall(line)
            matches2 = pattern2_re.findall(line)
            matches3 = pattern3_re.findall(line)
            matches4 = pattern4_re.findall(line)
            matches5 = pattern5_re.findall(line)
            if len(matches1) >= 1:
                extracted_answer = matches1[-1]
            elif len(matches2) >= 1:
                extracted_answer = matches2[-1]
            elif len(matches3) >= 1:
                extracted_answer = matches3[-1]
            elif len(matches4) >= 1:
                extracted_answer = matches4[-1]
            elif len(matches5) >= 1:
                extracted_answer = matches5[-1]
            else:
                extracted_answer = None
            gold = gold_list[i]
            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue
            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            try:
                # raise _TimeoutException after 5 seconds
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
            except Exception:
                ## Timeouts and grader errors are both counted here. Unlike a
                ## bare except, this lets KeyboardInterrupt stop the run.
                count_timeout += 1
            finally:
                ## Always disable the alarm
                signal.alarm(0)
    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def get_answer_by_marjority_voting(output_list):
    """Pick the most frequent final answer among several sampled outputs.

    From each output, the last match of the first extraction pattern that
    matches is normalized with ``math_answer_cleaning``. The result is merged
    into the first previously seen answer it is equivalent to. Equivalence
    means exact equality, ``math_equal``, fraction mapping, ``round_number``
    equality, or ``is_equal_after_calculation``, tried in that order.

    Args:
        output_list: Iterable of model output strings.

    Returns:
        dict or None: ``{"count": int, "original_output": str}`` for the most
        frequent answer, where ``original_output`` is the first output that
        produced it. Ties go to the answer seen first. Returns None when no
        output contains an extractable answer; previously this raised
        IndexError.
    """
    # Tried in priority order.
    extractors = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]
    answer_dict = {}
    for output in output_list:
        extracted_answer = None
        for extractor in extractors:
            found = extractor.findall(output)
            if found:
                extracted_answer = found[-1]
                break
        if extracted_answer is None:
            continue
        extracted_answer = math_answer_cleaning(extracted_answer)
        for prev_ans in answer_dict:
            if (extracted_answer == prev_ans
                    or math_equal(extracted_answer, prev_ans)
                    or check_after_fraction_mapping(extracted_answer, prev_ans)
                    or round_number(extracted_answer) == round_number(prev_ans)
                    or is_equal_after_calculation(extracted_answer, prev_ans)):
                answer_dict[prev_ans]['count'] += 1
                break
        else:
            answer_dict[extracted_answer] = {"count": 1, "original_output": output}
    if not answer_dict:
        return None
    ## max() returns the first key with the highest count, like the stable sort it replaces
    best = max(answer_dict, key=lambda ans: answer_dict[ans]["count"])
    return answer_dict[best]
def evaluate_gpqa(input_datapath, test_datapath):
    """Evaluate GPQA (Graduate-Level Google-Proof Q&A) benchmark outputs.

    The predicted option is taken from, in order:
    1. the last boxed single letter (``get_option_char``);
    2. otherwise the last general ``\\boxed{...}`` content;
    3. otherwise the last "Answer: X" / "Final Answer: X" / "ANSWER: X"
       marker, with X in A-D and not followed by another letter.

    After ``math_answer_cleaning`` (which lowercases), a sample is correct
    when any of these holds:
    - the prediction equals the correct letter;
    - it is ``math_equal`` to the correct choice text;
    - it contains the correct letter in parentheses, e.g. ``(b)``.

    Args:
        input_datapath: Path to model output JSONL file (``{"output": ...}`` per line).
        test_datapath: Path to GPQA test JSON file: a list of records with
            ``choice_A`` ... ``choice_D`` and ``correct_answer`` keys.

    Returns:
        float: Accuracy over all samples. Unanswered, timed-out and errored
        samples count as wrong.

    Raises:
        ValueError: If the number of outputs differs from the number of test samples.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    output_list = read_text_data(input_datapath)
    gold_list = read_json_data(test_datapath)
    num_samples = len(gold_list)
    # Explicit check: an assert is stripped under -O, and zip() would then silently truncate.
    if len(output_list) != num_samples:
        raise ValueError(
            f"{input_datapath} has {len(output_list)} outputs but {test_datapath} has {num_samples} samples"
        )
    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern1_re = re.compile(pattern1, re.DOTALL)
    # The letter must not be followed by another letter, so that
    # "Answer: Because ..." is not read as option B.
    pattern2_re = re.compile(r'\b(?:Answer|Final Answer|ANSWER)\b[:\s\*]*\(?([A-D])(?![A-Za-z])\)?')
    count_none = 0
    count_timeout = 0
    correct = 0
    for output, gold in zip(output_list, gold_list):
        choices = [gold['choice_A'], gold['choice_B'], gold['choice_C'], gold['choice_D']]
        correct_answer = gold["correct_answer"]
        correct_index = choices.index(correct_answer)
        correct_choice = "ABCD"[correct_index]
        extracted_answer = get_option_char(output)
        if extracted_answer is None:
            matches1 = pattern1_re.findall(output)
            matches2 = pattern2_re.findall(output)
            if matches1:
                extracted_answer = matches1[-1]
            elif matches2:
                extracted_answer = matches2[-1]
        if extracted_answer is None:
            count_none += 1
            continue
        correct_answer = math_answer_cleaning(correct_answer)
        extracted_answer = math_answer_cleaning(extracted_answer)
        # math_answer_cleaning lowercases, so letter checks must use the lowercase letter.
        # Previously "(A)" was searched in an already-lowercased string and never matched.
        choice_lower = correct_choice.lower()
        try:
            # raise _TimeoutException after 5 seconds
            signal.signal(signal.SIGALRM, _timeout_handler)
            signal.alarm(5)
            if extracted_answer.lower() == choice_lower:
                correct += 1
            elif math_equal(extracted_answer, correct_answer):
                correct += 1
            elif "(" + choice_lower + ")" in extracted_answer:
                correct += 1
        except Exception:
            # Timeouts and grader errors are both counted here. Unlike a bare
            # except, this lets KeyboardInterrupt stop the run.
            count_timeout += 1
        finally:
            # Always disable the alarm
            signal.alarm(0)
    acc = correct / num_samples
    print("num_samples:", num_samples)
    print("count_none:", count_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)
    return acc
def get_args():
    """Build the GPQA evaluation command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace: With the required string attributes
        ``modelfolder`` and ``testfolder``.
    """
    parser = argparse.ArgumentParser(description="GPQA Benchmark Evaluation")
    required_flags = (
        ("--modelfolder", "Path to model output folder"),
        ("--testfolder", "Path to test data folder"),
    )
    for flag, help_text in required_flags:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()
def main():
    """Evaluate every ``outputs_*/gpqa_diamond.jsonl`` run found under ``--modelfolder``.

    Each run is scored against ``<testfolder>/gpqa/gpqa_diamond.json`` with
    ``evaluate_gpqa``. The script then prints the mean accuracy and the
    population standard deviation across runs (0 for a single run).
    """
    args = get_args()
    model_folder = args.modelfolder
    test_datafolder = args.testfolder
    # glob() order is arbitrary; sort so runs are evaluated and printed reproducibly.
    input_datapaths = sorted(glob.glob(model_folder + "/outputs_*/gpqa_diamond.jsonl"))
    if not input_datapaths:
        print(f"No GPQA output files found in {model_folder}")
        return
    # Every run is scored against the same test file.
    test_datapath = os.path.join(test_datafolder, "gpqa/gpqa_diamond.json")
    gpqa_accs = []
    for input_datapath in input_datapaths:
        print(f"\nEvaluating: {input_datapath}")
        gpqa_accs.append(evaluate_gpqa(input_datapath, test_datapath))
    gpqa_acc = sum(gpqa_accs) / len(gpqa_accs)
    # np.std defaults to the population std (ddof=0).
    gpqa_std = np.std(gpqa_accs) if len(gpqa_accs) > 1 else 0
    print("=" * 80)
    print(f"Average accuracy for GPQA: {gpqa_acc:.4f} ± {gpqa_std:.4f}")
    print(f"Number of runs evaluated: {len(gpqa_accs)}")


if __name__ == "__main__":
    main()