| """Math benchmark evaluation utilities. |
| |
| This module provides evaluation functions for math benchmarks |
| |
| It includes answer extraction, cleaning, and grading logic for math problems. |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import signal |
| import sys |
|
|
| import numpy as np |
| from sympy import simplify |
| from sympy.parsing.latex import parse_latex |
| from tqdm import tqdm |
|
|
| from tools.grader import math_equal |
|
|
|
|
def read_text_data(datapath):
    """Read the ``output`` field of every record in a JSONL file.

    Args:
        datapath: Path to JSONL file containing model outputs. Every
            non-blank line must be a JSON object with an ``'output'`` key.

    Returns:
        list: The ``output`` values, in file order.
    """
    print("reading from %s" % datapath)
    data_list = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF)
                # instead of crashing inside json.loads("").
                continue
            data_list.append(json.loads(line)['output'])

    return data_list
|
|
|
|
def read_jsonl_data(datapath):
    """Read model outputs from JSONL file (alternative method).

    Args:
        datapath: Path to JSONL file. Every non-blank line must be a JSON
            object with an ``'output'`` key.

    Returns:
        list: List of ``output`` values, in file order.
    """
    print("reading from %s" % datapath)
    data_list = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            record_text = line.strip()
            if not record_text:
                # Skip blank lines rather than failing in json.loads("").
                continue
            data_item = json.loads(record_text)
            data_list.append(data_item['output'])

    return data_list
|
|
|
|
def read_json_data(datapath):
    """Read a JSON data file.

    Args:
        datapath: Path to JSON file

    Returns:
        The deserialized content of the file (whatever ``json.load`` yields).
    """
    print("reading from %s" % datapath)
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(datapath, "r", encoding="utf-8") as f:
        data_list = json.load(f)

    return data_list
|
|
|
|
def evaluate_gsm8k_zeroshot(input_datapath, test_datapath):
    """Score model outputs against GSM8K gold answers.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to GSM8K test JSON file; each entry's ``answer``
            holds the gold value after the last ``"#### "`` marker

    Returns:
        float: Accuracy score (correct / number of gold entries)
    """
    predictions = read_text_data(input_datapath)
    references = read_json_data(test_datapath)

    # Extraction patterns in priority order: the first pattern that matches
    # anywhere in the output is used, and its last match is the answer.
    extractors = [
        re.compile(regex, re.DOTALL)
        for regex in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _last_match(text):
        for extractor in extractors:
            found = extractor.findall(text)
            if found:
                return found[-1]
        return None

    total = len(references)
    assert len(predictions) == len(references) == total

    unparsed = 0
    hits = 0
    for prediction, reference in zip(predictions, references):
        target = reference['answer'].split("#### ")[-1]

        candidate = _last_match(prediction)
        if candidate is None:
            unparsed += 1
            continue

        candidate = math_answer_cleaning(candidate)
        target = math_answer_cleaning(target)

        if math_equal(candidate, target) or is_equal_after_calculation(candidate, target):
            hits += 1

    acc = hits / len(references)
    print("count_none:", unparsed)
    print("accuracy:", acc)
    return acc
|
|
|
|
def is_completely_wrapped_by_text(input_string):
    """Return the content of a string of the form ``\\text{...}``.

    The string must start with ``\\text{`` and end with ``}``. Parentheses
    and commas are removed from the captured content.

    Args:
        input_string: LaTeX string to check

    Returns:
        str or None: Captured content (possibly empty) if the string has the
        wrapped form, None otherwise
    """
    wrapped = re.match(r'^\\text{(.*)}$', input_string)
    if wrapped is None:
        return None
    # One pass that deletes "(", ")" and "," from the captured text.
    return wrapped.group(1).translate(str.maketrans("", "", "(),"))
|
|
|
|
def math_answer_cleaning(answer):
    """Clean and normalize math answer for comparison.

    Performs various cleaning operations:
    - Unwrap an answer that is entirely inside \\text{...}
    - Remove LaTeX formatting (\\text, \\quad, degree marks, ``\\$``, etc.)
    - Normalize dfrac/tfrac to frac and ``a\\times10^{b}`` to ``aeb``
    - Remove whitespace and commas
    - Convert to lowercase
    - Drop a short or ``f(x)``-style left-hand side of an equation

    Args:
        answer: Raw answer string

    Returns:
        str: Cleaned answer string
    """
    extracted_content = is_completely_wrapped_by_text(answer)
    answer = extracted_content if extracted_content else answer

    # NOTE: backslashes below are written as "\\" on purpose. The previous
    # spellings ("\!", "\$", "\circ", "\quad") were invalid escape sequences,
    # which emit SyntaxWarning on modern Python; the runtime strings are
    # unchanged.
    answer = answer.replace(",\\!", "").replace("{,}", "").replace("\\$", "")
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    answer = answer.replace("^\\circ", "").replace("^{\\circ}", "")
    answer = answer.replace("\\quad", "")
    # Strip \text{...} annotations (with or without a leading thin space).
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    # Remove " ^{-N}" exponents that are preceded by whitespace.
    answer = re.sub(r'(\s\^\{-\d+\})', '', answer)
    answer = answer.replace(" ", "").replace("\n", "").replace("\\n", "")
    # Scientific notation: a\times10^{b} / a\times10^b -> aeb.
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    answer = re.sub(r"10\^\{(-?\d+)\}", r"1e\1", answer)
    answer = answer.replace(",", "").lower()

    if answer.endswith("\\"):
        answer = answer[:-1]

    # Keep only the right-hand side when the left-hand side is a function
    # application such as "f(x)" or is at most three characters long.
    func_pattern = r'^[a-zA-Z_]\w*\([a-zA-Z_]\w*\)$'
    if "=" in answer:
        lhs = answer.split("=")[0]
        if re.match(func_pattern, lhs) or len(lhs) <= 3:
            answer = answer.split("=", 1)[1]

    return answer
|
|
|
|
def round_number(answer):
    """Round numbers of magnitude below 1 to 2 significant figures.

    The magnitude (absolute value) is tested, so large negative numbers are
    left untouched; previously any value below 1 -- including e.g. "-123" --
    was collapsed to two significant figures, which made distinct negative
    answers compare equal.

    Args:
        answer: Answer string

    Returns:
        str: Rounded answer if applicable, otherwise original answer
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        # Not a plain number: leave it for the other comparison strategies.
        return answer

    if abs(value) < 1:
        return f"{value:.2g}"

    return answer
|
|
|
|
def evaluate_math500_zeroshot(input_datapath, test_datapath):
    """Evaluate MATH-500 zero-shot performance with timeout protection.

    Args:
        input_datapath: Path to model output JSONL file (``output`` per line)
        test_datapath: Path to MATH-500 test JSONL file (``answer`` per line)

    Returns:
        float: Accuracy score (correct / number of gold answers)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order: the first pattern that matches
    # anywhere in the output is used, and its last match is the answer.
    answer_patterns = [
        re.compile(pattern, re.DOTALL)
        for pattern in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _extract_answer(text):
        for answer_pattern in answer_patterns:
            matches = answer_pattern.findall(text)
            if matches:
                return matches[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            gold_list.append(json.loads(line)['answer'])

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    # Install the handler once; it is only ever triggered by signal.alarm().
    signal.signal(signal.SIGALRM, _timeout_handler)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _extract_answer(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                # Give the symbolic comparison at most 5 seconds per sample.
                signal.alarm(5)
                if (math_equal(extracted_answer, gold)
                        or is_equal_after_calculation(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Any grading failure (timeout or otherwise) is tallied here,
                # as before; KeyboardInterrupt/SystemExit now propagate.
                count_timeout += 1
            finally:
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
|
|
|
|
def evaluate_minerva_math_zeroshot(input_datapath, test_datapath):
    """Evaluate Minerva Math zero-shot performance with timeout protection.

    The gold answer of each test item is the last ``\\boxed{...}`` found in
    its ``solution`` field (None when the solution has no boxed answer).

    Args:
        input_datapath: Path to model output JSONL file (``output`` per line)
        test_datapath: Path to Minerva Math test JSONL file (``solution`` per line)

    Returns:
        float: Accuracy score (correct / number of test items)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    boxed_re = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
    # Extraction patterns in priority order (no "**bold**" pattern here).
    answer_patterns = [
        boxed_re,
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    def _extract_answer(text):
        for answer_pattern in answer_patterns:
            matches = answer_pattern.findall(text)
            if matches:
                return matches[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            solution = json.loads(line)['solution']
            boxed = boxed_re.findall(solution)
            gold_list.append(boxed[-1] if boxed else None)

    # Trailing unit suffixes removed from the model answer before grading.
    unit_list = ["\\hbar^{4}"]

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    # Install the handler once; it is only ever triggered by signal.alarm().
    signal.signal(signal.SIGALRM, _timeout_handler)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _extract_answer(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            for unit in unit_list:
                if extracted_answer.endswith(unit):
                    extracted_answer = extracted_answer[:-len(unit)]

            try:
                # Give the symbolic comparison at most 5 seconds per sample.
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
            except Exception:
                # Any grading failure (timeout or otherwise) is tallied here,
                # as before; KeyboardInterrupt/SystemExit now propagate.
                count_timeout += 1
            finally:
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
|
|
|
|
def calculate_numbers(input_string):
    """Best-effort evaluation of an arithmetic expression string.

    Args:
        input_string: Mathematical expression as string

    Returns:
        Result of evaluation, or None if the expression cannot be evaluated
    """
    try:
        # SECURITY(review): eval() executes arbitrary Python with this
        # module's globals, and the input ultimately comes from model output.
        # Only run this on trusted data or inside a sandbox; consider a
        # restricted arithmetic parser instead.
        return eval(input_string)
    except (Exception, SystemExit):
        # Unparseable/failed expressions (and stray exit() calls) mean
        # "no value". KeyboardInterrupt is deliberately not swallowed.
        return None
|
|
|
|
def is_equal_after_calculation(extracted_answer, gold):
    """Check if answers are equal after converting fractions and evaluating.

    ``\\frac{a}{b}`` is rewritten to ``(a/b)`` in both strings, each string is
    evaluated with ``calculate_numbers`` and the results are compared.

    Args:
        extracted_answer: Extracted answer string
        gold: Gold standard answer string

    Returns:
        bool: True if the gold expression evaluates successfully and both
        results compare equal
    """
    gold = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', gold)
    extracted_answer = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', extracted_answer)
    gold_result = calculate_numbers(gold)
    extracted_answer_result = calculate_numbers(extracted_answer)

    # Compare against None explicitly: a gold value of 0 is a valid result
    # and must not be treated as "evaluation failed".
    if gold_result is None:
        return False
    return bool(gold_result == extracted_answer_result)
|
|
|
|
def is_math_formula_equal(extracted_answer, gold):
    """Compare two LaTeX formulas symbolically with SymPy.

    Both strings are parsed with ``parse_latex``; they are considered equal
    when ``simplify`` reduces their difference to 0.

    Args:
        extracted_answer: Extracted answer string (LaTeX)
        gold: Gold standard answer string (LaTeX)

    Returns:
        bool: True if the simplified difference equals 0; False otherwise,
        including when parsing or simplification raises
    """
    try:
        difference = parse_latex(extracted_answer) - parse_latex(gold)
        return simplify(difference) == 0
    except Exception as e:
        print("error:", e)
        return False
|
|
|
|
def check_after_fraction_mapping(extracted_answer, gold):
    """Compare two answers textually after rewriting ``\\frac{a}{b}`` as ``a/b``.

    Args:
        extracted_answer: Extracted answer string
        gold: Gold standard answer string

    Returns:
        bool: True if both strings are identical after the rewrite
    """
    def _flatten_fractions(expression):
        return re.sub(r'\\frac{(.*?)}{(.*?)}', r'\1/\2', expression)

    return _flatten_fractions(extracted_answer) == _flatten_fractions(gold)
|
|
|
|
def evaluate_gaokao2023en_zeroshot(input_datapath, test_datapath):
    """Evaluate GaoKao2023-EN zero-shot performance with timeout protection.

    When the gold answer is a single letter, the text of that option is cut
    out of the question (from "(X)" up to the next option marker) and is
    accepted as an alternative correct answer.

    Args:
        input_datapath: Path to model output JSONL file (``output`` per line)
        test_datapath: Path to test JSONL file with ``question`` and ``answer``

    Returns:
        float: Accuracy score (correct / number of test items)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order: the first pattern that matches
    # anywhere in the output is used, and its last match is the answer.
    answer_patterns = [
        re.compile(pattern, re.DOTALL)
        for pattern in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _extract_answer(text):
        for answer_pattern in answer_patterns:
            matches = answer_pattern.findall(text)
            if matches:
                return matches[-1]
        return None

    # Unit words stripped from cleaned answers (spaces are already removed
    # by math_answer_cleaning, hence the run-together spelling).
    unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesontheBreadusscale', '$', 'a.m.', 'am', 'minutes']

    gold_list = []
    question_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            question_list.append(item['question'].strip())
            # Strip one pair of enclosing "$...$" math delimiters, if present.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    # Install the handler once; it is only ever triggered by signal.alarm().
    signal.signal(signal.SIGALRM, _timeout_handler)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _extract_answer(output)
            gold = gold_list[i]
            question = question_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            if len(gold) == 1 and gold.isalpha():
                # Multiple choice: recover the text of the gold option.
                option = "(" + gold + ")"
                assert option in question
                start_option_idx = question.index(option)
                next_option = "(" + chr(ord(gold) + 1) + ")"
                next_option2 = "(" + chr(ord(gold) + 2) + ")"
                if next_option in question:
                    gold2 = question[start_option_idx: question.index(next_option)]
                elif next_option2 in question:
                    gold2 = question[start_option_idx: question.index(next_option2)]
                else:
                    # Gold is the last option: take the rest of the question.
                    gold2 = question[start_option_idx:]

                gold2 = gold2.replace(option, "").strip().replace("\\n", "")
            else:
                gold2 = None

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            if gold2:
                gold2 = math_answer_cleaning(gold2)

            for unit in unit_strings:
                extracted_answer = extracted_answer.replace(unit, "")
                gold = gold.replace(unit, "")
                if gold2:
                    gold2 = gold2.replace(unit, "")

            if "=" in extracted_answer and not gold2:
                # Keep only the right-hand side of an equation.
                extracted_answer = extracted_answer.split("=", 1)[1]

            try:
                # Give the symbolic comparison at most 5 seconds per sample.
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif gold2 and math_equal(extracted_answer, gold2):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
            except Exception:
                # Any grading failure (timeout or otherwise) is tallied here,
                # as before; KeyboardInterrupt/SystemExit now propagate.
                count_timeout += 1
            finally:
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("benchmark size:", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    # Previously tallied but never reported.
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
def evaluate_olympiadbench_zeroshot(input_datapath, test_datapath):
    """Evaluate OlympiadBench zero-shot performance with timeout protection.

    Args:
        input_datapath: Path to model output JSONL file (``output`` per line)
        test_datapath: Path to test JSONL file; each item's ``final_answer``
            must be a one-element list

    Returns:
        float: Accuracy score (correct / number of test items)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order: the first pattern that matches
    # anywhere in the output is used, and its last match is the answer.
    answer_patterns = [
        re.compile(pattern, re.DOTALL)
        for pattern in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _extract_answer(text):
        for answer_pattern in answer_patterns:
            matches = answer_pattern.findall(text)
            if matches:
                return matches[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            assert len(item['final_answer']) == 1
            # Strip one pair of enclosing "$...$" math delimiters, if present.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['final_answer'][0]))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    # Install the handler once; it is only ever triggered by signal.alarm().
    signal.signal(signal.SIGALRM, _timeout_handler)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _extract_answer(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                # Give the symbolic comparison at most 5 seconds per sample.
                signal.alarm(5)
                if (math_equal(extracted_answer, gold)
                        or is_equal_after_calculation(extracted_answer, gold)
                        or check_after_fraction_mapping(extracted_answer, gold)):
                    correct += 1
            except Exception:
                # Any grading failure (timeout or otherwise) is tallied here,
                # as before; KeyboardInterrupt/SystemExit now propagate.
                count_timeout += 1
            finally:
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
|
|
|
|
def evaluate_collegemath_zeroshot(input_datapath, test_datapath):
    """Evaluate College Math zero-shot performance with timeout protection.

    Besides the shared graders, this benchmark applies extra clean-up
    (dropping ``\\mathrm{...}``, trailing "." / "-" / "hours", text before
    "=" or ":") and several fallback comparisons for two-part answers and
    for gold answers that embed ``$...$`` fragments.

    Args:
        input_datapath: Path to model output JSONL file (``output`` per line)
        test_datapath: Path to test JSONL file (``answer`` per line)

    Returns:
        float: Accuracy score (correct / number of test items)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    def _extra_content_from_gold(input_string):
        # All "$...$" fragments of the string, in order.
        return re.findall(r'\$(.*?)\$', input_string)

    def _extra_cleaning(input_string):
        input_string = re.sub(r'\\mathrm\{.*?\}', '', input_string)
        input_string = re.sub(r'\$\([^)]*\)', '', input_string)
        return input_string

    def _check_after_equal(extracted_answer_ori, gold_ori, matches):
        # Identical once commas and dollar signs are ignored.
        if extracted_answer_ori.replace(",", "").replace("$", "") == gold_ori.replace(",", "").replace("$", ""):
            return True

        # Otherwise the parts after the last "=" must agree.
        if not extracted_answer_ori.split("=")[-1] == gold_ori.split("=")[-1].replace("$", ""):
            return False

        if not ("," in gold_ori or "or" in gold_ori or "and" in gold_ori):
            return True

        # Two-part gold answer: the model's last two boxed answers must
        # match the two gold parts, in either order.
        if len(matches) <= 1:
            return False

        answer1 = extracted_answer_ori.split("=")[-1]
        answer2 = math_answer_cleaning(matches[-2].split("=")[-1])
        gold_ori = gold_ori.replace("$", "")
        if "or" in gold_ori:
            gold_parts = gold_ori.split("or", 1)
        elif "and" in gold_ori:
            gold_parts = gold_ori.split("and", 1)
        else:
            gold_parts = gold_ori.split(",", 1)

        gold_ori1 = gold_parts[-1].split("=")[-1]
        gold_ori2 = gold_parts[-2].split("=")[-1]

        if math_equal(answer1, gold_ori1) and math_equal(answer2, gold_ori2):
            return True
        if math_equal(answer2, gold_ori1) and math_equal(answer1, gold_ori2):
            return True
        return False

    pattern1_re = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
    # Fallback patterns, in priority order, used when nothing is boxed.
    fallback_patterns = [
        re.compile(pattern, re.DOTALL)
        for pattern in (
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
            r"\\\[\s*(.*?)\s*\\\]",
        )
    ]

    def _extract_fallback(text):
        for fallback_pattern in fallback_patterns:
            found = fallback_pattern.findall(text)
            if found:
                return found[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            # Strip one pair of enclosing "$...$" math delimiters, if present.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    # Install the handler once; it is only ever triggered by signal.alarm().
    signal.signal(signal.SIGALRM, _timeout_handler)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            # All boxed answers are kept: _check_after_equal needs the
            # second-to-last one for two-part gold answers.
            matches1 = pattern1_re.findall(output)
            if matches1:
                extracted_answer = matches1[-1]
            else:
                extracted_answer = _extract_fallback(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            extracted_answer = _extra_cleaning(extracted_answer)
            gold = _extra_cleaning(gold)
            if gold.endswith("-"):
                gold = gold[:-1]
            if gold.endswith("."):
                gold = gold[:-1]
            if gold.endswith("hours"):
                gold = gold[:-len("hours")]
            if extracted_answer.endswith("."):
                extracted_answer = extracted_answer[:-1]

            # Keep the un-split forms for the "after equal" comparison.
            extracted_answer_ori = extracted_answer
            gold_ori = gold

            if "=" in gold:
                gold = gold.split("=", 1)[1]
            if ":" in gold:
                gold = gold.split(":", 1)[1]

            if "=" in extracted_answer:
                extracted_answer = extracted_answer.split("=", 1)[1]
            if ":" in extracted_answer:
                extracted_answer = extracted_answer.split(":", 1)[1]

            extracted_answer = extracted_answer.replace("\\emptyset", "\\oslash")

            try:
                # Give the symbolic comparison at most 5 seconds per sample.
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif _check_after_equal(extracted_answer_ori, gold_ori, matches1):
                    correct += 1
                elif _check_after_equal(extracted_answer, gold, matches1):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
                else:
                    # Last resort: compare against the concatenated "$...$"
                    # fragments of the gold answer (after any "\approx").
                    gold_matches = _extra_content_from_gold(gold)
                    if len(gold_matches) >= 1:
                        gold2 = "".join(gold_matches)
                        if "\\approx" in gold2:
                            gold2 = gold2.split("\\approx", 1)[1]

                        gold2 = gold2.replace(",", "")
                        extracted_answer2 = extracted_answer.replace(",", "")
                        if gold2 != "" and extracted_answer2 == gold2:
                            correct += 1
            except Exception:
                # Any grading failure (timeout or otherwise) is tallied here,
                # as before; KeyboardInterrupt/SystemExit now propagate.
                count_timeout += 1
            finally:
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
def evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath):
    """Score model outputs on AMC23 or AIME24/25.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to AMC23/AIME24/AIME25 test JSONL file

    Returns:
        float: Accuracy score (correct / number of test items)
    """
    # Extraction patterns in priority order: the first pattern that matches
    # anywhere in the output is used, and its last match is the answer.
    extractors = [
        re.compile(regex, re.DOTALL)
        for regex in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _pick_answer(text):
        for extractor in extractors:
            found = extractor.findall(text)
            if found:
                return found[-1]
        return None

    references = []
    problems = []  # collected alongside the answers; not used for grading
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as handle:
        for raw in handle:
            record = json.loads(raw)
            references.append(str(record['answer']))
            problems.append(record['problem'])

    missing_output = 0
    missing_gold = 0
    hits = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as handle:
        for idx, raw in enumerate(handle):
            response = json.loads(raw.strip())['output']

            candidate = _pick_answer(response)
            target = references[idx]

            if candidate is None:
                missing_output += 1
                continue
            if target is None:
                missing_gold += 1
                continue

            candidate = math_answer_cleaning(candidate)
            target = math_answer_cleaning(target)

            if (math_equal(candidate, target)
                    or round_number(candidate) == round_number(target)
                    or is_equal_after_calculation(candidate, target)):
                hits += 1

    acc = hits / len(references)
    print("count_output_none:", missing_output)
    print("count_answer_none:", missing_gold)
    print("accuracy:", acc)

    return acc
|
|
|
|
def evaluate_omnimath_zeroshot(input_datapath, test_datapath):
    """Evaluate Omni-MATH zero-shot performance with timeout protection.

    Args:
        input_datapath: Path to model output JSONL file (``output`` per line)
        test_datapath: Path to test JSONL file (``answer`` per line)

    Returns:
        float: Accuracy score (correct / number of test items)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order: the first pattern that matches
    # anywhere in the output is used, and its last match is the answer.
    answer_patterns = [
        re.compile(pattern, re.DOTALL)
        for pattern in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _extract_answer(text):
        for answer_pattern in answer_patterns:
            matches = answer_pattern.findall(text)
            if matches:
                return matches[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r", encoding="utf-8") as f:
        for line in f:
            gold_list.append(str(json.loads(line)['answer']))

    count_output_none = 0
    count_answer_none = 0
    # Was never initialised, so the first grading failure raised
    # UnboundLocalError from inside the except block.
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    # Install the handler once; it is only ever triggered by signal.alarm().
    signal.signal(signal.SIGALRM, _timeout_handler)
    with open(input_datapath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _extract_answer(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                # Give the symbolic comparison at most 5 seconds per sample.
                signal.alarm(5)
                if math_equal(extracted_answer, gold):
                    correct += 1
                elif check_after_fraction_mapping(extracted_answer, gold):
                    correct += 1
                elif round_number(extracted_answer) == round_number(gold):
                    correct += 1
                elif is_equal_after_calculation(extracted_answer, gold):
                    correct += 1
            except Exception:
                # Any grading failure (timeout or otherwise) is tallied here;
                # KeyboardInterrupt/SystemExit propagate.
                count_timeout += 1
            finally:
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
|
|
|
|
def get_answer_by_marjority_voting(output_list):
    """Get the most common answer from multiple model outputs via majority voting.

    Each output is scanned with five answer patterns in priority order (boxed
    answer, bold markdown span, display-math block, inline math following
    "is", display-math block written with literal backslash-n escapes). The
    last match of the first pattern that hits is cleaned with
    ``math_answer_cleaning`` and then merged into the first previously seen
    answer it is equivalent to. Outputs without any extractable answer do not
    vote.

    Args:
        output_list: List of model output strings

    Returns:
        dict: Dictionary with 'count' and 'original_output' for the majority
        answer. Ties are resolved in favour of the answer that was seen
        first. When no output contains an extractable answer (including an
        empty ``output_list``), 'count' is 0 and 'original_output' is the
        first raw output, or "" if there is none.
    """
    # Ordered by priority: the first pattern with at least one match wins.
    answer_patterns = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    answer_dict = {}
    for output in output_list:
        extracted_answer = None
        for pattern in answer_patterns:
            matches = pattern.findall(output)
            if matches:
                # The final occurrence is taken as the model's answer.
                extracted_answer = matches[-1]
                break

        if extracted_answer is None:
            continue

        extracted_answer = math_answer_cleaning(extracted_answer)

        # Look for an earlier bucket holding an equivalent answer; the checks
        # run from cheapest (string equality) to most expensive.
        for prev_ans in answer_dict:
            if (
                extracted_answer == prev_ans
                or math_equal(extracted_answer, prev_ans)
                or check_after_fraction_mapping(extracted_answer, prev_ans)
                or round_number(extracted_answer) == round_number(prev_ans)
                or is_equal_after_calculation(extracted_answer, prev_ans)
            ):
                answer_dict[prev_ans]['count'] += 1
                break
        else:
            # No equivalent answer seen so far: open a new bucket.
            answer_dict[extracted_answer] = {"count": 1, "original_output": output}

    if not answer_dict:
        # Nothing could be extracted; indexing an empty ranking used to raise
        # IndexError here. Return a zero-vote result of the usual shape.
        return {"count": 0, "original_output": output_list[0] if output_list else ""}

    # max() keeps the first bucket among equal counts, which matches a stable
    # descending sort followed by taking element 0.
    return max(answer_dict.values(), key=lambda entry: entry["count"])
|
|
|
|
def evaluate_gpqa(input_datapath, test_datapath):
    """Evaluate GPQA (Graduate-Level Google-Proof Q&A) benchmark.

    For every sample the answer is taken from the last boxed expression in
    the model output or, failing that, from the last "Answer: X" style
    letter. A sample is correct when the extracted answer equals the correct
    choice letter, is mathematically equal to the correct answer text, or
    contains the letter in parentheses. Grading of each sample is limited to
    5 seconds through SIGALRM, so this must run on a platform that provides
    ``signal.alarm`` and from the main thread.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to GPQA test JSON file

    Returns:
        float: Accuracy score (0.0 when the test set is empty)
    """
    class _TimeoutException(Exception):
        """Raised by the SIGALRM handler when grading one sample takes too long."""

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    output_list = read_text_data(input_datapath)
    gold_list = read_json_data(test_datapath)

    num_samples = len(gold_list)
    assert len(output_list) == num_samples, "model outputs and test set differ in length"

    pattern1_re = re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
    pattern2_re = re.compile(r'\b(?:Answer|Final Answer|ANSWER)\b[:\s\*]*\(?([A-D])\)?')

    count_none = 0
    count_timeout = 0
    count_error = 0
    correct = 0

    # Install the handler once instead of on every sample, and remember the
    # previous one so it can be put back afterwards.
    previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
    try:
        for output, gold in zip(output_list, gold_list):
            choices = [gold['choice_A'], gold['choice_B'], gold['choice_C'], gold['choice_D']]
            correct_answer = gold["correct_answer"]
            correct_choice = "ABCD"[choices.index(correct_answer)]

            # A boxed answer takes precedence over an "Answer: X" letter.
            matches = pattern1_re.findall(output) or pattern2_re.findall(output)
            if not matches:
                count_none += 1
                continue
            extracted_answer = matches[-1]

            correct_answer = math_answer_cleaning(correct_answer)
            extracted_answer = math_answer_cleaning(extracted_answer)

            try:
                signal.alarm(5)
                if (
                    extracted_answer.lower() == correct_choice.lower()
                    or math_equal(extracted_answer, correct_answer)
                    or "(" + correct_choice + ")" in extracted_answer
                ):
                    correct += 1
            except _TimeoutException:
                count_timeout += 1
            except Exception:
                # Best effort: a grader failure marks the sample as not
                # correct, but KeyboardInterrupt/SystemExit still propagate.
                count_error += 1
            finally:
                # Always cancel the pending alarm before the next sample.
                signal.alarm(0)
    finally:
        # signal.signal() returns None when the old handler was not installed
        # from Python; such a handler cannot be re-installed.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

    acc = correct / num_samples if num_samples else 0.0
    print("num_samples:", num_samples)
    print("count_none:", count_none)
    print("count_timeout:", count_timeout)
    print("count_error:", count_error)
    print("accuracy:", acc)

    return acc
|
|
def get_args():
    """Build the evaluation CLI and parse the process arguments.

    Both ``--modelfolder`` and ``--testfolder`` are mandatory string options.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    cli = argparse.ArgumentParser(description="Math Benchmark Evaluation")
    required_options = (
        ("--modelfolder", "Path to model output folder"),
        ("--testfolder", "Path to test data folder"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
|
|
|
|
def check_finish(input_datapath):
    """Check the finish rate of model outputs.

    A sample counts as finished when its 'reason' field is truthy and its
    'output' field is non-empty. Every line of the file contributes exactly
    one entry to the average.

    Args:
        input_datapath: Path to model output JSONL file

    Returns:
        float: Finish rate (proportion of finished samples); 0.0 when the
        file contains no samples
    """
    finish_flags = []
    with open(input_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            if not item['reason']:
                # Without this `continue` the sample was recorded twice (once
                # as 0 and once by its output), inflating the denominator.
                finish_flags.append(0)
                continue
            finish_flags.append(1 if item['output'] else 0)

    if not finish_flags:
        # np.mean([]) would emit a RuntimeWarning and return NaN.
        return 0.0
    return np.mean(finish_flags)
|
|
|
|
def _summarize_aime_runs(name, input_datapaths, test_datapath):
    """Evaluate every run of one AIME split and print its averaged metrics.

    Args:
        name: Label used in the printed summary (e.g. "aime24")
        input_datapaths: Paths to the per-run model output JSONL files
        test_datapath: Path to the test JSONL file for this split

    Returns:
        tuple: (mean accuracy, accuracy std, mean finish rate). Mean accuracy
        and mean finish rate are NaN when ``input_datapaths`` is empty.
    """
    accs = []
    finishes = []
    for input_datapath in input_datapaths:
        accs.append(evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath))
        finishes.append(check_finish(input_datapath))

    print("-"*80)
    if accs:
        mean_acc = sum(accs) / len(accs)
        mean_finish = np.mean(finishes)
    else:
        # Dividing by the number of runs used to raise ZeroDivisionError when
        # the glob matched nothing; report the gap instead of crashing.
        print("no output files found for %s" % name)
        mean_acc = float("nan")
        mean_finish = float("nan")
    # A standard deviation is only meaningful with more than one run.
    std_acc = np.std(accs) if len(accs) > 1 else 0

    print("avg acc for %s:" % name, mean_acc, "std:", std_acc)
    print("avg finish for %s:" % name, mean_finish)
    return mean_acc, std_acc, mean_finish


def main():
    """Main evaluation function for the AIME24 and AIME25 benchmarks.

    Collects every ``outputs_*/aime24.jsonl`` and ``outputs_*/aime25.jsonl``
    run under the model folder, evaluates each run against the matching test
    set, and prints the mean accuracy, its standard deviation and the mean
    finish rate per benchmark.
    """
    # Not imported at the top of the file, so keep the import local.
    import glob

    args = get_args()
    model_folder = args.modelfolder
    test_datafolder = args.testfolder

    # Sorted so that runs are always processed in the same order.
    aime24_paths = sorted(glob.glob(model_folder+"/outputs_*/aime24.jsonl"))
    # NOTE(review): aime24 is read from "qwen2_math/aime24" while aime25 is
    # read from "aime25" -- confirm this matches the test data layout.
    aime24_acc, aime24_std, aime24_finish = _summarize_aime_runs(
        "aime24",
        aime24_paths,
        os.path.join(test_datafolder, "qwen2_math/aime24/test.jsonl"),
    )

    aime25_paths = sorted(glob.glob(model_folder+"/outputs_*/aime25.jsonl"))
    aime25_acc, aime25_std, aime25_finish = _summarize_aime_runs(
        "aime25",
        aime25_paths,
        os.path.join(test_datafolder, "aime25/test.jsonl"),
    )

    print("="*80)
    print("average acc across AIME24:", aime24_acc, "±", aime24_std, ", AIME25:", aime25_acc, "±", aime25_std)
    print("average finish across AIME24:", aime24_finish, ", AIME25:", aime25_finish)


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|