|
|
import re |
|
|
import os |
|
|
import json |
|
|
from typing import Any |
|
|
import re |
|
|
import regex |
|
|
from latex2sympy2 import latex2sympy |
|
|
from word2number import w2n |
|
|
|
|
|
|
|
|
def _last_balanced_boxed(text: str):
    """Return the contents of the last ``\\boxed{...}`` whose braces balance, or None.

    Braces are matched by depth counting, so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole. If the last ``\\boxed{`` is
    never closed, earlier occurrences are tried in turn.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        for idx in range(body_start, len(text)):
            ch = text[idx]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[body_start:idx]
        # Unbalanced occurrence: fall back to the previous one.
        start = text.rfind(marker, 0, start)
    return None


def extract_math_answer(text: str) -> str:
    """Extract the final answer from a math solution.

    Lookup order:
      1. The contents of the last brace-balanced ``\\boxed{...}`` (stripped).
      2. The last match of "answer is", "final answer" or "therefore" phrases.
         These are searched in the lower-cased text, so the returned answer is
         lower-cased.
      3. Otherwise the whole text, stripped.

    Args:
        text: Model output or reference solution.

    Returns:
        The extracted answer string.
    """
    boxed = _last_balanced_boxed(text)
    if boxed is not None:
        return boxed.strip()

    answer_patterns = [
        r'answer is[:\s]+([^\n.]+)',
        r'final answer[:\s]+([^\n.]+)',
        r'therefore[,:\s]+([^\n.]+)'
    ]
    lowered = text.lower()
    for pattern in answer_patterns:
        matches = re.findall(pattern, lowered)
        if matches:
            return matches[-1].strip()

    return text.strip()
|
|
|
|
|
|
|
|
def extract_multiple_choice_answer(text: str) -> str:
    """Return the last standalone option letter A-D found in *text*.

    The search is case-insensitive; the letter is returned upper-case. When no
    standalone letter occurs, the stripped input is returned unchanged.
    """
    letters = re.findall(r'\b([A-D])\b', text.upper())
    return letters[-1] if letters else text.strip()
|
|
|
|
|
|
|
|
def normalize_answer(answer: str) -> str:
    """Normalize an answer for comparison.

    Lower-cases, drops every ``$`` and backslash, and collapses whitespace
    runs to single spaces (trimming both ends).
    """
    cleaned = answer.strip().lower()
    for symbol in ("$", "\\"):
        cleaned = cleaned.replace(symbol, "")
    return " ".join(cleaned.split())
|
|
|
|
|
|
|
|
def check_answer_match(pred: str, ground_truth: str, task_type: str = "math") -> bool:
    """Check whether a predicted answer matches the ground truth.

    Both values are converted with ``str()``. A match is exact equality, or
    either string containing the other.

    An empty or whitespace-only side only matches by exact equality. The
    empty string is a substring of every string, so without this guard an
    empty prediction would be counted as correct for any ground truth.

    Args:
        pred: Predicted answer.
        ground_truth: Reference answer.
        task_type: Accepted for API compatibility; currently unused.

    Returns:
        True if the answers match.
    """
    pred = str(pred)
    ground_truth = str(ground_truth)
    if pred == ground_truth:
        return True
    if not pred.strip() or not ground_truth.strip():
        return False
    return pred in ground_truth or ground_truth in pred
|
|
|
|
|
|
|
|
def ensure_dir(directory: str):
    """Create *directory* and any missing parents if it does not already exist.

    An empty string (as returned by ``os.path.dirname`` for a bare filename)
    refers to the current directory and is a no-op. ``exist_ok=True`` avoids
    the race between checking for the directory and creating it.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)
|
|
|
|
|
|
|
|
def save_json(data: Any, path: str):
    """Serialize *data* to *path* as indented JSON, creating parent directories.

    A bare filename (no directory component) is written to the current
    directory. Previously this passed ``""`` to ``ensure_dir``, which then
    called ``os.makedirs("")`` and raised.
    """
    directory = os.path.dirname(path)
    if directory:
        ensure_dir(directory)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
|
|
|
|
|
|
|
|
def load_json(path: str) -> Any:
    """Load and return the JSON document stored at *path*.

    The file is decoded as UTF-8 (the encoding required by RFC 8259) rather
    than the platform's locale default, so non-ASCII content loads the same
    way on every OS.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
|
def _fix_fracs(string): |
|
|
substrs = string.split("\\frac") |
|
|
new_str = substrs[0] |
|
|
if len(substrs) > 1: |
|
|
substrs = substrs[1:] |
|
|
for substr in substrs: |
|
|
new_str += "\\frac" |
|
|
if len(substr) > 0 and substr[0] == "{": |
|
|
new_str += substr |
|
|
else: |
|
|
try: |
|
|
assert len(substr) >= 2 |
|
|
except: |
|
|
return string |
|
|
a = substr[0] |
|
|
b = substr[1] |
|
|
if b != "{": |
|
|
if len(substr) > 2: |
|
|
post_substr = substr[2:] |
|
|
new_str += "{" + a + "}{" + b + "}" + post_substr |
|
|
else: |
|
|
new_str += "{" + a + "}{" + b + "}" |
|
|
else: |
|
|
if len(substr) > 2: |
|
|
post_substr = substr[2:] |
|
|
new_str += "{" + a + "}" + b + post_substr |
|
|
else: |
|
|
new_str += "{" + a + "}" + b |
|
|
string = new_str |
|
|
return string |
|
|
|
|
|
|
|
|
def _fix_a_slash_b(string): |
|
|
if len(string.split("/")) != 2: |
|
|
return string |
|
|
a = string.split("/")[0] |
|
|
b = string.split("/")[1] |
|
|
try: |
|
|
if "sqrt" not in a: |
|
|
a = int(a) |
|
|
if "sqrt" not in b: |
|
|
b = int(b) |
|
|
assert string == "{}/{}".format(a, b) |
|
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
|
|
return new_string |
|
|
except: |
|
|
return string |
|
|
|
|
|
|
|
|
def _fix_sqrt(string): |
|
|
_string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) |
|
|
return _string |
|
|
|
|
|
|
|
|
def convert_word_number(text: str) -> str:
    """Best-effort conversion of an English number phrase to a digit string.

    Delegates to ``w2n.word_to_num``. If the conversion fails for any reason,
    *text* is returned unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Deliberately broad so odd input never aborts normalization, but unlike
        # the previous bare ``except:`` this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return text
|
|
|
|
|
|
|
|
|
|
|
# Unit words and trailing tokens that strip_string() removes (unless
# skip_unit=True). It substitutes r"(^|\W)" + unit + r"($|\W)" for each entry,
# in list order. Entries are interpolated into that regex unescaped, so a "."
# inside an entry such as "q ." matches any character.
# NOTE(review): the list contains duplicates ("kmph", "sec") and what looks like
# a typo ("m sqaure"). They are kept verbatim because strip_string() applies
# entries in order, and a repeated entry can remove an occurrence the first
# pass skipped, so deduplicating could change its output.
unit_texts = [
    "east",
    "degree",
    "mph",
    "kmph",
    "ft",
    "m sqaure",
    " m east",
    "sq m",
    "deg",
    "mile",
    "q .",
    "monkey",
    "prime",
    "ratio",
    "profit of rs",
    "rd",
    "o",
    "gm",
    "p . m",
    "lb",
    "tile",
    "per",
    "dm",
    "lt",
    "gain",
    "ab",
    "way",
    "west",
    "a .",
    "b .",
    "c .",
    "d .",
    "e .",
    "f .",
    "g .",
    "h .",
    "t",
    "a",
    "h",
    "no change",
    "men",
    "soldier",
    "pie",
    "bc",
    "excess",
    "st",
    "inches",
    "noon",
    "percent",
    "by",
    "gal",
    "kmh",
    "c",
    "acre",
    "rise",
    "a . m",
    "th",
    "π r 2",
    "sq",
    "mark",
    "l",
    "toy",
    "coin",
    "sq . m",
    "gallon",
    "° f",
    "profit",
    "minw",
    "yr",
    "women",
    "feet",
    "am",
    "pm",
    "hr",
    "cu cm",
    "square",
    "v â € ™",
    "are",
    "rupee",
    "rounds",
    "cubic",
    "cc",
    "mtr",
    "s",
    "ohm",
    "number",
    "kmph",
    "day",
    "hour",
    "minute",
    "min",
    "second",
    "man",
    "woman",
    "sec",
    "cube",
    "mt",
    "sq inch",
    "mp",
    "∏ cm ³",
    "hectare",
    "more",
    "sec",
    "unit",
    "cu . m",
    "cm 2",
    "rs .",
    "rs",
    "kg",
    "g",
    "month",
    "km",
    "m",
    "cm",
    "mm",
    "apple",
    "liter",
    "loss",
    "yard",
    "pure",
    "year",
    "increase",
    "decrease",
    "d",
    "less",
    "Surface",
    "litre",
    "pi sq m",
    "s .",
    "metre",
    "meter",
    "inch",
]

# Also match simple "s" plurals ("mile" -> "miles"). The comprehension is
# evaluated over the list as defined above before extending, so each original
# entry gets exactly one plural (duplicated entries get duplicated plurals).
unit_texts.extend([t + "s" for t in unit_texts])
|
|
|
|
|
|
|
|
def strip_string(string, skip_unit=False):
    """Normalize a LaTeX/plain-text math answer into a canonical comparison form.

    The transformations are order-dependent and applied in sequence:
      - drop newlines and a trailing "." run;
      - unify matrix environments on pmatrix and fraction variants on \\frac;
      - remove \\left/\\right, unit words (see ``unit_texts``, unless
        *skip_unit*), degree marks, dollar signs and percent signs;
      - map words like "infinity" to \\infty;
      - strip quotes, drop trailing ".0" on integers, and remove a short
        "lhs=" prefix;
      - finally remove spaces and canonicalize \\sqrt, \\frac and a/b forms.

    Args:
        string: Answer to normalize; converted with ``str()`` first.
        skip_unit: If True, the ``unit_texts`` removal pass is skipped.

    Returns:
        The normalized string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # trailing "."
    string = string.rstrip(".")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # matrix environments -> pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # fraction and comparison-operator variants
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge"))

    # \left / \right and escaped braces
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Drop a trailing \text{...} (typically a unit) unless that empties the string.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Two passes so units exposed by the first removal are also caught.
        for _ in range(2):
            for unit_text in unit_texts:
                _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
                if _string != "":
                    string = _string

    # degrees
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # dollar signs and inline-math delimiters
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # number words -> digits (best effort)
    string = convert_word_number(string)

    # unwrap \text{...}, drop variable prefixes, and map common set notation
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # percentages. A former extra `"\%"` replacement was removed: it was an
    # invalid escape equal to "\\%" and already handled by the first line.
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " ." -> " 0." and "{." -> "{0."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # NOTE(review): str.isalnum() is False for any string containing a bracket,
    # so this condition can never be true and the branch is dead. It is left
    # unchanged because "fixing" it would alter scoring; confirm the intent.
    if (string.startswith("{") and string.endswith("}") and string.isalnum() or
            string.startswith("(") and string.endswith(")") and string.isalnum() or
            string.startswith("[") and string.endswith("]") and string.isalnum()):
        string = string[1:-1]

    # infinity spellings
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Was "+\\inity" (typo), which never matched anything.
    string = string.replace("+\\infty", "\\infty")

    # "and" separators and bold markup
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # \mbox{...} annotations
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Quotes. The results were previously discarded, making these no-ops.
    # Note this also removes apostrophes used as primes, e.g. f'(x).
    string = string.replace("'", "")
    string = string.replace('"', "")

    # imaginary unit written as j
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # drop trailing zero decimals on integers: "2.0" -> "2", "2.00x" -> "2x"
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "k=" in "k=5".
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac shorthand -> braced form
    string = _fix_fracs(string)

    # plain a/b -> \frac{a}{b}
    string = _fix_a_slash_b(string)

    return string
|
|
|
|
|
|
|
|
# Phrases that introduce a direct answer in few-shot style outputs. Used by
# choice_answer_clean() in two ways: a trigger appearing more than once marks
# echoed in-context examples, and the triggers joined with "|" form the regex
# the text is split on.
direct_answer_trigger_for_fewshot = ("choice is", "answer is")
|
|
|
|
|
|
|
|
def choice_answer_clean(pred: str):
    """Reduce a free-form multiple-choice response to a single option letter A-E.

    Steps:
      - If any trigger phrase ("choice is" / "answer is") occurs more than
        once, the output is assumed to echo few-shot examples and only the
        first paragraph is kept.
      - The text after the last trigger is then scanned for standalone
        letters A-E (case-insensitive). With a trigger present the first
        letter wins; otherwise the last one wins.
      - If no letter is found, the cleaned text itself is returned.
    """
    text = pred.strip("\n")
    triggers = direct_answer_trigger_for_fewshot

    echoes_examples = any(text.count(trigger) > 1 for trigger in triggers)
    if echoes_examples:
        text = text.split("\n\n")[0]

    segments = re.split("|".join(triggers), text)
    after_trigger = len(segments) > 1
    if after_trigger:
        text = segments[-1]

    text = text.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", text.upper())
    candidates = letters or [text.strip().strip(".")]
    chosen = candidates[0] if after_trigger else candidates[-1]

    return chosen.rstrip(".").rstrip("/")
|
|
|
|
|
|
|
|
def find_box(pred_str: str):
    """Return the argument of the last ``boxed`` in *pred_str*.

    Behavior depends on what follows the last "boxed" (the whole string is
    used if "boxed" does not occur):
      - ``{``: the brace-balanced group is returned without its outer braces
        (or everything up to the end if it never closes).
      - anything else: the text up to the next ``$``, stripped.
      - nothing: ``""``.
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if not tail.startswith("{"):
        return tail.split("$")[0].strip()

    depth = 1
    collected = []
    for ch in tail[1:]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                break
        collected.append(ch)
    return "".join(collected)
|
|
|
|
|
|
|
|
def clean_units(pred_str: str):
    """Clean the units in the number.

    Replaces pi with 3.14 (multiplying when pi follows a digit), turns ``%``
    into ``/100``, and removes ``$``, ``¥``, ``°C``, `` C`` and ``°``.
    """

    def _pi_to_decimal(code_string):
        # Normalize \pi to the unicode symbol first, then rewrite it in order:
        # standalone pi, digit-prefixed pi (implicit multiplication),
        # brace-wrapped pi, and pi after an explicit "*".
        code_string = code_string.replace("\\pi", "π")
        substitutions = (
            (r"(?<![\d}])\\?π", "3.14"),
            (r"(\d)(\\?π)", r"\1*3.14"),
            (r"\{(\\?π)\}", "3.14"),
            (r"\*(\\?π)", "*3.14"),
        )
        for pattern, replacement in substitutions:
            code_string = re.sub(pattern, replacement, code_string)
        return code_string

    cleaned = _pi_to_decimal(pred_str).replace("%", "/100")
    for marker in ("$", "¥", "°C", " C", "°"):
        cleaned = cleaned.replace(marker, "")
    return cleaned
|
|
|
|
|
|
|
|
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Normalize a TheoremQA prediction to a boolean word, an option, or a number.

    Checks are made in this order:
      - the whole words "yes"/"true" give "True";
      - the whole words "no"/"false" give "False";
      - an option marker "(a)".."(f)" leaves *pred* unchanged.

    Whole-word matching matters: the previous substring test turned any text
    containing "not", "cannot", "nodes", "know", ... into "False".

    Otherwise the boxed content is taken if present. When *answer_flag* is
    set, the right-hand side of the last "=" is unit-cleaned and then either
    evaluated numerically via latex2sympy or reduced to its leading/last
    number.

    Args:
        pred: Raw prediction text.
        answer_flag: Whether to post-process the answer into a number.

    Returns:
        The normalized prediction string ("" if no number could be found).
    """
    lowered = pred.lower()
    if re.search(r"\b(?:yes|true)\b", lowered):
        pred = "True"
    elif re.search(r"\b(?:no|false)\b", lowered):
        pred = "False"
    elif any(option in lowered for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]):
        pass
    else:
        # Some of the models somehow get used to boxed output from pre-training
        if "boxed" in pred:
            pred = find_box(pred)

        if answer_flag:
            # Extract the numbers out of the string
            pred = pred.split("=")[-1].strip()
            pred = clean_units(pred)
            try:
                tmp = str(latex2sympy(pred))
                # SECURITY NOTE: *pred* comes from model output. The sympy string is
                # evaluated with no builtins so only plain arithmetic on literals
                # can succeed; anything else raises and falls through below.
                pred = str(eval(tmp, {"__builtins__": {}}, {}))
            except Exception:
                if re.match(r"-?[\d\.]+\s\D+$", pred):
                    pred = pred.split(" ")[0]
                elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
                    pred = pred.split(" ")[0]
                else:
                    # desperate search over the last number
                    preds = re.findall(r"-?\d*\.?\d+", pred)
                    if len(preds) >= 1:
                        pred = preds[-1]
                    else:
                        pred = ""

    return pred
|
|
|
|
|
|
|
|
def extract_answer(pred_str, data_name, use_last_number=True):
    """Extract and normalize the final answer from a model output.

    Dataset-specific shortcuts:
      - "humaneval": return the first ```python block following
        "### Function Body:" (or "" if none).
      - "mmlu": an output starting with "(X)" returns "X" directly.

    Otherwise, the first matching marker is used, in this order:
      1. "final answer is $...$. I hope"
      2. the last boxed group (via ``find_box``)
      3. "he answer is" / "final answer is" (text after the last occurrence)
      4. "答案是" (text after the first occurrence, up to a blank line)
    If none matches, the last number in the text (commas removed) is used
    when *use_last_number* is True.

    The result has newlines removed, a leading ":" and trailing "." / "/"
    dropped, and is passed through ``strip_string``. Unit stripping is skipped
    for datasets in ``STRIP_EXCEPTIONS``. For GPQA/MMLU (case-insensitive), a
    "(X)..." result is reduced to "X".
    """
    if data_name.lower() == "humaneval":
        pattern = r"### Function Body:\s*\n```python\n(.*?)\n```"
        matches = re.findall(pattern, pred_str, re.DOTALL)
        return matches[0] if matches else ""
    elif data_name.lower() == "mmlu":
        if len(pred_str) >= 3 and pred_str[0] == '(' and pred_str[2] == ')':
            return pred_str[1]

    # Remove the "ки" step-tag artifact.
    pred_str = pred_str.replace("\u043a\u0438", "")

    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        # Same brace-balanced parsing as before, now shared with find_box.
        pred = find_box(pred_str)
    elif "he answer is" in pred_str:
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple reasoning
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # Raw string: the previous "-?\d*\.?\d+" used invalid escape sequences.
        pattern = r"-?\d*\.?\d+"
        numbers = re.findall(pattern, pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # multiple line
    pred = re.sub(r"\n\s*", "", pred)
    if pred != "" and pred[0] == ":":
        pred = pred[1:]
    if pred != "" and pred[-1] == ".":
        pred = pred[:-1]
    if pred != "" and pred[-1] == "/":
        pred = pred[:-1]
    pred = strip_string(pred, skip_unit=data_name in STRIP_EXCEPTIONS)

    # Case-insensitive, consistent with the dataset dispatch at the top.
    if data_name.lower() in ("gpqa", "mmlu"):
        if len(pred) >= 3 and pred[0] == '(' and pred[2] == ')':
            pred = pred[1]
    return pred
|
|
|
|
|
|
|
|
# Dataset names whose answers should be normalized with unit stripping skipped
# (strip_string(..., skip_unit=True)); extract_answer() uses this same list.
STRIP_EXCEPTIONS = ["carp_en", "minerva"]
|
|
|
|
|
|
|
|
def parse_ground_truth(groudtruth_solution: str, data_name):
    """Extract the reference answer from a ground-truth solution string.

    Runs the same ``extract_answer`` pipeline used for model predictions, so
    both sides are normalized identically.
    """
    return extract_answer(groudtruth_solution, data_name)
|
|
|
|
|
|
|
|
# Lower-cased binary ground-truth answers mapped to the hint parse_question()
# appends to the question text.
_BINARY_ANSWER_HINTS = {
    "true": " (True or False)",
    "false": " (True or False)",
    "yes": " (Yes or No)",
    "no": " (Yes or No)",
}


def _ground_truth_for_hint(example, data_name):
    """Best-effort lookup of *example*'s ground-truth answer, or None if unavailable.

    Used only to decide whether parse_question() appends a True/False or
    Yes/No hint. The first of the keys "gt", "answer", "solution" present in
    *example* is used.

    NOTE(review): these key names are an assumption about the dataset schemas;
    confirm against the data loaders.

    A value that is already a bare binary answer is returned stripped. Any
    other string is passed through parse_ground_truth(). Non-string values
    give None.
    """
    for key in ("gt", "answer", "solution"):
        if key in example:
            raw = example[key]
            break
    else:
        return None
    if not isinstance(raw, str):
        return None
    if raw.strip().lower() in _BINARY_ANSWER_HINTS:
        return raw.strip()
    return parse_ground_truth(raw, data_name)


def parse_question(example, data_name):
    """Build the question text for *example* according to *data_name*.

    Some datasets get extra formatting: asdiv/svamp concatenate body and
    question; tabmwp includes the table; mmlu_stem, sat_math, aqua and
    gaokao_math_qa append answer choices. Any other dataset uses the first of
    the "question", "problem", "Question", "input" keys.

    If the ground truth is a bare true/false or yes/no answer, a matching
    "(True or False)" / "(Yes or No)" hint is appended. The previous code
    passed the whole example dict to parse_ground_truth(), which expects a
    string and returns a single value, so that step always raised.

    *example* is not modified.
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = (f'regarding "{example["table_title"]}" ' if example["table_title"] else "")
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += (f' Please select from the following options: {example["choices"]}')
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        choices = example["choices"]
        assert len(choices) == 4
        # Build new labels instead of writing back into example["choices"]; the
        # old in-place update re-labelled ("(A) (A) ...") on repeated calls.
        options = " ".join(
            f"({label}) {str(option).strip()}" for label, option in zip("ABCD", choices)
        )
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            if f" {ch}) " in options:
                # rf-string: same pattern as before without the invalid "\)" escape.
                options = regex.sub(rf" {ch}\) ", f" ({ch}) ", options)
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = " ".join(f"({key}) {options_dict[key]}" for key in options_dict)
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Add a Yes/No or True/False hint when the ground truth is binary.
    gt_ans = _ground_truth_for_hint(example, data_name)
    if isinstance(gt_ans, str):
        question += _BINARY_ANSWER_HINTS.get(gt_ans.lower(), "")
    return question.strip()