|
|
"""Math benchmark evaluation utilities. |
|
|
|
|
|
This module provides evaluation functions for math benchmarks.

It includes answer extraction, cleaning, and grading logic for math problems.
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import re |
|
|
import signal |
|
|
import sys |
|
|
|
|
|
import numpy as np |
|
|
from sympy import simplify |
|
|
from sympy.parsing.latex import parse_latex |
|
|
from tqdm import tqdm |
|
|
|
|
|
from tools.grader import math_equal |
|
|
|
|
|
|
|
|
def read_text_data(datapath):
    """Read the ``output`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file. Every non-blank line must be a JSON
            object that has an ``output`` key.

    Returns:
        list: The ``output`` values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``output`` key.
    """
    print("reading from %s" % datapath)
    data_list = []
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) carry no record.
                continue
            data_list.append(json.loads(line)['output'])
    return data_list
|
|
|
|
|
|
|
|
def read_jsonl_data(datapath):
    """Read the ``output`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file. Every non-blank line must be a JSON
            object that has an ``output`` key.

    Returns:
        list: The ``output`` values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``output`` key.
    """
    print("reading from %s" % datapath)
    data_list = []
    # Explicit UTF-8: do not rely on the platform's default locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines instead of failing in json.loads("").
                continue
            data_item = json.loads(line)
            data_list.append(data_item['output'])
    return data_list
|
|
|
|
|
|
|
|
def read_json_data(datapath):
    """Load and return the content of a single JSON document.

    Args:
        datapath: Path to a JSON file.

    Returns:
        Whatever ``json.load`` produces for the file (the decoded document).

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print("reading from %s" % datapath)
    # Explicit UTF-8: do not rely on the platform's default locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
|
|
def evaluate_gsm8k_zeroshot(input_datapath, test_datapath):
    """Score model outputs against the GSM8K references.

    For each output the answer is taken from the last match of the first
    pattern that matches, tried in this order: ``\\boxed{...}``, ``**...**``,
    a display-math block ``\\[ ... \\]`` on its own lines, ``is \\( ... \\)``,
    and the display-math block with literal ``\\n`` sequences. The reference
    answer is the text after the last ``"#### "`` in the record's ``answer``.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to GSM8K test JSON file

    Returns:
        float: Number of correct answers divided by the number of references.
            Outputs with no extractable answer count as incorrect.
    """
    predictions = read_text_data(input_datapath)
    references = read_json_data(test_datapath)

    # Ordered by priority; the first pattern that matches wins.
    answer_regexes = [
        re.compile(expr, re.DOTALL)
        for expr in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    num_samples = len(references)
    assert len(predictions) == len(references) == num_samples

    count_none = 0
    correct = 0
    for output, record in zip(predictions, references):
        gold = record['answer'].split("#### ")[-1]

        extracted_answer = None
        for regex in answer_regexes:
            found = regex.findall(output)
            if found:
                extracted_answer = found[-1]
                break

        if extracted_answer is None:
            count_none += 1
            continue

        extracted_answer = math_answer_cleaning(extracted_answer)
        gold = math_answer_cleaning(gold)

        # The second check only runs when the first one fails.
        if math_equal(extracted_answer, gold) or is_equal_after_calculation(extracted_answer, gold):
            correct += 1

    acc = correct / len(references)
    print("count_none:", count_none)
    print("accuracy:", acc)
    return acc
|
|
|
|
|
|
|
|
def is_completely_wrapped_by_text(input_string):
    r"""Return the content of ``\text{...}`` when it wraps the whole input.

    The input qualifies only if it starts with ``\text{``, ends with ``}``
    and the opening brace is not closed before the final character, so
    ``\text{a}+\text{b}`` is *not* considered wrapped.

    Args:
        input_string: LaTeX string to check

    Returns:
        str or None: The wrapped content with ``(``, ``)`` and ``,`` removed
            if the input is wrapped, None otherwise
    """
    match = re.match(r'^\\text\{(.*)\}$', input_string)
    if not match:
        return None
    content = match.group(1)

    # The greedy regex alone also accepts inputs such as "\text{a}+\text{b}".
    # Walk the content and reject it as soon as a "}" closes the outer
    # \text{ brace early (i.e. the nesting depth would drop below zero).
    depth = 0
    pos = 0
    while pos < len(content):
        char = content[pos]
        if char == "\\":
            # Skip the escaped character (e.g. "\{" or "\}").
            pos += 2
            continue
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth < 0:
                return None
        pos += 1

    return content.replace("(", "").replace(")", "").replace(",", "")
|
|
|
|
|
|
|
|
def math_answer_cleaning(answer):
    r"""Clean and normalize a math answer string for comparison.

    Steps, in order:
    - If the whole answer is wrapped in ``\text{...}``, use its content.
    - Remove thousands separators (``,\!`` and ``{,}``), ``\$``, degree
      markers (``^\circ`` / ``^{\circ}``) and ``\quad``.
    - Rewrite ``dfrac{`` / ``tfrac{`` as ``frac{``.
    - Drop remaining ``\text{...}`` groups and whitespace-preceded negative
      exponents such as `` ^{-1}``.
    - Remove spaces, newlines and literal ``\n`` sequences.
    - Rewrite ``a\times10^{b}`` / ``a\times10^b`` as ``aeb``, ``d^{n}`` as
      ``d^n`` and ``10^{-n}`` as ``1e-n``.
    - Remove commas and lowercase the result.
    - Strip one trailing backslash.
    - For ``lhs=rhs`` where ``lhs`` looks like ``f(x)`` or has at most three
      characters, keep only ``rhs``.

    Args:
        answer: Raw answer string

    Returns:
        str: Cleaned answer string
    """
    extracted_content = is_completely_wrapped_by_text(answer)
    # An empty wrapped content falls back to the raw answer.
    answer = extracted_content if extracted_content else answer

    # Raw strings keep the backslashes literal without relying on Python
    # passing unknown escapes (e.g. "\!", "\$", "\c", "\q") through unchanged.
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    answer = answer.replace(r"^\circ", "").replace(r"^{\circ}", "")
    answer = answer.replace(r"\quad", "")
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    answer = re.sub(r'(\s\^\{-\d+\})', '', answer)
    answer = answer.replace(" ", "").replace("\n", "").replace("\\n", "")
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    answer = re.sub(r"10\^\{(-?\d+)\}", r"1e\1", answer)
    answer = answer.replace(",", "").lower()

    if answer.endswith("\\"):
        answer = answer[:-1]

    if "=" in answer:
        func_pattern = r'^[a-zA-Z_]\w*\([a-zA-Z_]\w*\)$'
        lhs = answer.split("=")[0]
        # Drop a short or function-call-like left-hand side ("x=", "f(x)=").
        if re.match(func_pattern, lhs) or len(lhs) <= 3:
            answer = answer.split("=", 1)[1]

    return answer
|
|
|
|
|
|
|
|
def round_number(answer):
    """Round numbers whose magnitude is below 1 to 2 significant figures.

    Values with an absolute value of 1 or more, and inputs that cannot be
    converted with ``float()``, are returned unchanged.

    Args:
        answer: Answer string

    Returns:
        str: Rounded answer if applicable, otherwise original answer
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        return answer

    # Use the magnitude: a plain "value < 1" test would also round every
    # negative number, making e.g. "-500" and "-499" compare equal.
    if abs(value) < 1:
        return f"{value:.2g}"

    return answer
|
|
|
|
|
|
|
|
def evaluate_math500_zeroshot(input_datapath, test_datapath):
    """Evaluate MATH-500 zero-shot performance with timeout protection.

    The answer of each output is the last match of the first matching
    pattern, tried in this order: ``\\boxed{...}``, ``**...**``, a
    display-math block on its own lines, ``is \\( ... \\)``, and the
    display-math block with literal ``\\n`` sequences. Each comparison runs
    under a 5-second ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file (records with an
            ``output`` key)
        test_datapath: Path to MATH-500 test JSONL file (records with an
            ``answer`` key)

    Returns:
        float: Number of correct answers divided by the number of reference
            answers. Outputs without an extractable answer and comparisons
            that fail or time out count as incorrect.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Ordered by priority; the first pattern that matches wins.
    answer_patterns = (
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    )

    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        gold_list = [json.loads(line)['answer'] for line in f]

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = None
            for pattern_re in answer_patterns:
                matches = pattern_re.findall(output)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                # Later checks only run when the earlier ones fail.
                if (math_equal(extracted_answer, gold)
                        or is_equal_after_calculation(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Best effort: a timeout or any other comparison failure is
                # tallied here and the sample counts as incorrect.
                # KeyboardInterrupt/SystemExit are deliberately not swallowed.
                count_timeout += 1
            finally:
                # Always cancel a pending alarm before the next sample.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def evaluate_minerva_math_zeroshot(input_datapath, test_datapath):
    """Evaluate Minerva Math zero-shot performance with timeout protection.

    The reference answer is the last ``\\boxed{...}`` in each record's
    ``solution`` (None when there is none). The model answer is the last
    match of the first matching pattern, tried in this order:
    ``\\boxed{...}``, a display-math block on its own lines,
    ``is \\( ... \\)``, and the display-math block with literal ``\\n``
    sequences. Each comparison runs under a 5-second ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file (records with an
            ``output`` key)
        test_datapath: Path to Minerva Math test JSONL file (records with a
            ``solution`` key)

    Returns:
        float: Number of correct answers divided by the number of test
            records. Missing answers and comparisons that fail or time out
            count as incorrect.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    boxed_re = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
    # Ordered by priority; the first pattern that matches wins.
    answer_patterns = (
        boxed_re,
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    )
    # Trailing unit suffixes removed from the model answer before comparison.
    unit_list = ("\\hbar^{4}",)

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            boxed = boxed_re.findall(item['solution'])
            gold_list.append(boxed[-1] if boxed else None)

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = None
            for pattern_re in answer_patterns:
                matches = pattern_re.findall(output)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            for unit in unit_list:
                if extracted_answer.endswith(unit):
                    extracted_answer = extracted_answer[:-len(unit)]

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                # Later checks only run when the earlier ones fail.
                if (math_equal(extracted_answer, gold)
                        or round_number(extracted_answer) == round_number(gold)
                        or is_equal_after_calculation(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Best effort: a timeout or any other comparison failure is
                # tallied here and the sample counts as incorrect.
                # KeyboardInterrupt/SystemExit are deliberately not swallowed.
                count_timeout += 1
            finally:
                # Always cancel a pending alarm before the next sample.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def calculate_numbers(input_string):
    """Evaluate a string as a Python expression, returning None on failure.

    Args:
        input_string: Expression to evaluate (e.g. ``"(1/2)+3"``)

    Returns:
        The value of the expression, or None if evaluation raised anything
    """
    # SECURITY NOTE(review): ``eval`` executes arbitrary Python. Callers in
    # this file pass text derived from model outputs and dataset answers, so
    # this is NOT safe for untrusted input; consider a restricted arithmetic
    # evaluator before using it outside a sandboxed evaluation run.
    try:
        return eval(input_string)
    except BaseException:
        # Deliberately as broad as a bare ``except``: evaluated text such as
        # "exit()" raises SystemExit, which must not end the evaluation run.
        return None
|
|
|
|
|
|
|
|
def is_equal_after_calculation(extracted_answer, gold):
    """Check if answers are equal after converting fractions and evaluating.

    Every ``\\frac{a}{b}`` is rewritten as ``(a/b)`` and both strings are
    evaluated with ``calculate_numbers``.

    Args:
        extracted_answer: Extracted answer string
        gold: Gold standard answer string

    Returns:
        bool: True if the gold answer could be evaluated and both evaluated
            values compare equal, False otherwise
    """
    frac_pattern = r'\\frac{(.*?)}{(.*?)}'
    gold = re.sub(frac_pattern, r'(\1/\2)', gold)
    extracted_answer = re.sub(frac_pattern, r'(\1/\2)', extracted_answer)
    gold_result = calculate_numbers(gold)
    extracted_answer_result = calculate_numbers(extracted_answer)

    # Compare against None explicitly: a truthiness test would reject a gold
    # answer that legitimately evaluates to 0.
    if gold_result is not None and gold_result == extracted_answer_result:
        return True
    return False
|
|
|
|
|
|
|
|
def is_math_formula_equal(extracted_answer, gold):
    """Compare two LaTeX formulas by simplifying their difference with SymPy.

    Args:
        extracted_answer: Extracted answer string (LaTeX)
        gold: Gold standard answer string (LaTeX)

    Returns:
        bool: True if the simplified difference of the parsed formulas
            equals 0; False otherwise, including when parsing or
            simplification raises (the error is printed)
    """
    try:
        difference = parse_latex(extracted_answer) - parse_latex(gold)
        return simplify(difference) == 0
    except Exception as err:
        print("error:", err)
        return False
|
|
|
|
|
|
|
|
def check_after_fraction_mapping(extracted_answer, gold):
    """Compare two answers textually after rewriting ``\\frac{a}{b}`` as ``a/b``.

    Args:
        extracted_answer: Extracted answer string
        gold: Gold standard answer string

    Returns:
        bool: True if both strings are identical after the rewrite
    """
    frac_re = re.compile(r'\\frac{(.*?)}{(.*?)}')

    def _as_division(text):
        # Replace every LaTeX fraction with plain "numerator/denominator".
        return frac_re.sub(r'\1/\2', text)

    return _as_division(extracted_answer) == _as_division(gold)
|
|
|
|
|
|
|
|
def evaluate_gaokao2023en_zeroshot(input_datapath, test_datapath):
    """Evaluate gaokao2023en zero-shot performance with timeout protection.

    The model answer is the last match of the first matching pattern, tried
    in this order: ``\\boxed{...}``, ``**...**``, a display-math block on its
    own lines, ``is \\( ... \\)``, and the display-math block with literal
    ``\\n`` sequences. When the reference is a single letter, the text of
    that option is cut out of the question and accepted as an alternative
    reference. Each comparison runs under a 5-second ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file (records with an
            ``output`` key)
        test_datapath: Path to test JSONL file (records with ``question``
            and ``answer`` keys)

    Returns:
        float: Number of correct answers divided by the number of test
            records. Missing answers and comparisons that fail or time out
            count as incorrect.

    Raises:
        AssertionError: If a single-letter reference has no matching
            ``(X)`` option in its question.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Ordered by priority; the first pattern that matches wins.
    answer_patterns = (
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    )

    # Substrings removed from cleaned answers before comparison.
    unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesontheBreadusscale', '$', 'a.m.', 'am', 'minutes']

    gold_list = []
    question_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            question_list.append(item['question'].strip())
            # Drop a "$...$" wrapper around the whole reference answer.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = None
            for pattern_re in answer_patterns:
                matches = pattern_re.findall(output)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]
            question = question_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            if len(gold) == 1 and gold.isalpha():
                # Multiple choice: take the text between "(X)" and the next
                # option label (one or two letters later), or up to the end
                # of the question when neither label is present.
                option = "(" + gold + ")"
                assert option in question
                start_option_idx = question.index(option)
                end_option_idx = None
                for offset in (1, 2):
                    following_option = "(" + chr(ord(gold) + offset) + ")"
                    if following_option in question:
                        end_option_idx = question.index(following_option)
                        break
                gold2 = question[start_option_idx:end_option_idx]
                gold2 = gold2.replace(option, "").strip().replace("\\n", "")
            else:
                gold2 = None

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            if gold2:
                gold2 = math_answer_cleaning(gold2)

            for unit in unit_strings:
                extracted_answer = extracted_answer.replace(unit, "")
                gold = gold.replace(unit, "")
                if gold2:
                    gold2 = gold2.replace(unit, "")

            if "=" in extracted_answer and not gold2:
                # Keep only the right-hand side of "lhs=rhs".
                extracted_answer = extracted_answer.split("=", 1)[1]

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                # Later checks only run when the earlier ones fail.
                if (math_equal(extracted_answer, gold)
                        or (gold2 and math_equal(extracted_answer, gold2))
                        or is_equal_after_calculation(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Best effort: a timeout or any other comparison failure is
                # tallied here and the sample counts as incorrect.
                # KeyboardInterrupt/SystemExit are deliberately not swallowed.
                count_timeout += 1
            finally:
                # Always cancel a pending alarm before the next sample.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("benchmark size:", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def evaluate_olympiadbench_zeroshot(input_datapath, test_datapath):
    """Evaluate OlympiadBench zero-shot performance with timeout protection.

    The model answer is the last match of the first matching pattern, tried
    in this order: ``\\boxed{...}``, ``**...**``, a display-math block on its
    own lines, ``is \\( ... \\)``, and the display-math block with literal
    ``\\n`` sequences. Each comparison runs under a 5-second ``SIGALRM``
    alarm.

    Args:
        input_datapath: Path to model output JSONL file (records with an
            ``output`` key)
        test_datapath: Path to test JSONL file (records whose
            ``final_answer`` is a one-element list)

    Returns:
        float: Number of correct answers divided by the number of test
            records. Missing answers and comparisons that fail or time out
            count as incorrect.

    Raises:
        AssertionError: If a record's ``final_answer`` does not contain
            exactly one entry.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Ordered by priority; the first pattern that matches wins.
    answer_patterns = (
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    )

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            assert len(item['final_answer']) == 1
            # Drop a "$...$" wrapper around the whole reference answer.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['final_answer'][0]))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0

    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = None
            for pattern_re in answer_patterns:
                matches = pattern_re.findall(output)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                # Later checks only run when the earlier ones fail.
                if (math_equal(extracted_answer, gold)
                        or is_equal_after_calculation(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Best effort: a timeout or any other comparison failure is
                # tallied here and the sample counts as incorrect.
                # KeyboardInterrupt/SystemExit are deliberately not swallowed.
                count_timeout += 1
            finally:
                # Always cancel a pending alarm before the next sample.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_collegemath_zeroshot(input_datapath, test_datapath):
    """Evaluate College Math zero-shot performance with timeout protection.

    The model answer is the last ``\\boxed{...}`` if there is one; otherwise
    the last match of the first matching fallback pattern: a display-math
    block on its own lines, ``is \\( ... \\)``, the display-math block with
    literal ``\\n`` sequences, then any ``\\[ ... \\]`` block. Both answers
    are cleaned, stripped of ``\\mathrm{...}`` groups and a few trailing
    tokens, and reduced to the part after ``=`` / ``:`` before a chain of
    equality checks runs under a 5-second ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file (records with an
            ``output`` key)
        test_datapath: Path to test JSONL file (records with an ``answer``
            key)

    Returns:
        float: Number of correct answers divided by the number of test
            records. Missing answers and comparisons that fail or time out
            count as incorrect.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    boxed_re = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
    # Tried in order, only when the output has no \boxed{...}.
    fallback_patterns = (
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
        re.compile(r"\\\[\s*(.*?)\s*\\\]", re.DOTALL),
    )

    def _extra_content_from_gold(input_string):
        # Collect every "$...$" inline-math span.
        return re.findall(r'\$(.*?)\$', input_string)

    def _extra_cleaning(input_string):
        # Remove \mathrm{...} groups and "$(...)" fragments.
        input_string = re.sub(r'\\mathrm\{.*?\}', '', input_string)
        input_string = re.sub(r'\$\([^)]*\)', '', input_string)
        return input_string

    def _check_after_equal(extracted_answer_ori, gold_ori, matches):
        # Compare the parts after "="; for two-part references ("a, b",
        # "a or b", "a and b") also use the second-to-last boxed answer.
        if extracted_answer_ori.replace(",", "").replace("$", "") == gold_ori.replace(",", "").replace("$", ""):
            return True

        if not extracted_answer_ori.split("=")[-1] == gold_ori.split("=")[-1].replace("$", ""):
            return False

        if "," in gold_ori or "or" in gold_ori or "and" in gold_ori:
            # A two-part reference needs two boxed answers in the output.
            if len(matches) <= 1:
                return False

            answer1 = extracted_answer_ori.split("=")[-1]
            answer2 = math_answer_cleaning(matches[-2].split("=")[-1])
            gold_ori = gold_ori.replace("$", "")
            if "or" in gold_ori:
                gold_parts = gold_ori.split("or", 1)
            elif "and" in gold_ori:
                gold_parts = gold_ori.split("and", 1)
            else:
                gold_parts = gold_ori.split(",", 1)

            gold_ori1 = gold_parts[-1].split("=")[-1]
            gold_ori2 = gold_parts[-2].split("=")[-1]

            # Accept the two parts in either order.
            if math_equal(answer1, gold_ori1) and math_equal(answer2, gold_ori2):
                return True
            elif math_equal(answer2, gold_ori1) and math_equal(answer1, gold_ori2):
                return True

            return False

        else:
            return True

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            # Drop a "$...$" wrapper around the whole reference answer.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            # matches1 is always needed: _check_after_equal inspects it.
            matches1 = boxed_re.findall(output)
            if matches1:
                extracted_answer = matches1[-1]
            else:
                extracted_answer = None
                for pattern_re in fallback_patterns:
                    found = pattern_re.findall(output)
                    if found:
                        extracted_answer = found[-1]
                        break
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            extracted_answer = _extra_cleaning(extracted_answer)
            gold = _extra_cleaning(gold)
            if gold.endswith("-"):
                gold = gold[:-1]
            if gold.endswith("."):
                gold = gold[:-1]
            if gold.endswith("hours"):
                gold = gold[:-len("hours")]
            if extracted_answer.endswith("."):
                extracted_answer = extracted_answer[:-1]

            # Keep the full forms for _check_after_equal.
            extracted_answer_ori = extracted_answer
            gold_ori = gold

            if "=" in gold:
                gold = gold.split("=", 1)[1]
            if ":" in gold:
                gold = gold.split(":", 1)[1]

            if "=" in extracted_answer:
                extracted_answer = extracted_answer.split("=", 1)[1]
            if ":" in extracted_answer:
                extracted_answer = extracted_answer.split(":", 1)[1]

            extracted_answer = extracted_answer.replace("\\emptyset", "\\oslash")

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                if math_equal(extracted_answer, gold):
                    correct += 1
                elif _check_after_equal(extracted_answer_ori, gold_ori, matches1):
                    correct += 1
                elif _check_after_equal(extracted_answer, gold, matches1):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                else:
                    # Last resort: compare against the "$...$" spans of the
                    # reference (after any "\approx"), ignoring commas.
                    gold_matches = _extra_content_from_gold(gold)
                    if len(gold_matches) >= 1:
                        gold2 = "".join(gold_matches)
                        if "\\approx" in gold2:
                            gold2 = gold2.split("\\approx", 1)[1]

                        gold2 = gold2.replace(",", "")
                        extracted_answer2 = extracted_answer.replace(",", "")
                        if gold2 != "" and extracted_answer2 == gold2:
                            correct += 1
            except Exception:
                # Best effort: a timeout or any other comparison failure is
                # tallied here and the sample counts as incorrect.
                # KeyboardInterrupt/SystemExit are deliberately not swallowed.
                count_timeout += 1
            finally:
                # Always cancel a pending alarm before the next sample.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath):
    """Score model outputs against AMC23 / AIME24 / AIME25 references.

    The model answer is the last match of the first matching pattern, tried
    in this order: ``\\boxed{...}``, ``**...**``, a display-math block on its
    own lines, ``is \\( ... \\)``, and the display-math block with literal
    ``\\n`` sequences. References are the ``answer`` fields converted with
    ``str()``.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to AMC23/AIME24/AIME25 test JSONL file

    Returns:
        float: Number of correct answers divided by the number of test
            records. Outputs with no extractable answer count as incorrect.
    """
    # Ordered by priority; the first pattern that matches wins.
    answer_regexes = [
        re.compile(expr, re.DOTALL)
        for expr in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    references = []
    problems = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as test_file:
        for raw in test_file:
            record = json.loads(raw)
            references.append(str(record['answer']))
            problems.append(record['problem'])

    count_output_none = 0
    count_answer_none = 0
    correct = 0
    print("reading from %s" % input_datapath)

    with open(input_datapath, "r") as output_file:
        for idx, raw in enumerate(output_file):
            response = json.loads(raw.strip())['output']

            extracted_answer = None
            for regex in answer_regexes:
                found = regex.findall(response)
                if found:
                    extracted_answer = found[-1]
                    break
            gold = references[idx]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            # Later checks only run when the earlier ones fail.
            if (math_equal(extracted_answer, gold)
                    or round_number(extracted_answer) == round_number(gold)
                    or is_equal_after_calculation(extracted_answer, gold)):
                correct += 1

    acc = correct / len(references)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def evaluate_omnimath_zeroshot(input_datapath, test_datapath):
    """Evaluate OmniMath zero-shot performance.

    For every model output the final answer is extracted by trying a fixed
    list of regex patterns in priority order (the last match of the first
    pattern that matches wins). The extracted answer and the gold answer are
    cleaned with ``math_answer_cleaning`` and compared with several
    equivalence checks. Grading of one sample is limited to 5 seconds via
    ``SIGALRM``; a sample whose grading times out or raises an exception is
    counted in ``count_timeout`` and scored as incorrect.

    Args:
        input_datapath: Path to model output JSONL file (each line is a JSON
            object with an 'output' field)
        test_datapath: Path to test JSONL file (each line is a JSON object
            with an 'answer' field)

    Returns:
        float: Accuracy score (correct answers / number of gold answers)

    Raises:
        ZeroDivisionError: If the test file contains no samples.
        IndexError: If the output file has more lines than the test file.
    """

    class _TimeoutException(Exception):
        """Raised from the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Answer patterns, highest priority first:
    #   1. LaTeX boxed{...} with up to two levels of nested braces
    #   2. markdown bold text
    #   3. display math delimited by real newlines
    #   4. inline math introduced by the word "is"
    #   5. display math delimited by literal backslash-n sequences
    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern2 = r"\*\*(.*?)\*\*"
    pattern3 = r"\\\[\n(.*?)\n\\\]"
    pattern4 = r'is \\\((.*?)\\\)'
    pattern5 = r"\\\[\\n(.*?)\\n\\\]"
    answer_patterns = [
        re.compile(pattern, re.DOTALL)
        for pattern in (pattern1, pattern2, pattern3, pattern4, pattern5)
    ]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            gold_list.append(str(item['answer']))

    count_output_none = 0
    count_answer_none = 0
    # Must be initialised here: the except branch below increments it.
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = None
            for pattern_re in answer_patterns:
                matches = pattern_re.findall(output)
                if matches:
                    extracted_answer = matches[-1]
                    break

            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            # Gold answers are built with str() above, so this guard is purely
            # defensive; it is kept so the reported counter stays meaningful.
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                # Checks are tried in order; evaluation stops at the first hit.
                if (math_equal(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)
                        or round_number(extracted_answer) == round_number(gold)
                        or is_equal_after_calculation(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Best effort: a timeout or grader failure scores the sample as
                # incorrect. KeyboardInterrupt/SystemExit still propagate.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm so it cannot fire later.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def get_answer_by_marjority_voting(output_list):
    """Pick the majority answer among several model outputs.

    An answer is extracted from each output (last match of the first regex
    pattern that matches), cleaned with ``math_answer_cleaning`` and merged
    into the first previously seen answer it is judged equivalent to.
    Outputs without an extractable answer are ignored.

    Args:
        output_list: List of model output strings

    Returns:
        dict: The entry of the most frequent answer, with keys 'count'
            (number of equivalent answers) and 'original_output' (the first
            output that produced it). Ties go to the answer seen first.

    Raises:
        IndexError: If no output contains an extractable answer.
    """
    # Extraction patterns, highest priority first.
    extraction_patterns = tuple(
        re.compile(raw, re.DOTALL)
        for raw in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    )

    votes = {}
    for output in output_list:
        candidate = None
        for regex in extraction_patterns:
            found = regex.findall(output)
            if found:
                candidate = found[-1]
                break
        if candidate is None:
            continue

        candidate = math_answer_cleaning(candidate)

        for seen_answer, entry in votes.items():
            # Cheapest comparison first; short-circuits on the first success.
            if (candidate == seen_answer
                    or math_equal(candidate, seen_answer)
                    or check_after_fraction_mapping(candidate, seen_answer)
                    or round_number(candidate) == round_number(seen_answer)
                    or is_equal_after_calculation(candidate, seen_answer)):
                entry['count'] += 1
                break
        else:
            # No equivalent answer seen so far: open a new bucket.
            votes[candidate] = {"count": 1, "original_output": output}

    # Stable descending sort keeps insertion order among equal counts.
    ranked = sorted(votes.values(), key=lambda entry: entry["count"], reverse=True)
    return ranked[0]
|
|
|
|
|
|
|
|
def evaluate_gpqa(input_datapath, test_datapath):
    """Evaluate GPQA (Graduate-Level Google-Proof Q&A) benchmark.

    The predicted choice is taken from the last LaTeX boxed{...} expression
    in the output or, failing that, from the last "Answer: X" style match
    (X in A-D). A prediction is correct when it equals the correct choice
    letter (case-insensitive), is mathematically equal to the correct answer
    text, or contains the correct letter in parentheses. Grading of one
    sample is limited to 5 seconds via ``SIGALRM``; a sample whose grading
    times out or raises is counted in ``count_timeout`` and scored as
    incorrect.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to GPQA test JSON file; each item provides
            'choice_A' .. 'choice_D' and 'correct_answer'

    Returns:
        float: Accuracy score

    Raises:
        AssertionError: If the number of outputs differs from the number of
            test items.
        ValueError: If an item's 'correct_answer' is not among its choices.
        ZeroDivisionError: If the test file contains no items.
    """

    class _TimeoutException(Exception):
        """Raised from the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    output_list = read_text_data(input_datapath)
    gold_list = read_json_data(test_datapath)

    num_samples = len(gold_list)
    assert len(output_list) == len(gold_list) == num_samples

    # Primary pattern: boxed answer with up to two levels of nested braces.
    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern1_re = re.compile(pattern1, re.DOTALL)

    # Fallback pattern: "Answer", "Final Answer" or "ANSWER" followed by a
    # choice letter, optionally wrapped in parentheses.
    pattern2_re = re.compile(r'\b(?:Answer|Final Answer|ANSWER)\b[:\s\*]*\(?([A-D])\)?')

    count_none = 0
    count_timeout = 0
    correct = 0

    for output, gold in zip(output_list, gold_list):
        choices = [gold['choice_A'], gold['choice_B'], gold['choice_C'], gold['choice_D']]
        correct_answer = gold["correct_answer"]
        correct_choice = "ABCD"[choices.index(correct_answer)]

        matches = pattern1_re.findall(output) or pattern2_re.findall(output)
        if not matches:
            count_none += 1
            continue
        extracted_answer = matches[-1]

        correct_answer = math_answer_cleaning(correct_answer)
        extracted_answer = math_answer_cleaning(extracted_answer)

        try:
            signal.signal(signal.SIGALRM, _timeout_handler)
            signal.alarm(5)

            # Checks are tried in order; evaluation stops at the first hit.
            if (extracted_answer.lower() == correct_choice.lower()
                    or math_equal(extracted_answer, correct_answer)
                    or "(" + correct_choice + ")" in extracted_answer):
                correct += 1
        except Exception:
            # Best effort: a timeout or grader failure scores the sample as
            # incorrect. KeyboardInterrupt/SystemExit still propagate.
            count_timeout += 1
        finally:
            # Always cancel the pending alarm so it cannot fire later.
            signal.alarm(0)

    acc = correct / num_samples
    print("num_samples:", num_samples)
    print("count_none:", count_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
def get_args(argv=None):
    """Parse command-line arguments for evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from ``sys.argv[1:]``,
            which is the behavior of the original zero-argument call.

    Returns:
        argparse.Namespace: Parsed arguments with ``modelfolder`` and
            ``testfolder`` attributes
    """
    parser = argparse.ArgumentParser(description="Math Benchmark Evaluation")
    parser.add_argument("--modelfolder", type=str, required=True, help="Path to model output folder")
    parser.add_argument("--testfolder", type=str, required=True, help="Path to test data folder")
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def check_finish(input_datapath):
    """Check the finish rate (non-empty outputs) of model outputs.

    Each line of the file contributes exactly one sample. A sample counts as
    finished only when its 'reason' field is truthy and its 'output' field
    is non-empty.

    Args:
        input_datapath: Path to model output JSONL file (each line is a JSON
            object with 'reason' and 'output' fields)

    Returns:
        float: Finish rate (proportion of finished samples); ``numpy.mean``
            yields nan for an empty file
    """
    finish_rates = []
    with open(input_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            # One entry per sample: a missing finish reason already means
            # "not finished", so the output must not be counted a second time.
            finished = bool(item['reason']) and bool(item['output'])
            finish_rates.append(1 if finished else 0)

    return np.mean(finish_rates)
|
|
|
|
|
|
|
|
def _evaluate_aime_split(model_folder, test_datapath, split_name):
    """Evaluate every ``outputs_*/<split_name>.jsonl`` run under model_folder.

    Args:
        model_folder: Folder containing the ``outputs_*`` run directories
        test_datapath: Path to the test JSONL file of this split
        split_name: Split name used for the file name and the log lines

    Returns:
        tuple: (mean accuracy, accuracy std, mean finish rate) over all runs

    Raises:
        FileNotFoundError: If no output file of this split is found.
    """
    import glob

    pattern = model_folder + "/outputs_*/%s.jsonl" % split_name
    input_datapaths = glob.glob(pattern)
    if not input_datapaths:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise FileNotFoundError("no model outputs found matching %s" % pattern)

    accs = []
    finishes = []
    for input_datapath in input_datapaths:
        accs.append(evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath))
        finishes.append(check_finish(input_datapath))

    mean_acc = sum(accs) / len(accs)
    # The spread is only reported when more than one run is available.
    std_acc = np.std(accs) if len(accs) > 1 else 0
    mean_finish = np.mean(finishes)

    print("-"*80)
    print("avg acc for %s:" % split_name, mean_acc, "std:", std_acc)
    print("avg finish for %s:" % split_name, mean_finish)
    return mean_acc, std_acc, mean_finish


def main():
    """Main evaluation function for AIME benchmarks.

    Evaluates all ``outputs_*/aime24.jsonl`` and ``outputs_*/aime25.jsonl``
    runs found in the model folder and prints, for each benchmark, the mean
    accuracy, its standard deviation across runs and the mean finish rate.
    """
    args = get_args()

    model_folder = args.modelfolder
    test_datafolder = args.testfolder

    aime24_acc, aime24_std, aime24_finish = _evaluate_aime_split(
        model_folder,
        os.path.join(test_datafolder, "qwen2_math/aime24/test.jsonl"),
        "aime24",
    )
    aime25_acc, aime25_std, aime25_finish = _evaluate_aime_split(
        model_folder,
        os.path.join(test_datafolder, "aime25/test.jsonl"),
        "aime25",
    )

    print("="*80)
    print("average acc across AIME24:", aime24_acc, "±", aime24_std, ", AIME25:", aime25_acc, "±", aime25_std)
    print("average finish across AIME24:", aime24_finish, ", AIME25:", aime25_finish)
|
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|