| | """Math benchmark evaluation utilities. |
| | |
| | This module provides evaluation functions for math benchmarks. |
| | |
| | It includes answer extraction, cleaning, and grading logic for math problems. |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import re |
| | import signal |
| | import sys |
| |
|
| | import numpy as np |
| | from sympy import simplify |
| | from sympy.parsing.latex import parse_latex |
| | from tqdm import tqdm |
| |
|
| | from tools.grader import math_equal |
| |
|
| |
|
def read_text_data(datapath):
    """Read the ``output`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file whose records each carry an
            ``output`` key.

    Returns:
        list: The ``output`` values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``output`` key.
    """
    print("reading from %s" % datapath)
    data_list = []
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # A blank (typically trailing) line is not a record; json.loads("")
            # would raise, so skip it instead of aborting the whole read.
            if not line:
                continue
            data_list.append(json.loads(line)['output'])

    return data_list
| |
|
| |
|
def read_jsonl_data(datapath):
    """Read the ``output`` field of every record in a JSONL file.

    Args:
        datapath: Path to a JSONL file whose records each carry an
            ``output`` key.

    Returns:
        list: The ``output`` values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``output`` key.
    """
    print("reading from %s" % datapath)
    data_list = []
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Blank lines carry no record and would make json.loads raise.
            if not line:
                continue
            data_item = json.loads(line)
            data_list.append(data_item['output'])

    return data_list
| |
|
| |
|
def read_json_data(datapath):
    """Load and return the contents of a JSON file.

    Args:
        datapath: Path to the JSON file.

    Returns:
        The deserialized JSON document (whatever ``json.load`` produces).

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print("reading from %s" % datapath)
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(datapath, "r", encoding="utf-8") as f:
        data_list = json.load(f)

    return data_list
| |
|
| |
|
def evaluate_gsm8k_zeroshot(input_datapath, test_datapath):
    """Score model outputs against GSM8K gold answers.

    The gold answer is the text after the last ``"#### "`` marker of each
    test record's ``answer`` field. The model answer is the last match of the
    first extraction pattern (tried in priority order) that matches the output.

    Args:
        input_datapath: Path to the model-output JSONL file.
        test_datapath: Path to the GSM8K test JSON file.

    Returns:
        float: Number of correct samples divided by the number of gold
        records (samples with no extractable answer count as incorrect).
    """
    predictions = read_text_data(input_datapath)
    references = read_json_data(test_datapath)

    # Extraction patterns in priority order; the first one that matches wins.
    answer_regexes = [
        re.compile(source, re.DOTALL)
        for source in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _last_match(text):
        # Last occurrence of the highest-priority pattern found in the text.
        for regex in answer_regexes:
            found = regex.findall(text)
            if found:
                return found[-1]
        return None

    total = len(references)
    assert len(predictions) == total

    missing = 0
    num_correct = 0
    for prediction, reference in zip(predictions, references):
        target = reference['answer'].split("#### ")[-1]

        candidate = _last_match(prediction)
        if candidate is None:
            missing += 1
            continue

        candidate = math_answer_cleaning(candidate)
        target = math_answer_cleaning(target)

        if math_equal(candidate, target) or is_equal_after_calculation(candidate, target):
            num_correct += 1

    acc = num_correct / total
    print("count_none:", missing)
    print("accuracy:", acc)
    return acc
| |
|
| |
|
def is_completely_wrapped_by_text(input_string):
    """Return the inner text when the string starts with ``\\text{`` and ends with ``}``.

    Args:
        input_string: LaTeX string to inspect.

    Returns:
        str or None: The text between the opening ``\\text{`` and the final
        ``}`` with every ``(``, ``)`` and ``,`` removed, or None when the
        string does not have that shape.
    """
    wrapped = re.match(r'^\\text{(.*)}$', input_string)
    if not wrapped:
        return None
    # Strip parentheses and commas from the inner text in a single pass.
    return wrapped.group(1).translate(str.maketrans("", "", "(),"))
| |
|
| |
|
def math_answer_cleaning(answer):
    """Clean and normalize a math answer string for comparison.

    Steps, in order:
    - unwrap an answer that is entirely inside ``\\text{...}``
    - drop thousands separators (``,\\!``, ``{,}``), ``\\$``, degree marks,
      ``\\quad`` and remaining ``\\text{...}`` fragments
    - map ``dfrac``/``tfrac`` to ``frac``
    - remove spaces and newlines (real and literal ``\\n``)
    - rewrite ``a\\times10^{b}`` / ``a\\times10^b`` as ``aeb``
    - remove commas and lowercase the result
    - drop a trailing backslash
    - keep only the right-hand side of ``lhs=rhs`` when ``lhs`` looks like
      ``f(x)`` or is at most three characters long

    Args:
        answer: Raw answer string.

    Returns:
        str: Cleaned answer string.
    """
    extracted_content = is_completely_wrapped_by_text(answer)
    if extracted_content:
        answer = extracted_content

    # Raw strings: "\!", "\$", "\c" and "\q" are not valid escape sequences in
    # ordinary literals (SyntaxWarning on recent Python versions). The raw
    # forms denote exactly the same characters.
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    answer = answer.replace(r"^\circ", "").replace(r"^{\circ}", "")
    answer = answer.replace(r"\quad", "")
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    # Removes a whitespace-prefixed negative power such as " ^{-1}".
    answer = re.sub(r'(\s\^\{-\d+\})', '', answer)
    answer = answer.replace(" ", "").replace("\n", "").replace("\\n", "")
    # Scientific notation: "1.5\times10^{3}" / "1.5\times10^3" -> "1.5e3".
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    # Order matters: digit^{digits} is unbraced first, so the following rule
    # only ever sees braced exponents that the first rule did not match
    # (i.e. negative ones).
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    answer = re.sub(r"10\^\{(-?\d+)\}", r"1e\1", answer)
    answer = answer.replace(",", "").lower()

    if answer.endswith("\\"):
        answer = answer[:-1]

    func_pattern = r'^[a-zA-Z_]\w*\([a-zA-Z_]\w*\)$'
    if "=" in answer:
        lhs = answer.split("=")[0]
        if re.match(func_pattern, lhs) or len(lhs) <= 3:
            answer = answer.split("=", 1)[1]

    return answer
| |
|
| |
|
def round_number(answer):
    """Round numbers of magnitude below 1 to 2 significant figures.

    Args:
        answer: Answer string.

    Returns:
        str: ``answer`` formatted with ``.2g`` when it parses as a float whose
        absolute value is below 1; otherwise ``answer`` unchanged.
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        # Not a plain number: leave it untouched.
        return answer

    # Compare the magnitude, not the signed value: a plain "< 1" test would
    # also round every negative number (e.g. "-123.456" -> "-1.2e+02"), which
    # lets clearly different large negative answers compare equal.
    if abs(value) < 1:
        return f"{value:.2g}"

    return answer
| |
|
| |
|
def evaluate_math500_zeroshot(input_datapath, test_datapath):
    """Evaluate MATH-500 zero-shot performance with timeout protection.

    Gold answers come from the ``answer`` field of each test record. The model
    answer is the last match of the first extraction pattern (tried in
    priority order) found in the record's ``output`` field. Each comparison
    runs under a 5-second ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to MATH-500 test JSONL file

    Returns:
        float: Number of correct samples divided by the number of gold
        records (skipped and timed-out samples count as incorrect).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order; the first one that matches wins.
    answer_regexes = [
        re.compile(source, re.DOTALL)
        for source in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _last_match(text):
        # Last occurrence of the highest-priority pattern found in the text.
        for regex in answer_regexes:
            found = regex.findall(text)
            if found:
                return found[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            gold_list.append(json.loads(line)['answer'])

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _last_match(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                if (
                    math_equal(extracted_answer, gold)
                    or is_equal_after_calculation(extracted_answer, gold)
                    or check_after_fraction_mapping(extracted_answer, gold)
                ):
                    correct += 1
            except Exception:
                # Timeouts and grader errors are both counted here. Unlike a
                # bare ``except``, this lets KeyboardInterrupt/SystemExit
                # propagate so the run can still be stopped.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm, whatever happened above.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
def evaluate_minerva_math_zeroshot(input_datapath, test_datapath):
    """Evaluate Minerva Math zero-shot performance with timeout protection.

    The gold answer is the last ``\\boxed{...}`` content in each test record's
    ``solution`` field (None when there is none). The model answer is the last
    match of the first extraction pattern (tried in priority order) found in
    the record's ``output`` field. Each comparison runs under a 5-second
    ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to Minerva Math test JSONL file

    Returns:
        float: Number of correct samples divided by the number of gold
        records (skipped and timed-out samples count as incorrect).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    boxed_re = re.compile(
        r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL
    )
    # Extraction patterns in priority order (no "**bold**" pattern here).
    answer_regexes = [
        boxed_re,
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    def _last_match(text):
        # Last occurrence of the highest-priority pattern found in the text.
        for regex in answer_regexes:
            found = regex.findall(text)
            if found:
                return found[-1]
        return None

    # Trailing unit suffixes stripped from the model answer before grading.
    unit_list = ["\\hbar^{4}"]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            boxed = boxed_re.findall(item['solution'])
            gold_list.append(boxed[-1] if boxed else None)

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _last_match(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            for unit in unit_list:
                if extracted_answer.endswith(unit):
                    extracted_answer = extracted_answer[:-len(unit)]

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                if (
                    math_equal(extracted_answer, gold)
                    or round_number(extracted_answer) == round_number(gold)
                    or is_equal_after_calculation(extracted_answer, gold)
                    or check_after_fraction_mapping(extracted_answer, gold)
                ):
                    correct += 1
            except Exception:
                # Timeouts and grader errors are both counted here. Unlike a
                # bare ``except``, this lets KeyboardInterrupt/SystemExit
                # propagate so the run can still be stopped.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm, whatever happened above.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("len(gold_list):", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_timeout:", count_timeout)
    print("count_answer_none:", count_answer_none)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
def calculate_numbers(input_string):
    """Evaluate an expression string with ``eval``, returning None on failure.

    SECURITY NOTE(review): this runs ``eval`` on its argument. The callers in
    this module pass cleaned model outputs and gold answers, so the input is
    not trusted; only run this on data you are willing to execute, or replace
    it with a restricted arithmetic evaluator.

    Args:
        input_string: Expression as a string.

    Returns:
        Result of evaluation, or None if evaluation raised an exception.
    """
    try:
        return eval(input_string)
    except Exception:
        # Best effort: anything that cannot be evaluated yields None.
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        return None
| |
|
| |
|
def is_equal_after_calculation(extracted_answer, gold):
    """Check if answers are equal after converting fractions and evaluating.

    Every ``\\frac{a}{b}`` is rewritten as ``(a/b)`` and both strings are then
    passed to ``calculate_numbers``.

    Args:
        extracted_answer: Extracted answer string
        gold: Gold standard answer string

    Returns:
        bool: True if the gold string could be evaluated and its value equals
        the value of the extracted answer; False otherwise.
    """
    frac_pattern = r'\\frac{(.*?)}{(.*?)}'
    gold = re.sub(frac_pattern, r'(\1/\2)', gold)
    extracted_answer = re.sub(frac_pattern, r'(\1/\2)', extracted_answer)
    gold_result = calculate_numbers(gold)
    extracted_answer_result = calculate_numbers(extracted_answer)

    # Test against None rather than truthiness: a gold value of 0 is a valid
    # result and must still be comparable.
    if gold_result is not None and gold_result == extracted_answer_result:
        return True
    return False
| |
|
| |
|
def is_math_formula_equal(extracted_answer, gold):
    """Compare two LaTeX formulas by simplifying their difference with SymPy.

    Args:
        extracted_answer: Extracted answer string (LaTeX)
        gold: Gold standard answer string (LaTeX)

    Returns:
        bool: True when the simplified difference of the parsed expressions
        equals 0. Any exception during parsing or simplification is printed
        and reported as False.
    """
    try:
        difference = parse_latex(extracted_answer) - parse_latex(gold)
        return simplify(difference) == 0
    except Exception as err:
        print("error:", err)
        return False
| |
|
| |
|
def check_after_fraction_mapping(extracted_answer, gold):
    """Compare two answers as strings after rewriting ``\\frac{a}{b}`` as ``a/b``.

    Args:
        extracted_answer: Extracted answer string
        gold: Gold standard answer string

    Returns:
        bool: True if both strings are identical after the rewrite.
    """
    frac_regex = re.compile(r'\\frac{(.*?)}{(.*?)}')

    def _flatten(expression):
        # "\frac{a}{b}" -> "a/b" (no parentheses are added).
        return frac_regex.sub(r'\1/\2', expression)

    return _flatten(extracted_answer) == _flatten(gold)
| |
|
| |
|
def evaluate_gaokao2023en_zeroshot(input_datapath, test_datapath):
    """Evaluate Gaokao2023-EN zero-shot performance with timeout protection.

    Gold answers come from the ``answer`` field of each test record, with one
    enclosing ``$...$`` pair removed. When the gold answer is a single letter,
    the text of that option is cut out of the question and accepted as an
    alternative gold answer. Each comparison runs under a 5-second ``SIGALRM``
    alarm.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to the test JSONL file (records with ``question``
            and ``answer`` fields)

    Returns:
        float: Number of correct samples divided by the number of gold
        records (skipped and timed-out samples count as incorrect).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order; the first one that matches wins.
    answer_regexes = [
        re.compile(source, re.DOTALL)
        for source in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _last_match(text):
        # Last occurrence of the highest-priority pattern found in the text.
        for regex in answer_regexes:
            found = regex.findall(text)
            if found:
                return found[-1]
        return None

    def _option_text(question, letter):
        # Text of option "(letter)": everything up to the next option marker
        # (or the one after it), or to the end of the question.
        option = "(" + letter + ")"
        assert option in question
        start_option_idx = question.index(option)
        text = question[start_option_idx:]
        for offset in (1, 2):
            following = "(" + chr(ord(letter) + offset) + ")"
            if following in question:
                text = question[start_option_idx: question.index(following)]
                break
        return text.replace(option, "").strip().replace("\\n", "")

    unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesontheBreadusscale', '$', 'a.m.', 'am', 'minutes']

    gold_list = []
    question_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            question_list.append(item['question'].strip())
            # Drop one pair of enclosing "$...$" from the gold answer.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _last_match(output)
            gold = gold_list[i]
            question = question_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            # Multiple-choice gold ("A", "B", ...): also accept the option text.
            if len(gold) == 1 and gold.isalpha():
                gold2 = _option_text(question, gold)
            else:
                gold2 = None

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)
            if gold2:
                gold2 = math_answer_cleaning(gold2)

            for unit in unit_strings:
                extracted_answer = extracted_answer.replace(unit, "")
                gold = gold.replace(unit, "")
                if gold2:
                    gold2 = gold2.replace(unit, "")

            if "=" in extracted_answer and not gold2:
                extracted_answer = extracted_answer.split("=", 1)[1]

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                if (
                    math_equal(extracted_answer, gold)
                    or (gold2 and math_equal(extracted_answer, gold2))
                    or is_equal_after_calculation(extracted_answer, gold)
                    or check_after_fraction_mapping(extracted_answer, gold)
                ):
                    correct += 1
            except Exception:
                # Timeouts and grader errors are both counted here. Unlike a
                # bare ``except``, this lets KeyboardInterrupt/SystemExit
                # propagate so the run can still be stopped.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm, whatever happened above.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("benchmark size:", len(gold_list))
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    # Report timeouts too, as the other evaluators in this module do.
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
def evaluate_olympiadbench_zeroshot(input_datapath, test_datapath):
    """Evaluate OlympiadBench zero-shot performance with timeout protection.

    Each test record must carry exactly one entry in ``final_answer``; one
    enclosing ``$...$`` pair is removed from it. The model answer is the last
    match of the first extraction pattern (tried in priority order) found in
    the record's ``output`` field. Each comparison runs under a 5-second
    ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to the OlympiadBench test JSONL file

    Returns:
        float: Number of correct samples divided by the number of gold
        records (skipped and timed-out samples count as incorrect).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns in priority order; the first one that matches wins.
    answer_regexes = [
        re.compile(source, re.DOTALL)
        for source in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _last_match(text):
        # Last occurrence of the highest-priority pattern found in the text.
        for regex in answer_regexes:
            found = regex.findall(text)
            if found:
                return found[-1]
        return None

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            assert len(item['final_answer']) == 1
            # Drop one pair of enclosing "$...$" from the gold answer.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['final_answer'][0]))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = _last_match(output)
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                if (
                    math_equal(extracted_answer, gold)
                    or is_equal_after_calculation(extracted_answer, gold)
                    or check_after_fraction_mapping(extracted_answer, gold)
                ):
                    correct += 1
            except Exception:
                # Timeouts and grader errors are both counted here. Unlike a
                # bare ``except``, this lets KeyboardInterrupt/SystemExit
                # propagate so the run can still be stopped.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm, whatever happened above.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
| |
|
| |
|
def evaluate_collegemath_zeroshot(input_datapath, test_datapath):
    """Evaluate College Math zero-shot performance with timeout protection.

    Gold answers come from the ``answer`` field of each test record, with one
    enclosing ``$...$`` pair removed. The model answer is the last match of
    the first extraction pattern (tried in priority order) found in the
    record's ``output`` field. Both sides are cleaned, trailing punctuation
    and a trailing "hours" are stripped from the gold answer, and text up to
    the first ``=`` / ``:`` is dropped before a cascade of comparisons runs
    under a 5-second ``SIGALRM`` alarm.

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to the test JSONL file (records with ``question``
            and ``answer`` fields)

    Returns:
        float: Number of correct samples divided by the number of gold
        records (skipped and timed-out samples count as incorrect).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    def _extra_content_from_gold(input_string):
        # All "$...$" spans of the gold answer.
        return re.findall(r'\$(.*?)\$', input_string)

    def _extra_cleaning(input_string):
        # Drop "\mathrm{...}" fragments and "$(...)" groups.
        input_string = re.sub(r'\\mathrm\{.*?\}', '', input_string)
        input_string = re.sub(r'\$\([^)]*\)', '', input_string)
        return input_string

    def _check_after_equal(extracted_answer_ori, gold_ori, matches):
        # Compare on the text after the last "="; for two-part gold answers
        # (joined by ",", "or" or "and") also require the second-to-last
        # boxed answer to match the other part, in either order.
        if extracted_answer_ori.replace(",", "").replace("$", "") == gold_ori.replace(",", "").replace("$", ""):
            return True

        if extracted_answer_ori.split("=")[-1] != gold_ori.split("=")[-1].replace("$", ""):
            return False

        if not ("," in gold_ori or "or" in gold_ori or "and" in gold_ori):
            return True

        if len(matches) <= 1:
            return False

        answer1 = extracted_answer_ori.split("=")[-1]
        answer2 = math_answer_cleaning(matches[-2].split("=")[-1])
        gold_ori = gold_ori.replace("$", "")
        if "or" in gold_ori:
            gold_parts = gold_ori.split("or", 1)
        elif "and" in gold_ori:
            gold_parts = gold_ori.split("and", 1)
        else:
            gold_parts = gold_ori.split(",", 1)

        gold_ori1 = gold_parts[-1].split("=")[-1]
        gold_ori2 = gold_parts[-2].split("=")[-1]

        if math_equal(answer1, gold_ori1) and math_equal(answer2, gold_ori2):
            return True
        if math_equal(answer2, gold_ori1) and math_equal(answer1, gold_ori2):
            return True
        return False

    def _matches_dollar_content(extracted_answer, gold):
        # Last resort: compare against the concatenated "$...$" spans of the
        # gold answer (after any "\approx"), ignoring commas.
        gold_matches = _extra_content_from_gold(gold)
        if len(gold_matches) < 1:
            return False
        gold2 = "".join(gold_matches)
        if "\\approx" in gold2:
            gold2 = gold2.split("\\approx", 1)[1]
        gold2 = gold2.replace(",", "")
        return gold2 != "" and extracted_answer.replace(",", "") == gold2

    boxed_re = re.compile(
        r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL
    )
    # Fallback patterns, in priority order, used when nothing is boxed.
    fallback_regexes = [
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
        re.compile(r"\\\[\s*(.*?)\s*\\\]", re.DOTALL),
    ]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            # Drop one pair of enclosing "$...$" from the gold answer.
            gold_list.append(re.sub(r'^\$(.*)\$$', r'\1', item['answer']))

    count_output_none = 0
    count_answer_none = 0
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)
    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            # All boxed answers are kept: _check_after_equal also looks at
            # the second-to-last one.
            matches1 = boxed_re.findall(output)
            extracted_answer = matches1[-1] if matches1 else None
            if extracted_answer is None:
                for regex in fallback_regexes:
                    found = regex.findall(output)
                    if found:
                        extracted_answer = found[-1]
                        break
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = _extra_cleaning(math_answer_cleaning(extracted_answer))
            gold = _extra_cleaning(math_answer_cleaning(gold))
            if gold.endswith("-"):
                gold = gold[:-1]
            if gold.endswith("."):
                gold = gold[:-1]
            if gold.endswith("hours"):
                gold = gold[:-len("hours")]
            if extracted_answer.endswith("."):
                extracted_answer = extracted_answer[:-1]

            # Keep the un-split forms for the first _check_after_equal pass.
            extracted_answer_ori = extracted_answer
            gold_ori = gold

            if "=" in gold:
                gold = gold.split("=", 1)[1]
            if ":" in gold:
                gold = gold.split(":", 1)[1]

            if "=" in extracted_answer:
                extracted_answer = extracted_answer.split("=", 1)[1]
            if ":" in extracted_answer:
                extracted_answer = extracted_answer.split(":", 1)[1]

            extracted_answer = extracted_answer.replace("\\emptyset", "\\oslash")

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                # Checks run in this order and stop at the first success.
                if (
                    math_equal(extracted_answer, gold)
                    or _check_after_equal(extracted_answer_ori, gold_ori, matches1)
                    or _check_after_equal(extracted_answer, gold, matches1)
                    or check_after_fraction_mapping(extracted_answer, gold)
                    or round_number(extracted_answer) == round_number(gold)
                    or is_equal_after_calculation(extracted_answer, gold)
                    or _matches_dollar_content(extracted_answer, gold)
                ):
                    correct += 1
            except Exception:
                # Timeouts and grader errors are both counted here. Unlike a
                # bare ``except``, this lets KeyboardInterrupt/SystemExit
                # propagate so the run can still be stopped.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm, whatever happened above.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
def evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath):
    """Score model outputs on AMC23 or AIME24/25 (no timeout protection).

    Gold answers are the stringified ``answer`` field of each test record. The
    model answer is the last match of the first extraction pattern (tried in
    priority order) found in the record's ``output`` field.

    Args:
        input_datapath: Path to the model-output JSONL file.
        test_datapath: Path to the AMC23/AIME24/AIME25 test JSONL file.

    Returns:
        float: Number of correct samples divided by the number of gold
        records (samples with no extractable answer count as incorrect).
    """
    # Extraction patterns in priority order; the first one that matches wins.
    answer_regexes = [
        re.compile(source, re.DOTALL)
        for source in (
            r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
            r"\*\*(.*?)\*\*",
            r"\\\[\n(.*?)\n\\\]",
            r'is \\\((.*?)\\\)',
            r"\\\[\\n(.*?)\\n\\\]",
        )
    ]

    def _last_match(text):
        # Last occurrence of the highest-priority pattern found in the text.
        for regex in answer_regexes:
            found = regex.findall(text)
            if found:
                return found[-1]
        return None

    references = []
    problems = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as gold_file:
        for raw in gold_file:
            record = json.loads(raw)
            references.append(str(record['answer']))
            problems.append(record['problem'])

    missing_output = 0
    missing_gold = 0
    num_correct = 0
    print("reading from %s" % input_datapath)

    with open(input_datapath, "r") as pred_file:
        for index, raw in enumerate(pred_file):
            response = json.loads(raw.strip())['output']

            candidate = _last_match(response)
            target = references[index]

            if candidate is None:
                missing_output += 1
                continue
            if target is None:
                missing_gold += 1
                continue

            candidate = math_answer_cleaning(candidate)
            target = math_answer_cleaning(target)

            # Checks run in this order and stop at the first success.
            if (
                math_equal(candidate, target)
                or round_number(candidate) == round_number(target)
                or is_equal_after_calculation(candidate, target)
            ):
                num_correct += 1

    acc = num_correct / len(references)
    print("count_output_none:", missing_output)
    print("count_answer_none:", missing_gold)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
def evaluate_omnimath_zeroshot(input_datapath, test_datapath):
    """Evaluate zero-shot model outputs against the OmniMath test set.

    The final answer is pulled out of each model output by trying a list of
    regex patterns in priority order (``\\boxed{...}`` first) and keeping the
    last match of the first pattern that matches at all. The extracted answer
    and the gold answer are cleaned and then compared with several equivalence
    checks, each sample being limited to 5 seconds via ``SIGALRM``.

    Args:
        input_datapath: Path to model output JSONL file; each line must have
            an 'output' field. Line i is graded against gold answer i.
        test_datapath: Path to test JSONL file; each line must have an
            'answer' field.

    Returns:
        float: Accuracy, i.e. number of correct answers divided by the number
            of gold answers (outputs without an extractable answer count as
            wrong).
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    # Extraction patterns, highest priority first.
    answer_patterns = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    gold_list = []
    print("reading from %s" % test_datapath)
    with open(test_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            gold_list.append(str(item['answer']))

    count_output_none = 0
    count_answer_none = 0
    # Counts samples whose comparison timed out or raised any other error.
    count_timeout = 0
    correct = 0
    print("reading from %s" % input_datapath)

    with open(input_datapath, "r") as f:
        for i, line in enumerate(f):
            output = json.loads(line.strip())['output']

            extracted_answer = None
            for pattern_re in answer_patterns:
                matches = pattern_re.findall(output)
                if matches:
                    extracted_answer = matches[-1]
                    break
            gold = gold_list[i]

            if extracted_answer is None:
                count_output_none += 1
                continue
            # Gold answers are stored via str(), so this guard is not expected
            # to fire; it is kept so the reported counter keeps its meaning.
            if gold is None:
                count_answer_none += 1
                continue

            extracted_answer = math_answer_cleaning(extracted_answer)
            gold = math_answer_cleaning(gold)

            try:
                # Installed per sample so the handler is in place even if a
                # grading helper replaced it -- TODO confirm whether the
                # graders touch SIGALRM at all.
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(5)

                if (
                    math_equal(extracted_answer, gold)
                    or check_after_fraction_mapping(extracted_answer, gold)
                    or round_number(extracted_answer) == round_number(gold)
                    or is_equal_after_calculation(extracted_answer, gold)
                ):
                    correct += 1
            except Exception:
                # Best effort: a timeout or grader failure marks the sample as
                # wrong instead of aborting the run. KeyboardInterrupt and
                # SystemExit are deliberately not caught.
                count_timeout += 1
            finally:
                # Always cancel the pending alarm, on success and on failure.
                signal.alarm(0)

    acc = correct / len(gold_list)
    print("count_output_none:", count_output_none)
    print("count_answer_none:", count_answer_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
| |
|
| |
|
def get_answer_by_marjority_voting(output_list):
    """Get the most common answer from multiple model outputs via majority voting.

    Each output's final answer is extracted with a prioritized list of regex
    patterns (``\\boxed{...}`` first; the last match of the first matching
    pattern wins) and cleaned. Answers that are identical or judged equivalent
    by the grading helpers are pooled into one vote bucket, keyed by the first
    answer seen for that bucket.

    Args:
        output_list: List of model output strings

    Returns:
        dict: Dictionary with 'count' and 'original_output' for the majority
            answer. Ties go to the bucket that was created first. If no output
            contains an extractable answer, 'count' is 0 and 'original_output'
            is the first output (or "" when output_list is empty).
    """
    # Extraction patterns, highest priority first.
    answer_patterns = [
        re.compile(r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL),
        re.compile(r"\*\*(.*?)\*\*", re.DOTALL),
        re.compile(r"\\\[\n(.*?)\n\\\]", re.DOTALL),
        re.compile(r'is \\\((.*?)\\\)', re.DOTALL),
        re.compile(r"\\\[\\n(.*?)\\n\\\]", re.DOTALL),
    ]

    answer_dict = {}
    for output in output_list:
        extracted_answer = None
        for pattern_re in answer_patterns:
            matches = pattern_re.findall(output)
            if matches:
                extracted_answer = matches[-1]
                break

        if extracted_answer is None:
            continue

        extracted_answer = math_answer_cleaning(extracted_answer)

        # Cheap string equality first, then progressively looser equivalence
        # checks; `or` short-circuits in the same order as before.
        for prev_ans, bucket in answer_dict.items():
            if (
                extracted_answer == prev_ans
                or math_equal(extracted_answer, prev_ans)
                or check_after_fraction_mapping(extracted_answer, prev_ans)
                or round_number(extracted_answer) == round_number(prev_ans)
                or is_equal_after_calculation(extracted_answer, prev_ans)
            ):
                bucket['count'] += 1
                break
        else:
            # No existing bucket is equivalent: open a new one.
            answer_dict[extracted_answer] = {"count": 1, "original_output": output}

    if not answer_dict:
        # Nothing could be extracted from any output. Return a zero-vote
        # result instead of raising IndexError so callers can still grade the
        # sample (as having no answer).
        fallback_output = output_list[0] if output_list else ""
        return {"count": 0, "original_output": fallback_output}

    # max() returns the first maximal bucket in insertion order, which matches
    # a stable descending sort by count.
    return max(answer_dict.values(), key=lambda bucket: bucket["count"])
| |
|
| |
|
def evaluate_gpqa(input_datapath, test_datapath):
    """Evaluate GPQA (Graduate-Level Google-Proof Q&A) benchmark.

    The predicted answer is taken from the last ``\\boxed{...}`` in the model
    output, or failing that from the last "Answer: X" style letter (A-D). A
    prediction is correct when it equals the correct choice letter, is
    mathematically equal to the correct answer text, or contains the letter in
    parentheses, e.g. "(B)".

    Args:
        input_datapath: Path to model output JSONL file
        test_datapath: Path to GPQA test JSON file; each entry must provide
            'choice_A' .. 'choice_D' and 'correct_answer', where
            'correct_answer' equals one of the four choices

    Returns:
        float: Accuracy score (correct predictions / number of test samples)

    Raises:
        AssertionError: If the number of outputs differs from the number of
            test samples.
    """
    class _TimeoutException(Exception):
        pass

    def _timeout_handler(signum, frame):
        raise _TimeoutException

    output_list = read_text_data(input_datapath)
    gold_list = read_json_data(test_datapath)

    num_samples = len(gold_list)
    assert len(output_list) == len(gold_list) == num_samples

    pattern1 = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    pattern1_re = re.compile(pattern1, re.DOTALL)

    # Fallback: "Answer: (B)", "Final Answer: **C**", "ANSWER D", ...
    pattern2_re = re.compile(r'\b(?:Answer|Final Answer|ANSWER)\b[:\s\*]*\(?([A-D])\)?')

    count_none = 0
    # Counts samples whose comparison timed out or raised any other error.
    count_timeout = 0
    correct = 0

    for output, gold in zip(output_list, gold_list):
        choices = [gold['choice_A'], gold['choice_B'], gold['choice_C'], gold['choice_D']]
        correct_answer = gold["correct_answer"]
        # Map the correct answer text back to its letter.
        correct_index = choices.index(correct_answer)
        correct_choice = "ABCD"[correct_index]

        matches1 = pattern1_re.findall(output)
        matches2 = pattern2_re.findall(output)
        if len(matches1) >= 1:
            extracted_answer = matches1[-1]
        elif len(matches2) >= 1:
            extracted_answer = matches2[-1]
        else:
            extracted_answer = None

        if extracted_answer is None:
            count_none += 1
            continue

        correct_answer = math_answer_cleaning(correct_answer)
        extracted_answer = math_answer_cleaning(extracted_answer)

        try:
            # Installed per sample so the handler is in place even if a
            # grading helper replaced it -- TODO confirm whether the graders
            # touch SIGALRM at all.
            signal.signal(signal.SIGALRM, _timeout_handler)
            signal.alarm(5)

            if extracted_answer.lower() == correct_choice.lower():
                correct += 1
            elif math_equal(extracted_answer, correct_answer):
                correct += 1
            elif "("+correct_choice+")" in extracted_answer:
                correct += 1
        except Exception:
            # Best effort: a timeout or grader failure marks the sample as
            # wrong instead of aborting the run. KeyboardInterrupt and
            # SystemExit are deliberately not caught, so Ctrl-C still works.
            count_timeout += 1
        finally:
            # Always cancel the pending alarm, on success and on failure.
            signal.alarm(0)

    acc = correct / num_samples
    print("num_samples:", num_samples)
    print("count_none:", count_none)
    print("count_timeout:", count_timeout)
    print("accuracy:", acc)

    return acc
| |
|
def get_args():
    """Build the command-line interface of the evaluation script and parse it.

    Both options are mandatory string paths: ``--modelfolder`` and
    ``--testfolder``.

    Returns:
        argparse.Namespace: Parsed arguments with ``modelfolder`` and
            ``testfolder`` attributes
    """
    cli = argparse.ArgumentParser(description="Math Benchmark Evaluation")
    required_folder_options = (
        ("--modelfolder", "Path to model output folder"),
        ("--testfolder", "Path to test data folder"),
    )
    for flag, help_text in required_folder_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
| |
|
| |
|
def check_finish(input_datapath):
    """Check the finish rate (non-empty outputs) of model outputs.

    Every JSONL record contributes exactly one flag: 0 when its 'reason'
    field is empty/falsy (presumably the generation finish reason -- verify
    against the producer of these files) or when its 'output' is empty, and
    1 otherwise.

    Args:
        input_datapath: Path to model output JSONL file; each line must have
            a 'reason' field, and an 'output' field when 'reason' is set

    Returns:
        float: Finish rate (proportion of records with a reason and a
            non-empty output); NaN if the file has no records
    """
    finish_flags = []
    with open(input_datapath, "r") as f:
        for line in f:
            item = json.loads(line)
            if not item['reason']:
                finish_flags.append(0)
                # Skip the output check: without this the record would be
                # counted a second time and distort the mean.
                continue
            finish_flags.append(1 if item['output'] else 0)

    return np.mean(finish_flags)
| |
|
| |
|
def _summarize_aime_runs(model_folder, output_filename, test_datapath, label):
    """Evaluate every run of one AIME benchmark and print its averages.

    Args:
        model_folder: Folder containing the ``outputs_*`` run sub-folders
        output_filename: Name of the output file inside each run folder
        test_datapath: Path to the benchmark's test JSONL file
        label: Benchmark name used in the printed summary (e.g. "aime24")

    Returns:
        tuple: (mean accuracy, accuracy std, mean finish rate). When no run
            output is found, the means are NaN and the std is 0.
    """
    import glob

    input_datapaths = glob.glob(model_folder + "/outputs_*/" + output_filename)
    if not input_datapaths:
        # Nothing to average; report it instead of dividing by zero.
        print("-"*80)
        print("no %s outputs found under %s" % (output_filename, model_folder))
        return float("nan"), 0, float("nan")

    accs = []
    finishes = []
    for input_datapath in input_datapaths:
        accs.append(evaluate_amc23_or_aime24_zeroshot(input_datapath, test_datapath))
        finishes.append(check_finish(input_datapath))

    mean_acc = sum(accs) / len(accs)
    # A standard deviation over a single run is reported as 0.
    std_acc = np.std(accs) if len(accs) > 1 else 0
    mean_finish = np.mean(finishes)
    print("-"*80)
    print("avg acc for %s:" % label, mean_acc, "std:", std_acc)
    print("avg finish for %s:" % label, mean_finish)
    return mean_acc, std_acc, mean_finish


def main():
    """Evaluate the AIME24 and AIME25 benchmarks and print summary statistics.

    For each benchmark, every ``outputs_*/<benchmark>.jsonl`` file under the
    model folder is scored; the mean accuracy, its standard deviation across
    runs, and the mean finish rate are printed.
    """
    args = get_args()

    model_folder = args.modelfolder
    test_datafolder = args.testfolder

    aime24_acc, aime24_std, aime24_finish = _summarize_aime_runs(
        model_folder,
        "aime24.jsonl",
        os.path.join(test_datafolder, "qwen2_math/aime24/test.jsonl"),
        "aime24",
    )
    aime25_acc, aime25_std, aime25_finish = _summarize_aime_runs(
        model_folder,
        "aime25.jsonl",
        os.path.join(test_datafolder, "aime25/test.jsonl"),
        "aime25",
    )

    print("="*80)
    print("average acc across AIME24:", aime24_acc, "±", aime24_std, ", AIME25:", aime25_acc, "±", aime25_std)
    print("average finish across AIME24:", aime24_finish, ", AIME25:", aime25_finish)
| |
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
| |
|